# Extractive Summarization with Batch

## imports

In [1]:
import json
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split

import torch_geometric
from torch_geometric.nn import GATConv

from transformers import AlbertTokenizer, AlbertModel

from sklearn.metrics import pairwise_distances

In [2]:
# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

In [3]:
DEVICE

device(type='cuda')

## Extractive Summarization architecture

### 1) GAT Classifier

In [4]:
class GATClassifier(nn.Module):
    """Score every node of a graph with two GAT layers followed by an LSTM.

    The node features pass through two ``GATConv`` layers, the resulting node
    embeddings are treated as one sequence (in node order) and fed to a
    single-layer unidirectional LSTM, and a linear layer maps every LSTM
    state to ``num_classes`` logits (many-to-many).

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Per-head output size of the first GAT layer.
        out_dim: Output size of the second GAT layer (LSTM input size).
        num_heads: Number of attention heads in the first GAT layer.
        num_classes: Number of logits produced per node.
    """

    # Hidden size of the LSTM that runs over the node sequence.
    LSTM_HIDDEN_SIZE = 32
    # Dropout probability for node features and GAT attention coefficients.
    DROPOUT = 0.6

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, num_classes=1):
        super().__init__()

        self.out_head = 1
        self.out_dim = out_dim

        self.conv1 = GATConv(in_dim, hidden_dim, heads=num_heads,
                             dropout=self.DROPOUT)
        # The first layer concatenates its heads, hence hidden_dim * num_heads.
        self.conv2 = GATConv(hidden_dim * num_heads, out_dim, concat=False,
                             heads=self.out_head, dropout=self.DROPOUT)

        self.lstm = nn.LSTM(out_dim, self.LSTM_HIDDEN_SIZE, 1,
                            batch_first=True, bidirectional=False)
        self.fc = nn.Linear(self.LSTM_HIDDEN_SIZE, num_classes)

    def init_hidden(self, batch_size, device=None):
        """Return zero-initialised ``(hidden, cell)`` states for the LSTM.

        Args:
            batch_size: Batch dimension of the LSTM input.
            device: Device to allocate the states on; ``None`` keeps the
                PyTorch default device.

        Returns:
            Two tensors of shape
            ``(num_layers * num_directions, batch_size, hidden_size)``,
            i.e. ``(1, batch_size, 32)``.
        """
        shape = (1, batch_size, self.LSTM_HIDDEN_SIZE)
        hidden = torch.zeros(shape, device=device)
        cell = torch.zeros(shape, device=device)
        return hidden, cell

    def forward(self, data):
        """Compute per-node logits for a (batched) graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``;
                both are moved to ``DEVICE``.

        Returns:
            Tensor of shape ``(1, num_nodes, num_classes)``.
        """
        x, edge_index = data.x.to(DEVICE), data.edge_index.to(DEVICE)

        # Dropout must follow the module mode: with a hard-coded
        # ``training=True`` it would stay active after ``model.eval()`` and
        # make inference non-deterministic.
        x = F.dropout(x, p=self.DROPOUT, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.DROPOUT, training=self.training)
        x = self.conv2(x, edge_index)

        # Treat all nodes as a single sequence: (1, num_nodes, out_dim).
        x = x.view(-1, x.size(0), self.out_dim)

        h_0, c_0 = self.init_hidden(x.size(0), device=x.device)
        output, _ = self.lstm(x, (h_0, c_0))

        # many-to-many: one set of logits per node
        return self.fc(output)

### 2) Summarizer

