In [8]:
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import torch
import functools
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup
import math
from time import time
from datasets import load_dataset

In [9]:
path_train = "/Users/datle/Downloads/ELI5.jsonl"
path_val = "/Users/datle/Downloads/ELI5_val.jsonl"

In [10]:
dataset_train = load_dataset('json', data_files = path_train)
dataset_val = load_dataset('json', data_files = path_val)

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-0913b0ed92067c10/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset json (/Users/datle/.cache/huggingface/datasets/json/default-f1b40a908fd17224/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
train = dataset_train['train'].select(range(0,1000))
val = dataset_val['train']

In [12]:
# create model and tokenizer
# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [13]:
def replacetext(string):
    return string.replace("\\","").replace("\n","")
class eli5dataset(Dataset):
    def __init__(self, data, num_docs):
        self.data = data
        self.qa_id_list = [
            (i,j)
            for i, qa in enumerate(self.data)
            for j, a in enumerate(qa['answers'])
        ]
        self.num_docs = num_docs
    def __len__(self):
        return len(self.qa_id_list)
    def make_example(self, idx):
        i,j = self.qa_id_list[idx]
        question = self.data['question'][i]
        question = replacetext(question)

        context = self.data['ctxs'][i][:self.num_docs]
        context = [k[0] for k in context]
        context = ' '.join(context)
        context = replacetext(context)

        answer = self.data['answers'][i][j]

        inputs = 'question: {} context: {}'.format(question.lower(), context.lower().strip())

        outputs = answer

        return (inputs, outputs)
    def __getitem__(self, idx):
        return self.make_example(idx)

In [14]:
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360):
    q_ls = (q for q, a in qa_list)
    a_ls = (a for q, a in qa_list)

    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks['input_ids']),
        torch.LongTensor(q_toks['attention_mask'])
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks['input_ids']),
        torch.LongTensor(a_toks['attention_mask'])
    )
    labels = a_ids[:, 1:].contiguous().clone()
    labels[a_mask[:, 1:].contiguous() == 0] = -100

    model_inputs = {
        'input_ids': q_ids,
        'attention_mask': q_mask,
        'decoder_input_ids': a_ids[:, :-1].contiguous(),
        'labels': labels,
    }
    return model_inputs

In [15]:
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        loss = pre_loss
        loss.backward()
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0

In [16]:
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                    )
                )
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))

In [17]:
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    s2s_optimizer = torch.optim.AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))

In [18]:
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 3
        self.backward_freq = 16
        self.max_length = 512
        self.print_freq = 100
        self.model_save_name = "seq2seq_models/eli5_bart_model"
        self.learning_rate = 2e-4
        self.num_epochs = 3

s2s_args = ArgumentsS2S()

In [19]:
train1 = eli5dataset(train, num_docs =3)
val1 = eli5dataset(val, num_docs =3)

In [None]:
train_qa_s2s(model, tokenizer, train1, val1, s2s_args)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
