In [1]:
# Re-import edited modules automatically before each cell runs, so changes to the project
# code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys

# Make the project-local modules imported below (create_config, diffusion_holder, utils, ...)
# importable. Set the PROJECT_ROOT environment variable to point at a different checkout
# instead of editing this hardcoded default.
PROJECT_ROOT = os.environ.get(
    "PROJECT_ROOT", "/home/jovyan/vmeshchaninov/DiffusionTextGeneration-cond-ca"
)

# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [5]:
import os
import pandas as pd
import torch.distributed as dist
import json
import torch
import numpy as np
from transformers import AutoTokenizer
from tqdm import tqdm
from torch.nn.functional import cross_entropy, softmax
import itertools
from collections import Counter
import matplotlib.pyplot as plt

In [74]:
from create_config import create_config
from diffusion_holder import DiffusionRunner
from utils import set_seed

In [8]:
# Move to the repository root so relative paths such as ./checkpoints/... resolve.
# NOTE(review): `%cd ..` is not idempotent; re-running it climbs another directory.
# In the saved session this ran as In [8], before the project imports (In [74]); on a fresh
# Restart & Run All it runs after them.
%cd ..

/home/jovyan/vmeshchaninov/DiffusionTextGeneration-cond-ca


In [102]:
# Stage-1 runner: sampled unconditionally (cond_x = None) to produce compressed latents.
config = create_config()

# Checkpoint to restore and the architecture it was trained with.
config.training.checkpoints_prefix = "rocstory-sd-8-64-lr=0.0002-compressed"
config.bert_config.num_hidden_layers = 12
config.bert_config.is_decoder = False

# Single-GPU evaluation run with DDP disabled.
config.device = "cuda:0"
config.ddp = False
config.eval = True

diffusion_1 = DiffusionRunner(config, eval=True)

Dataset tokenization (num_proc=30):   0%|          | 0/1000 [00:00<?, ? examples/s]

Checkpoint is loaded ./checkpoints/rocstory-sd-8-64-lr=0.0002-compressed/400000.pth


In [103]:
# Stage-2 runner: sampled with the stage-1 latents as its conditioning input (cond_x).
config = create_config()

# Checkpoint to restore and the architecture it was trained with.
config.training.checkpoints_prefix = "rocstory-sd-8-64-512-lr=0.0004-compressed-v2-0.5cfg"
config.bert_config.num_hidden_layers = 4
config.bert_config.is_decoder = True

# Single-GPU evaluation run with DDP disabled.
config.device = "cuda:0"
config.ddp = False
config.eval = True

diffusion_2 = DiffusionRunner(config, eval=True)

Dataset tokenization (num_proc=30):   0%|          | 0/1000 [00:00<?, ? examples/s]

Checkpoint is loaded ./checkpoints/rocstory-sd-8-64-512-lr=0.0004-compressed-v2-0.5cfg/30000.pth


In [104]:
batch_size = 1_000  # number of samples drawn per stage (and texts evaluated)

In [105]:
# Fix the RNG seed so the sampling cells below are reproducible.
# NOTE(review): the seed is set only once, so the stage-2 samples also depend on how much
# randomness stage-1 sampling consumed. Presumably utils.set_seed seeds torch as well; verify.
SEED = 1
set_seed(SEED)

In [106]:
# Stage 1: draw `batch_size` latent samples from diffusion_1 with no conditioning.
diffusion_1.score_estimator.eval()

cond_x = cond_mask = attention_mask = None

shape = (
    batch_size,
    diffusion_1.config.autoencoder.compressor.num_latents,
    diffusion_1.config.autoencoder.compressor.latent_dim,
)

with torch.no_grad():
    x = diffusion_1.dynamic.prior_sampling(shape).to(diffusion_1.device)
    # Self-conditioning starts from an all-zero estimate of x_0 (same dtype/device as x).
    x_0_self_cond = torch.zeros_like(x)
    eps_t = diffusion_1.config.generation.t_min

    # N + 1 grid points from T down to t_min give N solver steps.
    timesteps = torch.linspace(
        diffusion_1.dynamic.T, eps_t, diffusion_1.dynamic.N + 1, device=diffusion_1.device
    )
    # Broadcast each scalar time to one value per sample in the batch.
    batch_ones = torch.ones(shape[0], device=diffusion_1.device)

    for t, next_t in tqdm(zip(timesteps[:-1], timesteps[1:]), total=diffusion_1.dynamic.N):
        output = diffusion_1.diff_eq_solver.step(
            x_t=x,
            t=t * batch_ones,
            next_t=next_t * batch_ones,
            cond=cond_x,
            cond_mask=cond_mask,
            attention_mask=attention_mask,
            x_0_self_cond=x_0_self_cond,
        )
        x = output["x"]
        x_mean = output["x_mean"]
        x_0_self_cond = output["x_0"]

    # The final step's mean (not the noisy x) is used as the stage-1 result.
    pred_latents = x_mean

