# Train Summarizer Using PL

## imports

In [1]:
import json
from collections import OrderedDict

import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch_geometric.transforms as T

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import pairwise_distances

from models import BasicSummarizerWithGDC
from types_ import *

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [3]:
DEVICE

device(type='cuda')

## SummaExperiment Class with PL 

In [4]:
# class SummaryDataset(Dataset):
    
#     def __init__(self, path):
        
#         with open(path, 'r', encoding='utf8') as f:
#             self.data = [json.loads(line) for line in f]
        
#     def __len__(self):
#         """Returns the number of data."""
#         return len(self.data)
    
#     def __getitem__(self, idx):
#         sentences = self.data[idx]['doc'].split('\n')
#         labels = self.data[idx]['labels'].split('\n')
#         labels = [int(label) for label in labels]
        
#         return sentences, labels

In [5]:
class SummaExperiment(pl.LightningModule):
    
    def __init__(self,
                 model: BasicSummarizerWithGDC,
                 params: dict) -> None:
        
        super(SummaExperiment, self).__init__()
        
        self.model = model
        self.params = params
        # self.curr_device = None
        
        
    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, docs, offsets, tdms, labels) -> Tensor:
        return self.model(docs, offsets, tdms, labels)
    
    def loss_function(self, logits, labels):
        labels = torch.cat(
            [torch.tensor(label, dtype=torch.float) for label in labels]
        )
        labels = labels.view(-1, logits.size()[1]).to(DEVICE)
        logits = logits.view(-1, logits.size()[1])
        
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)
        return bce_loss
    
    def accuracy(self, logits, labels):
        """Computes the accuracy for multiple binary predictions"""

        labels = torch.cat(
            [torch.tensor(label, dtype=torch.float) for label in labels]
        )
        labels = labels.view(-1, logits.size()[1]).to(DEVICE)
        logits = logits.view(-1, logits.size()[1])

        preds = torch.round(logits)
        corrects = (preds == labels).sum().float()
        acc = corrects / labels.numel()
        return acc
    
    
    def training_step(self, batch, batch_idx):
        docs, offsets, labels, tdms = batch
        
        logits = self.forward(docs, offsets, tdms, labels)
        train_loss = self.loss_function(logits, labels)
        train_acc = self.accuracy(logits, labels)
        
        tqdm_dict = {'train_acc': train_acc}
        output = OrderedDict({
            'loss': train_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output
    
    def validation_step(self, batch, batch_idx):
        docs, offsets, labels, tdms = batch
        
        logits = self.forward(docs, offsets, tdms, labels)
        val_loss = self.loss_function(logits, labels)
        
        # acc
        val_acc = self.accuracy(logits, labels)
        
        output = OrderedDict({
            'val_loss': val_loss,
            'val_acc': val_acc,
        })
        return output
    
    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs
        :param outputs: list of individual outputs of each validation step
        :return:
        """
        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss_mean += output['val_loss']
            val_acc_mean += output['val_acc']
        
        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': val_loss_mean}
        return result
    
    
    def test_step(self, batch, batch_idx):
        docs, offsets, labels, tdms = batch
        logits = self.forward(docs, offsets, tdms, labels)
        test_loss = self.loss_function(logits, labels)
        
        # acc
        test_acc = self.accuracy(logits, labels)
        
        output = OrderedDict({
            'test_loss': test_loss,
            'test_acc': test_acc,
        })
        return output
    
    def test_epoch_end(self, outputs):
        test_loss_mean = 0
        test_acc_mean = 0
        for output in outputs:
            test_loss_mean += output['test_loss']
            test_acc_mean += output['test_acc']
        
        test_loss_mean /= len(outputs)
        test_acc_mean /= len(outputs)
        tqdm_dict = {'test_loss': test_loss_mean, 'test_acc': test_acc_mean}
        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'test_loss': test_loss_mean}
        return result
    
    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.params['LR'],
                               weight_decay=self.params['weight_decay'])
        
        return [optimizer]
    
    
    @staticmethod
    def __collate_fn(batch):
        docs = [entry[0] for entry in batch]
        labels = [entry[1] for entry in batch]
        tdms = [entry[2] for entry in batch]
        offsets = [0] + [len(doc) for doc in docs]
        return docs, offsets, labels, tdms
    
    def __dataloader(self, phase='train'):
        
        train_path = '../../data/summary/data/train.json'
        valid_path = '../../data/summary/data/val.json'
        test_path = '../../data/summary/data/test.json'

        trainset = SummaryDataset(train_path)
        validset = SummaryDataset(valid_path)
        testset = SummaryDataset(test_path)
        
        if phase == 'train':
            loader =  DataLoader(trainset, 
                                 batch_size=self.params['batch_size'], 
                                 shuffle=False, 
                                 collate_fn=self.__collate_fn)
        elif phase == 'valid':
            loader =  DataLoader(validset, 
                                 batch_size=self.params['batch_size'], 
                                 shuffle=False, 
                                 collate_fn=self.__collate_fn)
        elif phase == 'test':
            loader =  DataLoader(testset, 
                                 batch_size=self.params['batch_size'], 
                                 shuffle=False, 
                                 collate_fn=self.__collate_fn)
        
        return loader
    
    def train_dataloader(self):
        # log.info('Training data loader called.')
        print('Training data loader called.')
        return self.__dataloader(phase='train')
    
    def val_dataloader(self):
        # log.info('Validation data loader called.')
        print('Validation data loader called.')
        return self.__dataloader(phase='valid')
    
    def test_dataloader(self):
        # log.info('Test data loader called.')
        print('Test data loader called.')
        return self.__dataloader(phase='test')

## dataloader

### 1) Summary Dataset

In [6]:
def generate_batch(batch):
    docs = [entry[0] for entry in batch]
    labels = [entry[1] for entry in batch]
    tdms = [entry[2] for entry in batch]
    
    offsets = [0] + [len(doc) for doc in docs]
        
    return docs, offsets, labels, tdms

In [7]:
class SummaryDataset(Dataset):
    
    def __init__(self, path):
        
        with open(path, 'r', encoding='utf8') as f:
            self.data = [json.loads(line) for line in f]
        
    def __len__(self):
        """Returns the number of data."""
        return len(self.data)
    
    def __getitem__(self, idx):
        sentences = self.data[idx]['doc'].split('\n')
        labels = self.data[idx]['labels'].split('\n')
        labels = [int(label) for label in labels]
        
        tfidf = TfidfVectorizer().fit(sentences)
        tdm = tfidf.transform(sentences).toarray()
        
        return sentences, labels, tdm

## Train 

In [8]:
hparams = {
    'batch_size' : 32,
    'LR': 0.005,
    'weight_decay': 0.0001
}

model = BasicSummarizerWithGDC(in_dim=128,
                               hidden_dim=64,
                               out_dim=32,
                               num_heads=2,
                               num_classes=1,
                               use_gdc=True).to(DEVICE)
experiment = SummaExperiment(model, hparams)

In [9]:
# checkpoint
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/basicsummarizerGDC_{epoch:02d}_{val_acc:.2f}_lr005_v02',
    monitor='val_acc',
    verbose=True,
    save_top_k=5,
)

trainer = pl.Trainer(max_epochs=20,
                     checkpoint_callback=checkpoint_callback)

In [None]:
trainer.fit(experiment)

In [12]:
# model.CNT

## Test

In [13]:
hparams = {
    'batch_size' : 32,
    'LR': 0.005,
    'weight_decay': 0.0001
}

model = BasicSummarizer(in_dim=128,
                        hidden_dim=64,
                        out_dim=32,
                        num_heads=2,
                        num_classes=1).to(DEVICE)

INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model from cache at C:\Users\korea\.cache\torch\transformers\dd1588b85b6fdce1320e224d29ad062e97588e17326b9d05a0b29ee84b8f5f93.c81d4deb77aec08ce575b7a39a989a79dd54f321bfb82c2b54dd35f52f8182cf


In [14]:
experiment = SummaExperiment(model, hparams)

In [15]:
trainer = pl.Trainer(resume_from_checkpoint='./checkpoints/basicsummarizer_epoch=18_val_acc=0.62_lr005_v02.ckpt')

In [18]:
trainer.test(experiment)

Test data loader called.


HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=1213.0, style=Prog…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': 0.6256545782089233, 'test_loss': 0.7528885006904602}
--------------------------------------------------------------------------------



## Test2

In [8]:
save_path = './save_weights/basicsumarizer_epoch=18_val_acc=0.62_lr005.pth'
# torch.save(model.state_dict(), save_path)

In [9]:
hparams = {
    'batch_size' : 32,
    'LR': 0.005,
    'weight_decay': 0.0001
}

model = BasicSummarizer(in_dim=128,
                        hidden_dim=64,
                        out_dim=32,
                        num_heads=2,
                        num_classes=1).to(DEVICE)

model.load_state_dict(torch.load(save_path))

INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model from cache at C:\Users\korea\.cache\torch\transformers\dd1588b85b6fdce1320e224d29ad062e97588e17326b9d05a0b29ee84b8f5f93.c81d4deb77aec08ce575b7a39a989a79dd54f321bfb82c2b54dd35f52f8182cf


<All keys matched successfully>

In [10]:
experiment = SummaExperiment(model, hparams)

In [12]:
trainer = pl.Trainer()
trainer.test(experiment)

Test data loader called.


HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=1213.0, style=Prog…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': 0.6236557960510254, 'test_loss': 0.7534582614898682}
--------------------------------------------------------------------------------

