In [1]:
import numpy as np
import torch
import transformers
import matplotlib.pyplot as plt
import time

from transformers import BertConfig, BertTokenizerFast
from transformers import get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from tqdm.auto import tqdm

from models import BertForDiffusion, DiffusionLM, ConditionalDiffusionLM
from data_utils import load_qqp_dataset_and_tokenizer_from_disk, QQPParaphraseDataset, load_split_qqp_dataset_and_tokenizer_from_disk
from noise_schedule import get_named_beta_schedule
from train_utils import train_conditional, evaluate_conditional
from metric_utils import calculate_bleu, calculate_rouge

%matplotlib inline

In [2]:
# Reproducibility: seed the RNGs behind data shuffling, weight init and noise sampling.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# dataset args
# Maximum sequence length; reused below as `max_position_embeddings` and as the
# sequence dimension of the Gaussian noise drawn for sampling.
max_len = 32

# training args
batch_size = 64
# Prefer the second GPU (the original setting), but degrade gracefully on hosts
# with a single GPU or none instead of failing later at `.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
lr = 1e-4
num_epoch = 30
weight_decay = 0
num_warmup_steps = 100

# model args
# Only takes effect if assigned onto the config (that assignment is commented out
# in the config cell); with emb_type='bit' the model log reports its own dimension.
word_embedding_dim = 128
# BERT-base sized alternative, kept for reference:
# hidden_size = 768
# num_hidden_layers = 12
# num_attention_heads = 12
# intermediate_size = 3072
hidden_size = 512
num_hidden_layers = 4
num_attention_heads = 8
intermediate_size = 2048
max_position_embeddings = max_len

encoder_type = 'from-scratch'

In [3]:
# Load the pre-split QQP paraphrase datasets and the task vocabulary from ./data.
train_dataset, eval_dataset, tokenizer = load_split_qqp_dataset_and_tokenizer_from_disk(data_path="data")

# Alternative loading path, kept for reference (tokenized splits wrapped manually):
# tokenized_qqp_train, tokenized_qqp_eval, tokenizer = load_qqp_dataset_and_tokenizer_from_disk(data_path="data")

# `tokenizer` is used as a token -> id mapping; invert it so ids can be decoded back to tokens.
rev_tokenizer = {v: k for k, v in tokenizer.items()}

print("Tokenizer vocab size:", len(tokenizer))

# train_dataset = QQPParaphraseDataset(dataset=tokenized_qqp_train, random_swap=True)
print("Training set size:", len(train_dataset))
# eval_dataset = QQPParaphraseDataset(dataset=tokenized_qqp_eval, random_swap=False)
print("Evaluation set size:", len(eval_dataset))

# Shuffle only the training data; the evaluation loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

Tokenizer vocab size: 15672
Training set size: 120940
Evaluation set size: 13438


In [4]:
# Transformer configuration built from the hyper-parameters above; the vocabulary
# size and the padding id both come from the task tokenizer.
config = BertConfig(vocab_size=len(tokenizer), hidden_size=hidden_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, max_position_embeddings=max_position_embeddings, pad_token_id=tokenizer['[PAD]'])

# Number of diffusion timesteps, attached to the config as an extra attribute.
config.T = 2000
# comment next line if using bit word embedding
#config.word_embedding_dim = word_embedding_dim

print(config)

# Cosine noise schedule requested for config.T timesteps, converted to a torch tensor.
# NOTE(review): presumably one beta per timestep — confirm in noise_schedule.py.
betas = torch.Tensor(get_named_beta_schedule(schedule_name="cosine", num_diffusion_timesteps=config.T))

BertConfig {
  "T": 2000,
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 32,
  "model_type": "bert",
  "num_attention_heads": 8,
  "num_hidden_layers": 4,
  "pad_token_id": 3,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 15672
}



In [5]:
# Build the conditional diffusion language model and place it on the target device.
diffusion_model = ConditionalDiffusionLM(
    config=config,
    betas=betas,
    use_shared_weight=True,
    lm_head_bias=False,
    add_emb_noise=False,
    conditional_gen=True,
    encoder_type=encoder_type,
    encoder_name_or_path='bert-base-uncased',
    emb_type='bit',
).to(device)

# Total number of parameter elements in the model.
print("Diffusion model #parameters:")
print(sum(param.numel() for param in diffusion_model.parameters()))

