In [1]:
import functools
import math
import re
from pathlib import Path
from time import time

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import AdamW, get_linear_schedule_with_warmup



In [183]:
# Raw ELI5 splits stored as JSONL. DATA_DIR defaults to ~/Downloads; point it elsewhere as needed
# instead of hard-coding a user-specific absolute path.
DATA_DIR = Path.home() / "Downloads"
N_TRAIN = 1000  # train-subset size for quick experiments; set to None to use every row

path_train = str(DATA_DIR / "ELI5.jsonl")
path_val = str(DATA_DIR / "ELI5_val.jsonl")
dataset_train = load_dataset('json', data_files=path_train)
dataset_val = load_dataset('json', data_files=path_val)

# A single JSON data file is loaded under the 'train' split key.
train_full = dataset_train['train']
# Clamp the subset size so a smaller file does not make .select() fail on out-of-range indices.
train = train_full if N_TRAIN is None else train_full.select(range(min(N_TRAIN, len(train_full))))
val = dataset_val['train']

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-0913b0ed92067c10/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-f1b40a908fd17224/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

In [184]:
# Display the raw validation split (feature names and row count).
val

Dataset({
    features: ['question_id', 'question', 'answers', 'ctxs'],
    num_rows: 1507
})

In [185]:
# Pretrained BART-base tokenizer; data_loader's collate function reads this module-level `tokenizer`.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# Module-level BART-base model.
# NOTE(review): bart_model.__init__ loads its own copy into self.model — confirm whether this global is still needed.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [186]:
def replace_text(string):
    """Normalise a raw ELI5 text field before tokenisation.

    Removes apostrophes and ``URL_<k>`` link placeholders, turns newlines (and any
    other runs of whitespace) into single spaces, lower-cases, and strips.

    Args:
        string: raw question / answer / context text.

    Returns:
        The cleaned, lower-cased string.
    """
    cleaned = string.replace("'", "")
    # The original code removed only URL_0; presumably further links are numbered
    # URL_1, URL_2, ... — drop every such placeholder.
    cleaned = re.sub(r"URL_\d+", "", cleaned)
    # Newlines separate words/paragraphs: map them to spaces instead of gluing the
    # neighbouring words together. split/join also collapses whitespace runs and strips.
    cleaned = " ".join(cleaned.split())
    return cleaned.lower()


def preprocess(ex, n):
    """Clean one ELI5 record in place (used as a ``Dataset.map`` callback).

    Args:
        ex: record with 'question' (str), 'ctxs' (list of passages, each either a
            string or a list whose first item is the passage text) and 'answers'
            (list of str).
        n: number of leading passages from 'ctxs' to keep.

    Returns:
        ``ex`` with 'question' and every answer cleaned, and 'ctxs' replaced by one
        cleaned string joining the first ``n`` passages.
    """
    ex['question'] = replace_text(ex['question'])
    passages = ex['ctxs'][:n]
    # Passages may be stored as [text, ...] lists; keep only the text. Checking
    # `passages` first means a record with no passages gets an empty context
    # instead of raising IndexError.
    if passages and isinstance(passages[0], list):
        passages = [p[0] for p in passages]
    ex['ctxs'] = replace_text(' '.join(passages))
    ex['answers'] = [replace_text(a) for a in ex['answers']]
    return ex
    

In [187]:
# Clean both splits, collapsing each record's first 3 passages into a single 'ctxs' string;
# question_id is dropped from the mapped splits.
train1, val1 = (
    split.map(lambda record: preprocess(record, n=3), remove_columns=['question_id'])
    for split in (train, val)
)

Loading cached processed dataset at /Users/datle/.cache/huggingface/datasets/json/default-0913b0ed92067c10/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-49d3551ef0864a00.arrow
Loading cached processed dataset at /Users/datle/.cache/huggingface/datasets/json/default-f1b40a908fd17224/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-c160336ec261883e.arrow


In [188]:
class eli5dataset(Dataset):
    """One example per (question, answer) pair of a preprocessed ELI5 split.

    Each item is an ``(inputs, outputs)`` string tuple, where ``inputs`` is
    ``'question: <question> context: <ctxs>'`` and ``outputs`` is one answer.

    Args:
        data: indexable collection of records (e.g. a ``datasets.Dataset`` or a
            list of dicts) with 'question', 'ctxs' (already a single string) and
            'answers' (list of str).
        num_docs: stored as ``self.num_docs`` but not read by this class; the
            context size appears to be fixed upstream, when 'ctxs' is built.
    """

    def __init__(self, data, num_docs):
        self.data = data
        # Flatten to (record index, answer index) so a question with several
        # answers contributes several examples.
        self.qa_id_list = [
            (i, j)
            for i, qa in enumerate(self.data)
            for j in range(len(qa['answers']))
        ]
        self.num_docs = num_docs

    def __len__(self):
        return len(self.qa_id_list)

    def make_example(self, idx):
        """Build the ``(inputs, outputs)`` pair for flattened index ``idx``."""
        i, j = self.qa_id_list[idx]
        # Fetch one row. Indexing by column first (self.data['question'][i]) can
        # materialise the entire column of a datasets.Dataset on every call, which
        # makes one pass over the data quadratic in its size.
        record = self.data[i]
        inputs = 'question: {} context: {}'.format(record['question'], record['ctxs'])
        return (inputs, record['answers'][j])

    def __getitem__(self, idx):
        return self.make_example(idx)

