In [1]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from datasets import load_dataset
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Run on GPU when available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DECODER_NAME = 'Qwen/Qwen2.5-0.5B'
ENCODER_NAME = 'FacebookAI/xlm-roberta-base'
# Overridable through the CHECKPOINT_PATH environment variable so the notebook is not tied
# to one machine; the default keeps the original location.
CHECKPOINT_PATH = os.environ.get('CHECKPOINT_PATH', 'D:\\diploma\\checkpoints\\best_model.pt')
DATASET_NAME = 'databricks/databricks-dolly-15k'
# Width of the e/m projection heads. The projected vectors are fed to the decoder as
# inputs_embeds, so this must equal the decoder's embedding width and the checkpoint's heads.
OUTPUT_DIM = 896
MIN_WORDS = 5    # responses with <= MIN_WORDS space-separated words are dropped
MAX_WORDS = 200  # longer responses are truncated to this many space-separated words
BATCH_SIZE = 8
# Seed torch's global RNG (used e.g. by the 'sample' length strategy) for reproducible runs.
SEED = 42
torch.manual_seed(SEED)

In [None]:
class TextDataset(Dataset):
    """Map-style dataset that simply serves items from an in-memory sequence of texts."""

    def __init__(self, texts):
        # The sequence is stored by reference; no copy is made.
        self.texts = texts

    def __getitem__(self, idx):
        """Return the raw text at position ``idx``."""
        return self.texts[idx]

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.texts)
    
def collate_fn(batch, tokenizer, device):
    """Tokenize a list of raw strings and right-pad them into one batch.

    Args:
        batch: list of strings delivered by the DataLoader.
        tokenizer: HF-style tokenizer; it is called once per text with truncation to
            ``tokenizer.model_max_length``, and ``tokenizer.pad_token_id`` is the fill value.
        device: torch device the returned tensors are moved to.

    Returns:
        dict with the original ``texts``, padded ``input_ids`` of shape (batch, max_len)
        and a 0/1 long ``attention_mask`` of the same shape.
    """
    sequences = [
        tokenizer(text, add_special_tokens=True, return_tensors='pt', truncation=True,
                  max_length=tokenizer.model_max_length)['input_ids'].reshape(-1)
        for text in batch
    ]
    lengths = torch.tensor([seq.numel() for seq in sequences])
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=tokenizer.pad_token_id)
    # Derive the mask from the true lengths rather than `input_ids != pad_token_id`:
    # a genuine token whose id equals pad_token_id would otherwise be masked out.
    positions = torch.arange(input_ids.shape[1]).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()
    return {
        'texts': batch,
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
    }
    
# Custom encoder model: a pretrained transformer plus four linear heads on the first-token state.
class Model(torch.nn.Module):
    """Pretrained encoder with projection heads for the e/m vectors and the length distribution.

    Args:
        model_name: name/path passed to ``AutoModel.from_pretrained``.
        output_dim: width of the e and m projection heads.
        freeze_bert: when True, all encoder parameters get ``requires_grad=False``.

    Attribute names (bert, e_proj, m_proj, mu, std) are state_dict keys and must match
    the checkpoint that is loaded into this model.
    """

    def __init__(self, model_name, output_dim, freeze_bert=True):
        super().__init__()
        # Encoder backbone.
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        self.e_proj = torch.nn.Linear(hidden_size, output_dim)  # head for the "e" vector
        self.m_proj = torch.nn.Linear(hidden_size, output_dim)  # head for the "m" vector
        self.mu = torch.nn.Linear(hidden_size, 1)                # mean of the length distribution
        # Spread of the length distribution; get_lengths interprets it as a log-variance
        # (sigma = exp(0.5 * value)), so the name "std" is historical.
        self.std = torch.nn.Linear(hidden_size, 1)

        if freeze_bert:
            # Freeze every encoder parameter.
            self.bert.requires_grad_(False)

    def forward(self, input_ids, attention_mask=None):
        """Return (e, m, mu, std) computed from the encoder's first-position hidden state."""
        first_token = self.bert(input_ids, attention_mask).last_hidden_state[:, 0]
        heads = (self.e_proj, self.m_proj, self.mu, self.std)
        return tuple(head(first_token) for head in heads)

# Turn the predicted length distribution into integer response lengths.
def get_lengths(mu, std, strategy):
    """Convert predicted length-distribution parameters into integer lengths.

    Args:
        mu: tensor of predicted mean lengths.
        std: tensor broadcastable with ``mu``. Despite its name it is treated as a
            log-variance: the standard deviation used is ``exp(0.5 * std)``.
        strategy: ``'sample'`` draws from N(mu, sigma) and clips the draw to
            ``mu +/- 2 * sigma``; ``'mean'`` uses ``mu`` directly.

    Returns:
        LongTensor of rounded lengths (same shape as the broadcast inputs), each >= 1.

    Raises:
        ValueError: if ``strategy`` is neither ``'sample'`` nor ``'mean'`` (previously this
            surfaced as an UnboundLocalError).
    """
    sigma = torch.exp(0.5 * std)
    if strategy == 'sample':
        lengths = torch.normal(mean=mu, std=sigma)
        lengths = torch.clamp(lengths, min=mu - 2 * sigma, max=mu + 2 * sigma)
    elif strategy == 'mean':
        lengths = mu
    else:
        raise ValueError(f"Unknown strategy {strategy!r}; expected 'sample' or 'mean'")
    return torch.clamp(lengths.round(), min=1).long()