# Same count restricted to parameters that will receive gradients.
print("Diffusion model #trainable parameters")
print(sum(param.numel() for param in diffusion_model.parameters() if param.requires_grad))

using bit word embedding
set word_embedding_dim to: 14
Diffusion model #parameters:
38886414
Diffusion model #trainable parameters
38886414


In [None]:
# AdamW over the parameters that require gradients only.
optimizer = torch.optim.AdamW(filter(lambda p:p.requires_grad, diffusion_model.parameters()), lr=lr, weight_decay=weight_decay)
# Cosine learning-rate schedule with warm-up, spanning every optimisation step of the run
# (num_epoch * number of training batches).
scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_epoch*len(train_dataloader))

In [None]:
# train loop
# Collects the return value of train_conditional for every epoch
# (presumably a dict of loss terms, going by the name — confirm in train_utils).
loss_terms_dict_lst = []
# A single progress bar sized for all training steps across all epochs.
progress_bar = tqdm(range(num_epoch*len(train_dataloader)))

for epoch in range(num_epoch):
    print("epoch:",epoch+1)
    loss_terms_dict_lst.append(train_conditional(diffusion_model=diffusion_model, dataloader=train_dataloader, optimizer=optimizer, scheduler=scheduler ,progress_bar=progress_bar ,verbose=True))
    # Evaluate on the held-out split after each epoch; the returned value is not stored.
    evaluate_conditional(diffusion_model=diffusion_model, dataloader=eval_dataloader,)

In [None]:
# Persist the trained weights (state dict only, not the whole module).
# NOTE(review): this filename differs from the checkpoint loaded in the following cell — confirm which is intended.
torch.save(diffusion_model.state_dict(), "checkpoints/conditional_from_scratch.pth")

In [6]:
# Restore previously trained weights.
# `map_location=device` remaps tensors that were saved from a different device, so the
# checkpoint also loads on machines where the original device does not exist.
# strict=False tolerates entries absent from the checkpoint; the result displayed by this
# cell lists them (`alphas_bar`, `alphas_bar_prev`).
# NOTE(review): presumably those are derived from `betas` when the model is built rather
# than learned — confirm in models.py before relying on this checkpoint.
# diffusion_model.load_state_dict(torch.load("checkpoints/20221015_2026"))
diffusion_model.load_state_dict(torch.load("checkpoints/20220906_0519", map_location=device), strict=False)

_IncompatibleKeys(missing_keys=['alphas_bar', 'alphas_bar_prev'], unexpected_keys=[])

In [7]:
# Evaluate the restored model on the held-out split; as the last expression of the cell,
# the value returned by evaluate_conditional is displayed.
evaluate_conditional(diffusion_model=diffusion_model, dataloader=eval_dataloader,)

  0%|          | 0/210 [00:00<?, ?it/s]

eval loss=0.07535


tensor(0.0753, device='cuda:1')

In [34]:
# Put the model in evaluation mode before generating.
diffusion_model.eval()

# Generate outputs for the whole evaluation set with 20 sampling steps, eta=0 and mbr=5.
# NOTE(review): going by the argument and variable names this is DDIM sampling with
# MBR selection over 5 candidates — confirm against ConditionalDiffusionLM.generate.
generated_questions_mbr5_ddim20 = diffusion_model.generate(
    dataset = eval_dataset,
    rev_tokenizer=rev_tokenizer,
    sampling_timesteps=20,
    eta=0,
    mbr=5,
    verbose=True,
)

  0%|          | 0/105 [00:00<?, ?it/s]

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [35]:
# Corpus-level summary for the 20-step run: mean sentence BLEU against the reference
# ("bleu") followed by mean sentence BLEU against the source question ("self_bleu").
bleu_dict = calculate_bleu(generated_questions_mbr5_ddim20, eval_dataset, rev_tokenizer)
for metric_name in ("bleu", "self_bleu"):
    metric_values = bleu_dict[metric_name]
    print(sum(metric_values)/len(metric_values))

0.15155854506376149
0.22320761914790413


In [36]:
# Per-example ROUGE scores for the 20-step run.
rouge_scores = calculate_rouge(generated_questions_mbr5_ddim20, eval_dataset, rev_tokenizer)
# Keep only the ROUGE-L F-measure of each example and report its mean.
rouge_l_f = [d['rouge-l']['f'] for d in rouge_scores]
print(sum(rouge_l_f)/len(rouge_l_f))

