In [None]:
import os
import csv

root_dir = 'graphs_new_pannuke'  # top-level folder; one sub-folder per subtype
label_map = {'Benign': 0, 'InSitu': 1, 'Invasive': 2, 'Normal': 3}  # subtype folder name -> integer class label
metadata_path = 'metadata.csv'

# Write one CSV row (graph_path, label) for every .graphml file found in
# root_dir/<subtype>/. Sub-folders whose name is not in label_map are reported and skipped.
with open(metadata_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['graph_path', 'label'])

    # os.listdir returns entries in arbitrary order; sorting makes the CSV
    # (and everything derived from its row order) reproducible across runs.
    for subtype in sorted(os.listdir(root_dir)):
        subtype_path = os.path.join(root_dir, subtype)
        if not os.path.isdir(subtype_path):
            continue

        if subtype not in label_map:
            print(f"Unknown label for {subtype}")
            continue
        label = label_map[subtype]

        for fname in sorted(os.listdir(subtype_path)):
            if fname.endswith('.graphml'):
                writer.writerow([os.path.join(subtype_path, fname), label])

print(f"Metadata written to {metadata_path}")


In [None]:
import os
import torch
from torch_geometric.data import Data, InMemoryDataset
from glob import glob
from natsort import natsorted
import pandas as pd

