In [212]:
from transformers import BartForConditionalGeneration, BartTokenizer
from torch.utils.data import random_split, DataLoader

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from torchmetrics.text.rouge import ROUGEScore

import pandas as pd
import numpy as np



In [170]:
BART_CHKPT = 'facebook/bart-large-xsum'
DATA_SRCS   = {'train': '../../data/data_clean/polisumm_fs_train.csv', 
               'test': '../../data/data_clean/polisumm_fs_test.csv', 
               'val': '../../data/data_clean/polisumm_fs_val.csv'}

In [3]:
model = BartForConditionalGeneration.from_pretrained(BART_CHKPT, forced_bos_token_id = 0)
tokenizer = BartTokenizer.from_pretrained(BART_CHKPT)

# Dataset Definition 

In [40]:
def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=1024, pad_to_max_length=True, return_tensors="pt"):
    ''' Function that tokenizes a sentence 
      Args: tokenizer - the BART tokenizer; source and target sentences are the source and target sentences
      Returns: Dictionary with keys: input_ids, attention_mask, target_ids
    '''

    input_ids = []
    attention_masks = []
    target_ids = []
    tokenized_sentences = {}

    for sentence in source_sentences:
        encoded_dict = tokenizer(
              sentence,
              max_length=max_length,
              padding="max_length" if pad_to_max_length else None,
              truncation=True,
              return_tensors=return_tensors,
              add_prefix_space = True
          )

        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)

    for sentence in target_sentences:
        encoded_dict = tokenizer(
              sentence,
              max_length=max_length,
              padding="max_length" if pad_to_max_length else None,
              truncation=True,
              return_tensors=return_tensors,
              add_prefix_space = True
          )
        # Shift the target ids to the right
        # shifted_target_ids = shift_tokens_right(encoded_dict['input_ids'], tokenizer.pad_token_id)
        target_ids.append(encoded_dict['input_ids'])

    target_ids = torch.cat(target_ids, dim = 0)


    batch = {
      "input_ids": input_ids,
      "attention_mask": attention_masks,
      "labels": target_ids,
    }

    return batch

In [74]:
class PoliSummDataset(torch.utils.data.Dataset):
    
    def __init__(self, encodings):
        
        self.encodings = encodings
        
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        return item
    
    def __len__(self):
        return len(self.encodings['input_ids'])

In [171]:
class PoliSummEvalModule(pl.LightningDataModule):
    
    def __init__(self, tokenizer, src_csv, batch_size = 64):
        '''
            Data expected to have columns ('all_texts', 'all_sum')
        '''
        
        super().__init__()
        
        self.tokenizer = tokenizer
        self.src_csv = src_csv
        self.batch_size = batch_size
        
    def prepare_data(self):
        
        self.data = pd.read_csv(self.src_csv)
        assert 'all_texts' in self.data.columns, 'Missing source texts column: all_texts'
        assert 'all_sum' in self.data.columns, 'Missing target summaries column: all_sum'
        
    def setup(self, stage = None):
#         self.test_encodings = self.tokenizer(self.data['all_texts'].astype(str).values, 
#                                                 self.data['all_sum'].astype(str).values,
#                                                 truncation = True, padding = True,
#                                                 return_tensors = 'pt')
        src_texts = self.data['all_texts'].astype(str).values
        targ_texts = self.data['all_sum'].astype(str).values
        
        self.test_encodings = encode_sentences(self.tokenizer, src_texts, targ_texts, return_tensors = 'pt')
        
    def test_dataloader(self):
        dataset = PoliSummDataset(self.test_encodings)
        test_dl = torch.utils.data.DataLoader(dataset, batch_size = self.batch_size, shuffle = False)
        return test_dl

In [213]:
class PoliSummDataModule(pl.LightningDataModule):
    
    def __init__(self, tokenizer, src_csvs, batch_size = 64):
        '''
            Data expected to have columns ('all_texts', 'all_sum')
        '''
        
        super().__init__()
        
        self.tokenizer = tokenizer
        self.src_csvs = src_csvs
        self.batch_size = batch_size
        
    def prepare_data(self):
        
        self.data_dict = {key: pd.read_csv(val) for key, val in self.src_csvs.items()}
        
    def setup(self, stage = None):
        
        train_src, train_targ = self.df_to_pairs(self.data_dict['train'])
        test_src, test_targ   = self.df_to_pairs(self.data_dict['test'])
        val_src, val_targ     = self.df_to_pairs(self.data_dict['val'])
        
        self.train_encodings = encode_sentences(self.tokenizer, train_src, train_targ, return_tensors = 'pt')
        self.test_encodings  = encode_sentences(self.tokenizer, test_src, test_targ, return_tensors = 'pt')
        self.val_encodings = encode_sentences(self.tokenizer, val_src, val_targ, return_tensors = 'pt')
    
    def train_dataloader(self):
        dataset = PoliSummDataset(self.train_encodings)
        train_dl = DataLoader(dataset, 
                              batch_size = self.batch_size, 
                              shuffle = True)
        return train_dl
        
    def test_dataloader(self):
        dataset = PoliSummDataset(self.test_encodings)
        test_dl = DataLoader(dataset, batch_size = self.batch_size, shuffle = False)
        return test_dl
    
    def val_dataloader(self):
        dataset = PoliSummDataset(self.val_encodings)
        val_dl = DataLoader(dataset, batch_size = self.batch_size, shuffle = False)
    
    def df_to_pairs(self, df):
        src_l = 'summarize left: ' + df['all_texts'].astype(str)
        src_r = 'summarize right: ' + df['all_texts'].astype(str)
        src_texts_double = np.concatenate((src_texts.values, src_texts.values))
        
        targ_l    = df['left_sum'].astype(str).values
        targ_r    = df['right_sum'].astype(str).values
        targs     = np.concatenate((targ_l, targ_r))
        
        return src_texts_double, targs
        