0.5461694549349757


In [33]:
# Qualitative check of a single evaluation example: source question, reference
# paraphrase and the model output for the same index.
i = 51
special_tokens = ('[PAD]', '[START]', '[END]')

src_question = [rev_tokenizer[token_id.item()] for token_id in eval_dataset[i]['question1_input_ids']]
src_question = [token for token in src_question if token not in special_tokens]
print(" ".join(src_question))

tgt_question = [rev_tokenizer[token_id.item()] for token_id in eval_dataset[i]['question2_input_ids']]
tgt_question = [token for token in tgt_question if token not in special_tokens]
print(" ".join(tgt_question))

# The cell used to print `generated_questions_mbr15_ddim2`, which no cell of this notebook
# defines (NameError on a fresh kernel). Print the run generated above instead.
print(" ".join(generated_questions_mbr5_ddim20[i]))

What are some ways to help a teenager overcome depression ?
How does anyone overcome depression ?
How do I overcome depression ?


In [7]:
# Put the model in evaluation mode before generating.
diffusion_model.eval()

# Same generation call as the 20-step run, but with 200 sampling steps (eta=0, mbr=5).
generated_questions_mbr5 = diffusion_model.generate(
    dataset = eval_dataset,
    rev_tokenizer=rev_tokenizer,
    sampling_timesteps=200,
    eta=0,
    mbr=5,
    verbose=True,
)

  0%|          | 0/105 [00:00<?, ?it/s]

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [9]:
# Corpus-level summary for the 200-step run: mean sentence BLEU against the reference
# ("bleu") followed by mean sentence BLEU against the source question ("self_bleu").
bleu_dict = calculate_bleu(generated_questions_mbr5, eval_dataset, rev_tokenizer)
for metric_name in ("bleu", "self_bleu"):
    metric_values = bleu_dict[metric_name]
    print(sum(metric_values)/len(metric_values))

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0.18039911311842954
0.2751321938279151


In [10]:
# Per-example ROUGE scores for the 200-step run.
rouge_scores = calculate_rouge(generated_questions_mbr5, eval_dataset, rev_tokenizer)
# Keep only the ROUGE-L F-measure of each example and report its mean.
rouge_l_f = [d['rouge-l']['f'] for d in rouge_scores]
print(sum(rouge_l_f)/len(rouge_l_f))

0.5874787212343898


In [13]:
# Take the first batch of the (unshuffled) evaluation loader for the manual experiments below.
eval_batch = next(iter(eval_dataloader))

In [14]:
from nltk.translate.bleu_score import sentence_bleu

In [22]:
# Manual minimum-Bayes-risk (MBR) decoding for a single evaluation batch:
# draw `mbr` independent samples per source question, then keep, for every example,
# the candidate whose summed BLEU against the other candidates is highest.
generated_questions = []
bs = eval_batch['question1_input_ids'].shape[0]
mbr = 5
# The previous `if mbr > 1:` guard silently produced no output for mbr == 1; a single
# candidate is now simply selected, and non-positive values are rejected explicitly.
if mbr < 1:
    raise ValueError("mbr must be a positive number of candidates")

special_tokens = ('[PAD]', '[START]', '[END]')
# The conditioning inputs are identical for every candidate: move them to the device once.
src_ids = eval_batch['question1_input_ids'].to(device)
src_attention_mask = eval_batch['question1_attention_mask'].to(device)

batch_questions = []    # [mbr, bs, question]
# Sampling is inference only; no_grad keeps autograd from recording the denoising steps
# (harmless if ddim_sample already disables gradients internally).
with torch.no_grad():
    for cnt in range(mbr):
        x_T = torch.randn(size=(bs,
                                diffusion_model.config.max_position_embeddings,
                                diffusion_model.config.word_embedding_dim))
        final_hidden_state = diffusion_model.ddim_sample(x_T.to(device),
                                                         sampling_timesteps=200,
                                                         eta=0,
                                                         src_ids=src_ids,
                                                         src_attention_mask=src_attention_mask,
                                                         return_hidden_states=False,
                                                         verbose=False)
        sampled_ids = diffusion_model.rounding(final_hidden_state).cpu()
        questions = [[rev_tokenizer[token_id.item()] for token_id in sampled_id] for sampled_id in sampled_ids]
        batch_questions.append([[token for token in question if token not in special_tokens] for question in questions])

