# Extractive Summarization with DataLoader

## imports

In [1]:
import json
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split

from sentence_transformers import models
from sentence_transformers import SentenceTransformer

# from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv

from sklearn.metrics import pairwise_distances

In [2]:
# Pick the compute device: the first visible CUDA GPU if PyTorch can see one,
# otherwise the CPU. The bare expression below echoes it in the notebook.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## collate_fn

In [3]:
def generate_batch(batch):
    """Collate ``(sentences, labels)`` items into equally long padded lists.

    Every document is padded to the sentence count of the longest document
    in the batch: missing sentences become ``''`` and missing labels ``0``.
    Labels are taken positionally (one per sentence) and converted with
    ``int``.

    Args:
        batch: sequence of dataset items; ``item[0]`` is the list of
            sentences and ``item[1]`` the matching labels.

    Returns:
        ``(padded_docs, padded_labels)``: two lists with one inner list per
        document, each inner list of length ``max(len(sentences))``.
    """
    docs = [item[0] for item in batch]
    doc_labels = [item[1] for item in batch]
    target_len = max(len(doc) for doc in docs)

    padded_docs = []
    padded_labels = []
    for doc, labels in zip(docs, doc_labels):
        n_real = len(doc)
        n_pad = target_len - n_real
        # Labels are indexed by sentence position, so any surplus labels
        # beyond the document's sentence count are ignored.
        padded_docs.append([doc[i] for i in range(n_real)] + [''] * n_pad)
        padded_labels.append([int(labels[i]) for i in range(n_real)] + [0] * n_pad)

    return padded_docs, padded_labels

In [4]:
class SummaryDataset(Dataset):
    """Map-style dataset over a JSON-lines extractive-summarization file.

    Each non-blank line of the file must be a JSON object with two string
    fields:

    * ``'doc'``: the document's sentences, separated by newlines;
    * ``'labels'``: newline-separated integer labels, expected to hold one
      label per sentence.

    Items are returned as ``(sentences, labels)``. ``sentences`` is a list of
    str and ``labels`` is a list of int.
    """

    def __init__(self, path):
        """Load every record of the JSON-lines file at ``path`` into memory."""
        # Decode as UTF-8 explicitly. The platform default encoding is
        # locale-dependent and can mis-decode or reject non-ASCII text.
        # Blank lines (e.g. an extra empty line at the end of the file) are
        # skipped instead of making json.loads raise.
        with open(path, 'r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Returns the number of data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentences, labels)`` for record ``idx``."""
        record = self.data[idx]
        sentences = record['doc'].split('\n')
        labels = [int(label) for label in record['labels'].split('\n')]

        return sentences, labels

## dataset split (train, valid, test)

In [5]:
# JSON-lines source read by SummaryDataset: one object per line carrying
# newline-joined 'doc' sentences and matching 'labels'.
data_path = '../../data/summary/data/train.json'
dataset = SummaryDataset(path=data_path)

In [6]:
# 60 / 20 / ~20 split. The test split absorbs the rounding remainder, so the
# three sizes always sum to len(dataset) as random_split requires.
train_size = int(0.6 * len(dataset))
valid_size = int(0.2 * len(dataset))
test_size = len(dataset) - (train_size + valid_size)

# Seed the split so the partition is identical on every run. Unseeded, each
# re-run reshuffles which documents land in train/valid/test, which makes
# results irreproducible and leaks test data into training across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size], generator=split_generator
)

In [7]:
# All three loaders yield one document per batch (Summarizer.forward only
# reads the first document of a batch) and pad via generate_batch; only the
# shuffle flag differs.
loader_kwargs = dict(batch_size=1, collate_fn=generate_batch)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [17]:
# Peek at a single training batch to inspect its structure. With
# batch_size=1, `data` is a one-element list holding that document's
# sentence list and `labels` the matching list of integer labels.
for batch in train_dataloader:
    data, labels = batch
    break

In [19]:
# data

In [10]:
# sample

## Extractive Summarization architecture

### 1) GAT Classifier

In [11]:
class GATClassifier(nn.Module):
    """Two GAT layers, then an LSTM over the nodes, then a per-node linear head.

    The GAT layers mix each node's features with those of its graph
    neighbours. The resulting node embeddings are read, in node order, as a
    single sequence by a unidirectional LSTM. ``fc`` then maps every LSTM
    output step to ``num_classes`` scores (many-to-many).

    Args:
        in_dim: size of the input node features.
        hidden_dim: per-head output size of the first GAT layer.
        out_dim: output size of the second GAT layer (the LSTM input size).
        num_heads: attention heads in the first GAT layer.
        num_classes: number of scores produced per node.
        dropout: probability used for the feature dropout in ``forward`` and
            for GATConv's attention dropout (default 0.6).
        lstm_hidden: hidden size of the LSTM (default 32).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, num_classes=2,
                 dropout=0.6, lstm_hidden=32):
        super().__init__()

        self.out_head = 1
        self.out_dim = out_dim
        self.dropout_p = dropout

        # conv1 uses GATConv's default head concatenation, so conv2 receives
        # hidden_dim * num_heads features; conv2 averages (concat=False).
        self.conv1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout)
        self.conv2 = GATConv(hidden_dim * num_heads, out_dim, concat=False,
                             heads=self.out_head, dropout=dropout)

        self.lstm = nn.LSTM(out_dim, lstm_hidden, 1, batch_first=True,
                            bidirectional=False)
        self.fc = nn.Linear(lstm_hidden, num_classes)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero initial states ``(h_0, c_0)`` for ``self.lstm``.

        Both tensors have shape
        ``(num_layers * num_directions, batch_size, hidden_size)``, read from
        the LSTM's configuration. Pass ``device``/``dtype`` (e.g. those of
        the LSTM input) so the states live alongside it. With the defaults,
        torch's default device and dtype are used.
        """
        num_directions = 2 if self.lstm.bidirectional else 1
        shape = (self.lstm.num_layers * num_directions, batch_size,
                 self.lstm.hidden_size)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell

    def forward(self, features, edge_index):
        """Score every node of a single graph.

        Args:
            features: node feature matrix, presumably (num_nodes, in_dim).
            edge_index: connectivity in the (2, num_edges) COO layout that
                GATConv consumes.

        Returns:
            Tensor of shape (1, num_nodes, num_classes), assuming conv2
            yields (num_nodes, out_dim). All nodes form one sequence
            (batch of 1) in node order.
        """
        # Follow train()/eval() so dropout is disabled at evaluation time.
        x = F.dropout(features, p=self.dropout_p, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)
        # (num_nodes, out_dim) -> (1, num_nodes, out_dim): a batch_first
        # sequence whose time steps are the nodes.
        x = x.view(-1, x.size(0), self.out_dim)

        # Build the initial states on x's device/dtype so the module keeps
        # working after .to(device).
        h_0, c_0 = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)
        output, _ = self.lstm(x, (h_0, c_0))

        # many-to-many: one prediction per time step (node)
        return self.fc(output)

