In [3]:
import os
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv
from dgl.data.utils import load_graphs
from torch_geometric.utils import from_networkx
import dgl

from torch_geometric.nn import GCNConv, global_mean_pool, GraphConv
import numpy as np

In [4]:
class CustomDataset(Dataset):
    """PyG dataset over DGL cell-graph ``.bin`` files stored directly in ``root``.

    A file is included when its name ends with ``.bin`` and the third
    ``_``-separated token of its name is a key of ``CLASS_MAPPING``; that token
    is also the file's class label. Each graph is converted to a
    ``torch_geometric.data.Data`` with:

    * ``x``: the DGL node data ``'feat'``
    * ``edge_index``: the DGL edges stacked into a ``[2, E]`` long tensor
    * ``edge_weights``: inverse Euclidean distance between the first two
      components of the endpoints' ``'centroid'`` node data
    * ``y``: a 1-element long tensor looked up in ``CLASS_MAPPING``
    """

    # File-name label token -> integer class used as the target.
    CLASS_MAPPING = {'N': 0, 'PB': 0, 'UDH': 0, 'DCIS': 1, 'IC': 1}

    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)

    def get_num_features(self):
        """Return ``num_features`` of the first graph in the dataset."""
        sample_data = self.get(0)
        return sample_data.num_features

    @staticmethod
    def _class_label(file_name):
        """Return the third ``_``-separated token of ``file_name``, or None if it has fewer tokens."""
        parts = file_name.split("_")
        return parts[2] if len(parts) > 2 else None

    @property
    def raw_file_names(self):
        """Sorted names of the usable ``.bin`` files directly under ``self.root``.

        The listing is computed once and cached, because ``len``/``get`` read it
        for every sample and re-listing the directory each time made a full pass
        quadratic. Sorting makes the index -> file assignment independent of the
        platform's ``os.listdir`` order.
        """
        cached = getattr(self, "_raw_file_names_cache", None)
        if cached is None:
            cached = sorted(
                f for f in os.listdir(self.root)
                if f.endswith(".bin") and self._class_label(f) in self.CLASS_MAPPING
            )
            self._raw_file_names_cache = cached
        return cached

    def len(self):
        """Number of usable graph files."""
        return len(self.raw_file_names)

    def get(self, idx):
        """Load graph ``idx`` from disk and convert it to a PyG ``Data`` object."""
        file_name = self.raw_file_names[idx]
        file_path = os.path.join(self.root, file_name)
        class_label = self._class_label(file_name)

        # load_graphs returns (graphs, labels); only the first graph in the file is used.
        dgl_graphs, _ = load_graphs(file_path)
        dgl_graph = dgl_graphs[0]

        # An earlier check found no 'efeat' edge data in these graphs, so the
        # only edge attribute is the distance-based weight computed below.
        node_features = dgl_graph.ndata['feat']

        # Convert the DGL edge list to a PyG edge_index.
        src, dst = dgl_graph.edges()
        edge_index = torch.stack((src, dst), dim=0).to(torch.long)
        data = Data(x=node_features, edge_index=edge_index)

        data.edge_weights = self._inverse_distance_weights(
            dgl_graph.ndata["centroid"], edge_index)

        data.y = torch.tensor([self.CLASS_MAPPING[class_label]], dtype=torch.long)
        return data

    @staticmethod
    def _inverse_distance_weights(centroids, edge_index):
        """Return ``1 / dist`` per edge, using the first two centroid components of each endpoint.

        Edges whose endpoints have identical centroids (e.g. self-loops) get
        weight 0 instead of ``inf``; an infinite weight would propagate inf/NaN
        through the GraphConv aggregation of the whole graph.
        NOTE(review): weight 0 removes such edges from aggregation; confirm this
        is the intended treatment.
        """
        if not centroids.is_floating_point():
            centroids = centroids.float()
        delta = centroids[edge_index[0], :2] - centroids[edge_index[1], :2]
        dist = delta.pow(2).sum(dim=1).sqrt()
        return torch.where(dist > 0, dist.reciprocal(), torch.zeros_like(dist))
    
# Root folder holding the train/ val/ test/ split directories. Override it with the
# CELL_GRAPH_DATA_ROOT environment variable instead of editing a user-specific path.
DATA_ROOT = os.environ.get(
    "CELL_GRAPH_DATA_ROOT", "C:/Users/szhu337/Desktop/project/data_cell_graph")
