In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.nn import (
    GCNConv, GATv2Conv, SAGEConv, GINEConv,
    NNConv, PNAConv, global_mean_pool
)
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report
)

# Class name -> integer label. Insertion order follows the label ids (0..3), so
# list(LABEL_MAP.keys()) can be used directly as ordered display names.
LABEL_MAP = {
    name: idx
    for idx, name in enumerate(("Benign", "InSitu", "Invasive", "Normal"))
}


In [None]:
class GNN(nn.Module):
    """Two-layer graph classifier with a selectable message-passing operator.

    Each of the two conv layers is followed by batch norm, a residual add,
    ReLU and dropout (p=0.3). Node embeddings are then mean-pooled per graph
    and passed through a linear head.

    Args:
        kind: One of "GCN", "GAT", "SAGE", "GIN", "MPNN", "PNA".
        in_ch: Number of input node features.
        edge_ch: Number of edge features. Only the "GIN" and "MPNN" variants
            use edge attributes; the other kinds ignore them.
        hidden: Hidden width of both conv layers. Must be divisible by 4 for
            "GAT" (4 concatenated heads in the first layer).
        out: Number of output classes.
        deg: In-degree histogram tensor required by ``PNAConv``. Mandatory
            when ``kind == "PNA"``, ignored otherwise.

    Raises:
        ValueError: If ``kind`` is unknown, if "PNA" is requested without
            ``deg``, or if "GAT" is requested with ``hidden`` not divisible
            by the number of heads.
    """

    # Kinds whose conv layers take edge_attr as a third positional argument.
    _EDGE_AWARE_KINDS = ("GIN", "MPNN")

    def __init__(self, kind, in_ch, edge_ch, hidden, out, deg=None):
        super().__init__()
        self.kind = kind
        self.dropout = nn.Dropout(0.3)

        if kind == "GCN":
            self.conv1 = GCNConv(in_ch, hidden)
            self.conv2 = GCNConv(hidden, hidden)
        elif kind == "GAT":
            heads = 4
            # The first layer concatenates `heads` outputs of width hidden // heads;
            # that only adds up to `hidden` (what bn1 expects) when it divides evenly.
            if hidden % heads != 0:
                raise ValueError(
                    f"hidden={hidden} must be divisible by {heads} for kind 'GAT'"
                )
            self.conv1 = GATv2Conv(in_ch, hidden // heads, heads=heads)
            self.conv2 = GATv2Conv(hidden, hidden, heads=1, concat=False)
        elif kind == "SAGE":
            self.conv1 = SAGEConv(in_ch, hidden)
            self.conv2 = SAGEConv(hidden, hidden)
        elif kind == "GIN":
            nn1 = nn.Sequential(nn.Linear(in_ch, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
            nn2 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
            self.conv1 = GINEConv(nn1, edge_dim=edge_ch)
            self.conv2 = GINEConv(nn2, edge_dim=edge_ch)
        elif kind == "MPNN":
            # NNConv's edge network must emit in_channels * out_channels values
            # per edge, so each layer needs its own correctly sized network.
            # Sharing one in_ch * hidden network breaks conv2 when in_ch != hidden.
            edge_nn1 = nn.Linear(edge_ch, in_ch * hidden)
            edge_nn2 = nn.Linear(edge_ch, hidden * hidden)
            self.conv1 = NNConv(in_ch, hidden, edge_nn1)
            self.conv2 = NNConv(hidden, hidden, edge_nn2)
        elif kind == "PNA":
            # PNAConv's degree scalers are normalised with a degree histogram;
            # fail early with a clear message instead of inside PNAConv.
            if deg is None:
                raise ValueError(
                    "kind 'PNA' requires `deg`, the in-degree histogram of the training graphs"
                )
            pna_kwargs = dict(
                aggregators=["mean", "max", "min", "std"],
                scalers=["identity", "amplification", "attenuation"],
                deg=deg,
            )
            self.conv1 = PNAConv(in_ch, hidden, **pna_kwargs)
            self.conv2 = PNAConv(hidden, hidden, **pna_kwargs)
        else:
            raise ValueError(f"Unknown kind {kind}")

        # Project the input for the first residual only when the widths differ.
        self.skip1 = nn.Linear(in_ch, hidden) if in_ch != hidden else nn.Identity()
        self.skip2 = nn.Identity()
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.lin = nn.Linear(hidden, out)

    def _conv(self, conv, h, edge_index, edge_attr):
        """Apply ``conv``, forwarding ``edge_attr`` only for edge-aware kinds."""
        if self.kind in self._EDGE_AWARE_KINDS:
            return conv(h, edge_index, edge_attr)
        return conv(h, edge_index)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return per-graph class logits.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor passed to the conv layers.
            edge_attr: Edge features; only read by "GIN" and "MPNN".
            batch: Node-to-graph assignment vector used for mean pooling.
        """
        h = self._conv(self.conv1, x, edge_index, edge_attr)
        h = self.bn1(h)
        h = F.relu(h + self.skip1(x))
        h = self.dropout(h)

        h2 = self._conv(self.conv2, h, edge_index, edge_attr)
        h2 = self.bn2(h2)
        h2 = F.relu(h2 + self.skip2(h))
        h2 = self.dropout(h2)

        hg = global_mean_pool(h2, batch)
        return self.lin(hg)


In [None]:
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``.
        loader: Iterable of batches exposing ``.to(device)``, ``.x``,
            ``.edge_index``, ``.edge_attr``, ``.batch`` and ``.y``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        Sum of the per-batch cross-entropy losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for graphs in loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch)
        # If the model returns a tuple, its first element is treated as the logits.
        logits = logits[0] if isinstance(logits, tuple) else logits
        batch_loss = F.cross_entropy(logits, graphs.y)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def eval_model(model, loader, device):
    """Evaluate ``model`` on ``loader``, print a per-class report, return metrics.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``.
        loader: Iterable of batches exposing ``.to(device)``, ``.x``,
            ``.edge_index``, ``.edge_attr``, ``.batch`` and ``.y``.
        device: Device each batch is moved to.

    Returns:
        ``(accuracy, confusion_matrix)``. The confusion matrix always has one
        row and column per entry of ``LABEL_MAP``, in label-id order, even if
        a class does not occur in the evaluated data.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            # If the model returns a tuple, its first element is treated as the logits.
            if isinstance(out, tuple):
                out = out[0]
            all_preds.extend(out.argmax(dim=1).cpu().tolist())
            all_labels.extend(batch.y.cpu().tolist())

    # Pin the label set explicitly. Otherwise sklearn infers it from the data:
    # a class missing from this split makes classification_report reject the
    # four target names and shrinks the confusion matrix below 4x4.
    class_names = list(LABEL_MAP.keys())
    class_ids = list(LABEL_MAP.values())

    acc = accuracy_score(all_labels, all_preds)
    print(classification_report(
        all_labels, all_preds,
        labels=class_ids, target_names=class_names,
        digits=3, zero_division=0,
    ))
    return acc, confusion_matrix(all_labels, all_preds, labels=class_ids)

def plot_cm(cm, classes):
    """Draw ``cm`` as an annotated integer heatmap and show the figure.

    Args:
        cm: Confusion matrix of integer counts.
        classes: Tick labels used on both axes (x is labelled "Predicted",
            y is labelled "True").
    """
    heatmap_ax = sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
    )
    heatmap_ax.set_xlabel("Predicted")
    heatmap_ax.set_ylabel("True")
    heatmap_ax.set_title("Confusion Matrix")
    plt.show()


In [2]:
import torch
import networkx as nx
import matplotlib.pyplot as plt
from PIL import Image
import os
from glob import glob
from natsort import natsorted

# Node-type id -> matplotlib colour name, listed in id order (0..5).
# Lookups use .get(..., DEFAULT_NODE_COLOR), so unknown ids are drawn gray.
NODE_COLORS = dict(enumerate(('blue', 'yellow', 'red', 'black', 'green', 'aqua')))
DEFAULT_NODE_COLOR = "gray"

# Edge styling: every edge is drawn black; EDGE_WIDTH_SCALE multiplies the
# edge weight to obtain the line width of full-graph edges.
EDGE_COLOR = "black"
EDGE_WIDTH_SCALE = 1.5

def visualize_original_and_subgraph(image_path, graphml_path, subgraph_pt_path):
    """Show the full graph and one saved subgraph side by side over the image.

    Args:
        image_path: Path of the image used as the background of both panels.
        graphml_path: GraphML file of the full graph. Nodes must carry ``x``
            and ``y`` attributes (plus an optional ``type``); edges must carry
            a ``weight`` attribute.
        subgraph_pt_path: ``torch.save``-d object exposing ``.x`` and
            ``.edge_index`` tensors.

    Raises:
        KeyError: If a node lacks ``x``/``y`` or an edge lacks ``weight``.
    """
    # Image.open is lazy and keeps the file open; convert() returns an
    # in-memory copy, so the handle can be released right away.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    base = os.path.splitext(os.path.basename(image_path))[0]

    # Full graph: positions and colours come from node attributes.
    G = nx.read_graphml(graphml_path)
    pos_full = {n: (float(G.nodes[n]['x']), float(G.nodes[n]['y'])) for n in G.nodes}
    node_colors_full = [
        NODE_COLORS.get(int(G.nodes[n].get('type', 0)), DEFAULT_NODE_COLOR)
        for n in G.nodes
    ]
    edge_list_full = list(G.edges())
    widths_full = [float(G.edges[e]['weight']) * EDGE_WIDTH_SCALE for e in edge_list_full]

    # The .pt file holds a pickled Python object rather than plain tensors, so
    # it needs weights_only=False (newer PyTorch defaults to True and refuses).
    # SECURITY: this unpickles arbitrary code - only load files you trust.
    data = torch.load(subgraph_pt_path, map_location="cpu", weights_only=False)
    x = data.x.cpu().numpy()
    edge_index = data.edge_index.cpu().numpy()

    edge_list_sub = [(int(u), int(v)) for u, v in edge_index.T]

    # assumes x is 2-D with columns 0/1 = image coordinates and column 2 =
    # node-type id - TODO confirm against the code that writes the subgraphs
    pos_sub = {i: (row[0], row[1]) for i, row in enumerate(x)}
    node_colors_sub = [NODE_COLORS.get(int(row[2]), DEFAULT_NODE_COLOR) for row in x]

    G_sub = nx.Graph()
    G_sub.add_nodes_from(pos_sub.keys())
    G_sub.add_edges_from(edge_list_sub)
    # One scalar width for all subgraph edges: nx.Graph merges (u, v)/(v, u)
    # pairs, so a list sized from edge_index could be longer than the edges drawn.
    sub_edge_width = 1.5

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Left panel: full graph over the image.
    ax1.imshow(img)
    nx.draw_networkx_edges(G, pos_full, ax=ax1, edgelist=edge_list_full,
                           edge_color=EDGE_COLOR, width=widths_full)
    nx.draw_networkx_nodes(G, pos_full, ax=ax1, node_color=node_colors_full, node_size=30)
    ax1.set_title(f"Original Graph: {base}")
    ax1.axis('off')

    # Right panel: subgraph over the same image.
    ax2.imshow(img)
    nx.draw_networkx_edges(G_sub, pos_sub, ax=ax2, edge_color=EDGE_COLOR, width=sub_edge_width)
    nx.draw_networkx_nodes(G_sub, pos_sub, ax=ax2, node_color=node_colors_sub, node_size=50)
    ax2.set_title(f"Subgraph Overlay: {os.path.basename(subgraph_pt_path)}")
    ax2.axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure so repeated calls in a loop do not pile up open figures.
    plt.close(fig)


In [5]:
import torch
import networkx as nx
from torch_geometric.utils import from_networkx

# Load a saved full-graph object (.pt) with torch.load; no GraphML is read here.
# The file holds a pickled Python object, so weights_only=False is passed
# explicitly (newer PyTorch defaults to True and refuses to load it).
# SECURITY: this unpickles arbitrary code - only load files you trust.
G = torch.load('graphs_new_pannuke/Invasive/1.pt', weights_only=False)

# # Load subgraph (.pt) as usual
# data_subgraphs = torch.load('subgraphs_pannuke_s20/b017_sg240.pt')

# Print summaries (the subgraph summary is disabled along with its load above)
print("Original Full Graph:")
print(G)
print("\n\nSubGraph:")
# print(data_subgraphs)


Original Full Graph:
Data(x=[1504], edge_index=[2, 4487], edge_attr=[4487, 1], y=[1504], original_node_indices=[1504], type=[1504], area=[1504], perimeter=[1504], eccentricity=[1504], solidity=[1504], circularity=[1504])


SubGraph:


  G = torch.load('graphs_new_pannuke/Invasive/1.pt')


In [1]:
# Overlay subgraphs 1, 11, ..., 91 of slide b017 next to its full graph.
# NOTE: visualize_original_and_subgraph must already be defined in the kernel
# (run its defining cell first).
subgraph_ids = range(1, 100, 10)
for i in subgraph_ids:
    visualize_original_and_subgraph(
        "dataset/data/Photos/Benign/b017.tif",
        "graphs_new_pannuke/Benign/b017.graphml",
        f"subgraphs_pannuke_s20/b017_sg{i}.pt",
    )


NameError: name 'visualize_original_and_subgraph' is not defined

In [None]:
from torch_geometric.data import InMemoryDataset
import torch

class SubgraphDatasetFromSavedFiles(InMemoryDataset):
    """In-memory dataset built from individually saved subgraph files.

    Args:
        subgraph_metadata_csv: CSV file with a ``subgraph_path`` column; each
            entry is the path of one ``torch.save``-d graph object.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``.

    Raises:
        ValueError: If the metadata file lists no subgraphs.
    """

    def __init__(self, subgraph_metadata_csv, transform=None, pre_transform=None):
        super().__init__('.', transform, pre_transform)
        self.meta_df = pd.read_csv(subgraph_metadata_csv)

        # Each file holds a pickled Python object rather than plain tensors, so
        # weights_only=False is required (newer PyTorch defaults to True).
        # SECURITY: this unpickles arbitrary code - only load files you trust.
        self.data_list = [
            torch.load(path, weights_only=False)
            for path in self.meta_df['subgraph_path']
        ]
        if not self.data_list:
            raise ValueError(f"No subgraphs listed in {subgraph_metadata_csv!r}")

        self.data, self.slices = self.collate(self.data_list)

    def get_labels(self):
        """Return one Python scalar label per graph (``y`` must hold a single value)."""
        return [data.y.item() for data in self.data_list]
# Load the subgraph datasets once. Every file is read inside the constructor,
# so building the same dataset a second time only repeats the disk I/O.
train_ds = SubgraphDatasetFromSavedFiles("train_meta.csv")
test_ds = SubgraphDatasetFromSavedFiles("test_meta.csv")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use the PyG DataLoader imported from torch_geometric.data: it merges graphs
# into one batch, which torch.utils.data.DataLoader's default collate cannot
# do for graph data objects.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=32)

# Older names kept as aliases so cells that still reference them keep working.
train_loader = train_dl
test_loader = test_dl

hist = {}
results = {}

# The per-architecture training loop lives in the next cell.


In [None]:

# Create the checkpoint directory up front, so torch.save cannot fail on a
# missing folder after a model has already been trained.
os.makedirs("models", exist_ok=True)

# Feature widths are read once from the first training graph.
sample_graph = train_ds[0]
in_ch = sample_graph.x.size(1)
edge_ch = sample_graph.edge_attr.size(1) if sample_graph.edge_attr is not None else 0

for kind in ["GCN", "GAT", "SAGE", "GIN", "MPNN"]:
    model = GNN(kind, in_ch=in_ch, edge_ch=edge_ch, hidden=64, out=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)

    # 30 epochs; accuracy is measured on test_dl after every epoch.
    losses, accs = [], []
    for epoch in range(1, 31):
        loss = train_epoch(model, train_dl, optimizer, device)
        acc, _ = eval_model(model, test_dl, device)
        losses.append(loss)
        accs.append(acc)

    hist[kind] = (losses, accs)
    results[kind] = accs[-1]
    torch.save(model.state_dict(), f"models/{kind.lower()}.pth")


In [None]:
# Loss curve per architecture; hist maps kind -> (losses, accuracies), one value per epoch.
fig_loss, ax_loss = plt.subplots(figsize=(8, 5))
for arch, (loss_curve, _) in hist.items():
    ax_loss.plot(range(1, len(loss_curve) + 1), loss_curve, label=f"{arch} loss")
ax_loss.set(xlabel="Epoch", ylabel="Mean batch loss", title="Training Loss")
ax_loss.legend()
plt.show()

# Accuracy curve per architecture. These values are computed on the test
# loader after each epoch, so the plot is titled accordingly.
fig_acc, ax_acc = plt.subplots(figsize=(8, 5))
for arch, (_, acc_curve) in hist.items():
    ax_acc.plot(range(1, len(acc_curve) + 1), acc_curve, label=f"{arch} acc")
ax_acc.set(xlabel="Epoch", ylabel="Accuracy", title="Test Accuracy")
ax_acc.legend()
plt.show()


In [None]:

# Cell: Training & Saving (fixed hist init)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): NucleiGraphDataset is not defined in the cells shown above -
# confirm it is defined or imported before this cell runs.
train_ds = NucleiGraphDataset("train_meta.csv")
test_ds  = NucleiGraphDataset("test_meta.csv")
tr_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
te_dl = DataLoader(test_ds, batch_size=16)

# Create the checkpoint directory up front, so torch.save cannot fail on a
# missing folder after a model has already been trained.
os.makedirs("models", exist_ok=True)

hist = {}
final_acc = {}

# Feature widths are read once from the first training graph.
sample_graph = train_ds[0]
in_ch = sample_graph.x.size(1)
edge_ch = sample_graph.edge_attr.size(1)

# "PNA" is left out on purpose: GNN constructs PNAConv with deg=None, which
# PNAConv cannot use, so the run would abort only after the other five models
# had been trained. TODO: re-add once GNN can be given a degree histogram.
for kind in ["GCN", "GAT", "SAGE", "GIN", "MPNN"]:
    model = GNN(kind, in_ch=in_ch, edge_ch=edge_ch, hidden=64, out=4).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)

    # 30 epochs; accuracy is measured on te_dl after every epoch.
    losses, accs = [], []
    for epoch in range(1, 31):
        epoch_loss = train_epoch(model, tr_dl, opt, device)
        epoch_acc, _ = eval_model(model, te_dl, device)
        losses.append(epoch_loss)
        accs.append(epoch_acc)

    hist[kind] = (losses, accs)
    final_acc[kind] = accs[-1]
    torch.save(model.state_dict(), f"models/{kind.lower()}.pth")

results = {}
# NOTE(review): NucleiGraphDataset is not defined in the cells shown above -
# confirm it is defined or imported before this code runs.
test_ds = NucleiGraphDataset("test_meta.csv")
test_dl = DataLoader(test_ds, batch_size=16)

# Values that are the same for every architecture.
sample_graph = test_ds[0]
in_ch = sample_graph.x.size(1)
edge_ch = sample_graph.edge_attr.size(1)
class_names = list(LABEL_MAP.keys())

# "PNA" is left out on purpose: GNN constructs PNAConv with deg=None, which
# PNAConv cannot use, so building that model fails before any checkpoint is
# read. TODO: re-add once GNN can be given a degree histogram.
for kind in ["GCN", "GAT", "SAGE", "GIN", "MPNN"]:
    # load
    path = f"models/{kind.lower()}.pth"
    model = GNN(kind, in_ch=in_ch, edge_ch=edge_ch, hidden=64, out=4).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    # evaluate
    acc, cm = eval_model(model, test_dl, device)
    print(f"{kind} Test Accuracy: {acc:.3f}")

    # confusion matrix
    plot_cm(cm, class_names)

    results[kind] = acc

# summary bar chart (dict views are materialised before being passed to matplotlib)
plt.figure(figsize=(6,4))
plt.bar(list(results.keys()), list(results.values()))
plt.ylabel("Test Accuracy")
plt.title("Model Comparison")
plt.xticks(rotation=30)
plt.show()



In [None]:
# Re-evaluate every architecture that has an entry in `results`, using its saved checkpoint.
# Values that are the same for every architecture are computed once.
sample_graph = test_ds[0]
in_ch = sample_graph.x.size(1)
edge_ch = sample_graph.edge_attr.size(1) if sample_graph.edge_attr is not None else 0
class_names = list(LABEL_MAP.keys())

for kind in results:
    path = f"models/{kind.lower()}.pth"
    model = GNN(kind, in_ch=in_ch, edge_ch=edge_ch, hidden=64, out=4).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    acc, cm = eval_model(model, test_dl, device)
    print(f"{kind} Test Accuracy: {acc:.3f}")
    plot_cm(cm, class_names)
