In [1]:
import os
import torch
import torch.nn.functional as F
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Dataset, DataLoader
from dgl.data.utils import load_graphs
from torch_geometric.utils import from_networkx
from torch_geometric.nn import global_mean_pool

import dgl

In [2]:
torch.set_default_dtype(torch.float32)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class CustomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)


    def get_num_features(self):
        sample_data = self.get(0)
        return sample_data.num_features


    @property
    def raw_file_names(self):
        return [f for f in os.listdir(self.root) if f.endswith('.bin')]

    def len(self):
        return len(self.raw_file_names)

    def get(self, idx):
        file_name = self.raw_file_names[idx]
        file_path = os.path.join(self.root, file_name)
        class_label = file_name.split('_')[2]

        # Load the DGL graph using DGL's load_graphs function
        dgl_graphs, _ = load_graphs(file_path)
        dgl_graph = dgl_graphs[0]

        # Used to check if our graph had edge features - it does not
        # if 'efeat' in dgl_graph.edata:
        #     edge_features = dgl_graph.edata['efeat']
        #     print("Graph has edge features with shape:", edge_features.shape)
        # else:
        #     print("Graph does not have edge features")

        # Get node features from DGL graph
        node_features = dgl_graph.ndata['feat']

        # Convert the DGL graph to a PyTorch Geometric Data object
        src, dst = dgl_graph.edges()
        edge_index = torch.stack((src, dst), dim=0).to(torch.long)
        data = Data(x=node_features, edge_index=edge_index)


        # Map the class label to an integer
        class_mapping = {'N': 0, 'PB': 1, 'UDH': 2, 'FEA': 3, 'ADH': 4, 'DCIS': 5, 'IC': 6}
        y = class_mapping[class_label]
        data.y = torch.tensor([y], dtype=torch.long)

        return data
    
train_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/train/")
val_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/val/")
test_dataset = CustomDataset(root="C:/Users/szhu337/Desktop/project/data_cell_graph/test/")


train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=6)
val_loader = DataLoader(val_dataset, batch_size=6)

print(f"Total dataset size: {len(train_dataset) + len(test_dataset) + len(val_dataset)}")
print(f"Number of features: {train_dataset.get_num_features()}")
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")



Total dataset size: 2856
Number of features: 514
Number of classes: 7
Train dataset size: 2244
Val dataset size: 190
Test dataset size: 422


In [5]:
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.5)
        # On the Pubmed dataset, use `heads` output heads in `conv2`.
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads,
                             concat=False, dropout=0.5)
        self.lin1 = Linear(out_channels, 256)
        self.lin2 = Linear(256, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # node embedding
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)

        # readout layer
        x = global_mean_pool(x, batch)

        # Apply a final classifier
        x = self.lin1(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

In [6]:
hidden_channels = 128
out_channels = 512
heads = 7

In [7]:
model = GAT(train_dataset.num_features, hidden_channels, out_channels, train_dataset.num_classes,
            heads).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
num_epochs = 50
best_val_accuracy = 0

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output, data.y)
            val_loss += loss.item()
            pred = torch.argmax(F.sigmoid(output), dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    
    # Save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "./model/gat_model_7_classes.pth")
    
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


Epoch: 1, Train Loss: 1.7164, Val Loss: 2.3345, Val Accuracy: 0.2632
Epoch: 2, Train Loss: 1.4767, Val Loss: 2.7284, Val Accuracy: 0.2632
Epoch: 3, Train Loss: 1.3418, Val Loss: 2.3778, Val Accuracy: 0.2895
Epoch: 4, Train Loss: 1.2511, Val Loss: 2.1480, Val Accuracy: 0.2684
Epoch: 5, Train Loss: 1.2143, Val Loss: 2.4257, Val Accuracy: 0.3632
Epoch: 6, Train Loss: 1.1669, Val Loss: 2.2413, Val Accuracy: 0.3789
Epoch: 7, Train Loss: 1.1228, Val Loss: 4.1078, Val Accuracy: 0.2895
Epoch: 8, Train Loss: 1.0869, Val Loss: 3.0760, Val Accuracy: 0.3211
Epoch: 9, Train Loss: 1.0609, Val Loss: 3.1491, Val Accuracy: 0.3053
Epoch: 10, Train Loss: 1.0532, Val Loss: 2.3070, Val Accuracy: 0.3737
Epoch: 11, Train Loss: 1.0358, Val Loss: 4.9992, Val Accuracy: 0.2632
Epoch: 12, Train Loss: 1.0320, Val Loss: 3.6605, Val Accuracy: 0.3053
Epoch: 13, Train Loss: 1.0061, Val Loss: 3.5611, Val Accuracy: 0.2632
Epoch: 14, Train Loss: 1.0066, Val Loss: 4.1891, Val Accuracy: 0.3105
Epoch: 15, Train Loss: 1.0221

In [10]:
best_model = GAT(train_dataset.num_features, hidden_channels, out_channels, 
                 train_dataset.num_classes, heads)
best_model.load_state_dict(torch.load("./model/gat_model_7_classes.pth"))

# Testing loop
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.3815


In [11]:
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Train accuracy: {accuracy:.4f}")

Train accuracy: 0.6092


In [12]:
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in val_loader:
        # data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Val accuracy: {accuracy:.4f}")

Val accuracy: 0.4263


In [13]:
from torchmetrics.classification import MulticlassAUROC

import numpy as np

def graph_auroc(model, loader):
    model.eval()
    auroc_score = 0.0
    auroc_metric = MulticlassAUROC(num_classes=7, average='macro')
    i = 0
    with torch.no_grad():
        for data in loader: 
            out = model(data)
            if i == 0:
                all_out = out
                all_y = data.y
            else:
                all_out = np.append(all_out, out, axis=0)
                all_y = np.append(all_y, data.y, axis=0)

            i += 1
    auroc_score = auroc_metric(torch.tensor(all_out), torch.tensor(all_y).long())
    return auroc_score

In [14]:
graph_auroc(best_model, test_loader)

tensor(0.8042)