# Imports

In [30]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv
from dgl.data.utils import load_graphs
from torch_geometric.utils import from_networkx
import dgl

# Loading Data

I have the following folder structure where all size images are in their respective folders:

```
- some root folder
    - train
    - val
    - test
```

In [31]:
class CustomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)


    def get_num_features(self):
        sample_data = self.get(0)
        return sample_data.num_features


    @property
    def raw_file_names(self):
        return [f for f in os.listdir(self.root) if f.endswith('.bin')]

    def len(self):
        return len(self.raw_file_names)

    def get(self, idx):
        file_name = self.raw_file_names[idx]
        file_path = os.path.join(self.root, file_name)
        class_label = file_name.split('_')[2]

        # Load the DGL graph using DGL's load_graphs function
        dgl_graphs, _ = load_graphs(file_path)
        dgl_graph = dgl_graphs[0]

        # Used to check if our graph had edge features - it does not
        # if 'efeat' in dgl_graph.edata:
        #     edge_features = dgl_graph.edata['efeat']
        #     print("Graph has edge features with shape:", edge_features.shape)
        # else:
        #     print("Graph does not have edge features")

        # Get node features from DGL graph
        node_features = dgl_graph.ndata['feat']

        # Convert the DGL graph to a PyTorch Geometric Data object
        src, dst = dgl_graph.edges()
        edge_index = torch.stack((src, dst), dim=0).to(torch.long)
        data = Data(x=node_features, edge_index=edge_index)

        # Map the class label to an integer
        class_mapping = {'N': 0, 'PB': 1, 'UDH': 2, 'FEA': 3, 'ADH': 4, 'DCIS': 5, 'IC': 6}
        y = class_mapping[class_label]
        data.y = torch.tensor([y], dtype=torch.long)

        return data
    
train_dataset = CustomDataset(root='C:/Users/George/Downloads/cell_graph_dataset/train') # set path for all train graphs
val_dataset = CustomDataset(root='C:/Users/George/Downloads/cell_graph_dataset/val')
test_dataset = CustomDataset(root='C:/Users/George/Downloads/cell_graph_dataset/test')


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f"Total dataset size: {len(train_dataset) + len(test_dataset) + len(val_dataset)}")
print(f"Number of features: {train_dataset.get_num_features()}")
print(f"Number of classes: {train_dataset.num_classes}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")


Total dataset size: 772
Number of features: 514
Number of classes: 7
Train dataset size: 563
Val dataset size: 49
Test dataset size: 160


# Defining GCN Model

In [32]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GraphConv

class GCNModel(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=256):
        super(GCNModel, self).__init__()
        self.conv1 = GraphConv(num_features, hidden_size)
        self.conv2 = GraphConv(hidden_size, num_classes)


    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)  # Perform global mean pooling
        return x


In [33]:
# Initialize the GCNModel
hidden_size = 1024
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNModel(train_dataset.num_features, train_dataset.num_classes, hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 50
best_val_accuracy = 0

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output, data.y)
            val_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
            total += data.num_graphs
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    
    # Save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "best_gcn_model.pth")
    
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

# Testing loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        output = model(data)
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

Epoch: 1, Train Loss: 119.3547, Val Loss: 2.0594, Val Accuracy: 0.2245
Epoch: 2, Train Loss: 1.8386, Val Loss: 1.8592, Val Accuracy: 0.2245
Epoch: 3, Train Loss: 1.6930, Val Loss: 1.8186, Val Accuracy: 0.2245
Epoch: 4, Train Loss: 1.6342, Val Loss: 1.8057, Val Accuracy: 0.2245
Epoch: 5, Train Loss: 1.5515, Val Loss: 1.7731, Val Accuracy: 0.2245
Epoch: 6, Train Loss: 1.5582, Val Loss: 1.6654, Val Accuracy: 0.2449
Epoch: 7, Train Loss: 1.4879, Val Loss: 1.6491, Val Accuracy: 0.2449
Epoch: 8, Train Loss: 1.4476, Val Loss: 1.5495, Val Accuracy: 0.4082
Epoch: 9, Train Loss: 1.4540, Val Loss: 1.5111, Val Accuracy: 0.2857
Epoch: 10, Train Loss: 1.3439, Val Loss: 1.3520, Val Accuracy: 0.4694
Epoch: 11, Train Loss: 1.3041, Val Loss: 1.4936, Val Accuracy: 0.3878
Epoch: 12, Train Loss: 1.1896, Val Loss: 1.4501, Val Accuracy: 0.3673
Epoch: 13, Train Loss: 1.1993, Val Loss: 1.4517, Val Accuracy: 0.4082
Epoch: 14, Train Loss: 1.2011, Val Loss: 1.4600, Val Accuracy: 0.4082
Epoch: 15, Train Loss: 1.17

# Defining GIN Model