for batch_ind in range(bs):
    candidates = [one_generation[batch_ind] for one_generation in batch_questions]      # [mbr, question]
    bleu_scores = torch.zeros(mbr)
    for candidate_ind, candidate in enumerate(candidates):
        for ref_ind, ref in enumerate(candidates):
            if ref_ind != candidate_ind:
                # Each other candidate acts as the reference for this one.
                bleu_scores[candidate_ind] += sentence_bleu([ref], candidate)
    select_ind = torch.argmax(bleu_scores).item()
    generated_questions.append(batch_questions[select_ind][batch_ind])

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [None]:
# Pretrained BERT tokenizer; used further down to decode the `*_input_ids_bert` fields of a batch.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [None]:
# Put the model in evaluation mode before sampling.
diffusion_model.eval()

# Sample conditioned on the source questions of `eval_batch`, without tracking gradients.
with torch.no_grad():
    # Gaussian noise of shape (batch_size, max_len, word_embedding_dim) as the starting point.
    # NOTE(review): assumes eval_batch holds exactly batch_size examples.
    x_T = torch.randn(size=(batch_size, max_len, diffusion_model.config.word_embedding_dim))
    # return_hidden_states=True also returns the intermediate states, which later cells index by step.
    final_hidden_state, hidden_states = diffusion_model.sample(x_T.to(device),
                                        src_ids=eval_batch['question1_input_ids'].to(device),
                                        src_attention_mask=eval_batch['question1_attention_mask'].to(device),
                                        return_hidden_states=True,
                                        verbose=True)

In [None]:
diffusion_model.eval()

# 200-step DDIM sampling conditioned on the source questions of `eval_batch`;
# the intermediate hidden states are kept for step-by-step inspection.
with torch.no_grad():
    x_T = torch.randn(
        size=(batch_size, max_len, diffusion_model.config.word_embedding_dim)
    )
    final_hidden_state_ddim, hidden_states_ddim = diffusion_model.ddim_sample(
        x_T.to(device),
        sampling_timesteps=200,
        src_ids=eval_batch['question1_input_ids'].to(device),
        src_attention_mask=eval_batch['question1_attention_mask'].to(device),
        return_hidden_states=True,
        verbose=True,
    )

In [None]:
# Display the model's word-embedding attribute.
diffusion_model.word_embeddings

In [None]:
# Decode the source questions (question1) of the evaluation batch with the task
# vocabulary and print them one by one; special tokens are left in place.
src_questions = [
    [rev_tokenizer[token_id.item()] for token_id in token_row]
    for token_row in eval_batch['question1_input_ids']
]
for sample_idx in range(batch_size):
    print("idx:", sample_idx)
    print(" ".join(src_questions[sample_idx]))

In [None]:
# Decode the reference paraphrases (question2) of the evaluation batch with the task
# vocabulary and print them one by one; special tokens are left in place.
target_questions = [
    [rev_tokenizer[token_id.item()] for token_id in token_row]
    for token_row in eval_batch['question2_input_ids']
]
for sample_idx in range(batch_size):
    print("idx:", sample_idx)
    print(" ".join(target_questions[sample_idx]))

In [None]:
# Round the last hidden state of the full sampling run to token ids, decode, and print.
with torch.no_grad():
    sampled_ids = diffusion_model.rounding(hidden_states[-1])
    generated_questions = [
        [rev_tokenizer[token_id.item()] for token_id in id_row]
        for id_row in sampled_ids
    ]
for sample_idx in range(batch_size):
    print("sample_idx:", sample_idx)
    print(" ".join(generated_questions[sample_idx]))

In [None]:
# Round the last hidden state of the DDIM run to token ids, decode, and print.
with torch.no_grad():
    sampled_ids = diffusion_model.rounding(hidden_states_ddim[-1])
    generated_questions = [
        [rev_tokenizer[token_id.item()] for token_id in id_row]
        for id_row in sampled_ids
    ]
for sample_idx in range(batch_size):
    print("sample_idx:", sample_idx)
    print(" ".join(generated_questions[sample_idx]))

In [None]:
# Display the entry at index 199 of the stored DDIM trajectory.
# NOTE(review): presumably the last of the 200 sampling steps — confirm how ddim_sample indexes its states.
hidden_states_ddim[199]