In [189]:
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360):
    """Collate ``(inputs, answer)`` string pairs into a seq2seq training batch.

    Args:
        qa_list: sequence of ``(inputs, answer)`` string tuples.
        tokenizer: object exposing ``batch_encode_plus`` (e.g. a Hugging Face tokenizer).
        max_len: padded/truncated length of the encoder inputs.
        max_a_len: upper bound on answer length; answers use ``min(max_len, max_a_len)``.

    Returns:
        dict with 'input_ids' and 'attention_mask' (encoder side), 'decoder_input_ids'
        (answer ids without the last position) and 'labels' (answer ids without the
        first position, with padded positions set to -100).
    """
    # Materialise real lists: batch_encode_plus is documented to take a list of
    # strings rather than an arbitrary generator.
    questions = [q for q, _ in qa_list]
    answers = [a for _, a in qa_list]

    q_toks = tokenizer.batch_encode_plus(questions, max_length=max_len, padding='max_length', truncation=True, return_tensors='pt')
    a_toks = tokenizer.batch_encode_plus(answers, max_length=min(max_len, max_a_len), padding='max_length', truncation=True, return_tensors='pt')
    # return_tensors='pt' already yields tensors; .long() only pins the dtype.
    q_ids, q_mask = q_toks['input_ids'].long(), q_toks['attention_mask'].long()
    a_ids, a_mask = a_toks['input_ids'].long(), a_toks['attention_mask'].long()

    # Teacher forcing: the decoder reads positions [0, T-1) and predicts [1, T).
    # Copy into a fresh contiguous tensor so masking never writes back into a_ids.
    labels = a_ids[:, 1:].clone(memory_format=torch.contiguous_format)
    # Padding targets -> -100, presumably so the loss skips them
    # (-100 is PyTorch CrossEntropyLoss's default ignore_index).
    labels[a_mask[:, 1:] == 0] = -100

    return {
        'input_ids': q_ids,
        'attention_mask': q_mask,
        'decoder_input_ids': a_ids[:, :-1].contiguous(),
        'labels': labels,
    }

In [190]:
def data_loader(dataset, args, shuffle=False):
    """Wrap ``dataset`` in a DataLoader whose batches are built by ``make_qa_s2s_batch``.

    Args:
        dataset: map-style dataset of ``(inputs, answer)`` string pairs.
        args: config object providing ``batch_size`` and ``max_length``.
        shuffle: draw examples in random order (RandomSampler) instead of index
            order (SequentialSampler). Defaults to False, the previous behaviour;
            pass True for a training loader so batches are not always in file order.

    Returns:
        torch.utils.data.DataLoader yielding dicts of model inputs.

    The collate function uses the module-level ``tokenizer``.
    """
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    collate = functools.partial(make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length)
    return DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, collate_fn=collate)

In [191]:
class ArgumentsS2S():
    """Hyper-parameter bundle for the BART ELI5 seq2seq run.

    Every value can be overridden by keyword (e.g. ``ArgumentsS2S(batch_size=8)``);
    calling with no arguments gives exactly the previous settings.
    """

    def __init__(self, batch_size=3, backward_freq=16, max_length=512, print_freq=100,
                 model_save_name="seq2seq_models/eli5_bart_model", learning_rate=2e-4,
                 num_epochs=3):
        self.batch_size = batch_size  # examples per DataLoader batch (read by data_loader)
        self.backward_freq = backward_freq
        self.max_length = max_length  # passed as max_len to make_qa_s2s_batch by data_loader
        self.print_freq = print_freq
        self.model_save_name = model_save_name
        # NOTE(review): bart_model's optimizer does not read this value — confirm which LR is intended.
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

s2s_args = ArgumentsS2S()

In [192]:
import lightning as L
class bart_model(L.LightningModule):
    """Lightning wrapper that fine-tunes a pretrained BART seq2seq model.

    Batches are the dicts produced by ``make_qa_s2s_batch`` and are passed to the
    wrapped model as keyword arguments.

    Args:
        model_name: checkpoint to load (default "facebook/bart-base").
        learning_rate: AdamW learning rate (default 1e-4, the previous hard-coded value).
    """

    def __init__(self, model_name="facebook/bart-base", learning_rate=1e-4):
        super().__init__()
        self.model = BartForConditionalGeneration.from_pretrained(model_name)
        self.learning_rate = learning_rate

    def forward(self, batch_input):
        """Run the wrapped model on one batch; returns ``(loss, logits)``."""
        # Must call this module's own model: the optimizer only updates
        # self.parameters(), so running any other model here would leave the
        # trained weights untouched.
        output = self.model(**batch_input)
        return output.loss, output.logits

    def _shared_step(self, batch, name):
        # Forward one batch, log its loss under `name`, and return the loss.
        loss, _ = self(batch)
        self.log(name, loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [193]:
# One (inputs, answer) example per question/answer pair, for each split.
# NOTE(review): num_docs=2 is stored on the dataset, but 'ctxs' was already built from n=3 passages — confirm the intended value.
train2, val2 = (eli5dataset(split, num_docs=2) for split in (train1, val1))

In [194]:
# Batched loaders for both splits, using the same batch size and max length from s2s_args.
# NOTE(review): data_loader samples sequentially here, so training batches are not shuffled.
train_1, val_1 = (data_loader(ds, s2s_args) for ds in (train2, val2))

In [195]:
# Lightning module wrapping a freshly loaded facebook/bart-base model.
my_model = bart_model()

In [196]:
# Lightning's Trainer falls back to max_epochs=1000 when neither max_epochs nor
# max_steps is given, so bound the run by the configured epoch count.
# NOTE(review): s2s_args.backward_freq looks like a gradient-accumulation setting
# (Trainer(accumulate_grad_batches=...)) — confirm before wiring it in.
trainer = L.Trainer(max_epochs=s2s_args.num_epochs)
trainer.fit(my_model, train_1, val_1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]