# Extractive Summarization

## Imports

In [16]:
import json
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from sentence_transformers import models
from sentence_transformers import SentenceTransformer

from torch_geometric.data import Data
from torch_geometric.nn import GATConv

from sklearn.metrics import pairwise_distances

## Data loading

### 1) Sentence data

In [2]:
%%time

# Training split, read as JSON Lines: every non-empty line is one JSON object.
data_path = '../../data/summary/data/train.json'
# Explicit UTF-8 so decoding does not depend on the platform's default locale.
with open(data_path, 'r', encoding='utf-8') as f:
    # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
    data = [json.loads(line) for line in f if line.strip()]

CPU times: user 2.98 s, sys: 306 ms, total: 3.28 s
Wall time: 3.28 s


In [3]:
# Pick one training example uniformly at random to experiment with.
data_len = len(data)
data_idx = np.random.randint(data_len)
sample = data[data_idx]

# 'doc' and 'labels' are newline-delimited strings; split them into
# a list of sentences and a list of integer labels respectively.
text, summary = sample['doc'], sample['summaries']
sentences = text.split('\n')
labels = list(map(int, sample['labels'].split('\n')))

## Extractive Summarization architecture

### 1) GAT Classifier

In [4]:
class GATClassifier(nn.Module):
    """Score every node of a graph with two GAT layers followed by an LSTM.

    The node features go through two ``GATConv`` layers; the resulting node
    embeddings are then treated as one sequence (in node order) and fed to a
    single-layer unidirectional LSTM, whose per-step outputs are projected to
    ``num_classes`` scores.

    Args:
        in_dim: Size of each input node feature vector.
        hidden_dim: Per-head output size of the first GAT layer.
        out_dim: Output size of the second GAT layer (and LSTM input size).
        num_heads: Number of attention heads in the first GAT layer.
        num_classes: Number of scores produced per node.
        lstm_hidden_dim: Hidden size of the LSTM (previously fixed at 32).
        dropout: Dropout probability used on the features and passed to both
            GAT layers (previously fixed at 0.6).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, num_classes=2,
                 lstm_hidden_dim=32, dropout=0.6):
        super().__init__()

        self.out_head = 1
        self.out_dim = out_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.dropout_p = dropout

        # First layer concatenates its heads, hence hidden_dim * num_heads
        # input channels for the second layer.
        self.conv1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout)
        self.conv2 = GATConv(hidden_dim * num_heads, out_dim, concat=False,
                             heads=self.out_head, dropout=dropout)

        self.lstm = nn.LSTM(out_dim, lstm_hidden_dim, 1,
                            batch_first=True, bidirectional=False)
        self.fc = nn.Linear(lstm_hidden_dim, num_classes)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero-filled initial ``(hidden, cell)`` states for the LSTM.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device to allocate the states on (default: torch default).
            dtype: Dtype of the states (default: torch default).

        Returns:
            Two distinct tensors of shape ``(1, batch_size, lstm_hidden_dim)``,
            i.e. ``(num_layers * num_directions, batch_size, hidden_size)``.
        """
        shape = (1, batch_size, self.lstm_hidden_dim)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell

    def forward(self, features, edge_index):
        """Compute one score vector per node.

        Args:
            features: Node feature matrix; the first dimension indexes nodes.
            edge_index: Graph connectivity passed straight to ``GATConv``.

        Returns:
            Tensor of shape ``(1, num_nodes, num_classes)``.
        """
        # Honour train/eval mode: dropout must be disabled after model.eval().
        x = F.dropout(features, p=self.dropout_p, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)

        # Treat all nodes as a single sequence: (num_nodes, out_dim) ->
        # (1, num_nodes, out_dim) for the batch_first LSTM.
        x = x.unsqueeze(0)

        # Allocate the initial state next to the data so GPU / non-float32
        # inputs work as well.
        h_0, c_0 = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)
        output, _ = self.lstm(x, (h_0, c_0))

        # many-to-many: project every time step (node) to class scores.
        return self.fc(output)

### 2) Summarizer

