In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Root of the project checkout whose modules are imported in later cells
# (presumably diffusion_utils, diffusion_holder, utils, estimation_utils — confirm).
# Set the DIFFUSION_PROJECT_ROOT environment variable to point at your own
# checkout; the default is the path this notebook was originally run against.
PROJECT_ROOT = os.environ.get(
    "DIFFUSION_PROJECT_ROOT",
    "/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/",
)

# Re-running this cell in the same kernel must not pile up duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
from transformers import BertLMHeadModel, BertTokenizerFast, BertConfig
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from tqdm import tqdm
import ml_collections

In [13]:
from diffusion_utils import schedulers
from diffusion_holder import DiffusionRunner
from utils.util import set_seed, dict_to_cuda
from estimation_utils.estimate_glue import estimate_sst2

# SST-2

In [30]:
def create_config():
    """Assemble the configuration for the GLUE SST-2 diffusion experiment.

    Returns:
        ml_collections.ConfigDict: holds the sub-configs ``optim``,
        ``training``, ``loss``, ``refresh``, ``validation``, ``sde``,
        ``model`` and ``data``, plus the top-level entries
        ``checkpoints_prefix``, ``lin_input``, ``seed``, ``ddp`` and
        ``bert_config``.
    """

    def add_section(parent, name, **fields):
        # Attach an empty sub-config under `name`, then fill it in keyword order.
        section = ml_collections.ConfigDict()
        setattr(parent, name, section)
        for field, value in fields.items():
            setattr(section, field, value)
        return section

    config = ml_collections.ConfigDict()

    add_section(
        config, "optim",
        grad_clip_norm=1.,
        linear_warmup=0,
        lr=2e-4,
        min_lr=2e-4,
        warmup_lr=2e-4,
        weight_decay=0.01,
        beta_1=0.9,
        beta_2=0.98,
        eps=1e-6,
    )

    # training_iters is the base iteration budget plus the fine-tuning iterations.
    base_iters = 400_000
    finetuning_iters = 10_000
    add_section(
        config, "training",
        training_iters=base_iters + finetuning_iters,
        finetuning_iters=finetuning_iters,
        checkpoint_freq=1_000,
        eval_freq=1_000,
        batch_size=512,
        ode_sampling=False,
        checkpoints_folder='../checkpoints/',
    )
    # Empty by default; the caller is expected to fill this in.
    config.checkpoints_prefix = ''

    add_section(config, "loss", ce_coef=0.)

    add_section(
        config, "refresh",
        true=True,
        prefix="",
        wand_id="g5fb4af3",
    )

    # validation_iters is derived from the validation batch size.
    validation_batch_size = 1024
    add_section(
        config, "validation",
        batch_size=validation_batch_size,
        validation_iters=int(10_000 / validation_batch_size),
        num_gen_texts=2048,
        p_uncond=0.,
    )

    add_section(
        config, "sde",
        typename='vp-sde',
        solver='euler',
        N=1000,
        beta_min=0.1,
        beta_max=20,
        ode_sampling=False,
        scheduler=schedulers.CosineSD(d=10),
    )

    add_section(
        config, "model",
        ema_rate=0.9999,
        enc_type="base",
        embeddings_type="encodings",
        dif_enc_type="base",
        downstream_task="sst2",  # or "qqp"
        dataset="glue",
        prediction="x_0",
        loss="L_x_0",
    )

    add_section(config, "data", max_sequence_len=64)

    config.lin_input = True
    config.seed = 0
    config.ddp = False
    config.bert_config = BertConfig.from_pretrained("bert-base-uncased")

    return config

In [33]:
# Build the base config, then replace the empty default `checkpoints_prefix`
# with the name of one specific run.
# NOTE(review): presumably this prefix selects which checkpoint DiffusionRunner
# loads (the trailing "405000" looks like an iteration count) — confirm.
config = create_config()
config.checkpoints_prefix = "glue-sst2-encodings-prediction=x_0-loss=L_x_0-enc=base-bert=base-kl_cf=0.0-seq_len=64-clipgrad=1.0-lr=0.0002-min_lr=0.0002-lin_input=True-seed=0-wd=0.01-glue-sst2_405000_"

# Seed from the config before the runner is constructed.
seed = config.seed
set_seed(seed)

# `diffusion` is the runner object used by every cell below.
diffusion = DiffusionRunner(config, latent_mode="encodings", eval=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Run the SST-2 estimation on the runner; the cell displays the returned pair.
# NOTE(review): the pair is presumably (num_correct, num_total) — confirm in
# estimation_utils.estimate_glue.
# NOTE(review): this cell's execution count (6) is older than the cell that
# builds `diffusion` (33), so the saved output may come from an earlier runner;
# re-run the notebook top to bottom before trusting it.
estimate_sst2(diffusion)

100%|██████████| 1000/1000 [02:13<00:00,  7.50it/s]


(113.0, 128.0)

In [35]:
# NOTE(review): presumably this creates `diffusion.valid_loader`, which the next
# cell reads — confirm in DiffusionRunner.
diffusion.set_valid_data_generator()

In [36]:
# Sanity check: pass one validation batch through `sampler_emb` and `decoder`,
# then compare the decoded token at sequence position 1 with the input token there.
with torch.no_grad():
    # Only the first batch of the validation loader is used.
    X = next(iter(diffusion.valid_loader))
    X = dict_to_cuda(X)
    # The batch stores its mask under "input_mask"; it is passed on as "attention_mask".
    clean_X = diffusion.sampler_emb({"input_ids": X["input_ids"], "attention_mask": X["input_mask"]})
    output = diffusion.decoder(clean_X)
    # Arg-max over the last axis of the decoder output
    # (assumes that axis is the vocabulary — TODO confirm).
    tokens = output.argmax(dim=-1)
    # Naming caveat: `target` is the decoded token, `label` is the input token.
    # NOTE(review): position 1 presumably holds the class-label token
    # (position 0 being [CLS]) — confirm against the dataset code.
    target = tokens[:, 1]
    label = X["input_ids"][:, 1]
    # Prints the fraction of rows where the two tokens match, and the number of rows.
    print(torch.mean((target == label) * 1.), target.shape)

tensor(1.0000, device='cuda:0') torch.Size([872])
