# Graph Attention Network

## Imports

In [2]:
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

import pickle
import numpy as np

import itertools

import performance as pf

## GNN Definition

In [3]:
class GATLayer(nn.Module):
    """Single-head graph attention layer tied to one graph object.

    The layer projects node features with a bias-free linear map, scores
    every edge from the concatenated projections of its two endpoints, and
    aggregates the projected source features per node, weighted by a
    softmax over the scores found in that node's mailbox.

    Args:
        g: Graph object the layer operates on. It must provide ``ndata``,
            ``apply_edges`` and ``update_all`` (presumably a DGLGraph).
        in_dim: Size of the incoming node feature vectors.
        out_dim: Size of the projected / output node feature vectors.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.g = g
        # Equation (1): linear projection shared by all nodes.
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # Equation (2): maps a concatenated endpoint pair to one raw score.
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): unnormalised attention score per edge."""
        endpoint_pair = th.cat((edges.src['z'], edges.dst['z']), dim=1)
        raw_score = self.attn_fc(endpoint_pair)
        return {'e': F.leaky_relu(raw_score)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4).

        Forwards the projected source features together with the edge score.
        """
        return dict(z=edges.src['z'], e=edges.data['e'])

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4).

        Normalises the received scores with a softmax along dim 1 of the
        mailbox (presumably the incoming-message axis) and returns the
        correspondingly weighted sum of the received features.
        """
        inbox = nodes.mailbox
        # Equation (3): attention coefficients.
        weights = F.softmax(inbox['e'], dim=1)
        # Equation (4): attention-weighted aggregation.
        aggregated = (weights * inbox['z']).sum(dim=1)
        return {'h': aggregated}

    def forward(self, h):
        """Run one attention pass over ``self.g`` and return the new features.

        The projected features are stored in ``ndata['z']``; the aggregated
        result is popped from ``ndata['h']`` so it does not stay on the graph.
        """
        graph = self.g
        # Equation (1)
        graph.ndata['z'] = self.fc(h)
        # Equation (2)
        graph.apply_edges(self.edge_attention)
        # Equations (3) & (4)
        graph.update_all(self.message_func, self.reduce_func)
        return graph.ndata.pop('h')

In [4]:
class MultiHeadGATLayer(nn.Module):
    """Runs several independent ``GATLayer`` heads and merges their outputs.

    Args:
        g: Graph object handed to every attention head.
        in_dim: Size of the incoming node feature vectors.
        out_dim: Output size of each individual head.
        num_heads: Number of attention heads to create.
        merge: ``'cat'`` concatenates the head outputs along the feature
            dimension (result width ``out_dim * num_heads``); any other
            value averages the head outputs element-wise (result width
            ``out_dim``).
    """

    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList(
            [GATLayer(g, in_dim, out_dim) for _ in range(num_heads)]
        )
        self.merge = merge

    def forward(self, h):
        """Apply every head to ``h`` and merge the per-head results."""
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return th.cat(head_outs, dim=1)
        # Average across heads: stacking adds a leading head axis, and
        # reducing over that axis only keeps the per-node feature shape.
        return th.mean(th.stack(head_outs), dim=0)

In [5]:
class Net(nn.Module):
    """Two-layer graph attention network producing per-node log-probabilities.

    Args:
        g: Graph object passed on to both attention layers.
        in_dim: Size of the input node features.
        hidden_dim: Output size of each head in the first layer.
        out_dim: Output size of the second layer.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super().__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # The first layer's heads are concatenated, so the second layer
        # receives hidden_dim * num_heads features. The output layer uses
        # a single attention head.
        concat_dim = hidden_dim * num_heads
        self.layer2 = MultiHeadGATLayer(g, concat_dim, out_dim, 1)

    def forward(self, h):
        """Return the log-softmax (over dim 1) of the second layer's output."""
        hidden = F.elu(self.layer1(h))
        scores = self.layer2(hidden)
        return F.log_softmax(scores, dim=1)

## Data Loading

In [13]:
from dgl.data import citation_graph as citegrh
import networkx as nx

# Load the Cora dataset through DGL's citation_graph helper and convert
# its features, labels and training mask to torch tensors.
data = citegrh.load_cora()
features = th.FloatTensor(data.features)
labels = th.LongTensor(data.labels)
# NOTE: this dataset-provided mask is replaced by a custom mask in the
# "Select Training Set" cell below.
mask = th.ByteTensor(data.train_mask)
# presumably a networkx graph, since networkx calls are applied to it
# below — TODO confirm against the installed DGL version
g = data.graph