In [5]:
class Summarizer(nn.Module):
    """Extractive summarizer: ALBERT sentence embeddings + GAT classifier.

    Every sentence is tokenized and embedded with ALBERT, each document is
    turned into a sentence-similarity graph, and ``GATClassifier`` produces
    logits for every sentence node.

    Args:
        in_dim: Size of a sentence embedding (node feature size).
        hidden_dim: Per-head hidden size of the first GAT layer.
        out_dim: Output size of the second GAT layer.
        num_heads: Number of attention heads in the first GAT layer.
        num_classes: Number of logits per sentence.
    """

    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim,
                 num_heads,
                 num_classes=2):
        super().__init__()

        self.tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
        self.embedder = AlbertModel.from_pretrained('albert-base-v2')
        self.gat_classifier = GATClassifier(in_dim, hidden_dim, out_dim, num_heads, num_classes)

    def get_tokenize(self, docs):
        """Tokenize every sentence of every document.

        Args:
            docs: List of documents, each a list of sentence strings.

        Returns:
            A tensor with one row of token ids per sentence, in document
            order, each padded to 64 tokens.
        """
        # NOTE(review): rows can only be concatenated if no sentence encodes
        # to more than 64 tokens; this relies on the tokenizer truncating to
        # ``max_length`` -- confirm for the installed transformers version.
        encoded = [
            self.tokenizer.encode(
                sentence,
                add_special_tokens=True,
                max_length=64,
                pad_to_max_length=True,
                return_tensors='pt',
            )
            for sentences in docs
            for sentence in sentences
        ]
        return torch.cat(encoded)

    def get_sentence_embedding(self, word_vecs, offsets):
        """Get node features (sentence embeddings) for each document.

        Args:
            word_vecs: Tensor whose first axis indexes sentences across all
                documents; axis 1 is averaged away (presumably the token
                axis of the ALBERT output -- TODO confirm).
            offsets: ``[0, n_1, n_2, ...]`` where ``n_i`` is the number of
                sentences in document ``i`` (per-document counts, *not*
                cumulative positions), as built by ``generate_batch``.

        Returns:
            A list with one 2-D tensor per document, one row per sentence.
        """
        features = []
        start = 0
        # The counts have to be accumulated into a running start index;
        # using them directly as positions slices the wrong rows from the
        # third document onwards.
        for count in offsets[1:]:
            doc = word_vecs[start:start + count]
            # No squeeze(): a one-sentence document must stay 2-D so the
            # graph construction still receives a (1, dim) matrix.
            # NOTE(review): the mean also covers padded positions.
            features.append(torch.mean(doc, dim=1))
            start += count
        return features

    def build_graph(self, features_list, threshold=0.2):
        """Get one ``edge_index`` per document for the GAT layers.

        Two sentences are connected when their cosine similarity is strictly
        greater than ``threshold``.

        Args:
            features_list: List of 2-D sentence-embedding tensors.
            threshold: Cosine-similarity cut-off for creating an edge.

        Returns:
            A list of ``(2, num_edges)`` long tensors.
        """
        edge_index_list = []
        for features in features_list:
            embeddings = features.detach().cpu().numpy()
            cosine_matrix = 1 - pairwise_distances(embeddings, metric="cosine")
            adj_matrix = (cosine_matrix > threshold) * 1

            # ``from_numpy_matrix`` was removed in networkx 3.0;
            # ``from_numpy_array`` is the supported equivalent.
            G = nx.from_numpy_array(adj_matrix)

            # NOTE(review): an undirected networkx graph lists each edge
            # once, so only one direction ends up in edge_index -- confirm
            # this is intended for message passing.
            edges = list(G.edges)
            sources = [src for src, _ in edges]
            targets = [dst for _, dst in edges]
            # Explicit dtype keeps an empty edge list usable as an index.
            edge_index = torch.tensor([sources, targets], dtype=torch.long)
            edge_index_list.append(edge_index)

        return edge_index_list

    def gat_dataloader(self, features_list, edge_index_list, labels_list, batch_size):
        """Wrap the per-document graphs in a torch_geometric ``DataLoader``.

        Args:
            features_list: Node features per document.
            edge_index_list: Edge indices per document.
            labels_list: Sentence labels per document (stored as ``y``).
            batch_size: Number of graphs per mini-batch.

        Returns:
            A non-shuffling loader, so graphs keep the document order.
        """
        data_list = [
            torch_geometric.data.Data(features, edge_index, y=labels)
            for features, edge_index, labels
            in zip(features_list, edge_index_list, labels_list)
        ]

        gat_loader = torch_geometric.data.DataLoader(
            data_list, batch_size=batch_size, shuffle=False)
        return gat_loader

    def forward(self,
                docs,
                offsets,
                labels_list,
                threshold=0.2,
                batch_size=32):
        """Compute logits for every sentence of every document.

        Args:
            docs: List of documents, each a list of sentence strings.
            offsets: ``[0, n_1, n_2, ...]`` sentence counts per document.
            labels_list: Sentence labels per document.
            threshold: Cosine-similarity cut-off used to build the graphs.
            batch_size: Number of document graphs per GAT mini-batch.

        Returns:
            Tensor of shape ``(1, total_sentences, num_classes)``.
        """
        sent_tokens = self.get_tokenize(docs).to(DEVICE)
        word_vecs = self.embedder(sent_tokens)[0]
        features_list = self.get_sentence_embedding(word_vecs, offsets)
        edge_index_list = self.build_graph(features_list, threshold)

        # dataloader for GATLayer
        dataloader = self.gat_dataloader(features_list, edge_index_list, labels_list, batch_size)

        # Run every mini-batch; taking only the first one would silently
        # drop documents whenever len(docs) > batch_size. With a single
        # mini-batch the result is unchanged.
        outputs = [self.gat_classifier(graph_batch) for graph_batch in dataloader]
        return torch.cat(outputs, dim=1)

