In [None]:
import os

import numpy as np
import torch
import transformers
import matplotlib.pyplot as plt

from transformers import BertConfig
from transformers import get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from tqdm.auto import tqdm

from models import BertForDiffusion, DiffusionLM
from data_utils import load_e2enlg_dataset_and_tokenizer, E2enlgDataset, load_rocstories_dataset_and_tokenizer, RocstoriesDataset
from noise_schedule import get_named_beta_schedule
from train_utils import train, evaluate

%matplotlib inline

In [None]:
# dataset args
max_len = 72            # maximum length of input_ids
vocab_threshold = 10    # tokens occurring fewer than `vocab_threshold` times are mapped to [UNK]
test_size = 0.1         # fraction of the data held out as the evaluation set

# reproducibility
seed = 42
np.random.seed(seed)     # numpy global RNG; sklearn's train_test_split falls back to it when random_state is None
torch.manual_seed(seed)  # torch RNG on all devices: weight init, DataLoader shuffling, torch.randn

# training args
batch_size = 64
# fall back to CPU so the notebook still runs (slowly) on a machine without a GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 2e-4
num_epoch = 50
weight_decay = 0
num_warmup_steps = 100

# model args
word_embedding_dim = 128
# BERT-base sizes, for reference: hidden_size=768, num_hidden_layers=12,
# num_attention_heads=12, intermediate_size=3072
hidden_size = 512
num_hidden_layers = 4
num_attention_heads = 8
intermediate_size = 2048
max_position_embeddings = max_len  # the position table must cover every input position

In [None]:
tokenized_rocstories_dataset, tokenizer = load_rocstories_dataset_and_tokenizer(
    max_len=max_len,
    vocab_threshold=vocab_threshold,
)

# Inverse vocabulary (token id -> token string), used later to decode model outputs.
rev_tokenizer = dict(zip(tokenizer.values(), tokenizer.keys()))

train_set, eval_set = train_test_split(tokenized_rocstories_dataset, test_size=test_size, shuffle=True)


def _to_rocstories_dataset(split):
    """Wrap one split's `input_ids` / `attention_mask` columns in a RocstoriesDataset."""
    return RocstoriesDataset(data_lst=split['input_ids'], attention_mask_lst=split['attention_mask'])


train_dataset = _to_rocstories_dataset(train_set)
print("Training set size:", len(train_dataset))
eval_dataset = _to_rocstories_dataset(eval_set)
print("Evaluation set size:", len(eval_dataset))

# Only the training loader shuffles; evaluation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

In [None]:
config = BertConfig(
    vocab_size=len(tokenizer),
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    max_position_embeddings=max_position_embeddings,
    pad_token_id=tokenizer['[PAD]'],
)

# Extra, non-BertConfig attributes carried on the config for the diffusion model.
config.T = 2000                                 # number of diffusion timesteps (used for the beta schedule)
config.word_embedding_dim = word_embedding_dim  # dimension of the word embeddings / noise vectors x_T

print(config)

In [None]:
# Noise schedule: per-step betas, their cumulative product alpha_bar_t, and the resulting
# noise scale sqrt(1 - alpha_bar_t) at every timestep.
beta_schedule_name = "sqrt"  # alternative: "linear"
betas = torch.tensor(
    get_named_beta_schedule(schedule_name=beta_schedule_name, num_diffusion_timesteps=config.T),
    dtype=torch.float32,
)

alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
sqrt_one_minus_alphas_bar = torch.sqrt(1. - alphas_bar)

fig, ax = plt.subplots()
ax.plot(sqrt_one_minus_alphas_bar)
ax.set(
    title=f"Noise scale of the '{beta_schedule_name}' schedule",
    xlabel="timestep t",
    ylabel=r"$\sqrt{1 - \bar{\alpha}_t}$",
);

In [None]:
diffusion_model = DiffusionLM(
    config=config,
    betas=betas,
    use_shared_weight=True,
    lm_head_bias=False,
    add_emb_noise=False,
).to(device)

num_params = sum(param.numel() for param in diffusion_model.parameters())
print("Diffusion model #parameters:")
print(num_params)

# One scheduler step per training batch over the whole run.
num_training_steps = num_epoch * len(train_dataloader)
optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
# One entry per epoch: whatever `train` returns for that epoch (its loss terms).
loss_terms_dict_lst = []
total_steps = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(total_steps))

for epoch_no in range(1, num_epoch + 1):
    print("epoch:", epoch_no)
    epoch_loss_terms = train(
        diffusion_model=diffusion_model,
        dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        progress_bar=progress_bar,
        verbose=True,
    )
    loss_terms_dict_lst.append(epoch_loss_terms)
    evaluate(diffusion_model=diffusion_model, dataloader=eval_dataloader)