# --- Create Subgraphs ---
def create_subgraphs(graph, window_size, step_size, center_type=1):
    """Cut windows of consecutive nodes around every node of a given type.

    For every ``step_size``-th node whose type column (column 2 of ``graph.x``)
    equals ``center_type``, a window of up to ``window_size`` consecutive node
    indices is taken around it (shifted inwards at the graph borders). Edges with
    both endpoints inside the window are kept and re-indexed to start at 0.

    Args:
        graph: object exposing ``x`` (2-D node features), ``edge_index``,
            ``edge_attr`` (may be ``None``), ``y`` and ``num_nodes``.
        window_size: maximum number of nodes per subgraph.
        step_size: take every ``step_size``-th matching node as a window centre.
        center_type: value of the type column that marks a window centre.
            Matching is by exact equality, so the type column must hold the
            raw (unscaled) type codes.

    Returns:
        list of ``Data`` objects, one per window. Each carries ``x``,
        ``edge_index``, ``edge_attr``, ``y``, ``original_node_indices`` and one
        1-D attribute per feature column.

    Raises:
        ValueError: if ``graph.x`` is not 2-D.
    """
    feature_names = ['x', 'y', 'type', 'area', 'perimeter', 'eccentricity', 'solidity', 'circularity']
    # The coordinate features are called 'x' and 'y', which are also the names of
    # the core Data fields (feature matrix and label). Assigning them by name
    # would overwrite those fields, so they are stored under different names.
    attr_names = {'x': 'coord_x', 'y': 'coord_y'}

    if graph.x.dim() != 2:
        raise ValueError(f"Expected graph.x to be 2D, got shape {graph.x.shape}")

    num_nodes = graph.num_nodes

    full_edge_index = graph.edge_index
    if full_edge_index.numel() == 0:
        # Normalise an edge-less graph to the canonical (2, 0) shape.
        full_edge_index = full_edge_index.new_empty((2, 0))
    src, dst = full_edge_index[0], full_edge_index[1]

    centers = (graph.x[:, 2] == center_type).nonzero(as_tuple=True)[0].tolist()

    subgraphs = []
    for center in centers[::step_size]:
        # Window centred at the selected node, clamped to the graph bounds.
        start = max(0, center - window_size // 2)
        end = start + window_size
        if end > num_nodes:
            end = num_nodes
            start = max(0, end - window_size)
        window_len = end - start

        # The window is a contiguous index range, so edge filtering is a range
        # test and re-indexing is a subtraction of the window offset.
        inside = (src >= start) & (src < end) & (dst >= start) & (dst < end)
        edge_index = full_edge_index[:, inside] - start
        edge_attr = graph.edge_attr[inside] if graph.edge_attr is not None else None

        # Fallback if no edges remain (self-loops to avoid empty graphs)
        if edge_index.size(1) == 0:
            loops = torch.arange(window_len, dtype=torch.long)
            edge_index = torch.stack([loops, loops], dim=0)
            edge_attr = torch.ones((window_len, 1), dtype=torch.float)  # Default weight=1 for self-loops

        features = graph.x[start:end]
        subgraph = Data(
            x=features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=graph.y,
            original_node_indices=torch.arange(start, end, dtype=torch.long)
        )

        # Expose every feature column as its own 1-D attribute.
        for column, name in enumerate(feature_names):
            subgraph[attr_names.get(name, name)] = features[:, column]

        subgraphs.append(subgraph)

    return subgraphs

# --- Generate and Save Subgraphs as .pt ---
def generate_and_save_subgraphs(graph_folder, subgraph_folder, window_size=62, step_size=1):
    """Split every saved graph in a folder into subgraphs and write them to disk.

    Every ``*.pt`` file in ``graph_folder`` is loaded, passed to
    ``create_subgraphs`` and each resulting subgraph is saved as
    ``<graph basename>_sg<i>.pt`` in ``subgraph_folder``. A CSV
    ``sub_gphmeta.csv`` with columns ``subgraph_path`` and ``label`` is written
    to ``subgraph_folder``.

    Args:
        graph_folder: directory containing the full graphs as ``.pt`` files.
        subgraph_folder: output directory (created if missing).
        window_size: forwarded to ``create_subgraphs``.
        step_size: forwarded to ``create_subgraphs``.
    """
    os.makedirs(subgraph_folder, exist_ok=True)
    graph_paths = natsorted(glob(os.path.join(graph_folder, "*.pt")))

    metadata = []

    for graph_path in graph_paths:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load graph files that come from a trusted source.
        graph = torch.load(graph_path, weights_only=False)

        subgraphs = create_subgraphs(graph, window_size=window_size, step_size=step_size)

        base = os.path.splitext(os.path.basename(graph_path))[0]
        label = graph.y.item()
        for i, sg in enumerate(subgraphs):
            sg_path = os.path.join(subgraph_folder, f"{base}_sg{i}.pt")
            torch.save(sg, sg_path)
            print(f"Saved subgraph to {sg_path}")
            metadata.append({'subgraph_path': sg_path, 'label': label})

        print(f"{graph_path}: {len(subgraphs)} subgraphs generated")

    # Fix the columns explicitly so that the CSV still has a header (and can be
    # read back) when no subgraph was produced.
    metadata_csv = os.path.join(subgraph_folder, 'sub_gphmeta.csv')
    pd.DataFrame(metadata, columns=['subgraph_path', 'label']).to_csv(metadata_csv, index=False)
    print(f"Metadata saved to {metadata_csv}")


In [None]:
import torch
import networkx as nx
import pandas as pd
import os
from torch_geometric.data import Data, InMemoryDataset
from sklearn.preprocessing import StandardScaler
import numpy as np

# --- Convert .graphml → PyG Data object ---
def convert_nx_to_pyg(graphml_path, label):
    """Read a GraphML file and convert it to a PyG ``Data`` object.

    Node features are read from the node attributes listed in ``feature_names``
    (missing or non-numeric values become 0). All columns except ``type`` are
    standardised per graph with ``StandardScaler``; ``type`` is kept as the raw
    code because it is a categorical value that is later compared by equality.

    Args:
        graphml_path: path of the ``.graphml`` file.
        label: integer graph-level class label, stored in ``data.y``.

    Returns:
        ``Data`` with ``x`` (num_nodes x 8), ``edge_index`` (2 x num_edges),
        ``edge_attr`` (num_edges x 1, the edge ``weight`` or 1.0), ``y``,
        ``original_node_indices`` and one 1-D attribute per feature column.

    Raises:
        ValueError: if the graph contains no nodes.
    """
    feature_names = ['x', 'y', 'type', 'area', 'perimeter', 'eccentricity', 'solidity', 'circularity']
    type_col = feature_names.index('type')
    # The coordinate features are called 'x' and 'y', which are also the names of
    # the core Data fields (feature matrix and label). Assigning them by name
    # would overwrite those fields, so they are stored under different names.
    attr_names = {'x': 'coord_x', 'y': 'coord_y'}

    # Read graphml file and relabel nodes 0..N-1 so they can index tensor rows
    G = nx.read_graphml(graphml_path)
    G = nx.convert_node_labels_to_integers(G)

    # Extract node features
    print(f"Processing graph: {graphml_path}")
    node_features = []
    for node, attrs in G.nodes(data=True):
        features = []
        for f in feature_names:
            value = attrs.get(f, 0)
            try:
                features.append(float(value))
            except (ValueError, TypeError):
                print(f"Warning: Invalid value for {f} in node {node}: {value}. Using 0.")
                features.append(0.0)
        node_features.append(features)

    if not node_features:
        raise ValueError(f"Graph {graphml_path} contains no nodes")

    node_features_np = np.array(node_features, dtype=float)

    # Debug: Check if features have identical values
    for i, name in enumerate(feature_names):
        unique_values = np.unique(node_features_np[:, i])
        if len(unique_values) == 1:
            print(f"Warning: Feature '{name}' has identical values: {unique_values[0]}")
        else:
            print(f"Feature '{name}' range: {unique_values.min()} to {unique_values.max()}")

    # Normalize node features (zero mean, unit variance per column of this graph)
    scaler = StandardScaler()
    node_features_normalized = scaler.fit_transform(node_features_np)
    # Restore the raw type codes: standardising them would make an equality
    # test such as ``type == 1`` on the feature matrix never match.
    node_features_normalized[:, type_col] = node_features_np[:, type_col]

    x = torch.tensor(node_features_normalized, dtype=torch.float)

    # Extract edge indices and attributes
    # NOTE(review): each edge is stored once, in the direction networkx yields it.
    # If the GraphML graph is undirected and the model should pass messages both
    # ways, the reversed edges have to be added — confirm the intended behaviour.
    edge_index = []
    edge_attr = []
    for u, v, attrs in G.edges(data=True):
        edge_index.append([u, v])
        edge_attr.append([float(attrs.get('weight', 1.0))])

    # Debug: Check edge weights
    if not edge_attr:
        print("Warning: Graph has no edges")
    else:
        unique_weights = np.unique(np.array(edge_attr))
        if len(unique_weights) == 1:
            print(f"Warning: All edge weights are identical: {unique_weights[0]}")
        else:
            print(f"Edge weight range: {unique_weights.min()} to {unique_weights.max()}")

    # The explicit reshape keeps the (2, E) / (E, 1) layout even when E == 0.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).T.contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)

    # Create PyG Data object
    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([label], dtype=torch.long),
        original_node_indices=torch.arange(G.number_of_nodes(), dtype=torch.long)
    )

    # Expose every feature column as its own 1-D attribute.
    for i, feature_name in enumerate(feature_names):
        data[attr_names.get(feature_name, feature_name)] = x[:, i]

    return data

