In [76]:
import torch
import pickle
import datasets
import random

from transformers import T5ForConditionalGeneration, T5TokenizerFast, BartForConditionalGeneration, BartTokenizerFast
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup
from nltk.translate.bleu_score import sentence_bleu

In [2]:
# Load the pickled QQP train / eval dictionaries from disk.
# NOTE(security): pickle.load can execute arbitrary code stored in the file,
# so these paths must only ever point at trusted, locally produced files.
with open("data/qqp_train_dict", "rb") as train_file:
    qqp_train_dict = pickle.load(train_file)

with open("data/qqp_eval_dict", "rb") as eval_file:
    qqp_eval_dict = pickle.load(eval_file)

In [3]:
# Show which columns the loaded eval dictionary provides.
qqp_eval_dict.keys()

dict_keys(['question1_text', 'question2_text', 'question1_wordlst', 'question2_wordlst', 'question1_input_ids_bert', 'question1_token_type_ids_bert', 'question1_attention_mask_bert', 'question2_input_ids_bert', 'question2_token_type_ids_bert', 'question2_attention_mask_bert', 'question1_input_ids', 'question2_input_ids', 'question1_attention_mask', 'question2_attention_mask'])

In [4]:
# Keep only the raw question texts; every other column of the loaded
# dictionaries is dropped here, and the texts are re-tokenized further down.
text_columns = ('question1_text', 'question2_text')
qqp_train_dict = {column: qqp_train_dict[column] for column in text_columns}
qqp_eval_dict = {column: qqp_eval_dict[column] for column in text_columns}

In [5]:
# Wrap the text-only dictionaries as Hugging Face datasets so they can be
# tokenized with a batched .map() call.
qqp_train_raw = datasets.Dataset.from_dict(qqp_train_dict)
qqp_eval_raw = datasets.Dataset.from_dict(qqp_eval_dict)

In [6]:
# Display the eval dataset summary (features and row count).
qqp_eval_raw

Dataset({
    features: ['question1_text', 'question2_text'],
    num_rows: 13438
})

In [7]:
# Tokenizer matching the "facebook/bart-base" checkpoint that is fine-tuned below.
tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")

In [8]:
def preprocess_func(examples):
    """
    Tokenize both questions of every pair in a batch.
    Each question is padded / truncated to exactly 32 tokens using the
    notebook-level ``tokenizer``.
    :param examples: batch with 'question1_text' and 'question2_text' columns
    :return: dict holding every tokenizer output field twice, once prefixed
             with 'question1_' and once with 'question2_'
    """
    first_texts = examples['question1_text']
    second_texts = examples['question2_text']

    encoded = {}
    for prefix, texts in (('question1_', first_texts), ('question2_', second_texts)):
        tokenized = tokenizer(texts, max_length=32, padding='max_length', truncation=True)
        # Prefix each tokenizer field so the two questions do not collide.
        encoded.update({prefix + field: values for field, values in tokenized.items()})
    return encoded

In [9]:
# Tokenize both splits in batches; the new question1_*/question2_* columns are
# added alongside the original text columns.
qqp_train_processed = qqp_train_raw.map(preprocess_func, batched=True)
qqp_eval_processed = qqp_eval_raw.map(preprocess_func, batched=True)

  0%|          | 0/121 [00:00<?, ?ba/s]

  0%|          | 0/14 [00:00<?, ?ba/s]

In [10]:
# Display the processed training set to confirm the tokenized columns exist.
qqp_train_processed

Dataset({
    features: ['question1_text', 'question2_text', 'question1_input_ids', 'question1_attention_mask', 'question2_input_ids', 'question2_attention_mask'],
    num_rows: 120940
})

