In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import argparse
import torch.distributed as dist
import ml_collections
from datasets import disable_progress_bar
from transformers import BertConfig
import sys

# Put the project checkout on the import path. This must run before the
# project-local imports below (diffusion_holder, utils, estimation_utils,
# diffusion_utils), which are presumably resolved from this directory.
# NOTE(review): hard-coded absolute, user-specific path — the notebook only runs
# on a machine with this exact layout.
sys.path.append("/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/")

from diffusion_holder import DiffusionRunner
from utils.util import set_seed, _BERT_SMALL, dict_to_cuda
from estimation_utils.util import estimate_model, reduce_metrics, gather_texts
import diffusion_utils.schedulers as schedulers

In [3]:
def create_config(project_root="/home/vmeshchaninov/DiffusionTextGeneration-cond-ca"):
    """Build the configuration object passed to ``DiffusionRunner`` in this notebook.

    Args:
        project_root: Directory of the project checkout. The encoder statistics
            files are looked up under ``<project_root>/data``. The default
            reproduces the paths that were previously hard-coded, so calling
            ``create_config()`` with no arguments yields the same values as before.

    Returns:
        An ``ml_collections.ConfigDict`` with ``training``, ``validation``,
        ``sde``, ``model`` and ``data`` sub-configs plus top-level run options.
        ``checkpoints_prefix`` is left as ``None``; the caller is expected to
        set it before use.
    """
    config = ml_collections.ConfigDict()

    training = config.training = ml_collections.ConfigDict()
    training.ode_sampling = False
    training.checkpoints_folder = '../checkpoints'
    training.batch_size = 512
    # Placeholder only; the notebook overwrites this after calling create_config().
    config.checkpoints_prefix = None

    validation = config.validation = ml_collections.ConfigDict()
    validation.batch_size = 512

    sde = config.sde = ml_collections.ConfigDict()
    sde.typename = 'vp-sde'
    sde.solver = 'euler'
    # Presumably the number of solver steps (the sampling progress bar in this
    # notebook shows 200 iterations) — TODO confirm against DiffusionRunner.
    sde.N = 200
    sde.beta_min = 0.1
    sde.beta_max = 20
    # NOTE(review): ode_sampling is set both here and on `training`; confirm
    # which of the two DiffusionRunner actually reads.
    sde.ode_sampling = False
    sde.scheduler = schedulers.CosineSD(d=10)

    model = config.model = ml_collections.ConfigDict()
    model.ema_rate = 0.9999
    model.enc_type = "base"
    model.embeddings_type = "encodings"
    model.dif_enc_type = "base"
    model.downstream_task = "sst2"  # "qqp"
    model.dataset = "wikipedia"  # "glue"
    model.prediction = "x_0"
    model.loss = "L_x_0"
    model.decoder_path = "decoder-wikipedia-128.pth"  # "decoder-wikipedia-128.pth"  # "decoder-t5_base-wikipedia-128.pth"

    data = config.data = ml_collections.ConfigDict()
    data.max_sequence_len = 96
    # Mean/std tensors for the BERT and T5 encodings, all stored in one data
    # directory under the project root.
    data_dir = os.path.join(project_root, "data")
    data.enc_bert_mean = os.path.join(data_dir, "encodings-bert_base-wiki-mean.pt")
    data.enc_bert_std = os.path.join(data_dir, "encodings-bert_base-wiki-std.pt")
    data.enc_t5_mean = os.path.join(data_dir, "encodings-t5-wiki-mean.pth")
    data.enc_t5_std = os.path.join(data_dir, "encodings-t5-wiki-std.pth")

    config.finetuning = False
    config.seed = 0
    config.ddp = False
    config.bert_config = BertConfig.from_pretrained("bert-base-uncased")

    config.project_name = "bert-conditional-exps"

    return config

In [4]:
# Build the notebook config, then select which trained run to evaluate.
config = create_config()
# Name prefix of the checkpoint to load; presumably resolved relative to
# config.training.checkpoints_folder by DiffusionRunner — TODO confirm.
config.checkpoints_prefix = "wikipedia-sst2-prediction=x_0-loss=L_x_0-enc=base-bert=base-kl_cf=0.0-seq_len=96-clipgrad=1.0-lr=0.0002-min_lr=0.0002-lin_input=True-seed=0-wd=0.01-batch=512-t5-bert-womask_1000000_"

# latent_mode is taken from the config ("encodings"); eval=True presumably builds
# the runner for inference only — verify against DiffusionRunner.__init__.
diffusion = DiffusionRunner(config, latent_mode=config.model.embeddings_type, eval=True)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Some weights of the model checkpoint at t5-base were not used when initializing T5EncoderModel: ['decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.2.la

In [5]:
from estimation_utils.metrics import BloomMetricConditional, BloomMetric
from estimation_utils.util import compute_metric

In [6]:
# Instantiate the conditional Bloom-based metric on the first GPU. The progress
# bar printed below this cell shows checkpoint shards being loaded.
# NOTE(review): the device is hard-coded to "cuda:0".
metric_bloom_fn = BloomMetricConditional(device="cuda:0")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# Number of texts to generate per call; also sizes the conditioning batch below.
# This is independent of the 512 batch sizes set inside create_config().
batch_size = 64

In [15]:
# One empty conditioning string per sample to generate.
cond_texts = [""] * batch_size

# Tokenize the prompts into PyTorch tensors, padding/truncating every row to a
# fixed length of 96 tokens so the batch has a uniform shape.
cond_encoded = diffusion.tokenizer_gen(
    cond_texts,
    return_tensors="pt",
    add_special_tokens=True,
    padding="max_length",
    truncation=True,
    max_length=96,
)

# Repackage the token ids and attention mask under the keys passed on to
# diffusion.generate_text(..., cond=cond) in the next cell. Each stage has its
# own name so `cond` is only ever bound to this final dict.
cond = {
    "cond": cond_encoded["input_ids"],
    "cond_mask": cond_encoded["attention_mask"],
}

In [16]:
# Sample `batch_size` texts conditioned on `cond`. generate_text returns a pair;
# only the first element is kept here (presumably the decoded strings — confirm).
# The progress bar below runs for 200 iterations, matching config.sde.N.
text, _ = diffusion.generate_text(batch_size, cond=cond)

100%|██████████| 200/200 [00:19<00:00, 10.29it/s]
