In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel (mode 2 reloads all
# modules before each cell is executed).
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import argparse
import torch.distributed as dist
import ml_collections
from datasets import disable_progress_bar
from transformers import BertConfig
import sys

sys.path.append("/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/")

from diffusion_holder import DiffusionRunner
from utils.util import set_seed, _BERT_SMALL, dict_to_cuda
from estimation_utils.util import estimate_model, reduce_metrics, gather_texts
import diffusion_utils.schedulers as schedulers

In [17]:
def create_config(project_root="/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/"):
    """Build the configuration passed to ``DiffusionRunner`` in this notebook.

    Args:
        project_root: Directory that contains the ``data/`` folder holding the
            encoding mean/std statistics files. The default is the path that
            used to be hard-coded four times in this function, so calling
            ``create_config()`` without arguments yields the same paths as
            before.

    Returns:
        An ``ml_collections.ConfigDict`` with ``training``, ``validation``,
        ``sde``, ``model`` and ``data`` sub-configs plus top-level run flags.
        ``checkpoints_prefix`` is left as ``None`` here.
    """
    config = ml_collections.ConfigDict()

    training = config.training = ml_collections.ConfigDict()
    training.ode_sampling = False
    training.checkpoints_folder = '../checkpoints'
    training.batch_size = 512
    config.checkpoints_prefix = None

    validation = config.validation = ml_collections.ConfigDict()
    validation.batch_size = 512

    sde = config.sde = ml_collections.ConfigDict()
    sde.typename = 'vp-sde'
    sde.solver = 'euler'
    # Presumably the number of solver steps (the sampling progress bar in this
    # notebook runs for 200 iterations) -- TODO confirm.
    sde.N = 200
    sde.beta_min = 0.1
    sde.beta_max = 20
    sde.ode_sampling = False
    sde.scheduler = schedulers.CosineSD(d=10)

    model = config.model = ml_collections.ConfigDict()
    model.ema_rate = 0.9999
    model.enc_type = "base"
    model.embeddings_type = "encodings"
    model.dif_enc_type = "base"
    model.downstream_task = "sst2"  # alternative: "qqp"
    model.dataset = "wikipedia"  # alternative: "glue"
    model.prediction = "x_0"
    model.loss = "L_x_0"
    # Alternative decoder checkpoint: "decoder-t5_base-wikipedia-128.pth"
    model.decoder_path = "decoder-wikipedia-128.pth"

    data = config.data = ml_collections.ConfigDict()
    data.max_sequence_len = 96
    # All statistics files live under <project_root>/data.
    data_dir = os.path.join(project_root, "data")
    data.enc_bert_mean = os.path.join(data_dir, "encodings-bert_base-wiki-mean.pt")
    data.enc_bert_std = os.path.join(data_dir, "encodings-bert_base-wiki-std.pt")
    data.enc_t5_mean = os.path.join(data_dir, "encodings-t5-wiki-mean.pth")
    data.enc_t5_std = os.path.join(data_dir, "encodings-t5-wiki-std.pth")

    config.finetuning = False
    config.seed = 0
    config.ddp = False
    config.bert_config = BertConfig.from_pretrained("bert-base-uncased")

    config.project_name = "bert-conditional-exps"

    return config

In [18]:
# Start from the default configuration and select the trained checkpoint to load.
config = create_config()
config.checkpoints_prefix = (
    "wikipedia-sst2-prediction=x_0-loss=L_x_0-enc=base-bert=base-kl_cf=0.0-seq_len=96-clipgrad=1.0-lr=0.0002-min_lr=0.0002-lin_input=True-seed=0-wd=0.01-batch=512-t5-bert-womask_1000000_"
)

# Build the runner with eval=True, using the configured embeddings type as latent mode.
diffusion = DiffusionRunner(
    config,
    latent_mode=config.model.embeddings_type,
    eval=True,
)

Some weights of the model checkpoint at t5-base were not used when initializing T5EncoderModel: ['decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.8.layer.0.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.2.layer_norm.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.0.SelfAttention.o.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder

In [19]:
# Number of samples to condition on and generate in the cells below.
batch_size = 64

# Fix: the attribute name was misspelled as `bnnatch_size`, so
# `config.validation.batch_size` was never updated and kept its configured
# value of 512 instead of the 64 used for generation below.
diffusion.config.validation.batch_size = batch_size

# Rebuild the validation data generator (presumably it reads the validation
# batch size from the config -- TODO confirm) and take an iterator over it.
diffusion.set_valid_data_generator()
loader = iter(diffusion.valid_loader)

In [20]:
# Take the next batch from the validation loader.
condition = next(loader)

# Keep only the conditioning ids and mask; dict_to_cuda presumably moves the
# tensors to the GPU (name-based assumption -- TODO confirm).
cond = dict_to_cuda(
    {
        "cond": condition["cond_ids"],
        "cond_mask": condition["cond_mask"],
    }
)

# Encode the conditioning sequence with the runner's condition encoder.
cond_X = diffusion.encoder_cond(
    input_ids=cond["cond"],
    attention_mask=cond["cond_mask"],
)

In [21]:
# Generate `batch_size` predicted embeddings conditioned on the encoded
# condition batch and its attention mask.
pred_embeddings = diffusion.pred_embeddings(
    batch_size,
    cond_X=cond_X,
    cond_mask=cond["cond_mask"],
)

100%|██████████| 200/200 [00:19<00:00, 10.30it/s]


In [22]:
# Map the generated embeddings back through the generator-encoder normalizer.
# NOTE(review): this rebinds `pred_embeddings` to its own denormalized value,
# so re-running this cell without re-running the sampling cell applies
# denormalize twice -- consider a separate name for the denormalized tensor.
pred_embeddings = diffusion.gen_enc_normalizer.denormalize(pred_embeddings)

In [23]:
# Run the runner's decoder on the denormalized embeddings to obtain logits
# (softmax over the last axis is applied in the next step).
logits = diffusion.decoder(pred_embeddings)

In [24]:
# Normalize the logits into probabilities along the last axis.
probs = logits.softmax(dim=-1)

In [32]:
# Index of the generated sample to inspect and number of top candidates per position.
n = 9
k = 3

# Run top-k once and only on the selected sample; previously it was computed
# twice over the whole batch just to index out sample `n`.
top_probs, top_indices = torch.topk(probs[n], k=k, dim=-1)
p_s = (top_probs * 100).int()  # probabilities as truncated integer percentages
ind_s = top_indices.int()

n_chars_per_col = 30
sep = ' '

print(f"probabilities:{sep * (n_chars_per_col - 14)}indexes:{sep * (n_chars_per_col - 8)}")
# Iterate over the actual number of positions rather than a hard-coded 96.
for i in range(p_s.shape[0]):
    token_ids = ind_s[i].tolist()
    col1 = f"{p_s[i].tolist()}"
    col2 = f"{token_ids}"
    # Decode each candidate id separately so the k alternatives stay comma-separated.
    col3 = "".join(f"{diffusion.tokenizer_gen.decode(ind)},  " for ind in token_ids)

    # Left-justify every column to a fixed width; longer cells are printed unpadded.
    print(
        f"{col1.ljust(n_chars_per_col, sep)}"
        f"{col2.ljust(n_chars_per_col, sep)}"
        f"{col3.ljust(n_chars_per_col, sep)}"
    )

probabilities:                indexes:                      
[100, 0, 0]                   [101, 1037, 1043]             [CLS],  a,  g,                
[100, 0, 0]                   [1001, 1008, 1526]            #,  *,  †,                    
[100, 0, 0]                   [1001, 1008, 1030]            #,  *,  @,                    
[99, 0, 0]                    [1045, 1051, 9932]            i,  o,  ai,                   
[91, 0, 0]                    [3211, 3089, 5638]            ##ki,  ##ri,  ##bi,           
[32, 3, 3]                    [3067, 13173, 23214]          mine,  kara,  meiji,          
[100, 0, 0]                   [1000, 1006, 1524]            ",  (,  ”,                    
[99, 0, 0]                    [1012, 102, 1037]             .,  [SEP],  a,                
[71, 18, 2]                   [10556, 5003, 7842]           ka,  ma,  sa,                 
[14, 13, 12]                  [28260, 23778, 16566]         ##kura,  ##rika,  ##kara,     
[67, 3, 3]                   

In [43]:
# Show the shape of the decoder output.
# NOTE(review): this cell previously called `diffusion.gen_tokenizer.decode()`
# with no arguments, on an attribute name that differs from the
# `diffusion.tokenizer_gen` used elsewhere in this notebook. The recorded
# output of the cell is a torch.Size, so it is assumed the intent was to
# inspect the logits shape -- confirm.
logits.shape

torch.Size([64, 96, 30522])