In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import torch
import functools
from transformers import AdamW, get_linear_schedule_with_warmup
import math
from time import time
from datasets import load_dataset

In [None]:
# Locations of the ELI5 JSON-lines files on the Kaggle input mount.
path_train = "/kaggle/input/eli5-explain-like-i-am-5/ELI5-001.jsonl"
path_val = "/kaggle/input/eli5-explain-like-i-am-5/ELI5_val.jsonl"

# Each file is loaded on its own; the rows of either file are then read back
# through the 'train' key of the returned object (see below).
dataset_train = load_dataset('json', data_files=path_train)
dataset_val = load_dataset('json', data_files=path_val)

# Row window used for training.
# NOTE(review): the window starts at 1, so row 0 is never used - confirm this is intended.
TRAIN_ROW_START = 1
TRAIN_ROW_STOP = 50000

# Clamp the upper bound so a file with fewer rows than TRAIN_ROW_STOP yields a
# shorter subset instead of an out-of-range index error from `select`.
train_row_stop = min(TRAIN_ROW_STOP, len(dataset_train['train']))
train = dataset_train['train'].select(range(TRAIN_ROW_START, train_row_stop))
val = dataset_val['train']

In [None]:
def replace_text(string):
    """Normalise one piece of text.

    Deletes every apostrophe, every newline character and every literal
    ``URL_0`` token (nothing is inserted in their place), then lower-cases the
    result and trims leading/trailing whitespace.
    """
    cleaned = string
    # Removal order matches the original chain: apostrophe, newline, URL_0.
    for unwanted in ("\'", "\n", "URL_0"):
        cleaned = cleaned.replace(unwanted, "")
    return cleaned.lower().strip()
def preprocess(ex, n):
    """Clean one record in place and return it.

    Args:
        ex: Mapping with the keys ``'question'`` (str), ``'ctxs'`` (sequence of
            passages) and ``'answers'`` (iterable of str). Each passage is
            either a string or a list/tuple whose first item is the passage
            text.
        n: Number of leading passages from ``ex['ctxs']`` to keep.

    Returns:
        The same ``ex`` object, where ``'question'`` and every entry of
        ``'answers'`` have been passed through ``replace_text`` and ``'ctxs'``
        has been replaced by a single cleaned string made of the first ``n``
        passages joined with spaces. An empty ``'ctxs'`` produces ``''``.
    """
    ex['question'] = replace_text(ex['question'])

    # Decide per passage whether to unwrap it, so an empty `ctxs` no longer
    # fails on `context[0]` and the check does not rely on the first item only.
    passages = [
        passage[0] if isinstance(passage, (list, tuple)) else passage
        for passage in ex['ctxs'][:n]
    ]
    ex['ctxs'] = replace_text(' '.join(passages))

    ex['answers'] = [replace_text(answer) for answer in ex['answers']]
    return ex

In [None]:
def _clean_with_top3(ex):
    """Run `preprocess` on one record, keeping its first three context passages."""
    return preprocess(ex, n=3)


# Same cleaning for both splits; the 'question_id' column is dropped from the result.
train1 = train.map(_clean_with_top3, remove_columns=['question_id'])
val1 = val.map(_clean_with_top3, remove_columns=['question_id'])

In [None]:
class eli5dataset(Dataset):
    """One example per (question, answer) pair of a question-answer table.

    Args:
        data: Table-like object that yields one record per iteration and
            returns a record for an integer index (the notebook passes the
            output of ``.map(...)`` on a loaded dataset). Every record must
            expose ``'question'``, ``'ctxs'`` and a sized ``'answers'`` list.
        max_answers: Upper bound on how many leading answers of each question
            become examples. The default of 4 matches the previous hard-coded
            ``j <= 3`` filter.
    """

    def __init__(self, data, max_answers=4):
        self.data = data
        # Flat index: position `idx` -> (row number, answer number within the row).
        self.qa_id_list = [
            (i, j)
            for i, qa in enumerate(self.data)
            for j in range(min(len(qa['answers']), max_answers))
        ]

    def __len__(self):
        """Number of (question, answer) pairs."""
        return len(self.qa_id_list)

    def make_example(self, idx):
        """Return the ``(question, context, answer)`` tuple for flat index ``idx``."""
        i, j = self.qa_id_list[idx]
        # Row-first access: fetch the single record once instead of indexing
        # three whole columns and then picking one element from each.
        row = self.data[i]
        return (row['question'], row['ctxs'], row['answers'][j])

    def __getitem__(self, idx):
        return self.make_example(idx)