100%|██████████| 100/100 [00:08<00:00, 11.74it/s]


In [117]:
# Guidance coefficient for the stage-2 sampler (presumably classifier-free guidance,
# judging by the "cfg" naming; verify in diffusion_holder).
# NOTE(review): this cell executed as In [117], after the MAUVE cells (In [112]-[116]).
# The saved MAUVE score was therefore computed on samples drawn before this value was set,
# while the perplexity (In [122]) reflects the new samples. Re-run top to bottom.
CFG_COEF = 1
diffusion_2.config.generation.cfg_coef = CFG_COEF

In [118]:
# Stage 2: sample `pred_embeddings` from diffusion_2, conditioned on the stage-1 latents.
diffusion_2.score_estimator.eval()

cond_x = pred_latents
cond_mask = attention_mask = None

shape = (
    batch_size,
    diffusion_2.config.data.max_sequence_len,
    diffusion_2.config.autoencoder.compressor.latent_dim,
)

with torch.no_grad():
    x = diffusion_2.dynamic.prior_sampling(shape).to(diffusion_2.device)
    # Self-conditioning starts from an all-zero estimate of x_0 (same dtype/device as x).
    x_0_self_cond = torch.zeros_like(x)
    eps_t = diffusion_2.config.generation.t_min

    # N + 1 grid points from T down to t_min give N solver steps.
    timesteps = torch.linspace(
        diffusion_2.dynamic.T, eps_t, diffusion_2.dynamic.N + 1, device=diffusion_2.device
    )
    # Broadcast each scalar time to one value per sample in the batch.
    batch_ones = torch.ones(shape[0], device=diffusion_2.device)

    for t, next_t in tqdm(zip(timesteps[:-1], timesteps[1:]), total=diffusion_2.dynamic.N):
        output = diffusion_2.diff_eq_solver.step(
            x_t=x,
            t=t * batch_ones,
            next_t=next_t * batch_ones,
            cond=cond_x,
            cond_mask=cond_mask,
            attention_mask=attention_mask,
            x_0_self_cond=x_0_self_cond,
        )
        x = output["x"]
        x_mean = output["x_mean"]
        x_0_self_cond = output["x_0"]

    # The final step's mean (not the noisy x) is used as the stage-2 result.
    pred_embeddings = x_mean

100%|██████████| 100/100 [00:15<00:00,  6.42it/s]


In [119]:
# Map the sampled embeddings to vocabulary logits and take the greedy (argmax) token
# at each position. Logits are presumably (batch, seq_len, vocab); verify.
# Run under no_grad: pred_logits may go through modules with trainable parameters, and
# autograd would otherwise keep a graph alive for this large logits tensor.
with torch.no_grad():
    output = diffusion_2.pred_logits(pred_embeddings)
tokens = output.argmax(dim=-1)

In [120]:
# Cut every greedy-decoded sequence at its first end-of-sequence token, then detokenize.
# The generator tokenizer's sep token is used as the end-of-sequence marker here.
eos_id = diffusion_2.tokenizer_gen.vocab[diffusion_2.tokenizer_gen.sep_token]
tokens = tokens.detach().cpu().tolist()

# takewhile keeps the prefix before eos_id, or the whole sequence if eos_id never appears.
# This replaces a manual index loop whose counter, at notebook top level, was a global
# named `id` and shadowed the builtin id() for the rest of the kernel session.
tokens_list = [
    list(itertools.takewhile(lambda token_id: token_id != eos_id, seq))
    for seq in tokens
]

text = diffusion_2.tokenizer_gen.batch_decode(tokens_list, skip_special_tokens=True)

In [121]:
from estimation_utils.evaluation import *

In [122]:
# Perplexity of the generated `text`, as scored by estimation_utils.evaluation.
perplexity = compute_perplexity(all_texts_list=text)
perplexity

  0%|          | 0/63 [00:00<?, ?it/s]

30.80314806365967

In [112]:
from datasets import load_from_disk

In [113]:
# Human-written test split of the training dataset; used as the MAUVE reference set.
path = "{}/test/".format(diffusion_2.config.data.dataset_path)
dt = load_from_disk(path)

In [114]:
# Materialize the reference texts as a plain list of strings. Newer `datasets` releases
# return a lazy Column object for ds["text"] rather than a list.
references = list(dt["text"])

In [115]:
# MAUVE between the generated `text` and the human-written test references.
# NOTE(review): in the saved session this cell ran as In [115], before the stage-2
# resampling and decoding (In [117]-[120]). The stored score belongs to an earlier `text`
# than the perplexity reported above; re-run the notebook top to bottom.
mauve = compute_mauve(
    all_texts_list=text,
    human_references=references,
)

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/1000 [00:00<?, ?it/s]



In [116]:
# Display the MAUVE score computed above (higher means closer to the reference texts).
mauve

0.6057939736536875