In [1]:
import transformers
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import json
import torch
from tqdm import tqdm
import math
import functools
from time import time

In [2]:
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = qa_s2s_model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]

In [3]:
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    lm_labels = a_ids[:, 1:].contiguous().clone()
    lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "labels": lm_labels,
    }
    return model_inputs

In [4]:
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, 
                                                  device_map = 'auto'
                                                  ).to(device)
    print(model)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model

In [5]:
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device=args.device
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        # print(batch_inputs)
        batch_inputs['labels'] = batch_inputs.pop('labels')
        pre_loss = model(**batch_inputs)[0]
        # print(pre_loss)
        loss = pre_loss.sum()# / pre_loss.shape[0]
        loss.backward()
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0

In [6]:
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device=args.device
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            batch_inputs['labels'] = batch_inputs.pop('labels')
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss.sum() #/ pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                    )
                )
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))

In [7]:
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))

In [8]:
class ELI5DatasetS2S(Dataset):
    def __init__(
        self,
        data_array,
    ):
        self.data = data_array

    def __len__(self):
        return len(self.data)

    def append(self, question_doc, answer):
        self.data.append([question_doc, answer])

    # def make_example(self, idx):
    #     i, j = self.qa_id_list[idx]
    #     example = self.data[i]
    #     question = example["title"] + " " + example["selftext"]
    #     answer = example["answers"]["text"][j]
    #     q_id = example["q_id"]
    #     if self.make_doc_function is not None:
    #         self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))
    #     document = self.document_cache[q_id]
    #     in_st = "question: {} context: {}".format(
    #         question.lower().replace(" --t--", "").strip(), document.lower().strip(),
    #     )
    #     out_st = answer
    #     return (in_st, out_st)

    def __getitem__(self, idx):
        return (self.data[idx][0], self.data[idx][1])

In [9]:
# Training set
try:
  f.close()
except:
  print("No file to close")

path = "Bản sao của ELI5-001.jsonl"
f = open(path, "r")

train_data = ELI5DatasetS2S([])

for id, line in enumerate(f):
  # print(id)
  data = json.loads(line)
  # print(data)

  question = data['question']
  doc = '. '.join(map(str, data['ctxs']))
  answer = '. '.join(map(str, data['answers']))

  question_doc = "question: {} context: {}".format(question, doc)

  train_data.append(question_doc, answer)

f.close()
del question, doc, answer, question_doc, f

No file to close


In [None]:
# Val set
try:
  f.close()
except:
  print("No file to close")

path = "Bản sao của ELI5_val.jsonl"
f = open(path, "r")

val_data = ELI5DatasetS2S([])

for id, line in enumerate(f):
  # print(id)
  data = json.loads(line)
  # print(data)

  question = data['question']
  doc = '. '.join(map(str, data['ctxs']))
  answer = '. '.join(map(str, data['answers']))

  question_doc = "question: {} context: {}".format(question, doc)

  val_data.append(question_doc, answer)

f.close()
del question, doc, answer, question_doc, f

No file to close


In [None]:
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 2
        self.backward_freq = 16
        self.max_length = 1024
        self.print_freq = 100
        self.model_save_name = "seq2seq_models/eli5_flan_t5_model"
        self.learning_rate = 3e-4
        self.num_epochs = 3
        self.device = 'cuda:0'

s2s_args = ArgumentsS2S()

qa_s2s_tokenizer, qa_s2s_model = make_qa_s2s_model(
    model_name="google/flan-t5-small",
    from_file=None,
    device=s2s_args.device
)

peft_config = LoraConfig(
    r=4,
    lora_alpha=8,
    target_modules=["q", "v"],
    lora_dropout=0.02,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

qa_s2s_model = get_peft_model(qa_s2s_model, peft_config)
qa_s2s_model.print_trainable_parameters()

train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, train_data, val_data, s2s_args)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


 0     0 of 136317 	 L: 9.022 	 -- 2.791
 0     1 of 136317 	 L: 6.060 	 -- 3.017
 0   100 of 136317 	 L: 7.060 	 -- 26.085
 0   200 of 136317 	 L: 7.199 	 -- 53.934
 0   300 of 136317 	 L: 6.997 	 -- 80.729
 0   400 of 136317 	 L: 7.050 	 -- 109.480
 0   500 of 136317 	 L: 6.950 	 -- 136.372
 0   600 of 136317 	 L: 6.944 	 -- 161.829
 0   700 of 136317 	 L: 6.811 	 -- 188.210
 0   800 of 136317 	 L: 6.710 	 -- 215.266
 0   900 of 136317 	 L: 6.996 	 -- 240.762
 0  1000 of 136317 	 L: 6.744 	 -- 267.345
 0  1100 of 136317 	 L: 6.554 	 -- 294.244
 0  1200 of 136317 	 L: 6.514 	 -- 321.672
 0  1300 of 136317 	 L: 6.406 	 -- 348.211
 0  1400 of 136317 	 L: 5.995 	 -- 374.185
 0  1500 of 136317 	 L: 5.816 	 -- 400.251
 0  1600 of 136317 	 L: 5.756 	 -- 427.513
 0  1700 of 136317 	 L: 5.323 	 -- 453.075
 0  1800 of 136317 	 L: 5.221 	 -- 479.488
 0  1900 of 136317 	 L: 4.984 	 -- 506.428
 0  2000 of 136317 	 L: 5.056 	 -- 539.029
 0  2100 of 136317 	 L: 4.852 	 -- 564.433
 0  2200 of 136317