# Normalise self loops: drop any existing self-loop edges first, then
# (after converting to a DGLGraph) add exactly one self loop per node.
g.remove_edges_from(nx.selfloop_edges(g))
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())

## Select Training Set

In [14]:
# Fraction of all nodes that is marked as training nodes.
percentage_train = 0.02

# NOTE(review): pickle.load executes arbitrary code contained in the file;
# only load permutation files from a trusted source.
with open("data/cora_permutation1.pickle","rb") as f:
    perm1 = pickle.load(f)
# Build a 0/1 mask over all nodes: the first
# int(percentage_train * number_of_nodes) entries of perm1 are used as the
# indices of the training nodes.
# presumably perm1 is a NumPy integer array holding a permutation of the
# node ids (it is indexed with a range object) — TODO confirm
mask = np.zeros(g.number_of_nodes())
mask[perm1[range(int(percentage_train*g.number_of_nodes()))]] = 1
mask = th.ByteTensor(mask)

## Training

In [22]:
# Loss object built from the ground-truth labels by the project-local
# `performance` module; the training loop calls its approximate_loss().
# presumably a permutation-invariant (clustering-style) loss, judging by
# the name — TODO confirm in performance.py
loss_function = pf.perm_inv_loss(labels)

In [0]:
import time

# Two-layer GAT; the output width equals the number of distinct labels.
net = Net(g,
          in_dim=features.shape[1],
          hidden_dim=8,
          out_dim=len(np.unique(labels)),
          num_heads=2)

optimizer = th.optim.Adam(net.parameters(), lr=1e-2, weight_decay=0)
net.train()  # training mode (the modules defined above contain no dropout)

# Loop-invariant pieces of the per-epoch evaluation, computed once.
valid_mask = 1 - mask  # complement of the 0/1 training mask
train_labels = labels[mask].numpy()
valid_labels = labels[valid_mask].numpy()

dur = []
for epoch in range(500):
    # The first three epochs are excluded from the timing statistics.
    if epoch >= 3:
        t0 = time.time()

    # Rand scores on training and held-out nodes. These are for monitoring
    # only and are not used by the optimizer, so no autograd graph is built.
    net.eval()
    with th.no_grad():
        prediction = net(features)
    train_rand = pf.rand_score(
        train_labels, np.argmax(prediction[mask].numpy(), axis=1))
    validation_rand = pf.rand_score(
        valid_labels, np.argmax(prediction[valid_mask].numpy(), axis=1))
    net.train()

    # Forward pass used for the optimisation step.
    logits = net(features)

    # NOTE(review): nclasses is hard-coded to 6 while the network's output
    # width is len(np.unique(labels)) — confirm which value approximate_loss
    # expects.
    loss = loss_function.approximate_loss(logits,mask,nclasses=6)

    # Supervised alternative: loss = F.nll_loss(logits[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    # No durations are recorded during the first three epochs; report nan
    # explicitly instead of letting np.mean warn about an empty list.
    mean_dur = np.mean(dur) if dur else float("nan")

    print(f"Epoch {epoch:05d} | Loss {loss.item():.4f} | Train.Rand {train_rand:.4f} | Valid.Rand {validation_rand:.4f} | Time(s) {mean_dur:.4f}")

In [20]:
# Visualise predictions: show 10 randomly drawn training nodes as
# [predicted class, true label] rows.
net.eval() # Set net to evaluation mode
# Net.forward takes only the feature tensor; the graph was bound to the
# network when it was constructed.
final_prediction = net(features).detach()
a = np.transpose(np.vstack([final_prediction[mask].numpy().argmax(axis=1),labels[mask].numpy()]))
a[a[:,0].argsort()][np.random.choice(range(a.shape[0]),size=10)]
# as can be seen, the net predicts other labels, but gets the clusters right :)

array([[0, 6],
       [0, 3],
       [2, 2],
       [2, 2],
       [0, 6],
       [6, 4],
       [4, 0],
       [4, 0],
       [0, 2],
       [0, 4]], dtype=int64)

## Evaluation

In [21]:
net.eval() # Set net to evaluation mode
# Net.forward takes only the feature tensor; the graph was bound to the
# network when it was constructed.
final_prediction = net(features).detach()
# Summary table from the project-local `performance` module
# (columns All / Train / Test in the captured output below).
pf.performance_as_df(labels,final_prediction,mask)

Unnamed: 0,All,Train,Test
Mutual Information,0.200959,0.668763,0.193395
Rand-Index,0.092219,0.456485,0.088449
Variation of Information,2.655766,0.936472,2.667611
