In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import argparse
import torch.distributed as dist
import ml_collections
from datasets import disable_progress_bar
from transformers import BertConfig
import sys

sys.path.append("/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/")

from diffusion_holder import DiffusionRunner
from utils.util import set_seed, _BERT_SMALL
from estimation_utils.util import estimate_model, reduce_metrics, gather_texts
import diffusion_utils.schedulers as schedulers

  warn(f"Failed to load image Python extension: {e}")


In [3]:
def create_config():
    config = ml_collections.ConfigDict()

    training = config.training = ml_collections.ConfigDict()
    training.ode_sampling = False
    training.checkpoints_folder = './checkpoints'
    training.batch_size = 512
    config.checkpoints_prefix = None

    validation = config.validation = ml_collections.ConfigDict()
    validation.batch_size = 512

    sde = config.sde = ml_collections.ConfigDict()
    sde.typename = 'vp-sde'
    sde.solver = 'euler'
    sde.N = 100
    sde.beta_min = 0.1
    sde.beta_max = 20
    sde.ode_sampling = False
    sde.scheduler = schedulers.CosineSD(d=10)

    model = config.model = ml_collections.ConfigDict()
    model.ema_rate = 0.9999
    model.enc_type = "base"
    model.embeddings_type = "encodings"
    model.dif_enc_type = "base"
    model.downstream_task = "sst2"  # "qqp"
    model.dataset = "wikipedia"  # "glue"
    model.prediction = "x_0"
    model.loss = "L_x_0"
    model.decoder_path = "decoder-wikipedia-128.pth"  # "decoder-wikipedia-128.pth"  # "decoder-t5_base-wikipedia-128.pth"

    data = config.data = ml_collections.ConfigDict()
    data.max_sequence_len = 96
    data.enc_bert_mean = "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/data/encodings-bert_base-wiki-mean.pt"
    data.enc_bert_std = "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/data/encodings-bert_base-wiki-std.pt"
    data.enc_t5_mean = "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/data/encodings-t5-wiki-mean.pth"
    data.enc_t5_std = "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/data/encodings-t5-wiki-std.pth"

    config.finetuning = False
    config.seed = 0
    config.ddp = False
    config.bert_config = BertConfig.from_pretrained("bert-base-uncased")

    config.project_name = "bert-conditional-exps"

    return config

In [4]:
config = create_config()
config.checkpoints_prefix = "wikipedia-sst2-encodings-prediction=x_0-loss=L_x_0-enc=base-bert=base-kl_cf=0.0-seq_len=96-clipgrad=1.0-lr=0.0002-min_lr=0.0002-lin_input=True-seed=0-wd=0.01-ting-pretrain_200000_"