# --- Create Subgraphs ---
def create_subgraphs(graph, window_size, step_size, center_type=1):
    """Cut windows of consecutive nodes around every node of a given type.

    For every ``step_size``-th node whose type column (column 2 of ``graph.x``)
    equals ``center_type``, a window of up to ``window_size`` consecutive node
    indices is taken around it (shifted inwards at the graph borders). Edges with
    both endpoints inside the window are kept and re-indexed to start at 0.

    Args:
        graph: object exposing ``x`` (2-D node features), ``edge_index``,
            ``edge_attr`` (may be ``None``), ``y`` and ``num_nodes``.
        window_size: maximum number of nodes per subgraph.
        step_size: take every ``step_size``-th matching node as a window centre.
        center_type: value of the type column that marks a window centre.
            Matching is by exact equality, so the type column must hold the
            raw (unscaled) type codes.

    Returns:
        list of ``Data`` objects, one per window. Each carries ``x``,
        ``edge_index``, ``edge_attr``, ``y``, ``original_node_indices`` and one
        1-D attribute per feature column.

    Raises:
        ValueError: if ``graph.x`` is not 2-D.
    """
    feature_names = ['x', 'y', 'type', 'area', 'perimeter', 'eccentricity', 'solidity', 'circularity']
    # The coordinate features are called 'x' and 'y', which are also the names of
    # the core Data fields (feature matrix and label). Assigning them by name
    # would overwrite those fields, so they are stored under different names.
    attr_names = {'x': 'coord_x', 'y': 'coord_y'}

    if graph.x.dim() != 2:
        raise ValueError(f"Expected graph.x to be 2D, got shape {graph.x.shape}")

    num_nodes = graph.num_nodes

    full_edge_index = graph.edge_index
    if full_edge_index.numel() == 0:
        # Normalise an edge-less graph to the canonical (2, 0) shape.
        full_edge_index = full_edge_index.new_empty((2, 0))
    src, dst = full_edge_index[0], full_edge_index[1]

    centers = (graph.x[:, 2] == center_type).nonzero(as_tuple=True)[0].tolist()

    subgraphs = []
    for center in centers[::step_size]:
        # Window centred at the selected node, clamped to the graph bounds.
        start = max(0, center - window_size // 2)
        end = start + window_size
        if end > num_nodes:
            end = num_nodes
            start = max(0, end - window_size)
        window_len = end - start

        # The window is a contiguous index range, so edge filtering is a range
        # test and re-indexing is a subtraction of the window offset.
        inside = (src >= start) & (src < end) & (dst >= start) & (dst < end)
        edge_index = full_edge_index[:, inside] - start
        edge_attr = graph.edge_attr[inside] if graph.edge_attr is not None else None

        # Fallback if no edges remain (self-loops to avoid empty graphs)
        if edge_index.size(1) == 0:
            loops = torch.arange(window_len, dtype=torch.long)
            edge_index = torch.stack([loops, loops], dim=0)
            edge_attr = torch.ones((window_len, 1), dtype=torch.float)  # Default weight=1 for self-loops

        features = graph.x[start:end]
        subgraph = Data(
            x=features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=graph.y,
            original_node_indices=torch.arange(start, end, dtype=torch.long)
        )

        # Expose every feature column as its own 1-D attribute.
        for column, name in enumerate(feature_names):
            subgraph[attr_names.get(name, name)] = features[:, column]

        subgraphs.append(subgraph)

    return subgraphs

# --- PyG InMemoryDataset from metadata.csv ---
class SubgraphDatasetFromMetadata(InMemoryDataset):
    """In-memory dataset of subgraphs built from the graphs listed in a metadata CSV.

    For every row of ``metadata_path`` (columns ``graph_path`` and ``label``) the
    GraphML file is converted with ``convert_nx_to_pyg`` and split with
    ``create_subgraphs``. Each subgraph is saved as ``<graph name>_sg<i>.pt`` in
    ``out_dir`` and a CSV with columns ``subgraph_path`` / ``label`` is written
    to ``out_csv``.

    Args:
        metadata_path: CSV with one row per source graph.
        transform: forwarded to ``InMemoryDataset``.
        pre_transform: forwarded to ``InMemoryDataset``; this class does not
            apply it to the subgraphs itself.
        out_dir: directory for the saved subgraph files (created if missing).
        out_csv: path of the subgraph metadata CSV.
        window_size: forwarded to ``create_subgraphs``.
        step_size: forwarded to ``create_subgraphs``.

    Raises:
        ValueError: if no subgraph could be generated from any listed graph.
    """

    def __init__(self, metadata_path, transform=None, pre_transform=None, out_dir='./subgraphs',
                 out_csv='subghmeta.csv', window_size=62, step_size=1):
        super(SubgraphDatasetFromMetadata, self).__init__('.', transform, pre_transform)
        os.makedirs(out_dir, exist_ok=True)
        metadata = pd.read_csv(metadata_path)

        self.data_list = []  # all subgraphs, in generation order
        self.records = []    # one {'subgraph_path', 'label'} dict per subgraph

        for _, row in metadata.iterrows():
            graph_path = row['graph_path']
            label = int(row['label'])
            # NOTE(review): only the file's base name is used, so two source graphs
            # with the same base name in different folders would overwrite each
            # other's subgraph files — confirm that base names are unique.
            graph_name = os.path.splitext(os.path.basename(graph_path))[0]

            graph = convert_nx_to_pyg(graph_path, label)
            subgraphs = create_subgraphs(graph, window_size=window_size, step_size=step_size)
            for i, sg in enumerate(subgraphs):
                sg_filename = f"{graph_name}_sg{i}.pt"
                sg_path = os.path.join(out_dir, sg_filename)
                torch.save(sg, sg_path)
                self.records.append({'subgraph_path': sg_path, 'label': label})
                self.data_list.append(sg)
            print(f"{row['graph_path']}: {len(subgraphs)} subgraphs generated")

        # Fail with a clear message instead of an obscure error inside collate().
        if not self.data_list:
            raise ValueError(f"No subgraphs were generated from the graphs listed in '{metadata_path}'")

        self.data, self.slices = self.collate(self.data_list)

        # Save metadata CSV
        pd.DataFrame(self.records, columns=['subgraph_path', 'label']).to_csv(out_csv, index=False)
        print(f"Subgraphs saved to '{out_dir}', metadata saved to '{out_csv}'")

    def get_labels(self):
        """Return the graph-level label of every subgraph as a list of Python ints."""
        return [data.y.item() for data in self.data_list]

# --- Example Usage ---
if __name__ == "__main__":
    # Smoke test on one hand-picked graph: convert it, then window it.
    graphml_path = 'graphs_new_pannuke/Benign/b017.graphml'
    label = 0  # class index used for 'Benign'

    # Full graph first (convert_nx_to_pyg also prints its debug report)
    full_graph = convert_nx_to_pyg(graphml_path, label)
    print("Original Full Graph:", full_graph, sep="\n")

    # Then the windows; show the first one when at least one was produced
    subgraphs = create_subgraphs(full_graph, window_size=62, step_size=1)
    if len(subgraphs) > 0:
        print("\nExample SubGraph:", subgraphs[0], sep="\n")

In [None]:
# Print the raw attribute dict of every node to see what the GraphML file stores.
G = nx.read_graphml('graphs_new_pannuke/Benign/b017.graphml')
for node, attrs in G.nodes.items():
    print(f"Node {node}: {attrs}")

In [None]:

# Convert every graph listed in metadata.csv into subgraphs; the .pt files go to
# out_dir and the per-subgraph metadata to out_csv.
dataset = SubgraphDatasetFromMetadata(
    'metadata.csv',
    out_dir='./subgraphs_pannuke_s20',
    out_csv='sub_gphmeta_pannuke.csv',
)



In [None]:
from torch_geometric.loader import DataLoader

# Hold out 20% of the subgraphs for evaluation. Wrapping the full dataset in both
# loaders would evaluate the model on exactly the samples it was trained on.
# The seeded generator makes the split reproducible.
# NOTE(review): subgraphs cut from the same source graph are overlapping windows,
# so a random split at subgraph level can still leak information between the two
# sets — splitting by source graph is preferable; confirm the intended protocol.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(dataset), generator=split_generator)
num_test = int(0.2 * len(dataset))
test_dataset = dataset[shuffled_indices[:num_test]]
train_dataset = dataset[shuffled_indices[num_test:]]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation order does not matter, so the test set is not shuffled.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
from sklearn.model_selection import train_test_split

# Read the subgraph metadata written when the dataset was built
# (out_csv='sub_gphmeta_pannuke.csv'); no cell writes 'sub_gphmeta.csv' to the
# working directory.
sb_meta = pd.read_csv('sub_gphmeta_pannuke.csv')

# Stratified 80/20 split so both parts keep the class proportions
train_meta, test_meta = train_test_split(sb_meta, test_size=0.2, stratify=sb_meta["label"], random_state=42)

# Save new metadata files
train_meta.to_csv("train_meta.csv", index=False)
test_meta.to_csv("test_meta.csv", index=False)

print(f"Train size: {len(train_meta)}, Test size: {len(test_meta)}")

In [None]:
# Summary statistics of the subgraph dataset, gathered in a single pass.
num_graphs = len(dataset)
num_classes = dataset.num_classes

labels = []        # graph-level label of every sample
total_nodes = 0
total_edges = 0
for sample in dataset:
    labels.append(sample.y.item())
    total_nodes += sample.num_nodes
    total_edges += sample.num_edges

avg_nodes = total_nodes / num_graphs
avg_edges = total_edges / num_graphs

# Report
print(f"Number of graphs: {num_graphs}")
print(f"Number of classes: {num_classes}")
print(f"Labels: {set(labels)}")
print(f"Average number of nodes: {avg_nodes}")
print(f"Average number of edges: {avg_edges}")

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
class GAT(torch.nn.Module):
    """Two-layer graph attention network with mean pooling for graph classification.

    Args:
        hidden_channels: output width of each attention head / of the second layer.
        heads: number of attention heads in the first layer (their outputs are
            concatenated).
        dropout: dropout probability used inside both GATConv layers and after
            each activation.
        in_channels: number of input node features. Defaults to
            ``dataset.num_node_features`` of the notebook-level ``dataset``.
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes`` of the notebook-level ``dataset``.

    ``forward`` returns ``(log_probs, attn_weights)``: per-graph class
    log-probabilities and the attention output of the second layer as returned
    by ``GATConv(..., return_attention_weights=True)``. Edge attributes are not
    used by the model.
    """

    def __init__(self, hidden_channels, heads=8, dropout=0.6, in_channels=None, num_classes=None):
        super(GAT, self).__init__()
        # Seeds the global torch RNG so that weight initialisation is repeatable.
        torch.manual_seed(42)

        # Fall back to the notebook-level dataset only when sizes are not given,
        # so the model can also be built without that global.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes

        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # Single head, concat=False: output width is hidden_channels
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=dropout)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # First GAT layer
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Second GAT layer; also keep its attention weights for inspection
        x, attn_weights = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Global mean pooling: one vector per graph in the batch
        x = global_mean_pool(x, data.batch)

        # Classifier
        x = self.lin(x)

        return F.log_softmax(x, dim=1), attn_weights


