In [1]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("/home/vmeshchaninov/DiffusionTextGeneration-cond-ca/")

In [2]:
import torch
from transformers import BertLMHeadModel, BertTokenizerFast
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from tqdm import tqdm

In [3]:
from data.dataset import WikipediaDataset
from estimation_utils.metrics import BloomMetricConditional, BloomMetric
from utils.util import dict_to_cuda
from estimation_utils.util import compute_metric

In [4]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [25]:
dataset = next(WikipediaDataset(
    split="test",
    tokenizer=tokenizer,
    max_sequence_len=128,
    pos_begin=0.33,
    pos_end=0.67,
).get_data())

In [55]:
batch_size = 2048
loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,
            shuffle=False,
        )
loader = iter(loader)

# Условная генерация с помощью GPT-2

In [56]:
X = next(loader)

text_cond = tokenizer.batch_decode(
    X["cond_ids"], 
    skip_special_tokens=True,
)

text_gen = tokenizer.batch_decode(
    X["input_ids"], 
    skip_special_tokens=True,
)

In [57]:
torch.mean(torch.sum(X["cond_mask"], dim=1) * 1.)

tensor(53.0498)

In [58]:
torch.mean((torch.sum(X["cond_mask"], dim=1) + torch.sum(X["input_mask"], dim=1)) * 1.)

tensor(107.5869)

In [9]:
from transformers import pipeline

In [11]:
# gpt2 == gpt2-small
generator_gpt2 = pipeline('text-generation', model='gpt2', device=0)

## Условная метрика текста: BloomMetricConditional

In [12]:
metric_bloom_fn = BloomMetricConditional(device="cuda:0")

In [80]:
%%time

texts_gpt = generator_gpt2(text_cond, max_new_tokens=64, num_return_sequences=1, return_full_text=False, pad_token_id=50256)

CPU times: user 18min 17s, sys: 18 s, total: 18min 35s
Wall time: 18min 41s


In [81]:
sum_metric, num_tokens = 0., 0.

for ind in tqdm(range(len(text_cond))[:100]):
    output = metric_bloom_fn(text_cond=text_cond[ind], text_gen=texts_gpt[ind][0]["generated_text"], reduce="sum")
    if output[1] != 0:
        sum_metric += output[0]
        num_tokens += output[1]

sum_metric / num_tokens, num_tokens / batch_size

100%|██████████| 100/100 [00:13<00:00,  7.41it/s]


(3.606053533394779, 2.86474609375)

## Безусловная метрика текста: BloomMetric

In [15]:
metric_bloom_uncond_fn = BloomMetric(device="cuda:0")

In [71]:
%%time

texts_gpt = generator_gpt2(text_cond, max_new_tokens=64, num_return_sequences=1, return_full_text=True, pad_token_id=50256)

CPU times: user 18min 12s, sys: 18 s, total: 18min 30s
Wall time: 18min 32s


In [72]:
sum_metric, num_tokens = 0., 0.

for ind in tqdm(range(len(text_cond))):
    output = metric_bloom_uncond_fn(text=texts_gpt[ind][0]["generated_text"], reduce="sum")
    sum_metric += output[0]
    num_tokens += output[1]

sum_metric / num_tokens, num_tokens / batch_size

100%|██████████| 2048/2048 [04:10<00:00,  8.18it/s]


(3.6600863342248546, 112.189453125)

## Метрики батча

In [59]:
sum_metric, num_tokens = 0., 0.
texts = [f"{text_cond[i]} {text_gen[i]}" for i in range(batch_size)]

for ind in tqdm(range(len(text_cond))):
    output = metric_bloom_uncond_fn(text=texts[ind], reduce="sum")
    sum_metric += output[0]
    num_tokens += output[1]

sum_metric / num_tokens, num_tokens / batch_size

100%|██████████| 2048/2048 [04:05<00:00,  8.36it/s]


(3.4296206837482495, 107.177734375)

In [60]:
sum_metric, num_tokens = 0., 0.

for ind in tqdm(range(len(text_cond))):
    output = metric_bloom_fn(text_cond=text_cond[ind], text_gen=text_gen[ind], reduce="sum")
    sum_metric += output[0]
    num_tokens += output[1]

sum_metric / num_tokens, num_tokens / batch_size

100%|██████████| 2048/2048 [04:05<00:00,  8.33it/s]


(3.2399693615122613, 54.89990234375)

# Метрики реального текста