In [11]:
class QQPParaphraseDataset(Dataset):
    """
    QQP paraphrase dataset using torch.utils.data.Dataset.

    Each item is a seq2seq example: one question of the pair is the encoder
    input ("input_ids" / "attention_mask") and the other one is the target
    ("labels").
    """

    # Columns copied from the source dataset and the dtype each is stored as.
    # Attention masks stay float32, matching the previous torch.Tensor(...) call.
    _COLUMN_DTYPES = {
        'question1_input_ids': torch.long,
        'question2_input_ids': torch.long,
        'question1_attention_mask': torch.float32,
        'question2_attention_mask': torch.float32,
    }

    # Question index -> index of the other question in the pair.
    _OTHER_QUESTION = {'1': '2', '2': '1'}

    def __init__(self, dataset, random_swap=False, label_pad_token_id=None):
        """
        :param dataset: huggingface dataset (or any mapping) whose columns
                        'question{1,2}_input_ids' and 'question{1,2}_attention_mask'
                        are equal-length lists of token ids / mask values
        :param random_swap: randomly swap question1 and question2
        :param label_pad_token_id: when given, every label position equal to this
                        id is replaced by -100 so that padding can be ignored by
                        the loss. Defaults to None, which keeps the raw token ids
                        (the previous behaviour). Labels masked this way contain
                        -100 and must be un-masked before being decoded.
        """
        self.dataset = {
            key: torch.tensor(dataset[key], dtype=dtype)
            for key, dtype in self._COLUMN_DTYPES.items()
        }
        self.random_swap = random_swap
        self.label_pad_token_id = label_pad_token_id

    @staticmethod
    def flip_key(key_str):
        """
        Map a 'question1_*' column name to its 'question2_*' twin and vice versa.
        :param key_str: column name of the form 'question<1|2>...'
        :return: the same name with the question index swapped
        :raises ValueError: if key_str does not have that form
        """
        prefix, index, suffix = key_str[:8], key_str[8:9], key_str[9:]
        if prefix != 'question' or index not in QQPParaphraseDataset._OTHER_QUESTION:
            raise ValueError("expected a 'question1...' or 'question2...' key, got {!r}".format(key_str))
        return prefix + QQPParaphraseDataset._OTHER_QUESTION[index] + suffix

    def __getitem__(self, item):
        """
        :param item: row index
        :return: dict with "input_ids", "attention_mask" (source question) and
                 "labels" (target question)
        """
        input_key = 'question1_input_ids'
        mask_key = 'question1_attention_mask'
        label_key = 'question2_input_ids'

        # random.random() is only consulted when swapping is enabled.
        if self.random_swap and random.random() < 0.5:
            input_key = self.flip_key(input_key)
            mask_key = self.flip_key(mask_key)
            label_key = self.flip_key(label_key)

        labels = self.dataset[label_key][item]
        if self.label_pad_token_id is not None:
            # masked_fill returns a new tensor, so the stored ids stay intact.
            labels = labels.masked_fill(labels == self.label_pad_token_id, -100)

        return {"input_ids": self.dataset[input_key][item],
                "attention_mask": self.dataset[mask_key][item],
                "labels": labels}

    def __len__(self):
        """Number of question pairs."""
        return self.dataset['question1_input_ids'].shape[0]

In [12]:
# Training pairs are randomly swapped (either question can be the source);
# eval pairs always use question1 as the source and question2 as the target.
# NOTE(review): labels are the target ids padded to max_length, so pad positions
# are presumably counted in the loss — consider masking them with -100.
train_dataset = QQPParaphraseDataset(qqp_train_processed, random_swap=True)
eval_dataset = QQPParaphraseDataset(qqp_eval_processed, random_swap=False)

In [13]:
# Training hyper-parameters (single place to tune the run).
num_epoch = 10            # number of train + eval rounds
lr = 5e-5                 # peak learning rate handed to AdamW
weight_decay = 1e-3       # applied to every parameter not excluded from decay
batch_size = 32
# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_warmup_steps = 200    # scheduler warm-up steps


In [14]:
# Shuffle only the training data; eval order stays fixed so that batches taken
# from eval_dataloader are the same on every pass.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

In [15]:
# Load the pretrained BART-base seq2seq model and move it to the training device.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)

In [16]:
# Total number of scalar parameters in the model.
sum(param.numel() for param in bart_model.parameters())

139420416

In [17]:
# Display the model configuration (special token ids, layer sizes, dropout, ...).
bart_model.config

BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_ty

In [18]:
# Parameters whose name contains one of these fragments get no weight decay:
# biases and LayerNorm weights. "layernorm_embedding.weight" is listed explicitly
# because BART's embedding LayerNorms are named `layernorm_embedding`, which the
# "LayerNorm.weight" / "layer_norm.weight" fragments do not match.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "layernorm_embedding.weight"]