In [6]:
# Build the summarizer (768-d sentence embeddings in, one logit per sentence
# out) and then move its parameters to the selected device.
net = Summarizer(in_dim=768, hidden_dim=128, out_dim=64, num_heads=2, num_classes=1)
net = net.to(DEVICE)

## dataloader

### 1) Summary Dataset

In [7]:
def generate_batch(batch):
    """Collate ``(sentences, labels)`` samples into parallel lists.

    Args:
        batch: Sequence of samples; element 0 of each sample is the list of
            sentences and element 1 the list of labels.

    Returns:
        A ``(docs, offsets, labels)`` tuple where ``offsets`` is ``[0]``
        followed by the number of sentences of every document (per-document
        counts, not cumulative positions).
    """
    docs = []
    labels = []
    for entry in batch:
        docs.append(entry[0])
        labels.append(entry[1])

    offsets = [0]
    offsets.extend(len(doc) for doc in docs)

    return docs, offsets, labels

In [8]:
class SummaryDataset(Dataset):
    """Dataset over a JSON-lines file of documents with sentence labels.

    Every non-blank line of the file is a JSON object with a ``'doc'`` and a
    ``'labels'`` string, each holding newline-separated entries.
    """

    def __init__(self, path):
        """Load all records from ``path`` into memory.

        Args:
            path: Path of the JSON-lines file.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # letting json.loads fail on them.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Returns the number of data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentences, labels)`` for the record at ``idx``.

        Returns:
            ``sentences`` is the ``'doc'`` string split on newlines and
            ``labels`` the ``'labels'`` string split on newlines with every
            entry converted to ``int``.
        """
        record = self.data[idx]
        sentences = record['doc'].split('\n')
        labels = [int(label) for label in record['labels'].split('\n')]

        return sentences, labels

### 2) dataset split (train, valid, test)

In [9]:
# Corpus file read by SummaryDataset (one JSON record per line).
data_path = '../../data/summary/data/train.json'
dataset = SummaryDataset(path=data_path)

In [10]:
# 60 % / 20 % / remainder split; the test share absorbs any rounding loss so
# the three sizes always add up to len(dataset).
train_size = int(len(dataset) * 0.6)
valid_size = int(len(dataset) * 0.2)
test_size = len(dataset) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

### 3) Dataloader for Summarizer

In [11]:
# One loader per split, all configured identically: 8 documents per batch,
# no shuffling, collated with generate_batch.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(split, batch_size=8, shuffle=False, collate_fn=generate_batch)
    for split in (train_dataset, valid_dataset, test_dataset)
)

In [12]:
# Pull the first collated batch from the training loader for a smoke test.
batch = next(iter(train_dataloader))

In [13]:
# generate_batch returns (docs, offsets, labels); unpack in that order.
docs, offsets, labels_list = batch

In [14]:
# Forward pass over the whole batch (default threshold and graph batch size).
# The recorded run of this cell ended in a CUDA out-of-memory error.
output = net(docs, offsets, labels_list)

RuntimeError: CUDA out of memory. Tried to allocate 156.00 MiB (GPU 0; 10.92 GiB total capacity; 9.83 GiB already allocated; 149.44 MiB free; 9.97 GiB reserved in total by PyTorch)

In [28]:
# output

In [29]:
# Binary cross-entropy computed on raw logits, optimised with Adam over all
# parameters of the summarizer.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

In [32]:
# as_tensor accepts both Python lists and tensors, so re-running this cell
# (labels_list is rebound to tensors here) no longer copy-constructs tensors
# from tensors, which torch.tensor() warns about.
labels_list = [torch.as_tensor(labels, dtype=torch.float) for labels in labels_list]

# Flatten all documents into one row that lines up with output's node axis,
# then move the targets to the same device as the model.
labels = torch.cat(labels_list).view(-1, output.size(1)).to(DEVICE)

  """Entry point for launching an IPython kernel.


In [33]:
# Single optimisation step: drop the leading axis layout of the logits so
# they match `labels`, compute the loss, then clear stale gradients before
# back-propagating and updating the parameters.
output = output.view(-1, output.size(1))
loss = criterion(output, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

In [34]:
loss

tensor(0.6827, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)