In [15]:
# Install required packages.
!pip install torch

import os
import torch
from tqdm import tqdm 

os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.3.1+cu121


In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


cuda


In [17]:
# import torch
# import pickle
# from torch_geometric.datasets import TUDataset

# from google.colab import drive
# drive.mount('/content/drive')
# FOLDERNAME = 'cse493g1/project/data'
# assert FOLDERNAME is not None, "[!] Enter the foldername."

# %cd drive/My\ Drive

# dataset_path = os.path.join(FOLDERNAME, 'solutions_dataset_gnn_graphs.pkl')
# with open(dataset_path, 'rb') as f:
#     dataset = pickle.load(f)

from construct_gnn_dataset import SolutionDataset

dataset = SolutionDataset(root='../../data/raw')

In [18]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
# print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: SolutionDataset(5000):
Number of graphs: 5000
Number of features: 139

Data(edge_index=[2, 165], name=[166], cooccurrences=[166, 139], num_nodes=166, x=[166, 139], y=[10])
Number of nodes: 166
Number of edges: 165
Average node degree: 0.99
Has isolated nodes: False
Has self-loops: False
Is undirected: False


In [19]:
NODE_FEATURES = dataset.num_features
NUM_CLASSES = data.y.size(-1)

print(NODE_FEATURES)
print(NUM_CLASSES)


139
10


In [20]:
torch.manual_seed(12345)

assert (len(dataset) % 10 == 0)
split = (len(dataset) * 9) // 10

train_dataset = [item for i, item in enumerate(dataset) if (i + 1) % 10 != 0]
test_dataset = [item for i, item in enumerate(dataset) if (i + 1) % 10 == 0]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')


KeyboardInterrupt: 

In [None]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

for step, data in enumerate(train_loader):
    data = data.to(device)
    
    if (step + 1) % 2500 == 0:
        print(f'Step {step + 1}:')
        print('=======')
        print(f'Number of graphs in the current batch: {data.num_graphs}')
        print(data)
        print()

Step 2500:
Number of graphs in the current batch: 1
DataBatch(edge_index=[2, 155], name=[1], cooccurrences=[156, 139], num_nodes=156, x=[156, 139], y=[10], batch=[156], ptr=[2])



## Training a Graph Neural Network (GNN)

Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings:

$$
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v
$$

PyTorch Geometric provides this functionality via [`torch_geometric.nn.global_mean_pool`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool), which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `batch` to compute a graph embedding of size `[batch_size, hidden_channels]` for each graph in the batch.

The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(NODE_FEATURES, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.conv5 = GCNConv(hidden_channels, hidden_channels)
        self.conv6 = GCNConv(hidden_channels, hidden_channels)
        self.conv7 = GCNConv(hidden_channels, hidden_channels)
        self.conv8 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, NUM_CLASSES)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = x.relu()
        x = self.conv4(x, edge_index)
        x = x.relu()
        x = self.conv5(x, edge_index)
        x = x.relu()
        x = self.conv6(x, edge_index)
        x = x.relu()
        x = self.conv7(x, edge_index)
        x = x.relu()
        x = self.conv8(x, edge_index)

        x = global_mean_pool(x, batch)  
        
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x.squeeze(0) if x.size(0) == 1 else x

model = GCN(hidden_channels=500).to(device) 
print(model)

GCN(
  (conv1): GCNConv(139, 500)
  (conv2): GCNConv(500, 500)
  (conv3): GCNConv(500, 500)
  (conv4): GCNConv(500, 500)
  (conv5): GCNConv(500, 500)
  (conv6): GCNConv(500, 500)
  (conv7): GCNConv(500, 500)
  (conv8): GCNConv(500, 500)
  (lin): Linear(in_features=500, out_features=10, bias=True)
)


Here, we again make use of the [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.

Let's train our network for a few epochs to see how well it performs on the training as well as test set:

In [None]:
EPOCHS = 25
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

epoch_losses = {}

model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0
    for batch in tqdm(train_loader, desc="Epoch Progress"):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y.to(torch.float32))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        _, predicted = torch.max(out, dim=0)
        
    epoch_losses[epoch + 1] = epoch_loss
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}")

print("Training completed.")

Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.05it/s]


Epoch 1/50, Loss: 13014.5209


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.04it/s]


Epoch 2/50, Loss: 11771.9692


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.45it/s]


Epoch 3/50, Loss: 11893.9694


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.95it/s]


Epoch 4/50, Loss: 11902.9588


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.31it/s]


Epoch 5/50, Loss: 11922.5749


Epoch Progress: 100%|██████████| 4500/4500 [00:34<00:00, 129.13it/s]


Epoch 6/50, Loss: 11903.1666


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 126.84it/s]


Epoch 7/50, Loss: 11911.6625


Epoch Progress: 100%|██████████| 4500/4500 [00:34<00:00, 128.94it/s]


Epoch 8/50, Loss: 11950.9489


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.07it/s]


Epoch 9/50, Loss: 11923.1090


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.46it/s]


Epoch 10/50, Loss: 11921.1426


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.99it/s]


Epoch 11/50, Loss: 11913.9150


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.92it/s]


Epoch 12/50, Loss: 11880.1381


