In [2]:
import os
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv
from dgl.data.utils import load_graphs
from torch_geometric.utils import from_networkx
import dgl
import numpy as np

from torch_geometric.nn import GCNConv, global_mean_pool, GraphConv

In [4]:
class CustomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)


    def get_num_features(self):
        sample_data = self.get(0)
        return sample_data.num_features


    @property
    def raw_file_names(self):
        return [f for f in os.listdir(self.root) if f.endswith('.bin')]

    def len(self):
        return len(self.raw_file_names)

    def get(self, idx):
        file_name = self.raw_file_names[idx]
        file_path = os.path.join(self.root, file_name)
        class_label = file_name.split('_')[2]

        # Load the DGL graph using DGL's load_graphs function
        dgl_graphs, _ = load_graphs(file_path)
        dgl_graph = dgl_graphs[0]

        # Get node features from DGL graph
        node_features = dgl_graph.ndata['feat']

        # Convert the DGL graph to a PyTorch Geometric Data object
        src, dst = dgl_graph.edges()
        edge_index = torch.stack((src, dst), dim=0).to(torch.long)
        data = Data(x=node_features, edge_index=edge_index)

        # Derive edge weight
        centroids = dgl_graph.ndata["centroid"]
        edges = dgl_graph.edges()
        src_centroid_x = centroids[edges[0]][:, 0].numpy()
        src_centroid_y = centroids[edges[0]][:, 1].numpy()
        dst_centroid_x = centroids[edges[1]][:, 0].numpy()
        dst_centroid_y = centroids[edges[1]][:, 1].numpy()
        edge_weights = 1/np.sqrt((src_centroid_x - dst_centroid_x)**2 + \
                            (src_centroid_y - dst_centroid_y)**2)
        data.edge_weights = torch.from_numpy(edge_weights)

        # Map the class label to an integer
        class_mapping = {'N': 0, 'PB': 1, 'UDH': 2, 'FEA': 3, 'ADH': 4, 'DCIS': 5, 'IC': 6}
        y = class_mapping[class_label]
        data.y = torch.tensor([y], dtype=torch.long)

        return data
    
train_dataset = CustomDataset(root='/Users/zhusichen/Desktop/2023Spring/8803/Projects/cell_graph_data/train')
val_dataset = CustomDataset(root='/Users/zhusichen/Desktop/2023Spring/8803/Projects/cell_graph_data/val')
test_dataset = CustomDataset(root='/Users/zhusichen/Desktop/2023Spring/8803/Projects/cell_graph_data/test')


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f"Total dataset size: {len(train_dataset) + len(test_dataset) + len(val_dataset)}")
print(f"Number of features: {train_dataset.get_num_features()}")
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Total dataset size: 772
Number of features: 514
Number of classes: 7
Train dataset size: 563
Val dataset size: 49
Test dataset size: 160


# Test GCN

In [5]:
class GCNModel(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=256):
        super(GCNModel, self).__init__()
        self.conv1 = GraphConv(num_features, hidden_size)
        self.conv2 = GraphConv(hidden_size, 128)

    def forward(self, data):
        x, edge_index, batch, edge_weights = data.x, data.edge_index, data.batch, data.edge_weights
        x = self.conv1(x, edge_index, edge_weights)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_weights)
        x = global_mean_pool(x, batch)
        return x

In [6]:
# Initialize the GCNModel
hidden_size = 256
device = torch.device("cpu")

In [7]:
model = GCNModel(train_dataset.num_features, train_dataset.num_classes, hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
# Training loop
num_epochs = 50
best_val_accuracy = 0

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output, data.y)
            val_loss += loss.item()
            pred = torch.argmax(F.sigmoid(output), dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    
    # Save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "./model/best_gcn_model.pth")
    
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


Epoch: 1, Train Loss: 2.3315, Val Loss: 2.0809, Val Accuracy: 0.2245
Epoch: 2, Train Loss: 1.5485, Val Loss: 1.8983, Val Accuracy: 0.3265
Epoch: 3, Train Loss: 1.3571, Val Loss: 1.5885, Val Accuracy: 0.3469
Epoch: 4, Train Loss: 1.2567, Val Loss: 1.4259, Val Accuracy: 0.4286
Epoch: 5, Train Loss: 1.2459, Val Loss: 1.4967, Val Accuracy: 0.3469
Epoch: 6, Train Loss: 1.1141, Val Loss: 1.3874, Val Accuracy: 0.4286
Epoch: 7, Train Loss: 1.0879, Val Loss: 1.2798, Val Accuracy: 0.4490
Epoch: 8, Train Loss: 1.0560, Val Loss: 1.3705, Val Accuracy: 0.4082
Epoch: 9, Train Loss: 0.9948, Val Loss: 1.2682, Val Accuracy: 0.4490
Epoch: 10, Train Loss: 1.0084, Val Loss: 1.2991, Val Accuracy: 0.4694
Epoch: 11, Train Loss: 0.9654, Val Loss: 1.2766, Val Accuracy: 0.4898
Epoch: 12, Train Loss: 0.9561, Val Loss: 1.2319, Val Accuracy: 0.5102
Epoch: 13, Train Loss: 0.9089, Val Loss: 1.2260, Val Accuracy: 0.5714
Epoch: 14, Train Loss: 0.8834, Val Loss: 1.3500, Val Accuracy: 0.4082
Epoch: 15, Train Loss: 0.8570

In [9]:
best_model = GCNModel(train_dataset.num_features, train_dataset.num_classes, hidden_size)
best_model.load_state_dict(torch.load("./model/best_gcn_model.pth"))

# Testing loop
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        output = best_model(data)
        pred = F.sigmoid(output).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.5062


# Test add/max pool

In [None]:
class GCNModel(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=256):
        super(GCNModel, self).__init__()
        self.conv1 = GraphConv(num_features, hidden_size)
        self.conv2 = GraphConv(hidden_size, 128)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        return x