In [1]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/")

In [10]:
import torch
from transformers import BertLMHeadModel, BertTokenizerFast
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from tqdm import tqdm

In [3]:
from data.dataset import WikipediaDataset
from estimation_utils.metrics import BloomMetric
from utils.util import dict_to_cuda
from estimation_utils.util import compute_metric

In [4]:
pretrained_enc_type = "bert-base-uncased"

In [5]:
tokenizer = BertTokenizerFast.from_pretrained(pretrained_enc_type)

In [6]:
dataset = next(WikipediaDataset(
    split="test",
    tokenizer=tokenizer,
    max_sequence_len=64,
    p_uncond=0
).get_data())

In [7]:
batch_size = 1024
loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,
            shuffle=False,
        )
loader = iter(loader)

## Метрики реального текста

In [22]:
metric_bloom_fn = BloomMetric(device="cuda:0")

In [12]:
X = next(loader)

In [13]:
text_cond = tokenizer.batch_decode(
    X["cond_ids"], 
    skip_special_tokens=True
)

In [14]:
text_gen = tokenizer.batch_decode(
    X["input_ids"], 
    skip_special_tokens=True
)

### Правильно расположенный текст

In [63]:
text = [f"{text_cond[i]} {text_gen[i]}" for i in range(batch_size)]

In [64]:
compute_metric(metric_bloom_fn, text)

metric: bigscience/bloom-7b1, 3.5594: 100%|██████████| 1024/1024 [01:46<00:00,  9.63it/s]


3.5594316921614437

### Условие и генерацию поменяли местами 

In [65]:
text = [f"{text_gen[i]} {text_cond[i]}" for i in range(batch_size)]

In [66]:
compute_metric(metric_bloom_fn, text)

metric: bigscience/bloom-7b1, 3.7407: 100%|██████████| 1024/1024 [01:46<00:00,  9.63it/s]


3.7406894671179236

### Перемешенные части

In [67]:
import random

In [68]:
random.shuffle(text_cond)

In [70]:
text = [f"{text_gen[i]} {text_cond[i]}" for i in range(batch_size)]

In [71]:
compute_metric(metric_bloom_fn, text)

metric: bigscience/bloom-7b1, 3.9254: 100%|██████████| 1024/1024 [01:46<00:00,  9.62it/s]


3.9254232835413743

### Bloom loss только на условной части

In [8]:
from transformers import BloomTokenizerFast, BloomForCausalLM

In [9]:
name = "bigscience/bloom-7b1"
bloom = BloomForCausalLM.from_pretrained(name).eval().to("cuda:0")

In [36]:
tokenizer_bloom = BloomTokenizerFast.from_pretrained(name)

In [39]:
num = 0
loss = 0

with torch.no_grad():
    for i in tqdm(range(batch_size)):
        inputs_gen = tokenizer_bloom(text_gen[i], return_tensors="pt")
        inputs_cond = tokenizer_bloom(f" {text_cond[i]}", return_tensors="pt")

        inputs = {
            "input_ids": torch.cat([inputs_cond["input_ids"], inputs_gen["input_ids"]], dim=-1).type(torch.long),
            "attention_mask": torch.cat([inputs_cond["attention_mask"], inputs_gen["attention_mask"]], dim=-1).type(torch.long)
        }

        inputs = dict_to_cuda(inputs)
        outputs = bloom(**inputs, labels=inputs["input_ids"])

        losses = cross_entropy(
                input=outputs.logits.reshape(-1, outputs.logits.shape[-1])[:-1],
                target=inputs["input_ids"].reshape(-1)[1:],
                reduce=False,
            )
        losses = losses[torch.sum(inputs_cond["attention_mask"]).item() - 1:]
        loss += losses.sum()
        num += losses.shape[0]

loss / num

100%|██████████| 1024/1024 [01:41<00:00, 10.11it/s]


tensor(3.4119, device='cuda:0')

In [109]:
inputs_gen = tokenizer_bloom(text_gen[0], return_tensors="pt")
inputs_cond = tokenizer_bloom(text_cond[0], return_tensors="pt")

In [110]:
inputs = {
    "input_ids": torch.cat([inputs_cond["input_ids"], inputs_gen["input_ids"]], dim=-1),
    "attention_mask": torch.cat([inputs_cond["attention_mask"], inputs_gen["attention_mask"]], dim=-1)
}

In [111]:
inputs = dict_to_cuda(inputs)

In [112]:
outputs = bloom(**inputs, labels=inputs["input_ids"])

In [114]:
losses = cross_entropy(
            input=outputs.logits.reshape(-1, outputs.logits.shape[-1])[:-1],
            target=inputs["input_ids"].reshape(-1)[1:],
            reduce=False,
        )

In [127]:
losses = losses[torch.sum(inputs_cond["attention_mask"]).item() - 1:]

tensor(4.3345, device='cuda:0', grad_fn=<MeanBackward0>)