In [5]:
class Summarizer(nn.Module):
    """Embed sentences, link similar ones in a graph and score each sentence.

    Sentences are encoded with a pretrained ``SentenceTransformer``; sentences
    whose cosine similarity exceeds a threshold are connected, and the graph is
    scored by a ``GATClassifier``.

    Args:
        in_dim: Input feature size of the classifier; must equal the
            dimensionality of the sentence embeddings.
        hidden_dim: Per-head hidden size of the classifier's first GAT layer.
        out_dim: Output size of the classifier's second GAT layer.
        num_heads: Number of attention heads in the first GAT layer.
        num_classes: Number of scores produced per sentence.
    """

    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim,
                 num_heads,
                 num_classes=2):
        super().__init__()

        self.embedder = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
        self.gat_classifier = GATClassifier(in_dim, hidden_dim, out_dim, num_heads, num_classes)

    def build_graph(self, features, threshold=0.2):
        """Build the sentence-similarity graph as an edge index.

        Args:
            features: 2-D array-like of shape ``(num_sentences, dim)``.
            threshold: Two sentences are connected when their cosine
                similarity is strictly greater than this value.

        Returns:
            ``torch.int64`` tensor of shape ``(2, num_edges)``. Every connection
            is listed in both directions (``i -> j`` and ``j -> i``) because the
            edge index is interpreted as directed; self-connections of
            non-zero vectors are included.
        """
        features = np.asarray(features, dtype=float)

        # Cosine similarity = dot product of L2-normalised rows. Clamp the norm
        # so an all-zero embedding yields similarity 0 instead of NaN.
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        unit = features / np.maximum(norms, np.finfo(features.dtype).eps)
        cosine_matrix = unit @ unit.T

        # The boolean adjacency is symmetric, so nonzero() yields both
        # directions of every edge; stacking keeps a (2, E) integer shape even
        # when there are no edges at all.
        sources, targets = np.nonzero(cosine_matrix > threshold)
        edge_index = torch.from_numpy(np.stack([sources, targets])).long()

        return edge_index

    def forward(self, sents):
        """Score each sentence of one document.

        Args:
            sents: Non-empty sequence of sentence strings.

        Returns:
            Tensor of shape ``(1, len(sents), num_classes)`` holding the raw
            classifier outputs.

        Raises:
            ValueError: If ``sents`` is empty.
        """
        if len(sents) == 0:
            raise ValueError("sents must contain at least one sentence")

        # The embeddings pass through NumPy, so no gradient flows back into
        # the embedder; only the classifier is trained through this path.
        features = np.asarray(self.embedder.encode(sents), dtype=np.float32)

        edge_index = self.build_graph(features)

        # Put the inputs on the same device as the classifier's weights.
        device = next(self.gat_classifier.parameters()).device
        features = torch.from_numpy(features).to(device)

        return self.gat_classifier(features, edge_index.to(device))

In [6]:
# Build the model: 768-dimensional sentence features in, a single score per
# sentence out (num_classes=1).
net = Summarizer(
    in_dim=768,
    hidden_dim=128,
    out_dim=64,
    num_heads=2,
    num_classes=1,
)

In [7]:
# Forward pass on the sampled document; the result has shape
# (1, num_sentences, num_classes).
output = net(sents=sentences)

In [8]:
# Display the raw model scores (no sigmoid has been applied to them).
output

tensor([[[0.0529],
         [0.1494],
         [0.1218],
         [0.0843],
         [0.0522],
         [0.1826],
         [0.1093],
         [0.0701],
         [0.0628],
         [0.0239],
         [0.0327]]], grad_fn=<AddBackward0>)

In [9]:
# Re-parse the sample's newline-separated integer labels into a tensor.
labels = torch.tensor([int(flag) for flag in sample['labels'].split('\n')])

In [12]:
# Convert the targets to float and lay them out in rows whose length is the
# size of output's second dimension.
labels = labels.float().view(-1, output.size(1))

In [18]:
# Adam over every parameter registered on the model, paired with binary
# cross-entropy computed directly on raw logits.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

In [19]:
# One optimisation step on the sampled document.
optimizer.zero_grad()

# Run the forward pass inside the step. Reusing an output computed in an
# earlier cell would backpropagate through an already-freed graph when this
# cell is executed again, and would ignore the weights updated by the
# previous step.
output = net(sentences)
# (1, num_sentences, 1) -> (1, num_sentences) to line up with `labels`.
output = output.view(-1, output.size(1))

loss = criterion(output, labels)
loss.backward()
optimizer.step()

In [20]:
# Display the loss value computed for this single document.
loss

tensor(0.7067, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)