In [None]:



# Pick the device first, then build the model directly on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters of the attention network
hidden_channels, heads, dropout = 64, 16, 0.6

model = GAT(hidden_channels=hidden_channels, heads=heads, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The model already returns log-probabilities; CrossEntropyLoss applies
# log_softmax again, which leaves log-probabilities unchanged.
criterion = torch.nn.CrossEntropyLoss()

# Move data to device
def move_to_device(batch, device):
    """Return ``batch`` moved to ``device`` via its ``.to`` method."""
    return batch.to(device)

# Weakly supervised training
# One-off sanity check of what a collated batch looks like. Done once, before
# training: inside the epoch loop it would build an extra loader iterator and
# repeat the same six lines of output every epoch.
sample_batch = next(iter(train_loader))
sample_keys = sample_batch.keys
if callable(sample_keys):
    # `keys` is a method in newer PyG releases and a property in older ones
    sample_keys = sample_keys()
print("Type of batch:", type(sample_batch))
print("Batch object keys:", sample_keys)
print("Node feature shape (x):", sample_batch.x.shape)
print("Edge index shape:", sample_batch.edge_index.shape)
print("Labels (y):", sample_batch.y)
print("Batch vector shape (batch):", sample_batch.batch.shape)  # tells which node belongs to which graph

model.train()
for epoch in range(50):  # Number of epochs
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in train_loader:
        batch = move_to_device(batch, device)
        optimizer.zero_grad()
        out, attn_weights = model(batch)
        loss = criterion(out, batch.y)  # Compute loss using graph labels
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Predicted class = index of the largest log-probability
        predicted = out.argmax(dim=1)
        correct += (predicted == batch.y).sum().item()
        total += batch.y.size(0)

    avg_loss = total_loss / len(train_loader)  # mean of the per-batch losses
    accuracy = correct / total
    print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

In [None]:
# Leftover type checks: the class of a list literal and of the training loader.
print([0].__class__)
print("Type of train_loader:", train_loader.__class__)
