In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import torch
import functools
from transformers import AdamW, get_linear_schedule_with_warmup
import math
from time import time
from datasets import load_dataset, Dataset



In [2]:
# Directory holding the ELI5 JSONL exports. Edit this one constant to point the
# notebook at a different copy of the data.
DATA_DIR = "/Users/datle/Downloads"
# Exclusive upper row index of the training subset taken below.
TRAIN_SUBSET_END = 30000

path_train = f"{DATA_DIR}/ELI5.jsonl"
path_val = f"{DATA_DIR}/ELI5_val.jsonl"
dataset_train = load_dataset('json', data_files = path_train)
dataset_val = load_dataset('json', data_files = path_val)

# Both files are read back through the 'train' key, including the validation file.
train_split = dataset_train['train']
# Clamp the upper bound so a file with fewer rows than TRAIN_SUBSET_END yields a
# smaller subset instead of an out-of-range selection.
# NOTE(review): the range starts at 1, so row 0 is never used -- confirm this
# is intended and not an off-by-one.
train = train_split.select(range(1, min(TRAIN_SUBSET_END, len(train_split))))
val = dataset_val['train']

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-0913b0ed92067c10/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-f1b40a908fd17224/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
# Module-level tokenizer loaded from the "facebook/bart-base" checkpoint.
# preproces() and tokenize() below read this global by name.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

