# Graph Attention Network using DGL

## Imports

In [35]:
import json
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from sentence_transformers import models
from sentence_transformers import SentenceTransformer

from dgl import DGLGraph
from sklearn.metrics import pairwise_distances

## Data loading

### 1) Sentence data

In [2]:
%%time

data_path = '../../data/summary/data/train.json'
# The file is read as JSON Lines: one JSON object per line. The encoding is
# explicit so the result does not depend on the platform's locale. Blank lines
# (e.g. a trailing newline at EOF) are skipped instead of making json.loads raise.
with open(data_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

CPU times: user 2.92 s, sys: 355 ms, total: 3.27 s
Wall time: 3.27 s


In [3]:
sample = data[0]

# The record's 'doc' field holds newline-separated sentences, and 'labels' holds
# one integer per line (parsed below). The 'summaries' field is kept as-is.
text = sample['doc']
summary = sample['summaries']
labels = list(map(int, sample['labels'].split('\n')))

sentences = text.split('\n')

### 2) Node-feature matrix

In [4]:
# Encode every sentence with a pretrained SentenceTransformer model. `features`
# is used below both to build the cosine-similarity graph and as the input node
# features, so node i corresponds to sentences[i].
# NOTE(review): the sentence-transformers docs list the *-nli-stsb-mean-tokens
# checkpoints as deprecated. Confirm whether a newer model is acceptable.
embedder = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
features = embedder.encode(sentences)

### 3) Adjacency matrix

In [5]:
threshold = 0.2

# Pairwise cosine similarity between sentence embeddings (1 - cosine distance).
cosine_matrix = 1 - pairwise_distances(features, metric="cosine")
# Integer 0/1 adjacency: link two sentences whose similarity exceeds the
# threshold. A non-zero embedding has similarity ~1 with itself, so the
# diagonal is 1 and every node gets a self-loop.
adj_matrix = (cosine_matrix > threshold).astype(int)

### 4) Create the graph using NetworkX

In [6]:
# Build an undirected NetworkX graph with one edge per non-zero entry of
# adj_matrix. `nx.from_numpy_matrix` was removed in NetworkX 3.0;
# `nx.from_numpy_array` (available since NetworkX 2.0) is the drop-in replacement.
nx_g = nx.from_numpy_array(adj_matrix)

In [7]:
# Convert the NetworkX sentence graph into the DGL graph shared by all GAT layers.
# NOTE(review): `DGLGraph()` followed by `g.from_networkx(...)` is the pre-0.5 DGL
# API. Newer DGL releases expose `dgl.from_networkx(nx_g)` instead. Confirm the
# pinned DGL version.
g = DGLGraph()
g.from_networkx(nx_g)

## GAT architecture

In [8]:
class GATLayer(nn.Module):
    """One graph-attention head over the fixed graph ``g``.

    The numbered comments refer to the GAT equations:
      (1) projection          z = W h
      (2) edge score          e = LeakyReLU(a^T [z_src || z_dst])
      (3) normalisation       alpha = softmax of e over each node's messages
      (4) aggregation         h' = sum(alpha * z_src)

    Args:
        g: DGL graph. After ``forward`` the node field 'z' and the edge field
            'e' remain set on it; 'h' is popped and returned.
        in_dim: size of the input node features.
        out_dim: size of the projected node features.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.g = g
        # equation (1): the shared projection W (no bias)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2): the attention vector a, applied to [z_src || z_dst]
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): unnormalised attention score of each edge."""
        endpoint_pair = torch.cat((edges.src['z'], edges.dst['z']), dim=1)
        # F.leaky_relu uses PyTorch's default negative_slope (0.01).
        # NOTE(review): the GAT paper uses 0.2. Confirm the default is intended.
        return {'e': F.leaky_relu(self.attn_fc(endpoint_pair))}

    def message_func(self, edges):
        """Message UDF: send the source projection and the edge score."""
        source_z = edges.src['z']
        edge_score = edges.data['e']
        return {'z': source_z, 'e': edge_score}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) and (4)."""
        inbox = nodes.mailbox
        # equation (3): softmax over dim=1, the per-node message axis of the mailbox
        weights = F.softmax(inbox['e'], dim=1)
        # equation (4): attention-weighted sum of the incoming projections
        return {'h': (weights * inbox['z']).sum(dim=1)}

    def forward(self, h):
        """Return the attended features of every node of ``self.g`` for input ``h``."""
        graph = self.g
        graph.ndata['z'] = self.fc(h)                              # equation (1)
        graph.apply_edges(self.edge_attention)                     # equation (2)
        graph.update_all(self.message_func, self.reduce_func)      # equations (3) & (4)
        return graph.ndata.pop('h')
    
    
class MultiHeadGATLayer(nn.Module):
    """Run several independent GATLayer heads on the same graph and merge them.

    Args:
        g: graph handed to every GATLayer head.
        in_dim: input feature size of each head.
        out_dim: output feature size of each head.
        num_heads: number of attention heads.
        merge: 'cat' concatenates the head outputs along the feature dimension
            (dim=1). Any other value averages the heads element-wise. Assuming
            each head returns (num_nodes, out_dim), the result is
            (num_nodes, out_dim * num_heads) or (num_nodes, out_dim) respectively.
    """

    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList(
            GATLayer(g, in_dim, out_dim) for _ in range(num_heads)
        )
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        # Average over the head axis only (dim 0 of the stacked tensor). A mean
        # without `dim` would reduce every node and feature to a single scalar.
        return torch.mean(torch.stack(head_outs), dim=0)

In [65]:
class GATClassifier(nn.Module):
    """Two GAT layers, then an LSTM over the node sequence, then a per-node classifier.

    The node embeddings produced by the GAT layers are read in node order as a
    sequence by a unidirectional LSTM. Every time step is projected to
    ``num_classes`` scores (many-to-many). No softmax is applied.

    Args:
        g: graph shared by both GAT layers.
        in_dim: input node-feature size.
        hidden_dim: per-head output size of the first GAT layer.
        out_dim: output size of the second GAT layer (the LSTM input size).
        num_heads: number of attention heads in the first GAT layer.
        num_classes: number of scores produced per node.
        lstm_hidden: hidden size of the LSTM (default 32, the former hard-coded value).
        seq_len: number of consecutive nodes forming one LSTM sequence. ``None``
            (default) treats all nodes of the graph as one sequence. The node
            count must be divisible by ``seq_len``.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads, num_classes=2,
                 lstm_hidden=32, seq_len=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.seq_len = seq_len

        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        self.lstm = nn.LSTM(out_dim, lstm_hidden, 1, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(lstm_hidden, num_classes)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero initial states (h_0, c_0) for ``self.lstm``.

        Each has shape (num_layers * num_directions, batch_size, hidden_size).
        ``device``/``dtype`` default to PyTorch's defaults. ``forward`` passes
        those of its input so the states live beside the activations.
        """
        num_directions = 2 if self.lstm.bidirectional else 1
        shape = (self.lstm.num_layers * num_directions, batch_size, self.lstm.hidden_size)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell

    def forward(self, h):
        """Return per-node class scores shaped (num_sequences, seq_len, num_classes)."""
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)

        # Group the node embeddings into LSTM sequences (batch_first layout).
        seq_len = self.seq_len if self.seq_len is not None else h.size(0)
        h = h.reshape(-1, seq_len, self.out_dim)

        h_0, c_0 = self.init_hidden(h.size(0), device=h.device, dtype=h.dtype)
        output, _ = self.lstm(h, (h_0, c_0))

        # many-to-many: one score vector per node / time step
        return self.fc(output)

In [59]:
# Coerce the encoder output into one ndarray (a copy), then expose it as a torch
# tensor. torch.from_numpy shares memory with that ndarray, so the dtype carries over.
features = torch.from_numpy(np.array(features))

In [66]:
# GAT classifier over the sentence graph; the input size is the embedding width.
net = GATClassifier(
    g,
    in_dim=features.shape[1],
    hidden_dim=128,
    out_dim=64,
    num_heads=2,
    num_classes=2,
)

In [67]:
# One forward pass over all sentence nodes of the graph.
output = net(features)

In [68]:
# (num_sequences, nodes per sequence, num_classes); printed as (1, 21, 2) for this sample.
output.shape

torch.Size([1, 21, 2])

In [69]:
# Raw per-sentence scores from the final nn.Linear (no softmax applied).
output

tensor([[[-0.0786, -0.1580],
         [-0.0872, -0.1576],
         [-0.0927, -0.1620],
         [-0.0959, -0.1659],
         [-0.0960, -0.1728],
         [-0.0969, -0.1738],
         [-0.0979, -0.1701],
         [-0.0976, -0.1710],
         [-0.0942, -0.1739],
         [-0.0977, -0.1730],
         [-0.0968, -0.1771],
         [-0.0955, -0.1760],
         [-0.0977, -0.1744],
         [-0.0997, -0.1729],
         [-0.0999, -0.1725],
         [-0.0995, -0.1737],
         [-0.0995, -0.1735],
         [-0.0979, -0.1780],
         [-0.0979, -0.1769],
         [-0.0962, -0.1813],
         [-0.0952, -0.1839]]], grad_fn=<AddBackward0>)