In [None]:
# Display the model's repr (module structure).
diffusion_model

In [None]:
# Follow one example (index 63 of the batch) through the full sampling trajectory:
# round the stored hidden state at selected step indices and print the decoded tokens.
sample_idx = 63
for step in [1000,1900,1940,1980,1990,1993,1994,1995,1996,1997,1998,-1]:
    hidden_state = hidden_states[step][sample_idx]
    with torch.no_grad():
        sampled_ids = diffusion_model.rounding(hidden_state)
        sampled_seq = [rev_tokenizer[token_id.item()] for token_id in sampled_ids]
        print("step:", step)
        print(" ".join(sampled_seq))

In [None]:
# Same inspection for the DDIM trajectory: decode example 63 at selected step indices.
sample_idx = 63
for step in [0,150,180,190,195,197,198,-1]:
    hidden_state = hidden_states_ddim[step][sample_idx]
    with torch.no_grad():
        sampled_ids = diffusion_model.rounding(hidden_state)
        sampled_seq = [rev_tokenizer[token_id.item()] for token_id in sampled_ids]
        print("step:", step)
        print(" ".join(sampled_seq))

In [None]:
# Take one batch from the (shuffled) training loader to compare generation on seen data.
train_batch = next(iter(train_dataloader))

In [None]:
# Sample conditioned on the source questions of the training batch.
# The noise tensor `x_T` is the one left over from the preceding sampling cell.
# Unlike the evaluation sampling cells, this call was not wrapped in no_grad, so autograd
# could track the whole sampling trajectory; disable gradient tracking for inference here too.
with torch.no_grad():
    final_hidden_state2, hidden_states2 = diffusion_model.sample(
        x_T.to(device),
        src_ids=train_batch['question1_input_ids'].to(device),
        src_attention_mask=train_batch['question1_attention_mask'].to(device),
        return_hidden_states=True,
        verbose=True,
    )

In [None]:
# Decode the source questions (question1) of the training batch with the task
# vocabulary and print them one by one; special tokens are left in place.
src_questions = [
    [rev_tokenizer[token_id.item()] for token_id in token_row]
    for token_row in train_batch['question1_input_ids']
]
for sample_idx in range(batch_size):
    print("idx:", sample_idx)
    print(" ".join(src_questions[sample_idx]))

In [None]:
# Decode the reference paraphrases (question2) of the training batch with the task
# vocabulary and print them one by one; special tokens are left in place.
target_questions = [
    [rev_tokenizer[token_id.item()] for token_id in token_row]
    for token_row in train_batch['question2_input_ids']
]
for sample_idx in range(batch_size):
    print("idx:", sample_idx)
    print(" ".join(target_questions[sample_idx]))

In [None]:
# Round the last hidden state of the training-batch run to token ids, decode, and print.
with torch.no_grad():
    sampled_ids = diffusion_model.rounding(hidden_states2[-1])
    generated_questions = [
        [rev_tokenizer[token_id.item()] for token_id in id_row]
        for id_row in sampled_ids
    ]
for sample_idx in range(batch_size):
    print("sample_idx:", sample_idx)
    print(" ".join(generated_questions[sample_idx]))

In [None]:
# Decode the BERT-tokenized source question of the first evaluation example, dropping special tokens.
# NOTE(review): assumes the batch carries a 'question1_input_ids_bert' field — not shown elsewhere in this notebook; verify.
bert_tokenizer.decode(eval_batch['question1_input_ids_bert'][0],skip_special_tokens=True)

In [None]:
# Show the BERT word pieces of evaluation example 59, dropping special tokens.
# NOTE(review): assumes the batch carries a 'question1_input_ids_bert' field — verify.
bert_tokenizer.convert_ids_to_tokens(eval_batch['question1_input_ids_bert'][59], skip_special_tokens=True)

In [None]:
# First decoded target question with the padding/start/end markers removed.
list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], target_questions[0]))

In [None]:
# First generated question with the padding/start/end markers removed.
list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], generated_questions[0]))

In [None]:
# Sentence-level BLEU of every generated question against its target question,
# with the padding/start/end markers stripped from both sides.
# NOTE(review): in top-to-bottom order both lists come from the training batch at this point — confirm that is intended.
bleu_score = []
for target, generate in zip(target_questions, generated_questions):
    bleu_score.append(sentence_bleu([list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], target))],
    list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], generate))))