diffusion = DiffusionRunner(config, latent_mode=config.model.embeddings_type, eval=True)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Some weights of the model checkpoint at t5-base were not used when initializing T5EncoderModel: ['decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.9.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decod

In [5]:
batch_size = 64
seq_len = 96
mask = torch.ones(batch_size, seq_len).cuda()

In [6]:
noise = torch.load("data/noise.pth")

In [7]:
pred_embeddings, outputs = diffusion.pred_embeddings(
    batch_size=batch_size,
    cond_X=noise.cuda(),
    cond_mask=None,
    attention_mask=None,
)

100%|██████████| 100/100 [00:26<00:00,  3.79it/s]


In [9]:
pred_embeddings.shape

torch.Size([64, 96, 768])

In [10]:
outputs.keys()

dict_keys([1.0, 0.9900000095367432, 0.9800000190734863, 0.9700000286102295, 0.9599999785423279, 0.949999988079071, 0.9399999976158142, 0.9300000071525574, 0.9200000166893005, 0.9100000262260437, 0.8999999761581421, 0.8899999856948853, 0.8799999952316284, 0.8700000047683716, 0.8600000143051147, 0.8500000238418579, 0.8400000333786011, 0.8299999833106995, 0.8199999928474426, 0.8100000023841858, 0.800000011920929, 0.7900000214576721, 0.7800000309944153, 0.7699999809265137, 0.7599999904632568, 0.75, 0.7400000095367432, 0.7300000190734863, 0.7200000286102295, 0.7099999785423279, 0.699999988079071, 0.6899999976158142, 0.6800000071525574, 0.6700000166893005, 0.6600000262260437, 0.6500000357627869, 0.6399999856948853, 0.6299999952316284, 0.6200000047683716, 0.6100000143051147, 0.6000000238418579, 0.5900000333786011, 0.5799999833106995, 0.5699999928474426, 0.5600000023841858, 0.550000011920929, 0.5400000214576721, 0.5300000309944153, 0.5199999809265137, 0.5099999904632568, 0.5, 0.489999979734420

In [36]:
outputs[1.0].keys()

dict_keys(['x', 'x_mean', 'score', 'x_0', 'diffusion', 'drift', 'drift_par'])

In [24]:
torch.save(outputs, "outputs.pth")

In [13]:
out = torch.load("outputs.pth")

In [14]:
outputs[0.009999999776482582]["x_0"]

tensor([[[-0.2375,  0.3566, -0.3443,  ..., -0.9245,  0.1156,  0.1162],
         [ 0.1872,  0.3632,  1.6989,  ...,  0.7197, -0.5856, -0.6825],
         [-0.5971, -0.1787,  0.1655,  ...,  0.3111, -0.3182, -0.9170],
         ...,
         [-0.3196, -0.8789,  0.7730,  ..., -0.1017,  0.6921, -0.1522],
         [-0.3990,  0.0837, -0.4579,  ..., -0.2317, -0.4622,  0.0244],
         [-0.0850,  0.6905, -0.1843,  ...,  0.8798,  2.1656, -1.9218]],

        [[-0.5513, -0.3091, -0.6040,  ..., -0.1974,  1.4466, -0.5213],
         [-0.5075,  1.7510, -0.8983,  ...,  2.1446,  0.7130,  0.0743],
         [ 0.6342,  0.1784,  1.0736,  ...,  0.8481, -1.9557, -1.2758],
         ...,
         [-1.0950, -0.5760,  0.0230,  ...,  1.3681,  1.2309, -0.7965],
         [-0.1822,  0.0809,  0.5066,  ..., -1.1838, -0.1984, -0.1939],
         [ 0.2550,  0.5510,  0.2364,  ...,  1.4841,  1.0805, -1.7359]],

        [[-1.3289,  1.7212, -1.0028,  ..., -0.1714,  1.2455, -0.8029],
         [ 1.6172, -1.8103,  1.1269,  ...,  0

In [20]:
diffusion.decoder.load_state_dict(torch.load("./checkpoints/decoder-wikipedia-128.pth")["decoder"])

<All keys matched successfully>

In [21]:
diffusion.decoder.state_dict()

OrderedDict([('predictions.bias',
              tensor([-0.4030, -0.4316, -0.4306,  ..., -0.8043, -0.8011, -0.5078],
                     device='cuda:0')),
             ('predictions.transform.dense.weight',
              tensor([[ 0.5829, -0.0084,  0.0793,  ...,  0.0357, -0.0230,  0.0262],
                      [-0.0051,  0.3722, -0.0640,  ..., -0.0260, -0.0248,  0.0135],
                      [ 0.0137, -0.0793,  0.2726,  ...,  0.0437,  0.0706, -0.0374],
                      ...,
                      [ 0.0406,  0.0040,  0.0366,  ...,  0.4009,  0.0165, -0.0415],
                      [-0.0009, -0.0414,  0.0658,  ...,  0.0038,  0.2923, -0.0669],
                      [-0.0857,  0.0041,  0.0371,  ...,  0.0188, -0.0441,  0.5341]],
                     device='cuda:0')),
             ('predictions.transform.dense.bias',
              tensor([ 7.2733e-02,  8.9592e-02,  8.0551e-02,  4.4807e-03,  5.3272e-02,
                       2.5110e-03,  2.8013e-02,  3.0577e-03,  4.4827e-02,  1.2519e