<a href="https://colab.research.google.com/github/TheoBacqueyrisse/Graph-Neural-Networks/blob/main/GNN_MultiHead_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Graph Neural Network with Multi Attention**

In [1]:
# Let us first clone the GitHub repository
%%capture
!git clone https://github.com/TheoBacqueyrisse/Graph-Neural-Networks.git

In [2]:
# Install dependencies
%%capture
!pip install -r /content/Graph-Neural-Networks/requirements.txt

In [36]:
# Import Packages
import pandas
import numpy as np
import torch
from torch_geometric.datasets import ZINC

# Visualisation
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

# Data Loader
from torch_geometric.loader import DataLoader

# Neural Network Architecture
from torch.nn import Linear, Sequential, ReLU, Sigmoid
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv
from torch_geometric.nn import global_max_pool as gmp, global_mean_pool as gap, global_add_pool as gad
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Loss Function
from torch.nn import MSELoss, L1Loss

# Optimizer
from torch.optim import Adam, SGD, Adagrad

# See the progression of the Training
import tqdm

## Model Architecture

In [38]:
NUM_HEADS = 4
IN_CHANNELS = 1
OUT_CHANNELS = 128

class GNN_MultiHead(torch.nn.Module):
    def __init__(self, num_heads, in_channels, out_channels):
      super(GNN_MultiHead, self).__init__()

      # Care about the design of the NN here
      self.heads = num_heads
      self.convs = torch.nn.ModuleList([
            GINConv(Sequential(
                Linear(in_channels, out_channels // num_heads),
                ReLU(),
                Linear(out_channels // num_heads, out_channels // num_heads)
            ))
            for _ in range(num_heads)
        ])

      self.pooling = gap
      self.out = Linear(in_features = out_channels, out_features = 1)

    def forward(self, x, edge_index, batch_index):

      head_out = []
      for conv in self.convs:
          out = conv(x, edge_index)
          head_out.append(out)

      multi_head_output = torch.cat(head_out, dim=1)

      aggregated_output = self.pooling(multi_head_output, batch_index)

      out = self.out(aggregated_output)

      return out, aggregated_output

model = GNN_MultiHead(NUM_HEADS, IN_CHANNELS, OUT_CHANNELS)
print(model)

GNN_MultiHead(
  (convs): ModuleList(
    (0-3): 4 x GINConv(nn=Sequential(
      (0): Linear(in_features=1, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
    ))
  )
  (out): Linear(in_features=128, out_features=1, bias=True)
)


## Configuration

In [39]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

NUM_EPOCHS = 10

loss_function = L1Loss()
optimizer = Adam(params = model.parameters(), lr = 0.01)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=0.00001)

In [31]:
NB_GRAPHS_PER_BATCH = 64

train = ZINC('/content/Graph-Neural-Networks/data', split = 'train')
val = ZINC('/content/Graph-Neural-Networks/data', split = 'val')
test = ZINC('/content/Graph-Neural-Networks/data', split = 'test')

train_loader = DataLoader(train,
                          batch_size = NB_GRAPHS_PER_BATCH,
                          shuffle = True)

val_loader = DataLoader(val,
                        batch_size = NB_GRAPHS_PER_BATCH,
                        shuffle = False)

test_loader = DataLoader(test,
                         batch_size = NB_GRAPHS_PER_BATCH,
                         shuffle = False)

print("Number of Batches in Train Loader :", len(train_loader))
print("Number of Batches in Val Loader :", len(val_loader))
print("Number of Batches in Test Loader :", len(test_loader))

Number of Batches in Train Loader : 3438
Number of Batches in Val Loader : 382
Number of Batches in Test Loader : 79


## Train and Test Functions 🚀

In [40]:
def train_mh(model, optimizer, scheduler, train_loader, val_loader):
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0

        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()

            pred, _ = model(data.x.float(), data.edge_index, data.batch)

            loss = loss_function(pred, data.y.view(-1, 1).float())
            loss.backward()

            optimizer.step()
            train_loss += loss.item()

        average_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)

                pred, _ = model(data.x.float(), data.edge_index, data.batch)
                loss = loss_function(pred, data.y.view(-1, 1).float())

                val_loss += loss.item()

        average_val_loss = val_loss / len(val_loader)

        scheduler.step(average_val_loss)

        print(f"Epoch {epoch}/{NUM_EPOCHS} -> Train Loss: {average_train_loss:.4f} - Val Loss: {average_val_loss:.4f}")


def test_mh(test_loader):
  model.eval()
  with torch.no_grad():
      tot_test_loss = 0.0

      for test_batch in test_loader:
          test_batch.to(device)

          test_pred, test_y = model(test_batch.x.float(), test_batch.edge_index, test_batch.batch)
          test_loss = loss_function(test_pred, test_batch.y.view(-1, 1).float())

          tot_test_loss += test_loss.item()

      average_test_loss = tot_test_loss / len(test_loader)

  print(f"Test Loss: {average_test_loss:.4f}")

## Model Training and Evaluation ⚡

In [41]:
train_mh(model, optimizer, scheduler, train_loader, val_loader)

Epoch 1/10 -> Train Loss: 1.0986 - Val Loss: 1.0698
Epoch 2/10 -> Train Loss: 1.0727 - Val Loss: 1.0939
Epoch 3/10 -> Train Loss: 1.0681 - Val Loss: 1.0533
Epoch 4/10 -> Train Loss: 1.0652 - Val Loss: 1.0459
Epoch 5/10 -> Train Loss: 1.0632 - Val Loss: 1.0510
Epoch 6/10 -> Train Loss: 1.0634 - Val Loss: 1.1129
Epoch 7/10 -> Train Loss: 1.0601 - Val Loss: 1.0473
Epoch 8/10 -> Train Loss: 1.0598 - Val Loss: 1.0434
Epoch 9/10 -> Train Loss: 1.0603 - Val Loss: 1.0451
Epoch 10/10 -> Train Loss: 1.0598 - Val Loss: 1.0515


In [43]:
test_mh(test_loader)

Test Loss: 1.0524
