# Graph Attention Networks (GAT) tutorial

## 01. imports

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

SEED = 42  # single source of truth for every random seed in this notebook

# Seed every RNG library imported above so results are reproducible.
# np.random.seed returns None, so this cell no longer echoes the torch
# Generator repr as its output.
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7f098ee43f50>

## 02. GAT class

In [3]:
class GATLayer(nn.Module):
    """
    Simple PyTorch implementation of a single graph attention layer.

    The layer first computes ``h = inputs @ W``. For every node pair
    ``(i, j)`` it then computes an unnormalised score
    ``e_ij = LeakyReLU(a^T [h_i || h_j])``. Scores where ``adj <= 0`` are
    masked out before a row-wise softmax. Node ``i``'s output is the
    attention-weighted sum of the ``h_j`` it is allowed to attend to.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: dropout probability applied to the attention coefficients
            (only while the module is in training mode).
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: if True, apply ELU to the output (meant for every layer
            except the output layer); if False, return the raw aggregation.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout  # attention dropout probability, e.g. 0.6
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha  # LeakyReLU negative slope, e.g. 0.2
        self.concat = concat  # True for all layers except the output layer

        # Xavier-uniform initialisation (gain 1.414 ~= sqrt(2)); swap in any
        # other nn.init scheme to use different initial weights.
        self.W = nn.Parameter(torch.zeros(in_features, out_features))
        nn.init.xavier_uniform_(self.W, gain=1.414)
        # Attention vector: first half scores node i, second half scores j.
        self.a = nn.Parameter(torch.zeros(2 * out_features, 1))
        nn.init.xavier_uniform_(self.a, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, inputs, adj):
        """
        Args:
            inputs: 2-D node-feature tensor of shape (N, in_features).
            adj: tensor broadcastable to (N, N). ``adj[i, j] > 0`` lets node i
                attend to node j; include self-loops if a node should attend
                to itself. A row with no positive entry ends up with uniform
                attention over all N nodes.

        Returns:
            Tensor of shape (N, out_features); ELU-activated when ``concat``.
        """
        # Linear transformation: (N, in_features) @ (in_features, out_features)
        h = torch.mm(inputs, self.W)

        # Attention scores. Since a^T [h_i || h_j] == a_1^T h_i + a_2^T h_j,
        # the (N, N) score matrix is a broadcast sum of two (N, 1) projections.
        # This avoids materialising the (N, N, 2*out_features) tensor of all
        # concatenated node pairs.
        score_i = torch.matmul(h, self.a[:self.out_features, :])  # (N, 1)
        score_j = torch.matmul(h, self.a[self.out_features:, :])  # (N, 1)
        e = self.leakyrelu(score_i + score_j.T)  # (N, N)

        # Masked attention: non-edges get a huge negative score, so the
        # softmax gives them (numerically) zero weight.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        return F.elu(h_prime) if self.concat else h_prime

## 03. PyG imports

In [4]:
# Interactive matplotlib backend. NOTE(review): the `notebook` backend targets
# the classic Jupyter Notebook; `%matplotlib widget` (ipympl) is the usual
# choice in JupyterLab — confirm which frontend this tutorial targets.
%matplotlib notebook
import matplotlib.pyplot as plt

import warnings
# NOTE(review): this silences *all* warnings for the rest of the session
# (including deprecation warnings from torch / torch_geometric); consider
# filtering only the specific categories that are noisy.
warnings.filterwarnings("ignore")

In [1]:
import torch_geometric.transforms as T

from torch_geometric.nn import GATConv

from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid

### 1) Dataset

In [2]:
# Load the Planetoid citation dataset; the feature-normalisation transform is
# handed to the constructor instead of being attached afterwards.
data_name = 'Cora'
dataset = Planetoid(
    root=f'./data/{data_name}',
    name=data_name,
    transform=T.NormalizeFeatures(),
)

# Report the basic dataset statistics.
for stat_label, stat_value in (
    ("Classes", dataset.num_classes),
    ("Node Features", dataset.num_node_features),
):
    print(f"Number of {stat_label} in {data_name}:", stat_value)

Number of Classes in Cora: 7
Number of Node Features in Cora: 1433


In [3]:
# Graph object used for the rest of the notebook (graph 0 of the dataset);
# its node features and edge index are inspected in the next cells.
data = dataset[0]

In [4]:
data.x.shape  # node-feature matrix: (num_nodes, num_node_features)

torch.Size([2708, 1433])

In [8]:
data.edge_index.shape  # edge index; PyG stores it as (2, num_edges)

torch.Size([2, 10556])

### 2) Model

In [None]:
class GAT(torch.nn.Module):
    
    def __init__(self):
        super(GAT, self).__init__()
        self.hid = 8
        self.in_head = 8
        self.out_head = 1
        
        self.conv1 = GATConv()