In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to project source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Directory added to the import path so the project-local packages imported
# in the next cell (diffusion_holder, utils, estimation_utils, ...) resolve.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
PROJECT_ROOT = "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/"

# Guard the append so re-running this cell does not pile up duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
import torch
import numpy as np
import ml_collections
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from torch.utils.data import DataLoader
from diffusion_holder import DiffusionRunner
from transformers import BertConfig, BertTokenizerFast

from diffusion_holder import DiffusionRunner
from utils.util import set_seed, dict_to_cuda
from estimation_utils.util import estimate_model, reduce_metrics, gather_texts
import diffusion_utils.schedulers as schedulers

In [4]:
# Apply seaborn's default theme to the figures created in this notebook.
sns.set_theme()

In [5]:
def create_config(data_dir="/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/data"):
    """Assemble the nested ``ml_collections.ConfigDict`` used by this notebook.

    Settings are grouped into the ``training``, ``loss``, ``refresh``,
    ``validation``, ``dynamic``, ``model`` and ``data`` sub-configs, plus a
    few top-level flags. The caller hands the result to ``DiffusionRunner``.

    Args:
        data_dir: Directory containing the encoder statistics files
            (``encodings-bert_base-wiki-{mean,std}.pt`` and
            ``encodings-t5-wiki-{mean,std}.pth``). The default is the
            location this notebook was originally written against.

    Returns:
        The populated ``ml_collections.ConfigDict``.
    """
    config = ml_collections.ConfigDict()

    # --- training ---
    training = config.training = ml_collections.ConfigDict()
    training.training_iters = 500_000
    training.checkpoint_freq = 50_000
    training.eval_freq = 5_000
    training.batch_size = 512  # * 8

    training.ode_sampling = False
    training.checkpoints_folder = '../checkpoints/'
    config.checkpoints_prefix = ''

    # --- loss ---
    loss = config.loss = ml_collections.ConfigDict()
    loss.ce_coef = 0.

    # --- refresh ---
    # NOTE(review): this prefix is rooted at "./checkpoints/" while
    # training.checkpoints_folder uses "../checkpoints/" — confirm which is intended.
    refresh = config.refresh = ml_collections.ConfigDict()
    refresh.true = False
    refresh.prefix = "./checkpoints/wikipedia--t5-bert-self_cond_500000_.pth"
    refresh.wand_id = "g5fb4af3"

    # --- validation ---
    validation = config.validation = ml_collections.ConfigDict()
    validation.batch_size = 4
    # Integer division keeps the iteration count an int without a float round-trip.
    validation.validation_iters = 10_000 // validation.batch_size
    validation.num_gen_texts = 8192
    validation.p_uncond = 0.

    # --- diffusion dynamic / solver ---
    dynamic = config.dynamic = ml_collections.ConfigDict()
    dynamic.solver = 'euler'
    dynamic.scheduler = "sd"
    dynamic.N = 200
    dynamic.beta_min = 0.1
    dynamic.beta_max = 20
    dynamic.ode_sampling = False
    dynamic.coef_d = 10

    # --- model ---
    model = config.model = ml_collections.ConfigDict()
    model.ema_rate = 0.9999
    model.enc_type = "base"
    model.embeddings_type = "embeddings"
    model.dif_enc_type = "base"
    model.downstream_task = ""  # "qqp"
    model.dataset = "wikipedia"  # "glue"
    model.prediction = "x_0"
    model.loss = "L_x_0"
    model.decoder_path = "decoder-wikipedia-128.pth"

    # --- data ---
    data = config.data = ml_collections.ConfigDict()
    data.max_sequence_len = 64
    data.pos_begin = 0.0
    data.pos_end = 0.67
    data.enc_bert_mean = f"{data_dir}/encodings-bert_base-wiki-mean.pt"
    data.enc_bert_std = f"{data_dir}/encodings-bert_base-wiki-std.pt"

    data.enc_t5_mean = f"{data_dir}/encodings-t5-wiki-mean.pth"
    data.enc_t5_std = f"{data_dir}/encodings-t5-wiki-std.pth"

    # --- top-level flags ---
    config.finetuning = False
    config.seed = 0
    config.ddp = False
    config.bert_config = BertConfig.from_pretrained("bert-base-uncased")
    config.use_self_cond = False
    config.project_name = "test" #"dtg-exps-1.0"
    config.timesteps = "linear"

    return config

In [6]:
# Start from the notebook defaults, then override the checkpoint prefix
# (presumably selects which saved weights DiffusionRunner loads — TODO confirm).
config = create_config()
config.checkpoints_prefix = "wikipedia--t5-bert-initial_last_"