### 2) Summarizer

In [12]:
class Summarizer(nn.Module):
    """Sentence-graph extractive summarizer.

    The sentences of a document are embedded with a SentenceTransformer and
    linked into a graph by cosine similarity. Each sentence is then scored by
    a GATClassifier. Only the first document of a batch is processed, so
    feed it from a DataLoader with ``batch_size=1``.

    Args:
        in_dim: embedding size produced by the embedder (must match it).
        hidden_dim, out_dim, num_heads, num_classes: forwarded to
            GATClassifier.
        embedder_name: SentenceTransformer model to load
            (default ``'bert-base-nli-stsb-mean-tokens'``).
        threshold: cosine-similarity cut-off used to create edges
            (default 0.2).
    """

    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim,
                 num_heads,
                 num_classes=2,
                 embedder_name='bert-base-nli-stsb-mean-tokens',
                 threshold=0.2):
        super().__init__()

        self.threshold = threshold
        self.embedder = SentenceTransformer(embedder_name)
        self.gat_classifier = GATClassifier(in_dim, hidden_dim, out_dim, num_heads, num_classes)

    def build_graph(self, features, threshold=0.2):
        """Return the ``edge_index`` of the sentence-similarity graph.

        Sentences i and j are connected when the cosine similarity of their
        embeddings is strictly greater than ``threshold``. Each connection is
        emitted in both directions (i -> j and j -> i), as message passing
        over an undirected graph requires. Self-similarity is 1, so
        self-loops are included whenever ``threshold < 1``.

        Args:
            features: array of shape (num_sents, dim).
            threshold: similarity cut-off.

        Returns:
            int64 tensor of shape (2, num_edges). Row 0 holds source indices
            and row 1 target indices. With no edges it is an empty (2, 0)
            tensor that is still integer-typed.
        """
        cosine_matrix = 1 - pairwise_distances(features, metric="cosine")
        adj_matrix = cosine_matrix > threshold
        # Symmetrise so every edge is present in both directions even if
        # floating-point noise makes the similarity matrix slightly
        # asymmetric. (Iterating an undirected networkx graph yields each
        # edge only once, which would make message passing one-way.)
        adj_matrix = np.logical_or(adj_matrix, adj_matrix.T)

        src, dst = np.nonzero(adj_matrix)
        return torch.as_tensor(np.stack([src, dst]), dtype=torch.long)

    def forward(self, sents):
        """Score every sentence of the first document in ``sents``.

        Args:
            sents: a batch of documents as produced by ``generate_batch``
                (a list of sentence lists). Only ``sents[0]`` is used.

        Returns:
            The GATClassifier output for the sentences of ``sents[0]``.
        """
        # NOTE(review): SentenceTransformer.encode is usually run without
        # autograd, so the embedder is presumably not fine-tuned by the
        # optimizer. Confirm if that is intended.
        features = np.asarray(self.embedder.encode(sents[0]))
        edge_index = self.build_graph(features, threshold=self.threshold)

        # Place the inputs on the classifier's device so the model also
        # works after .to(device).
        device = next(self.gat_classifier.parameters()).device
        features = torch.from_numpy(features).to(device)
        return self.gat_classifier(features, edge_index.to(device))

In [13]:
# Build the summarizer. num_classes=1 yields a single logit per sentence,
# which pairs with the BCEWithLogitsLoss criterion.
model = Summarizer(
    in_dim=768,       # sentence-embedding size; must match the embedder
    hidden_dim=128,   # per-head width of the first GAT layer
    out_dim=64,       # width of the second GAT layer / LSTM input
    num_heads=2,      # attention heads in the first GAT layer
    num_classes=1,    # one logit per sentence
)

# model = model.to(device)

In [14]:
# Adam over every registered parameter, with learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Binary loss on raw logits; the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()

In [15]:
# One pass over the training set. Each batch holds a single document
# (batch_size=1), scored sentence by sentence.
model.train()
for batch_idx, (data, labels) in enumerate(train_dataloader):
    optimizer.zero_grad()
    logits = model(data)  # expected (1, num_sents, 1) with num_classes=1

    # Build the targets directly on the logits' device so the loop also works
    # once the model is moved with .to(device). Then flatten both to
    # (batch, num_sents) for BCEWithLogitsLoss.
    num_sents = logits.size(1)
    labels = torch.tensor(labels, dtype=torch.float, device=logits.device)
    labels = labels.view(-1, num_sents)
    logits = logits.view(-1, num_sents)

    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

KeyboardInterrupt: 