In [22]:
metric_bloom_fn = BloomMetric(device="cuda:0")

In [12]:
X = next(loader)

In [13]:
text_cond = tokenizer.batch_decode(
    X["cond_ids"], 
    skip_special_tokens=True
)

In [14]:
text_gen = tokenizer.batch_decode(
    X["input_ids"], 
    skip_special_tokens=True
)

### Правильно расположенный текст

In [41]:
text = [f"{text_cond[i]} {text_gen[i]}" for i in range(batch_size)]

In [42]:
compute_metric(metric_bloom_uncond_fn, text)

metric: bigscience/bloom-7b1, 3.4458: 100%|██████████| 1024/1024 [02:03<00:00,  8.30it/s]


3.4458400700800738

### Условие и генерацию поменяли местами 

In [61]:
text = [f"{text_gen[i]} {text_cond[i]}" for i in range(batch_size)]

In [62]:
compute_metric(metric_bloom_uncond_fn, text)

metric: bigscience/bloom-7b1, 3.5764: 100%|██████████| 2048/2048 [04:07<00:00,  8.28it/s]


3.5763897249226058

### Перемешенные части

In [63]:
import random

In [64]:
random.shuffle(text_cond)

In [65]:
text = [f"{text_gen[i]} {text_cond[i]}" for i in range(batch_size)]

In [66]:
compute_metric(metric_bloom_uncond_fn, text)

metric: bigscience/bloom-7b1, 3.7576: 100%|██████████| 2048/2048 [04:05<00:00,  8.33it/s]


3.757617524947899

In [67]:
text = [f"{text_cond[i]} {text_gen[i]}" for i in range(batch_size)]

In [68]:
compute_metric(metric_bloom_uncond_fn, text)

metric: bigscience/bloom-7b1, 3.8073: 100%|██████████| 2048/2048 [04:04<00:00,  8.36it/s]


3.807272912530623

### Bloom loss только на условной части

In [8]:
from transformers import BloomTokenizerFast, BloomForCausalLM

In [9]:
name = "bigscience/bloom-7b1"
bloom = BloomForCausalLM.from_pretrained(name).eval().to("cuda:0")

In [36]:
tokenizer_bloom = BloomTokenizerFast.from_pretrained(name)

In [39]:
num = 0
loss = 0

with torch.no_grad():
    for i in tqdm(range(batch_size)):
        inputs_gen = tokenizer_bloom(text_gen[i], return_tensors="pt")
        inputs_cond = tokenizer_bloom(f" {text_cond[i]}", return_tensors="pt")

        inputs = {
            "input_ids": torch.cat([inputs_cond["input_ids"], inputs_gen["input_ids"]], dim=-1).type(torch.long),
            "attention_mask": torch.cat([inputs_cond["attention_mask"], inputs_gen["attention_mask"]], dim=-1).type(torch.long)
        }

        inputs = dict_to_cuda(inputs)
        outputs = bloom(**inputs, labels=inputs["input_ids"])

        losses = cross_entropy(
                input=outputs.logits.reshape(-1, outputs.logits.shape[-1])[:-1],
                target=inputs["input_ids"].reshape(-1)[1:],
                reduce=False,
            )
        losses = losses[torch.sum(inputs_cond["attention_mask"]).item() - 1:]
        loss += losses.sum()
        num += losses.shape[0]

loss / num

100%|██████████| 1024/1024 [01:41<00:00, 10.11it/s]


tensor(3.4119, device='cuda:0')

In [109]:
inputs_gen = tokenizer_bloom(text_gen[0], return_tensors="pt")
inputs_cond = tokenizer_bloom(text_cond[0], return_tensors="pt")

In [110]:
inputs = {
    "input_ids": torch.cat([inputs_cond["input_ids"], inputs_gen["input_ids"]], dim=-1),
    "attention_mask": torch.cat([inputs_cond["attention_mask"], inputs_gen["attention_mask"]], dim=-1)
}

In [111]:
inputs = dict_to_cuda(inputs)

In [112]:
outputs = bloom(**inputs, labels=inputs["input_ids"])

In [114]:
losses = cross_entropy(
            input=outputs.logits.reshape(-1, outputs.logits.shape[-1])[:-1],
            target=inputs["input_ids"].reshape(-1)[1:],
            reduce=False,
        )

In [127]:
losses = losses[torch.sum(inputs_cond["attention_mask"]).item() - 1:]

tensor(4.3345, device='cuda:0', grad_fn=<MeanBackward0>)