In [None]:
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360):
    """Collate ``(question, context, answer)`` triples into seq2seq model inputs.

    Args:
        qa_list: Sequence of ``(question, context, answer)`` string triples.
        tokenizer: Object exposing ``batch_encode_plus`` that returns
            ``'input_ids'`` and ``'attention_mask'`` for a batch.
        max_len: Padded/truncated length of the encoded question+context pair.
        max_a_len: Cap for the answer length; answers are padded/truncated to
            ``min(max_len, max_a_len)``.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'`` for the source pair,
        ``'decoder_input_ids'`` (answer ids without the last position) and
        ``'labels'`` (answer ids without the first position, with padded
        positions set to -100).
    """
    # Materialise real lists: the batch is consumed by the tokenizer and a
    # generator is not accepted by every tokenizer implementation.
    questions = [q for q, _, _ in qa_list]
    contexts = [c for _, c, _ in qa_list]
    answers = [a for _, _, a in qa_list]

    # `batch_encode_plus` takes the whole batch as its single data argument; its
    # second positional parameter is `add_special_tokens`, not the pair text.
    # Passing the contexts positionally therefore dropped them silently, so the
    # question/context pairs are zipped into one batch of 2-tuples instead.
    source = tokenizer.batch_encode_plus(
        list(zip(questions, contexts)),
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    q_ids = torch.as_tensor(source['input_ids'], dtype=torch.long)
    q_mask = torch.as_tensor(source['attention_mask'], dtype=torch.long)

    target = tokenizer.batch_encode_plus(
        answers,
        max_length=min(max_len, max_a_len),
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    a_ids = torch.as_tensor(target['input_ids'], dtype=torch.long)
    a_mask = torch.as_tensor(target['attention_mask'], dtype=torch.long)

    # Labels are the answer shifted left by one position relative to the decoder
    # inputs. Padded positions get -100 (the conventional ignore index for the
    # loss - confirm against the model documentation).
    labels = a_ids[:, 1:].contiguous().clone()
    labels[a_mask[:, 1:] == 0] = -100

    return {
        'input_ids': q_ids,
        'attention_mask': q_mask,
        'decoder_input_ids': a_ids[:, :-1].contiguous(),
        'labels': labels,
    }

In [None]:
def data_loader(dataset, args, shuffle=False):
    """Build a ``DataLoader`` that collates examples with ``make_qa_s2s_batch``.

    Args:
        dataset: Dataset of ``(question, context, answer)`` examples.
        args: Object providing ``batch_size`` and ``max_length`` attributes.
        shuffle: When true, examples are drawn with ``RandomSampler``; the
            default keeps the original in-order ``SequentialSampler``.

    Returns:
        The configured ``DataLoader``.

    Note:
        The collate function binds the module-level ``tokenizer``, so that
        global must exist by the time this function is called.
    """
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    collate = functools.partial(
        make_qa_s2s_batch,
        tokenizer=tokenizer,
        max_len=args.max_length,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        collate_fn=collate,
    )

In [None]:
class ArgumentsS2S:
    """Batching settings read by ``data_loader``.

    Args:
        batch_size: Number of examples per batch (default 20).
        max_length: Value forwarded as ``max_len`` to the collate function
            (default 256).
    """

    def __init__(self, batch_size=20, max_length=256):
        self.batch_size = batch_size
        self.max_length = max_length

In [None]:
!pip install lightning --q

In [None]:
import lightning as L
class bart_model(L.LightningModule):
    """LightningModule wrapper around ``BartForConditionalGeneration``.

    Args:
        model_name: Checkpoint passed to ``from_pretrained``
            (default ``"facebook/bart-base"``).
        lr: Learning rate handed to ``torch.optim.AdamW`` (default ``2e-4``).
    """

    def __init__(self, model_name="facebook/bart-base", lr=2e-4):
        super().__init__()
        self.lr = lr
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch_input):
        """Run the wrapped model on a batch dict and return ``(loss, logits)``."""
        output = self.model(**batch_input)
        return output.loss, output.logits

    def _step(self, batch, metric_name, sync_dist=False):
        """Compute the loss for ``batch``, log it under ``metric_name`` and return it."""
        loss, _ = self(batch)
        self.log(metric_name, loss, prog_bar=True, logger=True, sync_dist=sync_dist)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        # The notebook trains on more than one device; sync_dist reduces the
        # logged value across devices so the reported metric covers all shards.
        return self._step(batch, 'val_loss', sync_dist=True)

    def test_step(self, batch, batch_idx):
        return self._step(batch, 'test_loss', sync_dist=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with the configured rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [None]:
# Module-level tokenizer loaded from the "facebook/bart-base" checkpoint.
# data_loader() binds this global into its collate function, so this cell has
# to run before data_loader() is called.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

In [None]:
# Batching settings shared by both loaders.
s2s_args = ArgumentsS2S()

# Training split: flatten to (question, context, answer) examples, then batch.
train2 = eli5dataset(train1)
train_1 = data_loader(dataset=train2, args=s2s_args)

# Validation split: identical pipeline.
val2 = eli5dataset(val1)
val_1 = data_loader(dataset=val2, args=s2s_args)

In [None]:
# Sanity-check the collate function on one batch. Show each entry's shape and
# dtype rather than dumping the raw tensors into the cell output.
sample_batch = next(iter(train_1))
{name: (tuple(tensor.shape), tensor.dtype) for name, tensor in sample_batch.items()}

In [None]:
# Instantiate the LightningModule; its constructor loads the pretrained
# "facebook/bart-base" weights.
my_model = bart_model()

In [None]:
# Trainer settings: GPU accelerator, two devices, three passes over the training loader.
trainer_options = {'accelerator': 'gpu', 'devices': 2, 'max_epochs': 3}
trainer = L.Trainer(**trainer_options)

# Fit on the training loader, validating with the validation loader.
trainer.fit(model=my_model, train_dataloaders=train_1, val_dataloaders=val_1)