BATCH_SIZE = 32
SEED = 42

# Seed the global torch RNG so the shuffled training order (and the weight
# initialisation in later cells) is reproducible across runs.
torch.manual_seed(SEED)

train_dataset = CustomDataset(root=os.path.join(DATA_ROOT, "train"))
val_dataset = CustomDataset(root=os.path.join(DATA_ROOT, "val"))
test_dataset = CustomDataset(root=os.path.join(DATA_ROOT, "test"))

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Total dataset size: {len(train_dataset) + len(test_dataset) + len(val_dataset)}")
print(f"Number of features: {train_dataset.get_num_features()}")
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")



Total dataset size: 1717
Number of features: 514
Number of classes: 2
Train dataset size: 1325
Val dataset size: 116
Test dataset size: 276


# GCN (GraphConv) classifier with inverse-distance edge weights

In [5]:
class GCNModel(torch.nn.Module):
    """Graph-level classifier: two GraphConv layers, mean pooling, then two linear layers.

    Despite the name, the convolutions are ``GraphConv`` (not ``GCNConv``).
    Both receive ``data.edge_weights`` as per-edge weights.

    Args:
        num_features: node-feature width of the input graphs.
        num_classes: number of output logits per graph.
        hidden_size: output width of the first GraphConv.
        conv2_size: output width of the second GraphConv, i.e. the pooled
            graph-embedding width (default 512, as originally hard-coded).
        lin_hidden: output width of the first linear layer (default 256).
        dropout: dropout probability applied after conv1 and after lin1 while
            training (default 0.5).

    The extra parameters default to the previously hard-coded values, so the
    parameter names and shapes (and therefore saved state_dicts) are unchanged.
    """

    def __init__(self, num_features, num_classes, hidden_size=256,
                 conv2_size=512, lin_hidden=256, dropout=0.5):
        super(GCNModel, self).__init__()
        self.dropout = dropout  # plain float, not a parameter: not part of state_dict
        self.conv1 = GraphConv(num_features, hidden_size)
        self.conv2 = GraphConv(hidden_size, conv2_size)
        self.lin1 = Linear(conv2_size, lin_hidden)
        self.lin2 = Linear(lin_hidden, num_classes)

    def forward(self, data):
        """Return unnormalised class scores of shape ``[num_graphs, num_classes]``."""
        x, edge_index, batch, edge_weights = data.x, data.edge_index, data.batch, data.edge_weights
        # Node embedding.
        x = self.conv1(x, edge_index, edge_weights)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weights)
        # Readout: average the node embeddings of each graph in the batch.
        x = global_mean_pool(x, batch)
        # Classifier head.
        # NOTE(review): there is no activation between lin1 and lin2, so (dropout
        # aside) they compose to a single affine map. Adding one would change the
        # outputs of checkpoints trained with this architecture.
        x = self.lin1(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return x

In [6]:
# Run configuration: width of the first GraphConv layer, and the compute device
# (CUDA when available, otherwise CPU). The bare `device` shows the choice.
hidden_size = 1024
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device

device(type='cuda')

In [7]:
# Build the classifier on the selected device and attach an Adam optimizer (lr = 1e-3).
model = GCNModel(
    num_features=train_dataset.num_features,
    num_classes=train_dataset.num_classes,
    hidden_size=hidden_size,
).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [8]:
# Training loop: train for `num_epochs`, evaluate on the validation split after
# each epoch, and checkpoint the weights with the best validation accuracy.
num_epochs = 50
best_val_accuracy = 0
checkpoint_path = "./model/gcn_w_edge_weight_model_2_classes.pth"
# torch.save does not create parent directories; create ./model up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / len(train_loader)  # mean of per-batch losses

    # Validation
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output, data.y)
            val_loss += loss.item()
            # sigmoid/softmax are monotonic, so the argmax of the raw scores is the prediction.
            pred = output.argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total

    # Save the best model (strict improvement only: ties keep the earlier checkpoint).
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), checkpoint_path)

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