In [None]:
torch.save(diffusion_model.state_dict(), "checkpoints/epoch50_unshared_dim24.pth")

In [None]:
loss_terms_dict = {'mse':[], 'L_T':[], 'rounding':[]}
for key in loss_terms_dict_lst[0].keys():
    for ep in range(num_epoch):
        loss_terms_dict[key] += loss_terms_dict_lst[ep][key]

In [None]:
plt.plot(loss_terms_dict['mse'], label='mse')
plt.plot(loss_terms_dict['rounding'], label='rounding')
plt.legend()
plt.yscale('log')

In [None]:
diffusion_model.load_state_dict(torch.load("checkpoints/roc_unshared_dim128.pth"))

In [None]:
# Starting Gaussian noise x_T for reverse diffusion: (batch_size, max_len, word_embedding_dim).
# A dedicated seeded generator makes this starting noise identical across kernel restarts, so a
# hand-picked sample index below keeps referring to the same initial noise.
sampling_seed = 0
x_T = torch.randn(
    size=(batch_size, max_len, word_embedding_dim),
    generator=torch.Generator().manual_seed(sampling_seed),
)

In [None]:
# Sampling is pure inference. eval() disables train-mode behaviour such as dropout: a freshly
# constructed module (e.g. when training is skipped and a checkpoint is loaded) starts in train
# mode. no_grad avoids recording an autograd graph across the whole reverse trajectory.
diffusion_model.eval()
with torch.no_grad():
    logits, hidden_states = diffusion_model.sample(x_T.to(device), return_hidden_states=True, verbose=True)

In [None]:
sample_idx = 62
for step in [0,1000,1500,1800,1900,1950,1990,1995,1998,1999]:
    hidden_state = hidden_states[step][sample_idx]
    with torch.no_grad():
        hidden_logits = diffusion_model.lm_head(hidden_state)
        sampled_ids = torch.argmax(hidden_logits,dim=-1).cpu()
        sampled_seq = [rev_tokenizer[token_id.item()] for token_id in sampled_ids]
        print("step:", step)
        print(" ".join(sampled_seq))

In [None]:
for sample_idx in range(batch_size):
    hidden_state = hidden_states[-1][sample_idx]
    with torch.no_grad():
        hidden_logits = diffusion_model.lm_head(hidden_state)
        sampled_ids = torch.argmax(hidden_logits,dim=-1).cpu()
        sampled_seq = [rev_tokenizer[token_id.item()] for token_id in sampled_ids]
        print("sample_idx:", sample_idx)
        print(" ".join(sampled_seq))

In [None]:
diffusion_model.config.word_embedding_dim

In [None]:
logits2, hidden_states2 = diffusion_model.sample(x_T.to(device), clamp='rounding', return_hidden_states=True, verbose=True)

In [None]:
for sample_idx in range(32):
    hidden_state = hidden_states2[-1][sample_idx]
    with torch.no_grad():
        hidden_logits = diffusion_model.lm_head(hidden_state)
        sampled_ids = torch.argmax(hidden_logits,dim=-1).cpu()
        sampled_seq = [rev_tokenizer[token_id.item()] for token_id in sampled_ids]
        print("sample_idx:", sample_idx)
        print(" ".join(sampled_seq))

In [None]:
# Print the name of every parameter registered on the model, in registration order.
for param_name in dict(diffusion_model.named_parameters()):
    print(param_name)

In [None]:
diffusion_model.betas

In [None]:
diffusion_model.word_embeddings.weight

In [None]:
diffusion_model.lm_head.weight

In [None]:
diffusion_model.word_embeddings.weight[3]

In [None]:
diffusion_model.lm_head

In [None]:
print(diffusion_model.lm_head.bias.data[:10])

In [None]:
learned_emb = diffusion_model.word_embeddings.weight.data.cpu().numpy()

In [None]:
learned_emb.shape

In [None]:
emb_2d = TSNE(learning_rate='auto').fit_transform(learned_emb)

In [None]:
emb_2d.shape

In [None]:
plt.scatter(x=emb_2d[:,0], y=emb_2d[:,1])

In [None]:
learned_rounding = diffusion_model.lm_head.weight.data.cpu().numpy()

In [None]:
learned_rounding.shape

In [None]:
plt.figure(figsize=(8,8))
plt.imshow(np.matmul(learned_emb[:20,:], learned_rounding[:20,:].T), cmap='gray')

In [None]:
t = torch.randn(size=(100,10))
plt.imshow(torch.matmul(t, t.T), cmap='gray')