Epoch Progress: 100%|██████████| 4500/4500 [00:34<00:00, 128.60it/s]


Epoch 13/50, Loss: 11924.7974


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 126.39it/s]


Epoch 14/50, Loss: 11936.1784


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.54it/s]


Epoch 15/50, Loss: 11925.4601


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 126.75it/s]


Epoch 16/50, Loss: 11915.2960


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.07it/s]


Epoch 17/50, Loss: 11910.3352


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.44it/s]


Epoch 18/50, Loss: 11894.4301


Epoch Progress: 100%|██████████| 4500/4500 [00:34<00:00, 128.92it/s]


Epoch 19/50, Loss: 11883.1444


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.02it/s]


Epoch 20/50, Loss: 11903.0044


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.76it/s]


Epoch 21/50, Loss: 11907.3576


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 126.61it/s]


Epoch 22/50, Loss: 11932.9513


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.13it/s]


Epoch 23/50, Loss: 11977.4264


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.93it/s]


Epoch 24/50, Loss: 11932.9064


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.40it/s]


Epoch 25/50, Loss: 11882.9789


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 128.25it/s]


Epoch 26/50, Loss: 11924.2368


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.55it/s]


Epoch 27/50, Loss: 11934.3919


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.23it/s]


Epoch 28/50, Loss: 11912.8631


Epoch Progress: 100%|██████████| 4500/4500 [00:35<00:00, 127.73it/s]


Epoch 29/50, Loss: 11892.1836


Epoch Progress: 100%|██████████| 4500/4500 [00:32<00:00, 137.84it/s]


Epoch 30/50, Loss: 11911.1020


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.54it/s]


Epoch 31/50, Loss: 11910.0833


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.66it/s]


Epoch 32/50, Loss: 11858.1227


Epoch Progress: 100%|██████████| 4500/4500 [00:28<00:00, 155.30it/s]


Epoch 33/50, Loss: 11913.5605


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 153.97it/s]


Epoch 34/50, Loss: 11950.2663


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 152.68it/s]


Epoch 35/50, Loss: 11907.2889


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.96it/s]


Epoch 36/50, Loss: 11885.6210


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 153.60it/s]


Epoch 37/50, Loss: 11884.4645


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.37it/s]


Epoch 38/50, Loss: 11898.1574


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.70it/s]


Epoch 39/50, Loss: 11935.5992


Epoch Progress: 100%|██████████| 4500/4500 [00:29<00:00, 154.36it/s]


Epoch 40/50, Loss: 11915.9987


Epoch Progress: 100%|██████████| 4500/4500 [00:30<00:00, 149.79it/s]


Epoch 41/50, Loss: 11843.5355


Epoch Progress: 100%|██████████| 4500/4500 [00:30<00:00, 146.48it/s]


Epoch 42/50, Loss: 11895.1547


Epoch Progress: 100%|██████████| 4500/4500 [00:31<00:00, 143.42it/s]


Epoch 43/50, Loss: 11927.8393


Epoch Progress: 100%|██████████| 4500/4500 [00:31<00:00, 144.45it/s]


Epoch 44/50, Loss: 11905.8646


Epoch Progress: 100%|██████████| 4500/4500 [00:30<00:00, 147.16it/s]


Epoch 45/50, Loss: 11911.1625


Epoch Progress: 100%|██████████| 4500/4500 [00:30<00:00, 148.07it/s]


Epoch 46/50, Loss: 11909.6662


Epoch Progress: 100%|██████████| 4500/4500 [00:31<00:00, 144.56it/s]


Epoch 47/50, Loss: 11916.8967


Epoch Progress: 100%|██████████| 4500/4500 [00:31<00:00, 144.09it/s]


Epoch 48/50, Loss: 11938.4222


Epoch Progress: 100%|██████████| 4500/4500 [00:31<00:00, 145.02it/s]


Epoch 49/50, Loss: 11861.2591


Epoch Progress: 100%|██████████| 4500/4500 [00:30<00:00, 145.52it/s]

Epoch 50/50, Loss: 11933.0255
Training completed.





In [None]:
MODEL_PATH = "DeepGCNModel.pth"

In [None]:
torch.save(model, MODEL_PATH)

In [None]:
loaded_model = torch.load(MODEL_PATH)
loaded_model.eval()

GCN(
  (conv1): GCNConv(139, 500)
  (conv2): GCNConv(500, 500)
  (conv3): GCNConv(500, 500)
  (conv4): GCNConv(500, 500)
  (conv5): GCNConv(500, 500)
  (conv6): GCNConv(500, 500)
  (conv7): GCNConv(500, 500)
  (conv8): GCNConv(500, 500)
  (lin): Linear(in_features=500, out_features=10, bias=True)
)

In [None]:
def test(model, test_loader):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing Progress"):
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)  # Forward pass            
            _, predicted = torch.max(out, 0)  # Get the index of the max log-probability
            total += batch.y.size(0)  # Total number of graphs
            correct += (predicted == batch.y).sum().item()  # Correct predictions

    accuracy = correct / total
    return accuracy

test_accuracy = test(loaded_model, test_dataset)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

Testing Progress: 100%|██████████| 500/500 [00:01<00:00, 263.18it/s]

Test Accuracy: 0.00%



