# Graph Attention Network using PyG

## imports

In [35]:
import json
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from sentence_transformers import models
from sentence_transformers import SentenceTransformer

from dgl import DGLGraph
from sklearn.metrics import pairwise_distances

## data load

### 1) Sentence data

In [2]:
%%time

data_path = '../../data/summary/data/train.json'
with open(data_path, 'r') as f:
    data = [json.loads(line) for line in f]

CPU times: user 2.92 s, sys: 355 ms, total: 3.27 s
Wall time: 3.27 s


In [115]:
# data

In [137]:
sample = data[0]

text = sample['doc']
summary = sample['summaries']
labels = sample['labels']
labels = labels.split('\n')
labels = [int(label) for label in labels]

sentences = text.split('\n')

In [116]:
sample = data[111]

text = sample['doc']
summary = sample['summaries']
labels = sample['labels']
labels = labels.split('\n')
labels = [int(label) for label in labels]

sentences = text.split('\n')

### 2) node-feature matrix

In [138]:
embedder = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
features = embedder.encode(sentences)

### 3) adjacency matrix

In [139]:
threshold = 0.2

cosine_matrix = 1 - pairwise_distances(features, metric="cosine")
adj_matrix = (cosine_matrix > threshold) * 1

### 4) create Graph using Networkx

In [140]:
G = nx.from_numpy_matrix(adj_matrix)

## GAT architecture

In [92]:
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

In [150]:
class GATClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, num_classes=2):
        super().__init__()
        
        self.out_head = 1
        self.out_dim = out_dim
        
        self.conv1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=0.6)
        self.conv2 = GATConv(hidden_dim * num_heads, out_dim, concat=False,
                             heads=self.out_head, dropout=0.6)
        
        self.lstm = nn.LSTM(out_dim, 32, 1, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(32, num_classes)
        
    
    def init_hidden(self, batch_size):
        # (num_layers * num_directions, batch_size, hidden_size)
        hidden = Variable(torch.zeros(1, batch_size, 32))
        cell = Variable(torch.zeros(1, batch_size, 32))
        return hidden, cell
    

    def forward(self, features, edge_index):
        x = F.dropout(features, p=0.6, training=True)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=True)
        x = self.conv2(x, edge_index)
        x = x.view(-1, x.size(0), self.out_dim)
        
        h_0, cell = self.init_hidden(x.size(0))  # initial h_0
        
        output, h_n = self.lstm(x, (h_0, cell))
        
        # many-to-many
        output = self.fc(output)
        
        return output

In [151]:
net = GATClassifier(in_dim=features.size()[1],
                    hidden_dim=128,
                    out_dim=64,
                    num_heads=2,
                    num_classes=2)

In [152]:
features = np.array(features)
features = torch.from_numpy(features)

In [153]:
features.size()

torch.Size([21, 768])

In [154]:
e1_list = [e1 for e1, _ in list(G.edges)]
e2_list = [e2 for _, e2 in list(G.edges)]

edge_index = [e1_list, e2_list]
edge_index = torch.tensor(edge_index)

In [155]:
edge_index.size()

torch.Size([2, 117])

In [156]:
output = net(features, edge_index)

In [157]:
output.shape

torch.Size([1, 21, 2])

In [158]:
output

tensor([[[-0.1372, -0.0718],
         [-0.1486, -0.0827],
         [-0.0085, -0.2377],
         [-0.2707, -0.1346],
         [-0.0291, -0.3249],
         [ 0.1115, -0.1386],
         [ 0.0857, -0.2462],
         [-0.1398, -0.2831],
         [-0.0196, -0.0489],
         [ 0.1478, -0.2591],
         [ 0.0219, -0.2626],
         [-0.0985, -0.2355],
         [-0.1451, -0.2023],
         [-0.1952, -0.2040],
         [-0.2347, -0.1790],
         [-0.1679, -0.1950],
         [-0.1660, -0.1966],
         [-0.1847, -0.1809],
         [-0.1425, -0.2505],
         [ 0.0766, -0.1208],
         [-0.0404, -0.1236]]], grad_fn=<AddBackward0>)