In [4]:
def preproces(ex):
    """Tokenize the 'answers' field of one example.

    Args:
        ex: mapping that provides an 'answers' entry to encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns when asked to pad and
        truncate to 128 tokens and to produce PyTorch tensors.
    """
    answer_text = ex['answers']
    encoded = tokenizer(
        answer_text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    return encoded

In [15]:
def replace_text(string):
    """Normalize a piece of raw text before it is put into a prompt.

    Steps, in order: drop apostrophes, turn newlines into spaces, remove the
    literal placeholder "URL_0", lowercase, and strip surrounding whitespace.

    Args:
        string: the text to clean.

    Returns:
        The cleaned, lowercased string.
    """
    cleaned = string.replace("'", "")
    # Newlines become a space (not an empty string) so that the last word of
    # one line and the first word of the next are not fused into one token.
    cleaned = cleaned.replace("\n", " ")
    # Removed before lowercasing, so only the upper-case placeholder matches.
    cleaned = cleaned.replace("URL_0", "")
    return cleaned.lower().strip()
def concat(ex, n):
    """Attach a combined question/context prompt to ``ex`` under 'ques_ctxs'.

    Args:
        ex: mapping with a 'question' string and a 'ctxs' sequence. Each
            'ctxs' entry is either a string or a list whose first element is
            used as the text (presumably the passage; the remaining elements
            are ignored).
        n: number of leading 'ctxs' entries to include.

    Returns:
        The same ``ex`` object, with the new 'ques_ctxs' key set to
        "question: <question> context: <contexts joined by spaces>".
    """
    question = replace_text(ex['question'])
    # Decide per entry whether to unwrap it, and never index into the sliced
    # list directly: an empty 'ctxs' (or n == 0) then gives an empty context
    # instead of raising IndexError.
    passages = [
        entry[0] if isinstance(entry, list) else entry
        for entry in ex['ctxs'][:n]
    ]
    context = replace_text(' '.join(passages))
    ex['ques_ctxs'] = f"question: {question} context: {context}"
    return ex

In [16]:
# Add the 'ques_ctxs' prompt column (question plus the first 2 'ctxs' entries)
# to every training example, then drop the raw 'question' and 'ctxs' columns.
question_ds = train.map(lambda x: concat(x, n=2), remove_columns = ['question', 'ctxs'])

Map:   0%|          | 0/29999 [00:00<?, ? examples/s]

In [37]:
def tokenize(ex):
    """Add tokenized question and answer fields to one example.

    The 'ques_ctxs' prompt and the 'answers' entry are each passed to the
    module-level ``tokenizer`` with the same settings (pad/truncate to 128
    tokens, PyTorch tensors). The resulting ids and attention masks are
    stored on ``ex`` under 'input_id_q'/'attention_q' and
    'input_id_a'/'attention_a'.

    Args:
        ex: mapping with 'ques_ctxs' and 'answers' entries.

    Returns:
        The same ``ex`` object with the four new keys set.
    """
    encode_kwargs = dict(
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    question_enc = tokenizer(ex['ques_ctxs'], **encode_kwargs)
    ex['input_id_q'] = question_enc['input_ids']
    ex['attention_q'] = question_enc['attention_mask']

    answer_enc = tokenizer(ex['answers'], **encode_kwargs)
    ex['input_id_a'] = answer_enc['input_ids']
    ex['attention_a'] = answer_enc['attention_mask']

    return ex

In [38]:
# Run tokenize() over the prompt dataset to add the input_id_*/attention_*
# columns. map() is called without batched=True here.
tokenized_ds = question_ds.map(tokenize)

Map:   0%|          | 0/29999 [00:00<?, ? examples/s]

In [43]:
# Display the dataset summary (column names and row count) to check that the
# tokenized columns were added.
tokenized_ds

Dataset({
    features: ['question_id', 'answers', 'ques_ctxs', 'input_id_q', 'attention_q', 'input_id_a', 'attention_a'],
    num_rows: 29999
})

In [69]:
class eli5dataset(Dataset):
    """Dataset with one item per (question, answer) pair of a tokenized table.

    ``dataset`` must support column access by name for 'input_id_q',
    'attention_q', 'input_id_a' and 'attention_a'. Each 'input_id_a' entry is
    iterated as the list of tokenized answers of that row, so a row with
    several answers contributes several items.

    NOTE(review): the notebook imports ``Dataset`` from both
    ``torch.utils.data`` and ``datasets``; the later import wins, so the base
    class here is presumably ``datasets.Dataset`` -- confirm that is intended.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Flat index -> (row index, answer index within that row).
        self.qa_id_list = [
            (row_idx, answer_idx)
            for row_idx, answers in enumerate(self.dataset['input_id_a'])
            for answer_idx in range(len(answers))
        ]

    def __len__(self):
        """Return the number of (question, answer) pairs."""
        return len(self.qa_id_list)

    def make_example(self, idx):
        """Build a batch of model inputs for the given flat indices.

        Args:
            idx: iterable of positions into ``qa_id_list``.

        Returns:
            dict of 2-D LongTensors with a leading batch axis:
            'input_ids' and 'attention_mask' for the question prompt,
            'decoder_input_ids' (answer ids without the last position) and
            'labels' (answer ids without the first position, with positions
            whose attention mask is 0 replaced by -100 -- presumably the
            index the loss ignores; confirm).
        """
        # Look each column up once per call instead of once per example;
        # column access may materialize the whole column on some backends.
        question_ids_col = self.dataset['input_id_q']
        question_mask_col = self.dataset['attention_q']
        answer_ids_col = self.dataset['input_id_a']
        answer_mask_col = self.dataset['attention_a']

        q_ids, q_mask, a_ids, a_mask = [], [], [], []
        for flat_idx in idx:
            row_idx, answer_idx = self.qa_id_list[flat_idx]
            q_ids.append(question_ids_col[row_idx])
            q_mask.append(question_mask_col[row_idx])
            a_ids.append(answer_ids_col[row_idx][answer_idx])
            a_mask.append(answer_mask_col[row_idx][answer_idx])

        # Reshape to (batch, -1) rather than calling a bare torch.squeeze:
        # squeeze also drops the batch axis when the batch holds a single
        # example, which makes the [:, 1:] slicing below fail.
        batch_size = len(q_ids)
        q_ids = torch.tensor(q_ids, dtype=torch.long).reshape(batch_size, -1)
        q_mask = torch.tensor(q_mask, dtype=torch.long).reshape(batch_size, -1)
        a_ids = torch.tensor(a_ids, dtype=torch.long).reshape(batch_size, -1)
        a_mask = torch.tensor(a_mask, dtype=torch.long).reshape(batch_size, -1)

        # Labels are the answer shifted left by one; padded positions are masked.
        labels = a_ids[:, 1:].contiguous().clone()
        labels[a_mask[:, 1:].contiguous() == 0] = -100

        model_inputs = {
            'input_ids': q_ids,
            'attention_mask': q_mask,
            'decoder_input_ids': a_ids[:, :-1].contiguous(),
            'labels': labels,
        }
        return model_inputs

    def __getitem__(self, idx):
        """Return model inputs for ``idx``.

        A list (or other iterable) of indices yields a batched dict, as
        before. A single int yields the same dict without the batch axis, so
        the class also works when indices arrive one at a time.
        """
        if isinstance(idx, int):
            batch = self.make_example([idx])
            return {name: tensor[0] for name, tensor in batch.items()}
        return self.make_example(idx)

In [70]:
def data_loader(dataset, args, shuffle=False):
    """Wrap ``dataset`` in a DataLoader using the batch size from ``args``.

    Args:
        dataset: the dataset to iterate over.
        args: object exposing a ``batch_size`` attribute.
        shuffle: when True, draw examples with a RandomSampler; the default
            (False) keeps the original sequential order.

    Returns:
        A ``DataLoader`` over ``dataset``.
    """
    sampler_cls = RandomSampler if shuffle else SequentialSampler
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler_cls(dataset),
    )

In [71]:
class ArgumentsS2S():
    """Small settings holder for the sequence-to-sequence run.

    Attributes:
        batch_size: number of examples per batch (read by data_loader()).
        max_length: maximum sequence length setting.
    """

    def __init__(self, batch_size=2, max_length=128):
        # Defaults match the values that used to be hard-coded here.
        self.batch_size = batch_size
        self.max_length = max_length

s2s_args = ArgumentsS2S()

In [72]:
import lightning as L
class bart_model(L.LightningModule):
    """LightningModule wrapping a pretrained BartForConditionalGeneration.

    Args:
        model_name: checkpoint passed to ``from_pretrained``; defaults to
            "facebook/bart-base".
        lr: learning rate handed to AdamW; defaults to 2e-4.
    """

    def __init__(self, model_name="facebook/bart-base", lr=2e-4):
        super().__init__()
        self.lr = lr
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch_input):
        """Run the wrapped model on a dict of keyword inputs.

        Returns:
            Tuple ``(loss, logits)`` taken from the model output.
        """
        output = self.model(**batch_input)
        return output.loss, output.logits

    def _shared_step(self, batch, metric_name):
        """Compute the loss for ``batch``, log it as ``metric_name``, return it."""
        loss, _ = self(batch)
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [73]:
# Display the dataset summary again (same columns and row count as shown
# earlier in the notebook) before building the training dataset from it.
tokenized_ds

Dataset({
    features: ['question_id', 'answers', 'ques_ctxs', 'input_id_q', 'attention_q', 'input_id_a', 'attention_a'],
    num_rows: 29999
})

In [74]:
# Wrap the tokenized table so that every (question, answer) pair is one item.
train2 = eli5dataset(tokenized_ds)

# Training DataLoader; the batch size comes from s2s_args.
train_1 = data_loader(train2, s2s_args)


In [75]:
# Instantiate the Lightning wrapper; its constructor loads the pretrained BART weights.
my_model = bart_model()

In [76]:
# Train for at most 3 epochs on the training loader.
# NOTE(review): only a training loader is passed to fit(); the `val` split
# loaded at the top of the notebook is not used anywhere in the visible cells,
# so validation_step presumably never runs -- TODO confirm / add a val loader.
trainer = L.Trainer(max_epochs=3)
trainer.fit(my_model, train_1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

[0, 1]
[2, 3]
[4, 5]
[6, 7]
[8, 9]


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
