In [1]:
import os
import os.path as osp

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.nn import Linear
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphSAGE
from torch_geometric.utils import from_networkx

import dgl
from dgl.data.utils import load_graphs

from torch_geometric.nn import GCNConv, global_mean_pool, GraphConv

from sklearn.linear_model import LogisticRegression

In [2]:
class CustomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)

    def get_num_features(self):
        sample_data = self.get(0)
        return sample_data.num_features

    @property
    def raw_file_names(self):
        file_name_lst = []
        for f in os.listdir(self.root):
            if f.endswith(".bin"):
                if f.split("_")[2] in ["N", "PB", "UDH", "DCIS", "IC"]:
                    file_name_lst.append(f)
        return file_name_lst

    def len(self):
        return len(self.raw_file_names)

    def get(self, idx):
        file_name = self.raw_file_names[idx]
        file_path = os.path.join(self.root, file_name)
        class_label = file_name.split('_')[2]

        # Load the DGL graph using DGL's load_graphs function
        dgl_graphs, _ = load_graphs(file_path)
        dgl_graph = dgl_graphs[0]

        # Used to check if our graph had edge features - it does not
        # if 'efeat' in dgl_graph.edata:
        #     edge_features = dgl_graph.edata['efeat']
        #     print("Graph has edge features with shape:", edge_features.shape)
        # else:
        #     print("Graph does not have edge features")

        # Get node features from DGL graph
        node_features = dgl_graph.ndata['feat']

        # Convert the DGL graph to a PyTorch Geometric Data object
        src, dst = dgl_graph.edges()
        edge_index = torch.stack((src, dst), dim=0).to(torch.long)

        # Map the class label to an integer
        class_mapping = {'N': 0, 'PB': 0, 'UDH': 0, 'DCIS': 1, 'IC': 1}
        y = class_mapping[class_label]

        data = Data(x=node_features, edge_index=edge_index)

        data.y = torch.tensor([y], dtype=torch.long)
        
        return data

In [3]:
train_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/train/")
val_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/val/")
test_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/test/")


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f"Total dataset size: {len(train_dataset) + len(test_dataset) + len(val_dataset)}")
print(f"Number of features: {train_dataset.get_num_features()}")
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")



Total dataset size: 1717
Number of features: 514
Number of classes: 2
Train dataset size: 1325
Val dataset size: 116
Test dataset size: 276


# Test model

In [4]:
import torch
from torch.nn import Linear, Dropout
from torch_geometric.nn import SAGEConv, GATv2Conv, GCNConv
import torch.nn.functional as F


class GraphSAGE(torch.nn.Module):
    """GraphSAGE"""
    def __init__(self, dim_in, dim_h, num_classes):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, 512)
        self.lin1 = Linear(512, 256)
        self.lin2 = Linear(256, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # node embedding
        x = self.sage1(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.sage2(x, edge_index)

        # readout layer
        x = global_mean_pool(x, batch)

        # Apply a final classifier
        x = self.lin1(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

In [5]:
# Initialize the GCNModel
hidden_size = 1024
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
model = GraphSAGE(train_dataset.num_features, hidden_size, train_dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
# Training loop
num_epochs = 50
best_val_accuracy = 0

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output, data.y)
            val_loss += loss.item()
            pred = torch.argmax(F.sigmoid(output), dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    
    # Save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "./model/gsage_model_2_classes.pth")
    
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

Epoch: 1, Train Loss: 0.7795, Val Loss: 0.4464, Val Accuracy: 0.8362
Epoch: 2, Train Loss: 0.4756, Val Loss: 0.3581, Val Accuracy: 0.7759
Epoch: 3, Train Loss: 0.3717, Val Loss: 0.4806, Val Accuracy: 0.7328
Epoch: 4, Train Loss: 0.2971, Val Loss: 0.3457, Val Accuracy: 0.8448
Epoch: 5, Train Loss: 0.3102, Val Loss: 0.2891, Val Accuracy: 0.8621
Epoch: 6, Train Loss: 0.2556, Val Loss: 0.3861, Val Accuracy: 0.8190
Epoch: 7, Train Loss: 0.2295, Val Loss: 0.3203, Val Accuracy: 0.8534
Epoch: 8, Train Loss: 0.2414, Val Loss: 0.2583, Val Accuracy: 0.8621
Epoch: 9, Train Loss: 0.2133, Val Loss: 0.3105, Val Accuracy: 0.8621
Epoch: 10, Train Loss: 0.2096, Val Loss: 0.3135, Val Accuracy: 0.8621
Epoch: 11, Train Loss: 0.2428, Val Loss: 0.2656, Val Accuracy: 0.8793
Epoch: 12, Train Loss: 0.1989, Val Loss: 0.2626, Val Accuracy: 0.8879
Epoch: 13, Train Loss: 0.1643, Val Loss: 0.2341, Val Accuracy: 0.9052
Epoch: 14, Train Loss: 0.1649, Val Loss: 0.2755, Val Accuracy: 0.8966
Epoch: 15, Train Loss: 0.1799

In [8]:
best_model = GraphSAGE(train_dataset.num_features, hidden_size, train_dataset.num_classes)
best_model.load_state_dict(torch.load("./model/gsage_model_2_classes.pth"))

# Testing loop
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.8370


In [9]:
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Train accuracy: {accuracy:.4f}")

Train accuracy: 0.9547


In [10]:
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in val_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Val accuracy: {accuracy:.4f}")

Val accuracy: 0.9483


In [11]:
from torchmetrics.classification import MulticlassAUROC

import numpy as np

def graph_auroc(model, loader):
    model.eval()
    auroc_score = 0.0
    auroc_metric = MulticlassAUROC(num_classes=2, average='macro')
    i = 0
    with torch.no_grad():
        for data in loader: 
            out = model(data)
            if i == 0:
                all_out = out
                all_y = data.y
            else:
                all_out = np.append(all_out, out, axis=0)
                all_y = np.append(all_y, data.y, axis=0)

            i += 1
    auroc_score = auroc_metric(torch.tensor(all_out), torch.tensor(all_y).long())
    return auroc_score

In [12]:
graph_auroc(best_model, test_loader)

tensor(0.8958)