# Decoder input layout for one sample: one e vector followed by (text_length - 1) copies of m.
def generate_input_one(vectors, text_length):
    """Build the input layout for a single sample.

    Args:
        vectors: tensor whose row 0 is the e vector and row 1 is the m vector.
        text_length: total number of positions to produce.

    Returns:
        Tensor with a leading batch dimension of 1: position 0 holds e, positions
        1..text_length-1 hold m.
    """
    head = vectors[:1].unsqueeze(1)
    tail = vectors[1:2].unsqueeze(1).expand(-1, text_length - 1, -1)
    return torch.cat((head, tail), dim=1)

# Decoder input layout for a whole batch, right-padded with zero vectors.
def generate_input(batch_vectors, lengths):
    """Build per-sample layouts with generate_input_one and pad them into one tensor.

    Args:
        batch_vectors: iterable of per-sample (e, m) vector stacks.
        lengths: iterable of per-sample target lengths, aligned with ``batch_vectors``.

    Returns:
        Batch-first tensor padded with 0.0 up to the longest sample.
    """
    per_sample = [
        generate_input_one(sample_vectors, sample_length).squeeze(0)
        for sample_vectors, sample_length in zip(batch_vectors, lengths)
    ]
    return pad_sequence(per_sample, batch_first=True, padding_value=0.0)


In [None]:
# Load the decoder, the trained encoder checkpoint and both tokenizers.
# The dataset is loaded in the data-preparation cell; loading it here as well was redundant.
decoder_model = AutoModelForCausalLM.from_pretrained(DECODER_NAME).to(DEVICE)
decoder_model.eval()
decoder_tokenizer = AutoTokenizer.from_pretrained(DECODER_NAME)

encoder_model = Model(model_name=ENCODER_NAME, output_dim=OUTPUT_DIM).to(DEVICE)
# load_state_dict only needs a mapping of tensors, so restrict unpickling to weights
# (weights_only=True) instead of allowing arbitrary objects from the file.
encoder_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
# Explicit inference mode for every submodule (the wrapper nn.Module defaults to training mode).
encoder_model.eval()
encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_NAME, truncation=True)

In [5]:
# Keep responses longer than MIN_WORDS space-separated words and cut each to at most MAX_WORDS words.
dataset = load_dataset(DATASET_NAME)
df = dataset['train'].to_pandas()
keep_mask = df['response'].map(lambda response: len(response.split(' ')) > MIN_WORDS)
df = df.loc[keep_mask].reset_index(drop=True)
# Splitting on ' ' and re-joining with ' ' is lossless, so short responses come back unchanged.
texts = [' '.join(response.split(' ')[:MAX_WORDS]) for response in df['response']]

In [6]:
def _encoder_collate(batch):
    # Bind the encoder tokenizer and target device so the DataLoader can call a one-argument collate.
    return collate_fn(batch, encoder_tokenizer, DEVICE)

# Evaluation loader: fixed order (no shuffling) so predictions stay aligned with `texts`.
test_dataset = TextDataset(texts)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=_encoder_collate,
)

In [7]:
# Encode each text, expand it into decoder inputs of the predicted length, and greedily decode.
corpus_preds = []
encoder_model.eval()
decoder_model.eval()
with torch.no_grad():  # inference only: do not build autograd graphs
    for batch in tqdm(test_dataloader, desc='Processing dataset'):
        e, m, mu, std = encoder_model(batch['input_ids'], batch['attention_mask'])
        lengths = get_lengths(mu, std, 'mean')
        row_lengths = lengths.view(-1)  # one predicted length per sample
        max_len = int(row_lengths.max())
        vectors = torch.stack([e, m], dim=1)

        current_input = generate_input(vectors, lengths)
        # Build the mask from the actual batch size: the final batch can be smaller than
        # BATCH_SIZE, which made the previous expand(BATCH_SIZE, max_len) fail.
        positions = torch.arange(max_len, device=row_lengths.device)
        decoder_mask = (positions.unsqueeze(0) < row_lengths.unsqueeze(1)).long()

        logits = decoder_model(inputs_embeds=current_input, attention_mask=decoder_mask).logits
        pred_ids = torch.argmax(logits, dim=-1)
        # Drop predictions at padded positions beyond each sample's own length.
        trimmed_ids = [ids[:n].tolist() for ids, n in zip(pred_ids, row_lengths.tolist())]
        preds = decoder_tokenizer.batch_decode(trimmed_ids, skip_special_tokens=True)
        corpus_preds.extend(pred.split(' ') for pred in preds)

Processing dataset:  12%|█▏        | 212/1701 [00:19<02:18, 10.71it/s]


KeyboardInterrupt: 

In [None]:
# Corpus BLEU over space-tokenised texts. References go into a new variable so the cell is
# idempotent (re-binding `texts` nested the lists again on every re-run), and they are aligned
# with however many predictions exist (the loader is unshuffled, so predictions are a prefix).
n_scored = len(corpus_preds)
references = [[text.split(' ')] for text in texts[:n_scored]]
bleu_score_corpus = corpus_bleu(references, corpus_preds)
if n_scored < len(texts):
    print(f'Warning: only {n_scored} of {len(texts)} texts have predictions; BLEU covers that prefix only.')
print(f'Corpus BLEU Score: {bleu_score_corpus}')