<a href="https://colab.research.google.com/github/Lua-Nova/Modern-GAP-GNN/blob/main/ModernGAP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
if torch.cuda.is_available():
  #NVIDIA GPU version

  !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f f'https://data.pyg.org/whl/torch-1.12.0+{cutorch.version.cuda.replace('.','')}.html'
else:
  #CPU version
  !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html

device = "cuda" if torch.cuda.is_available() else "cpu"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cpu.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcpu/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (286 kB)
[K     |████████████████████████████████| 286 kB 4.0 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcpu/torch_sparse-0.6.15-cp37-cp37m-linux_x86_64.whl (641 kB)
[K     |████████████████████████████████| 641 kB 36.9 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcpu/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (311 kB)
[K     |████████████████████████████████| 311 kB 49.0 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcpu/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (121 kB)
[K     |████████████████████████████████| 121 kB 46.1 MB/s 
[?25hCo

In [None]:
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.nn import Sequential, GCNConv

## Encoder Module

In [None]:
torch.manual_seed(11)
# create classes for layers that are used a lot to avoid repeating code

class MLP(nn.Module):
  # e.g. dimensions = [50,40,30,20]
    def __init__(self, dimensions):
        super().__init__()
        self.flatten = nn.Flatten()
        layers = []
        for i in range(len(dimensions)-1):
          layers.append(nn.Linear(dimensions[i], dimensions[i+1]))
          layers.append(nn.ReLU(inplace=True))

        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

## PMA

In [None]:
class PMA(nn.Module):
    # A - adjacency matrix     TODO: this should not be given to the module itself, it should access it in training (or from the graph dataset)
    # num_hops - the number of hops covered by this GNN
    def __init__(self, A, num_hops):
        super().__init__()
        # TODO: Figure out if you should tranpose this
        self.A_transpose = torch.transpose(A, 0,1)
        self.num_hops = num_hops
    
    def forward(self, x):
        out = [torch.nn.functional.normalize(x, dim=1)]
        for k in range(self.num_hops):
            aggr = torch.mm(self.A_transpose, out[-1])
            # TODO: noise it up
            noised = aggr
            normalized = torch.nn.functional.normalize(noised, dim=1)
            out.append(normalized)
        return torch.stack(out)

In [None]:
# TEMP CODE
smoothing = 0.2
A = torch.tensor([[1.,smoothing,smoothing],
                  [smoothing,1.,smoothing],
                  [smoothing,smoothing,1.]])
x = torch.tensor([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
pma = PMA(A, 10)
tensor = pma(x)
print(tensor.shape)
tensor = tensor.cpu().numpy()

# plt.figure(figsize=(16,7))
# plt.imshow(tensor)
# plt.show()
        # [encoder, pma, element_wise_mlp, combine, mlp]


torch.Size([11, 3, 3])


## Classification Module
NOTE: 

MLP base: The first MLP in the cassification module. 

MLP head: The last MLP and takes the combined output of all MLP base.

In [None]:
class Classification(nn.Module):
    # num_hops - the number of hops covered by this GNN
    # encoder_dimensions - the MLP dimensions of each base MLP
    # head_dimensions - the dimensions of the head MLP
    def __init__(self, num_hops, encoder_dimensions, head_dimensions):
        super().__init__()
        self.base_mlps = nn.ModuleList()
        for i in range(num_hops+1):
          self.base_mlps.append(MLP(encoder_dimensions))
        self.head_mlp = MLP(head_dimensions) # TODO: should this be softmax? I think we add a softmax for classification tasks. We can test if it works better
    
    def forward(self, cache):
        # forward through bases
        out = []
        for i in range(len(self.base_mlps)):
          encoding = self.base_mlps[i](cache[i,:,:])
          out.append(encoding) # add corresponding encoding
        # combine (use concatenation)
        combined_x = torch.cat(out, dim=1)
        # forward through head
        return self.head_mlp(combined_x)

In [None]:
class GAP(nn.Module):
  # encoder - pretrained encoder module
  # pma - PMA module
  # classification - classification module
  def __init__(self, encoder, pma, classification): # TODO: decide whether we should recieve the models as parameters
    super().__init__()
    self.encoder = encoder
    self.pma = pma
    self.classification = classification

  def forward(self, x):
    # initial node encoding
    x_encoded = self.encoder(x)
    # aggregation module
    cache = self.pma(x_encoded) 
    # classification
    return self.classification(cache) 


## Train/Test


In [None]:
# train
def train(dataset, model, loss_fn, optimizer): 
    # make this into dataloader using backup
    size = dataset['x'].size()[1]
    model.train()
    X, y = dataset['x'], dataset['y']
    X, y = X.to(device), y.to(device)

    # Compute prediction error
    pred = model(X)
    loss = loss_fn(pred, y)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# test
def test(dataset, model, loss_fn):
    size = dataset['x'].size()[1]
    model.eval()
    test_loss, correct = 0, 0
    with torch.inference_mode():
        X, y = dataset['x'], dataset['y']
        X, y = X.to(device), y.to(device)
        pred = model(X)
        test_loss += loss_fn(pred, y).item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

## Data

In [None]:
from torch_geometric.datasets import KarateClub
from torch_geometric.loader import DataLoader
dataset = KarateClub()[0]
num_classes = torch.unique(dataset['y']).size()[0]
# loader = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
A = torch.zeros((dataset['x'].size()[1], dataset['x'].size()[1]), dtype=torch.float)
# since we are using an adjacency matrix instead of edgelist, make that
for i in range(dataset['edge_index'].size()[1]):
  src, dst = dataset['edge_index'][0, i], dataset['edge_index'][1, i]
  # since undirected
  A[src, dst] = 1
  A[dst, src] = 1

In [12]:
from torch_geometric.datasets import Reddit
from torch_geometric.loader import DataLoader
dataset = Reddit('.')
# throw away classes with less than 10k nodes

tensor(30)

In [21]:
y = dataset.data['y']
y_per_class = []
mask_per_class = []
for label in range(dataset.num_classes):
    mask = y == label
    mask_per_class.append(mask)
    y_per_class.append(y[mask])
    print(label, y_per_class[label].size())

0 torch.Size([13101])
1 torch.Size([3550])
2 torch.Size([3302])
3 torch.Size([15181])
4 torch.Size([2322])
5 torch.Size([3597])
6 torch.Size([3952])
7 torch.Size([2138])
8 torch.Size([11187])
9 torch.Size([2246])
10 torch.Size([4928])
11 torch.Size([2964])
12 torch.Size([1696])
13 torch.Size([2731])
14 torch.Size([4854])
15 torch.Size([28272])
16 torch.Size([1003])
17 torch.Size([2639])
18 torch.Size([13999])
19 torch.Size([10308])
20 torch.Size([1596])
21 torch.Size([4066])
22 torch.Size([8222])
23 torch.Size([12146])
24 torch.Size([328])
25 torch.Size([1659])
26 torch.Size([4239])
27 torch.Size([5962])
28 torch.Size([4673])
29 torch.Size([5101])
30 torch.Size([2846])
31 torch.Size([4570])
32 torch.Size([1575])
33 torch.Size([4960])
34 torch.Size([3429])
35 torch.Size([4202])
36 torch.Size([4180])
37 torch.Size([4233])
38 torch.Size([12797])
39 torch.Size([3099])
40 torch.Size([5112])


## Encoder

Encoder Design


In [None]:
# encoder
dimensions = [34, 17, 5]
encoder_train = nn.Sequential(
    MLP(dimensions),
    nn.Linear(dimensions[-1],num_classes),
    nn.Softmax(dim=1)
)

Encoder Pretraining

In [None]:
encoder_model = encoder_train.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(encoder_model.parameters(), lr=1e-3)

epochs = 10
for t in range(epochs):
    # print(f"Epoch {t+1}\n-------------------------------")
    train(dataset, encoder_model, loss_fn, optimizer)
test(dataset, encoder_model, loss_fn)
print("Done!")

encoder = encoder_model[0]

# for name, param in encoder_model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data)

Test Error: 
 Accuracy: 35.3%, Avg loss: 1.383978 

Done!


Train full model

In [None]:
gap = GAP(encoder, PMA(A, 5), Classification(5, [5, 5, 4], [24, 12, 4]))
gap_model = gap.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(gap_model.parameters(), lr=1e-4)

epochs = 100
for t in range(epochs):
    # print(f"Epoch {t+1}\n-------------------------------")
    train(dataset, gap_model, loss_fn, optimizer)
test(dataset, gap_model, loss_fn)
print("Done!")

Test Error: 
 Accuracy: 38.2%, Avg loss: 1.364594 

Done!


## Backup

In [None]:
# # train
# def train(dataloader, model, loss_fn, optimizer, print_every = 100):
#     size = len(dataloader.dataset)
#     model.train()
#     for batch, (X, y) in enumerate(dataloader):
#         X, y = X.to(device), y.to(device)

#         # Compute prediction error
#         pred = model(X)
#         loss = loss_fn(pred, y)

#         # Backpropagation
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if batch % print_every == 0:
#             loss, current = loss.item(), batch * len(X)
#             print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# # test
# def test(dataloader, model, loss_fn):
#     size = len(dataloader.dataset)
#     num_batches = len(dataloader)
#     model.eval()
#     test_loss, correct = 0, 0
#     with torch.inference_mode():
#         for X, y in dataloader:
#             X, y = X.to(device), y.to(device)
#             pred = model(X)
#             test_loss += loss_fn(pred, y).item()
#             correct += (pred.argmax(1) == y).type(torch.float).sum().item()
#     test_loss /= num_batches
#     correct /= size
#     print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")