# NOTE(review): eval=True looks like it builds the runner for evaluation only — confirm in DiffusionRunner.
diffusion = DiffusionRunner(config, latent_mode=config.model.embeddings_type, eval=True)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Some weights of the model checkpoint at t5-base were not used when initializing T5EncoderModel: ['decoder.block.3.layer.0.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.q.weight', 'decoder.block.10.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.8.layer.2.DenseReluDense.wo.weight', 'decoder.block.6.layer.2.DenseReluDense.wi.wei

Dataset tokenization (num_proc=30):   0%|          | 0/38661 [00:00<?, ? examples/s]

In [7]:
# Presumably prepares `diffusion.valid_loader`, which the next cell iterates — TODO confirm.
diffusion.set_valid_data_generator()



In [8]:
# Iterator over the validation loader; each next(loader) call yields one batch.
loader = iter(diffusion.valid_loader)

In [9]:
# Take one validation batch and pass it through dict_to_cuda.
X = dict_to_cuda(next(loader))

# Run both encoders without building an autograd graph: encoder_gen on the
# input ids/mask and encoder_cond on the conditioning ids/mask.
with torch.no_grad():
    clean_X = diffusion.encoder_gen(
        input_ids=X["input_ids"],
        attention_mask=X["input_mask"],
    )
    cond_X = diffusion.encoder_cond(
        input_ids=X["cond_ids"],
        attention_mask=X["cond_mask"],
    )

# Masks handed to the solver step in the sampling cells below.
attention_mask = None
cond_mask = X["cond_mask"]

In [10]:
# Final time value of the grid: 1/N, where N is the number of solver steps.
eps_t = 1. / diffusion.dynamic.N
# N evenly spaced time points from dynamic.T to eps_t, both endpoints included
# (so timesteps[-1] == eps_t).
timesteps = torch.linspace(diffusion.dynamic.T, eps_t, diffusion.dynamic.N, device=diffusion.device)

In [40]:
# Collects the latent returned by every solver step of the sampling loop below.
x_list = []
# Seed the RNGs before sampling so the run is repeatable.
set_seed(0)

In [41]:
# Reverse-time sampling that records the latent after every solver step.
# The accumulator and the seed are (re)set here so the cell is idempotent:
# re-running it produces the same N entries instead of appending to stale ones.
x_list = []
set_seed(0)

with torch.no_grad():
    x = diffusion.dynamic.prior_sampling(clean_X.shape).to(diffusion.device)
    # The per-sample time vectors must match the batch dimension of the
    # tensor being stepped, so derive the size from `x` itself.
    batch_size = x.shape[0]
    num_steps = diffusion.dynamic.N

    for idx in tqdm(range(num_steps)):
        t = timesteps[idx]
        # The last grid point has no successor; fall back to eps_t there.
        next_t = timesteps[idx + 1] if idx < num_steps - 1 else eps_t

        input_t = t * torch.ones(batch_size, device=diffusion.device)
        next_input_t = next_t * torch.ones(batch_size, device=diffusion.device)

        output = diffusion.diff_eq_solver.step(
            x_t=x, t=input_t, next_t=next_input_t,
            cond=cond_X,
            cond_mask=cond_mask,
            attention_mask=attention_mask,
        )

        # Carry the new latent into the next step and keep it for the animation.
        x = output["x"]
        x_list.append(x)

100%|██████████| 200/200 [00:04<00:00, 41.33it/s]


In [42]:
from PIL import Image

def make_gif_from_list_of_tensors(x_list, path='animation-x_t.gif', size=(128 * 4, 32 * 4), duration=100):
    """Write an animated GIF with one frame per tensor in ``x_list``.

    For each tensor, element 0 along the first axis is taken, its first 128
    columns are min-max scaled to 0..255, converted to an 8-bit image and
    resized with nearest-neighbour resampling.

    Args:
        x_list: Non-empty sequence of tensors. ``x[0]`` is assumed to be a
            2-D floating-point array once moved to numpy — TODO confirm.
        path: Output file name of the GIF.
        size: ``(width, height)`` every frame is resized to.
            NOTE(review): the default height is fixed at 32 * 4 regardless of
            how many rows a frame has — confirm this matches the data.
        duration: Display time of each frame in milliseconds.

    Raises:
        ValueError: If ``x_list`` is empty.
    """
    if len(x_list) == 0:
        raise ValueError("x_list must contain at least one tensor")

    images = []
    # Iterate over the argument itself rather than a globally defined step
    # count, so the function works for any list length.
    for x in x_list:
        embeds = x[0].cpu().numpy()
        # np.array copies the slice, so the in-place scaling below never
        # writes into memory shared with the source tensor.
        image = np.array(embeds[:, :128])
        image -= image.min()
        peak = image.max()
        # A constant frame has peak == 0; skip the division to avoid NaNs.
        if peak > 0:
            image /= peak

        image = Image.fromarray(np.uint8(image * 255))
        image = image.resize(size, resample=Image.NEAREST)

        images.append(image)

    images[0].save(
        path,
        save_all=True,
        append_images=images[1:],  # remaining frames after the first
        duration=duration,  # in milliseconds
        loop=0,  # 0 = loop forever
    )


In [43]:
# Render the latents collected during sampling into 'animation-x_t.gif'.
make_gif_from_list_of_tensors(x_list)

## Text

In [87]:
# Per-step decoded text and token ids of the first sample, filled by the loop below.
text_list = []
token_list = []
# Seed the RNGs before sampling so the run is repeatable.
set_seed(0)

In [88]:
# Reverse-time sampling that, after every solver step, decodes the x_0
# prediction of the first sample into token ids and text.
# The accumulators and the seed are (re)set here so the cell is idempotent:
# re-running it produces the same N entries instead of appending to stale ones.
text_list = []
token_list = []
set_seed(0)

# Number of leading token positions kept per step for display.
num_shown_tokens = 16

with torch.no_grad():
    x = diffusion.dynamic.prior_sampling(clean_X.shape).to(diffusion.device)
    # The per-sample time vectors must match the batch dimension of the
    # tensor being stepped, so derive the size from `x` itself.
    batch_size = x.shape[0]
    num_steps = diffusion.dynamic.N

    for idx in tqdm(range(num_steps)):
        t = timesteps[idx]
        # The last grid point has no successor; fall back to eps_t there.
        next_t = timesteps[idx + 1] if idx < num_steps - 1 else eps_t

        input_t = t * torch.ones(batch_size, device=diffusion.device)
        next_input_t = next_t * torch.ones(batch_size, device=diffusion.device)

        output = diffusion.diff_eq_solver.step(
            x_t=x, t=input_t, next_t=next_input_t,
            cond=cond_X,
            cond_mask=cond_mask,
            attention_mask=attention_mask,
        )

        x = output["x"]

        # Most likely token at each position, truncated to the shown prefix.
        tokens = diffusion.pred_logits(output["x_0"]).argmax(dim=-1)[:, :num_shown_tokens]
        token_list.append(tokens[0].cpu().numpy())
        # Only the first sample is kept, so decode just that row.
        text = diffusion.tokenizer_gen.batch_decode(tokens[:1], skip_special_tokens=True)[0]
        text_list.append(text)

100%|██████████| 200/200 [00:05<00:00, 39.32it/s]


In [97]:
# Print every 10th decoded prediction as one table row: a row index followed
# by the words, each left-justified into a fixed-width column. Commas and
# periods are split off first so they land in their own column.
# Distinct loop names are used so the notebook-global `t` (the timestep from
# the sampling loop) is not overwritten by this display cell.
column_width = 11
for row_idx, decoded in enumerate(text_list[::10]):
    words = decoded.replace(",", " ,").replace(".", " .").split()
    row = f"{row_idx:<2} " + "".join(f"{word:<{column_width}}" for word in words)
    print(row)

0  -          and        and        and        and        and        and        and        -          and        and        -          and        and        -          
1  -          and        and        and        and        and        and        and        and        and        and        and        and        and        -          
2  -          and        and        and        and        and        and        and        and        and        -          -          and        and        and        
3  -          and        and        -          and        and        and        and        and        and        -          -          and        and        and        
4  '          and        and        .          and        and        and        and        and        and        the        the        and        and        and        
5  and        and        and        -          -          and        and        and        and        and        and        and        and        and      

In [99]:
# Print every 10th row of predicted token ids as an aligned table: a row
# index followed by the ids, each left-justified into a fixed-width column.
# Distinct loop names are used so the notebook-global `t` (the timestep from
# the sampling loop) is not overwritten by this display cell.
column_width = 6
for row_idx, token_ids in enumerate(token_list[::10]):
    row = f"{row_idx:<2} " + "".join(str(token_id).ljust(column_width) for token_id in token_ids)
    print(row)

0  101   1011  1998  1998  1998  1998  1998  1998  1998  1011  1998  1998  1011  1998  1998  1011  
1  101   1011  1998  1998  1998  1998  1998  1998  1998  1998  1998  1998  1998  1998  1998  1011  
2  101   1011  1998  1998  1998  1998  1998  1998  1998  1998  1998  1011  1011  1998  1998  1998  
3  101   1011  1998  1998  1011  1998  1998  1998  1998  1998  1998  1011  1011  1998  1998  1998  
4  101   1005  1998  1998  1012  1998  1998  1998  1998  1998  1998  1996  1996  1998  1998  1998  
5  101   1998  1998  1998  1011  1011  1998  1998  1998  1998  1998  1998  1998  1998  1998  1998  
6  101   1998  1998  2000  1996  1998  12241 1011  1998  1998  1998  1998  1998  1998  1998  1998  
7  101   1011  1998  2000  1037  1011  12241 14399 1998  1998  2004  1998  1996  1998  1011  14399 
8  101   1011  1011  1997  1037  8991  15134 15134 1998  2001  1998  1998  2000  1998  1011  2835  
9  101   2918  10515 2007  1037  1018  1011  1019  1011  2944  1998  1998  2000  1037  1011  1011  