Epoch: 1, Train Loss: 0.6668, Val Loss: 0.4963, Val Accuracy: 0.7328
Epoch: 2, Train Loss: 0.3765, Val Loss: 0.5105, Val Accuracy: 0.8017
Epoch: 3, Train Loss: 0.3445, Val Loss: 0.3909, Val Accuracy: 0.8103
Epoch: 4, Train Loss: 0.2563, Val Loss: 0.5320, Val Accuracy: 0.7759
Epoch: 5, Train Loss: 0.2371, Val Loss: 0.4359, Val Accuracy: 0.8362
Epoch: 6, Train Loss: 0.2108, Val Loss: 0.3201, Val Accuracy: 0.8879
Epoch: 7, Train Loss: 0.2534, Val Loss: 0.2912, Val Accuracy: 0.8707
Epoch: 8, Train Loss: 0.2041, Val Loss: 0.2807, Val Accuracy: 0.8879
Epoch: 9, Train Loss: 0.2250, Val Loss: 0.2940, Val Accuracy: 0.8707
Epoch: 10, Train Loss: 0.1832, Val Loss: 0.3257, Val Accuracy: 0.8793
Epoch: 11, Train Loss: 0.2095, Val Loss: 0.2324, Val Accuracy: 0.9052
Epoch: 12, Train Loss: 0.1654, Val Loss: 0.2871, Val Accuracy: 0.9052
Epoch: 13, Train Loss: 0.1926, Val Loss: 0.3953, Val Accuracy: 0.8793
Epoch: 14, Train Loss: 0.1634, Val Loss: 0.2774, Val Accuracy: 0.9310
Epoch: 15, Train Loss: 0.1809

In [9]:
# Rebuild the architecture and load the best checkpoint. The model stays on the CPU;
# map_location="cpu" also lets a checkpoint written from CUDA tensors load on a
# machine without a GPU.
best_model = GCNModel(train_dataset.num_features, train_dataset.num_classes, hidden_size)
best_model.load_state_dict(
    torch.load("./model/gcn_w_edge_weight_model_2_classes.pth", map_location="cpu"))

# Testing loop: accuracy of the best checkpoint on the test split.
best_model.eval()
model_device = next(best_model.parameters()).device
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        data = data.to(model_device)
        output = best_model(data)
        # argmax of the raw scores; a monotonic sigmoid would not change it.
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.8261


In [10]:
# Accuracy of the best checkpoint on the training split.
best_model.eval()
# Feed batches to wherever the model's weights live, so this works on CPU or GPU.
model_device = next(best_model.parameters()).device
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        data = data.to(model_device)
        output = best_model(data)
        # argmax of the raw scores; a monotonic sigmoid would not change it.
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Train accuracy: {accuracy:.4f}")

Train accuracy: 0.9615


In [11]:
# Accuracy of the best checkpoint on the validation split.
best_model.eval()
# Feed batches to wherever the model's weights live, so this works on CPU or GPU.
model_device = next(best_model.parameters()).device
correct = 0
total = 0
with torch.no_grad():
    for data in val_loader:
        data = data.to(model_device)
        output = best_model(data)
        # argmax of the raw scores; a monotonic sigmoid would not change it.
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Val accuracy: {accuracy:.4f}")

Val accuracy: 0.9483


In [12]:
from torchmetrics.classification import MulticlassAUROC

import numpy as np

def graph_auroc(model, loader, num_classes=None):
    """Macro-averaged multiclass AUROC of ``model`` over every batch in ``loader``.

    Args:
        model: graph classifier returning a ``[num_graphs, C]`` score tensor per batch.
        loader: iterable of batches exposing ``.y`` (integer labels) and ``.to(device)``.
        num_classes: number of classes passed to ``MulticlassAUROC``; when None
            it is taken from the width of the model output.

    Returns:
        The scalar tensor produced by ``MulticlassAUROC(average='macro')``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    # Run batches on the model's device; results are gathered on the CPU.
    model_device = next(model.parameters()).device
    score_batches, label_batches = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(model_device)
            score_batches.append(model(data).cpu())
            label_batches.append(data.y.cpu())

    if not score_batches:
        raise ValueError("loader yielded no batches; cannot compute AUROC")

    # One concatenation at the end instead of repeated np.append (quadratic copying).
    all_out = torch.cat(score_batches, dim=0)
    all_y = torch.cat(label_batches, dim=0).long()
    if num_classes is None:
        num_classes = all_out.size(1)

    auroc_metric = MulticlassAUROC(num_classes=num_classes, average='macro')
    return auroc_metric(all_out, all_y)

In [13]:
# Macro AUROC of the reloaded best checkpoint on the test split.
graph_auroc(model=best_model, loader=test_loader)

tensor(0.9159)