# Lightning Model 

In [198]:
class PoliSummModel(pl.LightningModule):
    
    def __init__(self, tokenizer, model, test_in_train = True):
        
        super().__init__()
        
        self.tokenizer = tokenizer
        self.model = model
        self.test_in_train = test_in_train
        
        if self.test_in_train:
            self.rouge_scorer = ROUGEScore()
        
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)
    
    def train_step(self, batch, batch_idx):
        src, mask, targ = batch['input_ids'], batch['attention_mask'], batch['labels']
        
        output = self(src, 
                     attention_mask = mask,
                     decoder_input_ids = targ)
        logits = output[0]
        
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        val_loss = ce_loss(logits, targ)
        
        return {'loss': val_loss}
    
    def val_step(self, batch, batch_idx):
        src, mask, targ = batch['input_ids'], batch['attention_mask'], batch['labels']
        
        output = self(src, 
                     attention_mask = mask,
                     decoder_input_ids = targ)
        logits = output[0]
        
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        val_loss = ce_loss(logits, targ)
        
        res_dict = {'loss': val_loss}
        
        # Summarization Metrics
        if self.test_in_train:
            gen_summs = self.generate_summ(batch, max_len = 256)
            ref_summs = self.untokenize_targ(batch)
            
            # ROUGE Score
            rouge_keep = ('rouge1_fmeasure', 'rouge2_fmeasure', 'rougeL_fmeasure')
            rouge_scores = self.rouge_scorer(gen_summs, ref_summs)
            rouge_scores = {name:score for name, score in rouge_scores if name in rouge_keep}          
            
            res_dict.update(rouge_scores)
        
        return res_dict
    
    def test_step(self, batch, batch_idx):
        src, mask, targ = batch['input_ids'], batch['attention_mask'], batch['labels']
        
        output = self(src, 
                     attention_mask = mask,
                     decoder_input_ids = targ)
        logits = output[0]
        
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        val_loss = ce_loss(logits, targ)
        
        return {'loss': val_loss}
    
    def generate_summ(self, sample, eval_beams = 3, early_stopping = True, max_len = 40):

        generated_ids = self.model.generate(
            sample["input_ids"],
            attention_mask=sample["attention_mask"],
            use_cache=True,
            decoder_start_token_id = self.tokenizer.pad_token_id,
            num_beams= eval_beams,
            max_length = max_len,
            early_stopping = early_stopping
        )
        
        gen_summ = [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]
        return gen_summ
    
    def untokenize_targ(self, sample):
        sent      = sample['labels']
        targ_summ = [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in sent]
        return targ_summ
        

# Evaluation 

In [199]:
lightning_model = PoliSummModel(tokenizer, model)

In [200]:
data_mod = PoliSummDataModule(tokenizer, DATA_SRCS)

In [201]:
data_mod.prepare_data()

In [202]:
data_mod.setup()

In [204]:
test_dl = data_mod.test_dataloader()

In [205]:
for batch in test_dl:
    print(batch)
    break

{'input_ids': tensor([[    0,   208,  9112,  ...,     1,     1,     1],
        [    0,   140, 33727,  ..., 50118, 13082,     2],
        [    0,  1205,   640,  ..., 21792, 27324,     2],
        ...,
        [    0,   849, 22616,  ...,  1265,   467,     2],
        [    0,   849,   975,  ...,    50,  5898,     2],
        [    0,  2101, 15478,  ...,  1673,    11,     2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[  0,  20, 314,  ...,   1,   1,   1],
        [  0,  20, 314,  ...,   1,   1,   1],
        [  0,  20, 314,  ...,   1,   1,   1],
        ...,
        [  0,  20, 314,  ...,   1,   1,   1],
        [  0,  20, 235,  ...,   1,   1,   1],
        [  0,  20, 314,  ...,   1,   1,   1]])}


In [118]:
trainer = pl.Trainer(gpus = 0)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [148]:
for batch in test_dl:
    one_samp = batch
    break

In [149]:
one_samp = {key: val[0:3, :] for key, val in one_samp.items()}

In [150]:
one_samp

{'input_ids': tensor([[    0,  7639, 15895,  ...,    17,    27,     2],
         [    0,  5213,  1721,  ...,   831,  3257,     2],
         [    0,   195,   538,  ..., 31014,    63,     2]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'labels': tensor([[  0,  20, 314,  ...,   1,   1,   1],
         [  0,  20, 314,  ...,   1,   1,   1],
         [  0,  20, 314,  ...,   1,   1,   1]])}

In [155]:
lightning_model.generate_summ(one_samp, eval_beams = 1)

['People are reacting to the news that Treasury Secretary Steven Mnuchin says the first payments under a plan to rescue the US economy from the Coronavirus could be $1,200 per',
 "People are rejoicing at the news that US Vice President Joe Biden has reversed President Donald Trump's ban on transgender people serving in the military.",
 "The latest from the world of tech, including a major News Feed change, a plan to break up Big Tech, and Mark Zuckerberg's old blog posts disappearing."]