<a href="https://colab.research.google.com/github/WhatRaSudeep/SAiDL-Spring-Assignment-2024/blob/main/GAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git


2.4.1+cu121
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-sparse (setup.py) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone


In [2]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, to_dense_adj, softmax
import torch_scatter


class GATLayer(MessagePassing):
    def __init__(self, in_features, out_features, dropout =0.6, alpha = 0.2, concat=True):
        super(GATLayer, self).__init__()
        self.dropout       = dropout        # drop prob = 0.6
        self.in_features   = in_features    #
        self.out_features  = out_features   #
        self.alpha         = alpha          # LeakyReLU with negative input slope, alpha = 0.2
        self.concat        = concat         # conacat = True for all layers except the output layer.


        # Xavier Initialization of Weights
        # Alternatively use weights_init to apply weights of choice
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        # LeakyReLU
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        # Linear Transformation
        h = torch.mm(input, self.W) # matrix multiplication
        N = h.size()[0]
        adj = to_dense_adj(adj)
        print(adj)
        # Attention Mechanism
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e       = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        # Masked Attention
        zero_vec  = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime   = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime


In [7]:
class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 1,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.lin_l = None
        self.lin_r = None
        self.att_l = None
        self.att_r = None

        self.lin_l = Linear(in_channels, out_channels*self.heads)
        self.lin_r = self.lin_l
        self.att_l = Parameter(torch.Tensor(1, self.heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, self.heads, out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size = None):

        H, C = self.heads, self.out_channels

        wh_l = self.lin_l(x).view(-1, H, C)
        wh_r = self.lin_r(x).view(-1, H, C)
        alpha_l = torch.mul(self.att_l, wh_l)
        alpha_r = torch.mul(self.att_r, wh_r)
        out = self.propagate(edge_index,x=(wh_l, wh_r), size=size, alpha=(alpha_l, alpha_r))
        out = out.view(-1, H*C)
        return out


    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):

        att = alpha_i + alpha_j

        att = F.leaky_relu(att, negative_slope=self.negative_slope)
        att = softmax(att, ptr if ptr else index)
        att = F.dropout(att, self.dropout)
        out = torch.mul(x_j, att)
        return out

    def aggregate(self, inputs, index, dim_size = None):
        out = torch_scatter.scatter(inputs, index = index, dim = self.node_dim, dim_size = dim_size, reduce = "sum")

        return out



In [4]:
from torch import nn
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
import torch.nn.functional as F
class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GAT(dataset.num_node_features, 16)
        self.conv2 = GAT(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.log_softmax(x, dim=1)
        return x

model = GNN()
print(model)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...


GNN(
  (conv1): GAT(1433, 16)
  (conv2): GAT(16, 7)
)


Done!


In [5]:
data = dataset[0]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay= 5e-4)
model.train()
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

Epoch: 000, Loss: 1.9438
Epoch: 001, Loss: 1.8397
Epoch: 002, Loss: 1.7263
Epoch: 003, Loss: 1.5902
Epoch: 004, Loss: 1.4831
Epoch: 005, Loss: 1.3434
Epoch: 006, Loss: 1.2008
Epoch: 007, Loss: 1.1354
Epoch: 008, Loss: 1.0615
Epoch: 009, Loss: 0.9722
Epoch: 010, Loss: 0.8942
Epoch: 011, Loss: 0.8492
Epoch: 012, Loss: 0.7291
Epoch: 013, Loss: 0.7271
Epoch: 014, Loss: 0.6293
Epoch: 015, Loss: 0.5602
Epoch: 016, Loss: 0.5262
Epoch: 017, Loss: 0.5063
Epoch: 018, Loss: 0.4399
Epoch: 019, Loss: 0.4251
Epoch: 020, Loss: 0.4155
Epoch: 021, Loss: 0.3268
Epoch: 022, Loss: 0.3520
Epoch: 023, Loss: 0.2759
Epoch: 024, Loss: 0.2591
Epoch: 025, Loss: 0.2614
Epoch: 026, Loss: 0.2808
Epoch: 027, Loss: 0.2211
Epoch: 028, Loss: 0.1798
Epoch: 029, Loss: 0.2321
Epoch: 030, Loss: 0.2125
Epoch: 031, Loss: 0.1825
Epoch: 032, Loss: 0.2064
Epoch: 033, Loss: 0.2302
Epoch: 034, Loss: 0.1193
Epoch: 035, Loss: 0.1542
Epoch: 036, Loss: 0.1150
Epoch: 037, Loss: 0.1386
Epoch: 038, Loss: 0.1115
Epoch: 039, Loss: 0.1072


In [6]:
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')


Accuracy: 0.7710