In [None]:
# Mean BLEU against the target questions.
sum(bleu_score)/len(bleu_score)

In [None]:
# Sentence-level BLEU of every generated question against its source question
# (how much of the input is copied), with the special markers stripped from both sides.
# This reuses and overwrites `bleu_score` from the target comparison.
bleu_score = []
for src, generate in zip(src_questions, generated_questions):
    bleu_score.append(sentence_bleu([list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], src))],
                                    list(filter(lambda x:x not in ['[PAD]','[START]','[END]'], generate))))

In [None]:
# Mean BLEU against the source questions.
sum(bleu_score)/len(bleu_score)

In [None]:
# Number of evaluation examples.
len(eval_dataset)

In [None]:
# Inspect the raw fields of the first evaluation example.
eval_dataset[0]

In [None]:
# Carve a random 500-example subset out of the evaluation set for quick experiments.
# The remainder size is derived from the dataset length instead of the hard-coded 12938,
# so the split stays valid if the evaluation set changes size.
small_dataset_size = 500
small_dataset, rest_dataset = torch.utils.data.random_split(
    eval_dataset,
    [small_dataset_size, len(eval_dataset) - small_dataset_size],
)

In [None]:
# Generate for the 500-example subset with 200 sampling steps, eta=0 and mbr=1
# (presumably a single candidate per example, i.e. no MBR selection — confirm in generate).
generated_questions_mbr1 = diffusion_model.generate(
    dataset = small_dataset,
    rev_tokenizer=rev_tokenizer,
    sampling_timesteps=200,
    eta=0,
    mbr=1,
)

In [None]:
# Token list generated for the first subset example.
generated_questions_mbr1[0]

In [None]:
# Raw token ids of the source question of the first subset example.
small_dataset[0]['question1_input_ids']

In [28]:
# Decode the reference paraphrase (question2) of the first subset example, special tokens included.
[rev_tokenizer[i.item()] for i in small_dataset[0]['question2_input_ids']]

['[START]',
 'Which',
 'fruits',
 'or',
 'vegetables',
 'should',
 'be',
 'eaten',
 'regularly',
 'to',
 'get',
 'vitamins',
 '?',
 '[END]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']

In [30]:
def calculate_bleu(generated_questions, dataset, rev_tokenizer, smoothing_function=None):
    """
    calculate BLEU metric

    For every sample the generated question is scored twice with sentence-level BLEU:
    against the target question ("bleu") and against the source question
    ("self_bleu", i.e. how much of the input is copied).

    NOTE: this definition shadows the `calculate_bleu` imported from `metric_utils`
    for all cells executed after it.

    :param generated_questions: list[token_list], aligned by index with `dataset`
    :param dataset: pytorch dataset whose samples provide 'question1_input_ids' and
        'question2_input_ids' as iterables of elements exposing `.item()`
    :param rev_tokenizer: token_id to token dict
    :param smoothing_function: optional smoothing callable forwarded to
        `sentence_bleu` (e.g. a method of nltk's SmoothingFunction). The default
        None keeps the previous unsmoothed behaviour, where a hypothesis with no
        higher-order n-gram overlap scores 0 and triggers an NLTK warning.
    :return: {"bleu": val_list, "self_bleu": val_list}
    """
    special_tokens = {'[PAD]', '[START]', '[END]'}

    def decode(token_ids):
        # ids -> tokens, dropping padding and sequence delimiters
        tokens = (rev_tokenizer[token_id.item()] for token_id in token_ids)
        return [token for token in tokens if token not in special_tokens]

    bleu, self_bleu = [], []
    for ind, sample in enumerate(dataset):
        src_question = decode(sample['question1_input_ids'])
        tgt_question = decode(sample['question2_input_ids'])
        hypothesis = generated_questions[ind]
        bleu.append(sentence_bleu([tgt_question], hypothesis, smoothing_function=smoothing_function))
        self_bleu.append(sentence_bleu([src_question], hypothesis, smoothing_function=smoothing_function))

    return {"bleu": bleu, "self_bleu": self_bleu}

In [31]:
# BLEU / self-BLEU of the mbr=1 generations on the 500-example subset.
bleu_dict = calculate_bleu(generated_questions_mbr1, small_dataset, rev_tokenizer)

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [34]:
# Mean sentence BLEU against the target questions.
sum(bleu_dict["bleu"])/len(bleu_dict["bleu"])

