In [1]:
import os
import csv

root_dir = 'graphs_new_pannuke_edgeAtr'  # top-level folder
label_map = {'Benign': 0, 'InSitu': 1, 'Invasive': 2, 'Normal': 3}  # your subtype→label mapping
metadata_path = 'metadata.csv'

with open(metadata_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['graph_path', 'label'])

    for subtype in os.listdir(root_dir):
        subtype_path = os.path.join(root_dir, subtype)
        if not os.path.isdir(subtype_path):
            continue

        label = label_map.get(subtype, -1)
        if label == -1:
            print(f"Unknown label for {subtype}")
            continue

        for fname in os.listdir(subtype_path):
            if fname.endswith('.pt'):
                rel_path = os.path.join(subtype, fname)
                writer.writerow([os.path.join(root_dir, rel_path), label])

print(f"Metadata written to {metadata_path}")


Metadata written to metadata.csv


In [2]:
import torch
import pandas as pd
import os
from torch_geometric.data import Data, InMemoryDataset

# --- Load .pt Graph and Create Subgraphs ---

def load_pt_and_create_subgraphs(pt_path, label, window_size=100, step_size=50):
    graph = torch.load(pt_path)
    graph.y = torch.tensor([label], dtype=torch.long)

    subgraphs = []
    num_nodes = graph.num_nodes
    selected_nodes = set()  # To keep track of already used nuclei

    try:
        nucleus_types = graph.x[:, 2]
    except Exception as e:
        print("Error accessing type column:", e)
        print("Example entry:", graph.x[1])
        return []

    type1_nodes = (nucleus_types == 1).nonzero(as_tuple=True)[0].tolist()

    for center in type1_nodes:
        if center in selected_nodes:
            continue  # Skip if nucleus already used

        start = max(0, center - window_size // 2)
        end = min(num_nodes, start + window_size)
        start = max(0, end - window_size)

        node_indices = list(range(start, end))

        # Skip if any of the current nodes have already been used
        if any(idx in selected_nodes for idx in node_indices):
            continue

        # Mark all nodes in this subgraph as used
        selected_nodes.update(node_indices)

        id_map = {old: i for i, old in enumerate(node_indices)}

        mask = torch.tensor([
            (src.item() in id_map and dst.item() in id_map)
            for src, dst in graph.edge_index.T
        ], dtype=torch.bool)

        edge_index = graph.edge_index[:, mask]
        edge_attr = graph.edge_attr[mask] if graph.edge_attr is not None else None

        edge_index = torch.tensor([
            [id_map[src.item()], id_map[dst.item()]]
            for src, dst in edge_index.T
        ], dtype=torch.long).T

        if edge_index.size(1) == 0:
            edge_index = torch.stack([
                torch.arange(len(node_indices)),
                torch.arange(len(node_indices))
            ], dim=0)
            edge_attr = torch.ones((len(node_indices), 1), dtype=torch.float)

        subgraph = Data(
            x=graph.x[start:end],
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=graph.y,
            original_node_indices=torch.tensor(node_indices, dtype=torch.long)
        )

        subgraphs.append(subgraph)

    return subgraphs


In [3]:
if __name__ == "__main__":
    pt_path = 'graphs_new_pannuke_edgeAtr/Benign/2.pt'
    label = 0  # e.g., Benign = 0
    subgraphs = load_pt_and_create_subgraphs(pt_path, label)
    print(subgraphs)


  graph = torch.load(pt_path)


[Data(x=[100, 8], edge_index=[2, 265], edge_attr=[265, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 239], edge_attr=[239, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 257], edge_attr=[257, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 258], edge_attr=[258, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 225], edge_attr=[225, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 235], edge_attr=[235, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 242], edge_attr=[242, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 205], edge_attr=[205, 3], y=[1], original_node_indices=[100]), Data(x=[100, 8], edge_index=[2, 225], edge_attr=[225, 3], y=[1], original_node_indices=[100])]


In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
import pandas as pd
import random
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import Dataset


In [5]:
class GraphSubgraphDataset(Dataset):
    def __init__(self, metadata_csv):
        self.meta = pd.read_csv(metadata_csv)
        self.graph_paths = self.meta['graph_path'].tolist()
        self.labels = self.meta['label'].tolist()

    def __len__(self):
        return len(self.graph_paths)

    def __getitem__(self, idx):
        return self.graph_paths[idx], self.labels[idx]


In [6]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.4):
        super(GAT, self).__init__()

        # First GATConv: multi-head
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # Second GATConv: single-head to collapse heads
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False)
        self.dropout = dropout

        # Final linear classifier
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Layer 1 with attention
        x, attn1 = self.conv1(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Layer 2 with attention
        x, attn2 = self.conv2(x, edge_index, return_attention_weights=True)

        # Global pooling over nodes
        x = global_mean_pool(x, batch)

        # Classification
        logits = self.lin(x)
        return logits, attn1, attn2


In [7]:
@torch.no_grad()
def score_subgraph_attention(model, subgraph, device):
    # Add dummy batch index for global pooling
    subgraph.batch = torch.zeros(subgraph.num_nodes, dtype=torch.long)
    subgraph = subgraph.to(device)

    # Run through model in eval mode
    model.eval()
    logits, attn1, _ = model(subgraph)
    attn_weights = attn1[1]  # Shape: [num_edges, num_heads]

    # Use mean attention across all edges and heads
    score = attn_weights.mean().item()
    return score


In [None]:
def train(model, dataset, optimizer, criterion, device):
    model.train()
    total_loss, correct = 0.0, 0

    for path, label in dataset:
        subgraphs = load_pt_and_create_subgraphs(path, label)
        if len(subgraphs) == 0:
            continue

        subgraph_scores = []
        for subgraph in subgraphs:
            try:
                score = score_subgraph_attention(model, subgraph, device)
                subgraph_scores.append((score, subgraph))
            except:
                continue
        # best_k = len(subgraphs)//3
        # sorted_graphs = sorted(subgraph_scores, key=lambda x: x[0], reverse=True)
        top_subgraphs = sorted(subgraph_scores, key=lambda x: x[0], reverse=True)[:6]
        if len(top_subgraphs) == 0:
            continue

        logits_list = []
        for _, subgraph in top_subgraphs:
            subgraph.batch = torch.zeros(subgraph.num_nodes, dtype=torch.long)
            subgraph = subgraph.to(device)
            logits, _, _ = model(subgraph)
            logits_list.append(logits)

        # Aggregate logits
        avg_logits = torch.stack(logits_list).mean(dim=0)
        min_loss=10000
        # Loss and backprop
        label_tensor = torch.tensor([label], dtype=torch.long).to(device)
        loss = criterion(avg_logits, label_tensor)
        if loss < min_loss and paitence:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            min_loss = loss
        else:
            paitence -=1
            continue

        total_loss += loss.item()
        correct += (avg_logits.argmax(dim=1) == label_tensor).sum().item()

    return total_loss / len(dataset), correct / len(dataset)


In [9]:
@torch.no_grad()
def evaluate(model, dataset, device):
    model.eval()
    y_true = []
    y_pred = []

    for path, label in dataset:
        subgraphs = load_pt_and_create_subgraphs(path, label)
        if len(subgraphs) == 0:
            continue

        scored = []
        for sg in subgraphs:
            try:
                score = score_subgraph_attention(model, sg, device)
                scored.append((score, sg))
            except:
                continue
        best_k = len(subgraphs)//3
        top_subgraphs = sorted(scored, key=lambda x: x[0], reverse=True)[:6]
        if len(top_subgraphs) == 0:
            continue

        logits_list = []
        for _, sg in top_subgraphs:
            sg.batch = torch.zeros(sg.num_nodes, dtype=torch.long)
            sg = sg.to(device)
            logits, _, _ = model(sg)
            logits_list.append(logits)

        avg_logits = torch.stack(logits_list).mean(dim=0)
        pred = avg_logits.argmax(dim=1).item()

        y_true.append(label)
        y_pred.append(pred)

    return y_true, y_pred


In [10]:
# Load and split metadata
metadata = pd.read_csv("metadata.csv").sample(frac=1).reset_index(drop=True)
split = int(0.8 * len(metadata))
metadata.iloc[:split].to_csv("train_meta.csv", index=False)
metadata.iloc[split:].to_csv("test_meta.csv", index=False)

# Create datasets
train_dataset = GraphSubgraphDataset("train_meta.csv")
test_dataset = GraphSubgraphDataset("test_meta.csv")

# Model, optimizer, loss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(in_channels=8, hidden_channels=64, out_channels=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
for epoch in range(1, 21):
    loss, acc = train(model, train_dataset, optimizer, criterion, device)
    print(f"Epoch {epoch:02d} | Train Loss: {loss:.4f} | Train Acc: {acc:.4f}")


  graph = torch.load(pt_path)


Epoch 01 | Train Loss: 14.1504 | Train Acc: 0.2812
Epoch 02 | Train Loss: 1.3952 | Train Acc: 0.3375
Epoch 03 | Train Loss: 1.3660 | Train Acc: 0.3500
Epoch 04 | Train Loss: 1.3987 | Train Acc: 0.3063
Epoch 05 | Train Loss: 1.3653 | Train Acc: 0.3406
Epoch 06 | Train Loss: 1.3836 | Train Acc: 0.3375
Epoch 07 | Train Loss: 1.4483 | Train Acc: 0.2938
Epoch 08 | Train Loss: 1.4289 | Train Acc: 0.2875
Epoch 09 | Train Loss: 1.4536 | Train Acc: 0.2844
Epoch 10 | Train Loss: 1.4897 | Train Acc: 0.2781
Epoch 11 | Train Loss: 1.4389 | Train Acc: 0.2969
Epoch 12 | Train Loss: 1.5180 | Train Acc: 0.2500
Epoch 13 | Train Loss: 1.5299 | Train Acc: 0.2781
Epoch 14 | Train Loss: 1.5162 | Train Acc: 0.2344
Epoch 15 | Train Loss: 1.5429 | Train Acc: 0.2844
Epoch 16 | Train Loss: 1.5291 | Train Acc: 0.2844
Epoch 17 | Train Loss: 1.6067 | Train Acc: 0.2375
Epoch 18 | Train Loss: 1.6412 | Train Acc: 0.2406
Epoch 19 | Train Loss: 1.6601 | Train Acc: 0.2625
Epoch 20 | Train Loss: 1.7264 | Train Acc: 0.2906

In [11]:
y_true, y_pred = evaluate(model, test_dataset, device)

print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=['Class 0', 'Class 1', 'Class 2', 'Class 3']))

print("\nConfusion Matrix:")
print(confusion_matrix(y_true, y_pred))

print(f"\nAccuracy Score: {accuracy_score(y_true, y_pred):.4f}")


  graph = torch.load(pt_path)



Classification Report:
              precision    recall  f1-score   support

     Class 0       0.00      0.00      0.00        22
     Class 1       0.26      0.69      0.37        16
     Class 2       0.00      0.00      0.00        21
     Class 3       0.50      0.90      0.64        20

    accuracy                           0.37        79
   macro avg       0.19      0.40      0.25        79
weighted avg       0.18      0.37      0.24        79


Confusion Matrix:
[[ 0 11  0 11]
 [ 0 11  0  5]
 [ 0 19  0  2]
 [ 0  2  0 18]]

Accuracy Score: 0.3671


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