# Split the parameters into the two groups in a single pass.
decay_params, no_decay_params = [], []
for param_name, param in bart_model.named_parameters():
    if any(fragment in param_name for fragment in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {
        "params": decay_params,
        "weight_decay": weight_decay,
    },
    {
        "params": no_decay_params,
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)
# Cosine decay with linear warm-up over the whole run: the scheduler is stepped
# once per batch, so the horizon is (number of epochs) x (batches per epoch).
scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                            num_warmup_steps=num_warmup_steps,
                                            num_training_steps=num_epoch * len(train_dataloader))

In [19]:
def train(model, dataloader, optimizer, scheduler=None, num_epoch=1, verbose=False, print_steps=200):
    """
    Run ``num_epoch`` passes of gradient updates over ``dataloader``.
    :param model: model exposing ``.device`` and returning an output with ``.loss``
                  when called with a batch as keyword arguments
    :param dataloader: yields dict batches of tensors; must contain 'input_ids'
    :param optimizer: optimizer stepped once per batch
    :param scheduler: optional LR scheduler, stepped once per batch after the optimizer
    :param num_epoch: number of passes over ``dataloader``
    :param verbose: if True, print the sample-weighted mean loss of every full
                    window of ``print_steps`` batches
    :param print_steps: window size (in batches) for the verbose report
    :return: None
    """
    device = model.device
    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(total=num_epoch * len(dataloader)) as progress_bar:
        for _ in range(num_epoch):
            running_loss = 0.0
            sample_cnt = 0
            model.train()
            for step, batch in enumerate(dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                progress_bar.update(1)

                bs = batch['input_ids'].shape[0]
                sample_cnt += bs
                # .item() yields a plain float. Accumulating the loss tensor itself
                # would chain autograd history across steps, and without verbose
                # reporting that chain would never be reset within an epoch.
                running_loss += loss.item() * bs

                if verbose and step % print_steps == print_steps-1:
                    # print training loss averaged over the window, then start a new window
                    print('step:', step+1, ' training loss={:.4f}'.format(running_loss / sample_cnt))
                    running_loss = 0.0
                    sample_cnt = 0


def evaluate(model, dataloader, verbose=True):
    """
    Compute the sample-weighted mean loss of ``model`` over ``dataloader``.
    The model is switched to eval mode and left in eval mode.
    :param model: model exposing ``.device`` and returning an output with ``.loss``
                  when called with a batch as keyword arguments
    :param dataloader: yields dict batches of tensors; must contain 'input_ids'
    :param verbose: if True, print the mean loss
    :return: mean loss as a float (``nan`` if the dataloader yields no samples)
    """
    device = model.device
    model.eval()
    total_loss = 0.0
    sample_cnt = 0
    with torch.no_grad(), tqdm(total=len(dataloader)) as progress_bar:
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            bs = batch['input_ids'].shape[0]
            # Weight each batch by its size so a smaller last batch is not over-counted.
            total_loss += bs * outputs.loss.item()
            sample_cnt += bs
            progress_bar.update(1)

    # Average over the samples actually seen; an empty loader yields nan
    # instead of raising ZeroDivisionError.
    mean_loss = total_loss / sample_cnt if sample_cnt else float('nan')
    if verbose:
        print('eval loss={:.4f}'.format(mean_loss))
    return mean_loss

In [None]:
# One training pass followed by one evaluation pass per epoch. train() is called
# with its default num_epoch=1, so over the whole loop the scheduler is stepped
# num_epoch * len(train_dataloader) times, matching the horizon it was built with.
for epoch in range(num_epoch):
    print("epoch:",epoch+1)
    train(model=bart_model, dataloader=train_dataloader, optimizer=optimizer, scheduler=scheduler ,verbose=True)
    evaluate(model=bart_model, dataloader=eval_dataloader)

In [21]:
# Grab the first eval batch for qualitative inspection (the eval loader is not
# shuffled, so this is always the same batch).
eval_batch = next(iter(eval_dataloader))

In [30]:
# Source questions of the sample batch, decoded back to text.
tokenizer.batch_decode(eval_batch['input_ids'], skip_special_tokens=True)

['"Is ""Pokemon Ranger and The Temple of The Sea"" considered appropriate for kids?"',
 'What is the best textbook for Hebrew?',
 'How do I take control on masturbation?',
 'Who is your favourite female movie director and why?',
 'Do ghost actually exists?',
 'Which is the best institute in Mumbai for doing Financial Modeling certification course?',
 'What is exam pattern of MH CET MBA?',
 'What is the difference between a porn figure and a prostitute?',
 '"Why my \'\'i"" is different than yours?"',
 'How do you become more masculine?',
 'How can improve my managerial skills?',
 'Why do people ask question on Quora that can be easily and definitively answered by Googling?',
 'How did you find a job abroad?',
 'Why would a boy love a girl?',
 'What do you hate about school?',
 'What are the consequences of lying about your ethnicity on your college applications?',
 'What are compounds? What are some examples?',
 'What would you do if you woke up to find that nuclear war had started?',
 

In [31]:
# Reference paraphrases (labels) of the sample batch, decoded back to text.
tokenizer.batch_decode(eval_batch['labels'], skip_special_tokens=True)

['"Is ""Pokémon Ranger and The Temple of The Sea"" considered childish?"',
 "What's the best self study book to learn Hebrew?",
 'How do I control on masturbation?',
 'Who is the best female movie director?',
 'Does ghost really exist?',
 'Which is the best institute in Mumbai from where a fresher can learn financial modeling?',
 'What is the exam pattern of MH CET MBA?',
 'What is a difference between a prostitute and a porn star?',
 '"Why my \'\' I "" is different than yours?"',
 'How can one become more masculine?',
 'How do you improve your managerial skills?',
 'Why do so many people ask questions on Quora that can be easily answered by any number of legitimate sources on the Web? Have they not heard of',
 'How do I find a job abroad?',
 'Why do boys love girls?',
 'Why do you hate school?',
 'Can I lie about my ethnicity to top college admissions?',
 'What are some examples of compounds?',
 'What would you do, if nuclear war began?',
 'How should I start learning about stock trad

In [38]:
# Generate paraphrases for the sample batch with 4-beam search, capped at 32 tokens.
# NOTE(review): no attention_mask is passed; presumably generate() derives one
# from the pad token id — confirm for the installed transformers version.
generate_output = bart_model.generate(eval_batch['input_ids'].to(device), num_beams=4, max_length=32)

In [39]:
# Generated paraphrases decoded back to text, for comparison with the references.
tokenizer.batch_decode(generate_output, skip_special_tokens=True)

['"Is ""Pokemon Ranger and The Temple of The Sea"" a good anime?"',
 'What is the best book for learning Hebrew?',
 'How can I stop masturbation?',
 'Who is your favorite female movie director and why?',
 'Do spirits really exist?',
 'Which is the best institute for financial modelling in Mumbai?',
 'What is the exam pattern of MH CET MBA?',
 'What is the difference between a prostitute and a porn star?',
 '"Why my ""i"" is different than yours?"',
 'How can I be more masculine?',
 'How can I improve my managerial skills?',
 'Why do some people ask questions on Quora that could easily be answered by using a search engine?',
 'How can I find a job abroad?',
 'Why do boys love girls?',
 'What do you hate most about school?',
 'What are the consequences of lying about ethnicity in college admissions?',
 'What are some examples of compounds?',
 'What would you do if you woke up in the middle of a nuclear war?',
 'What is the best way to learn about stock trading?',
 'What do you think of t

In [47]:
# Sub-word tokens of the second generated sentence (special tokens dropped).
(tokenizer.convert_ids_to_tokens(generate_output[1], skip_special_tokens=True))

['What', 'Ġis', 'Ġthe', 'Ġbest', 'Ġbook', 'Ġfor', 'Ġlearning', 'ĠHebrew', '?']

In [50]:
# Sub-word tokens of the second source question, for comparison with the
# generated tokens of the same example.
src = tokenizer.convert_ids_to_tokens(eval_batch['input_ids'][1], skip_special_tokens=True)
print(src)

['What', 'Ġis', 'Ġthe', 'Ġbest', 'Ġtextbook', 'Ġfor', 'ĠHebrew', '?']


In [56]:
# Raw token ids of the first source question, including special and padding ids.
eval_batch['input_ids'][0]

tensor([    0,   113,  6209, 41039, 46145, 23151,     8,    20,  9660,     9,
           20,  3939, 48149,  1687,  3901,    13,  1159,  1917,     2,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1])

In [58]:
# Vocabulary id of the bare token string 'is' (no leading-space marker).
tokenizer.convert_tokens_to_ids('is')

354

In [62]:
# Tokenize the first eval question without padding / truncation arguments, to
# compare against the padded ids stored in the batch.
tokenizer(qqp_eval_raw[0]['question1_text'])

{'input_ids': [0, 113, 6209, 41039, 46145, 23151, 8, 20, 9660, 9, 20, 3939, 48149, 1687, 3901, 13, 1159, 1917, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [64]:
# Decode the first generated sequence keeping special tokens, to see exactly
# which start / end / padding tokens generate() emitted.
tokenizer.decode(generate_output[0])

'</s><s>"Is ""Pokemon Ranger and The Temple of The Sea"" a good anime?"</s><pad><pad><pad><pad>'

In [66]:
# Split the batch tensor into a list of per-example id tensors.
list(eval_batch['input_ids'])

[tensor([    0,   113,  6209, 41039, 46145, 23151,     8,    20,  9660,     9,
            20,  3939, 48149,  1687,  3901,    13,  1159,  1917,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1]),
 tensor([    0,  2264,    16,     5,   275, 31046,    13, 27428,   116,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1]),
 tensor([    0,  6179,   109,    38,   185,   797,    15, 44473,   116,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1]),
 tensor([    0, 12375,    16,   110,  5548,  2182,  1569,   736,     8,   596,
           116,     2,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1

In [72]:
# Raw generated token ids for the sample batch.
generate_output

tensor([[    2,     0,   113,  6209, 41039, 46145, 23151,     8,    20,  9660,
             9,    20,  3939, 48149,    10,   205, 28805,  1917,     2,     1,
             1,     1,     1],
        [    2,     0,  2264,    16,     5,   275,  1040,    13,  2239, 27428,
           116,     2,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    2,     0,  6179,    64,    38,   912, 44473,   116,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    2,     0, 12375,    16,   110,  2674,  2182,  1569,   736,     8,
           596,   116,     2,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    2,     0,  8275, 11656,   269,  5152,   116,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    2,     0, 32251,    16,     5,   275, 1461

In [73]:
# Token lists for every eval example: source question, reference paraphrase and
# model output, in eval-set order.
src_inputs = []
reference_outputs = []
generated_outputs = []

# Make sure dropout is disabled regardless of which cell ran last.
bart_model.eval()

# Loop-local names are used on purpose so that the notebook-level `eval_batch` and
# `generate_output` inspected in other cells are not overwritten by this loop.
for batch in tqdm(eval_dataloader):
    src_inputs.extend(
        tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=True)
        for input_ids in batch['input_ids']
    )
    reference_outputs.extend(
        tokenizer.convert_ids_to_tokens(labels, skip_special_tokens=True)
        for labels in batch['labels']
    )
    # Same decoding settings as the single-batch sample: 4 beams, at most 32 tokens.
    batch_generated = bart_model.generate(batch['input_ids'].to(device), num_beams=4, max_length=32)
    generated_outputs.extend(
        tokenizer.convert_ids_to_tokens(outputs, skip_special_tokens=True)
        for outputs in batch_generated
    )

  0%|          | 0/420 [00:00<?, ?it/s]

In [75]:
# Sanity check: one generated sequence per eval example.
len(generated_outputs)

13438

In [78]:
# Per-example sentence BLEU of the generated output:
#   bleu_score      - against the reference paraphrase
#   self_bleu_score - against the source question
# NOTE(review): the inputs are the tokenizer's sub-word tokens and no smoothing
# function is passed to sentence_bleu — confirm that this is the intended metric.
bleu_score, self_bleu_score = [], []

for source_tokens, reference_tokens, hypothesis_tokens in zip(src_inputs, reference_outputs, generated_outputs):
    reference_bleu = sentence_bleu([reference_tokens], hypothesis_tokens)
    bleu_score.append(reference_bleu)
    source_bleu = sentence_bleu([source_tokens], hypothesis_tokens)
    self_bleu_score.append(source_bleu)


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [80]:
# Mean sentence BLEU of the generated outputs against the reference paraphrases.
sum(bleu_score)/len(bleu_score)

0.23770649604085697

In [81]:
# Mean sentence BLEU of the generated outputs against their own source questions.
sum(self_bleu_score)/len(self_bleu_score)

0.3532578247938156