0.15005539419638644

In [35]:
# Mean sentence BLEU against the source questions (self-BLEU).
sum(bleu_dict["self_bleu"])/len(bleu_dict["self_bleu"])

0.2198190329545948

In [44]:
# First generated question of the subset as a space-joined string.
" ".join(generated_questions_mbr1[0])

'Which is the better fruits to get eaten at critics ?'

In [46]:
from rouge import Rouge

In [45]:
def calculate_rouge(generated_questions, dataset, rev_tokenizer):
    """
    calculate ROUGE metrics of the generated questions against the target questions

    NOTE: this definition shadows the `calculate_rouge` imported from `metric_utils`
    for all cells executed after it.

    :param generated_questions: list[token_list], aligned by index with `dataset`
    :param dataset: pytorch dataset whose samples provide 'question2_input_ids'
        as an iterable of elements exposing `.item()`
    :param rev_tokenizer: token_id to token dict
    :return: list with one score dict per sample, keyed by 'rouge-1', 'rouge-2' and
        'rouge-l', each holding 'r', 'p' and 'f'
    """
    special_tokens = {'[PAD]', '[START]', '[END]'}
    rouge = Rouge()
    rouge_scores = []
    for ind, sample in enumerate(dataset):
        tgt_tokens = [rev_tokenizer[token_id.item()] for token_id in sample['question2_input_ids']]
        reference = " ".join(token for token in tgt_tokens if token not in special_tokens)
        hypothesis = " ".join(generated_questions[ind])

        if not hypothesis.strip() or not reference.strip():
            # The rouge package raises ValueError on an empty hypothesis/reference
            # (e.g. a generation made only of special tokens). Record a zero score
            # instead so one degenerate sample does not abort the evaluation and the
            # result stays aligned with the dataset indices.
            rouge_scores.append({
                metric: {'r': 0.0, 'p': 0.0, 'f': 0.0}
                for metric in ('rouge-1', 'rouge-2', 'rouge-l')
            })
            continue

        rouge_scores += rouge.get_scores(hypothesis, reference)

    return rouge_scores

In [48]:
# ROUGE scores of the mbr=1 generations on the 500-example subset.
rouge_scores = calculate_rouge(generated_questions_mbr1, small_dataset, rev_tokenizer)

In [50]:
# Inspect the score dict of the first example.
rouge_scores[0]

{'rouge-1': {'r': 0.5, 'p': 0.5454545454545454, 'f': 0.5217391254442345},
 'rouge-2': {'r': 0.09090909090909091, 'p': 0.1, 'f': 0.09523809024943337},
 'rouge-l': {'r': 0.4166666666666667,
  'p': 0.45454545454545453,
  'f': 0.434782603705104}}

In [51]:
# ROUGE-L F-measure of every example.
rouge_l_f = [d['rouge-l']['f'] for d in rouge_scores]

In [52]:
# Display the per-example ROUGE-L F-measures.
# NOTE(review): this prints one value per example; a slice or a summary statistic would keep the output short.
rouge_l_f

[0.434782603705104,
 0.6666666616666668,
 0.42105262659279785,
 0.5185185135253774,
 0.6666666617283951,
 0.5999999950500001,
 0.42424241928374656,
 0.33333332839506175,
 0.35294117148788934,
 0.8888888839506174,
 0.8235294069204152,
 0.4347826039319471,
 0.7058823479584776,
 0.7407407357475996,
 0.47999999500800006,
 0.4705882303114187,
 0.6666666616666668,
 0.24999999545000007,
 0.4444444395061729,
 0.9729729679766254,
 0.6153846108579881,
 0.4999999950347222,
 0.38461537988165684,
 0.6249999950195313,
 0.47619047129251707,
 0.34782608204158794,
 0.7142857093877552,
 0.5555555511111112,
 0.8571428522448981,
 0.5454545406198348,
 0.6666666616666668,
 0.4285714236734694,
 0.5333333283555556,
 0.2399999953920001,
 0.518518513580247,
 0.8965517191914388,
 0.35897435437212366,
 0.6315789423822715,
 0.31578946869806096,
 0.44444444000000005,
 0.5555555505555557,
 0.5714285664399092,
 0.7999999950080001,
 0.3448275812128419,
 0.3529411716262976,
 0.37499999500000003,
 0.4999999950347222,
 0