In [1]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; setting it here does not change the hash
    # seed of the interpreter that is already running.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (safe without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different (possibly
    # non-deterministic) kernels from run to run, which would defeat the
    # `deterministic` flag above, so it must be disabled for reproducibility.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде, где это возможно.

In [3]:
class Model:
    """Wrapper around a pretrained GPT-2 language model with hand-written decoding.

    ``generate`` supports three strategies: greedy argmax, multinomial
    sampling (optionally with temperature / top-k / top-p filtering) and
    beam search.
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load the pretrained model and tokenizer identified by ``model_name``."""
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # Reuse the end-of-sequence token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

    def _encode_prompt(self, prompt: str, device) -> torch.Tensor:
        """Tokenize ``prompt`` and move the ids to ``device``.

        Returns:
            The ``input_ids`` tensor produced by the tokenizer.

        Raises:
            ValueError: If the prompt tokenizes to zero tokens, because the
                model then has no position to predict from.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        if input_ids.size(-1) == 0:
            raise ValueError("prompt must contain at least one token")
        return input_ids

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        """Return the id of the highest-scoring token.

        Args:
            logits: Next-token scores, either 1-D or with a leading
                dimension of size 1.
        """
        if logits.dim() > 1 and logits.size(0) == 1:
            logits = logits.squeeze(0)
        return int(torch.argmax(logits, dim=-1).item())

    def random_sampling(self, logits: torch.Tensor) -> int:
        """Sample one token id from ``softmax(logits)``.

        Args:
            logits: Next-token scores, either 1-D or with a leading
                dimension of size 1.
        """
        if logits.dim() > 1 and logits.size(0) == 1:
            logits = logits.squeeze(0)
        probs = torch.softmax(logits, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int
    ) -> str:
        """Generate a continuation of ``prompt`` with beam search.

        Hypotheses are scored by the sum of token log-probabilities (no
        length normalisation). A hypothesis that emits EOS is moved to the
        finished pool; the search stops after ``max_length`` new tokens, when
        ``num_beams`` hypotheses have finished, or when no live beam remains.
        The best finished hypothesis is returned if there is one, otherwise
        the best live beam.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens to generate.
            num_beams: Beam width, must be at least 1.

        Returns:
            The decoded best hypothesis (prompt included, special tokens skipped).

        Raises:
            ValueError: If ``num_beams < 1`` or the prompt has no tokens.
        """
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")
        self.model.eval()
        device = next(self.model.parameters()).device
        input_ids = self._encode_prompt(prompt, device)[0].tolist()
        eos_id = self.tokenizer.eos_token_id
        beams = [(input_ids, 0.0)]
        finished = []
        for _ in range(max_length):
            # All live beams have the same length (they grow by one token per
            # step and finished ones are removed), so they can be scored in a
            # single batched forward pass instead of one call per beam.
            batch = torch.tensor([seq for seq, _score in beams], device=device)
            with torch.no_grad():
                logits = self.model(batch).logits[:, -1, :]
                log_probs = torch.log_softmax(logits, dim=-1)
            k = min(num_beams, log_probs.size(-1))
            topk_log_probs, topk_ids = torch.topk(log_probs, k=k, dim=-1)
            all_candidates = [
                (seq + [int(tid)], score + float(lp))
                for (seq, score), lps, tids in zip(
                    beams, topk_log_probs.tolist(), topk_ids.tolist()
                )
                for lp, tid in zip(lps, tids)
            ]
            all_candidates.sort(key=lambda x: x[1], reverse=True)
            # Keep the best `num_beams` candidates and split them into
            # finished (ended with EOS) and still-live hypotheses.
            beams = []
            for seq, score in all_candidates[:num_beams]:
                if eos_id is not None and seq[-1] == eos_id:
                    finished.append((seq, score))
                else:
                    beams.append((seq, score))
            if len(finished) >= num_beams or not beams:
                break
        pool = finished if finished else beams
        best_seq = max(pool, key=lambda x: x[1])[0]
        return self.tokenizer.decode(best_seq, skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Scale ``logits`` by ``1 / temperature``.

        ``None`` or ``1.0`` returns the input unchanged.

        Raises:
            ValueError: If ``temperature`` is zero or negative (division by
                zero, or an inverted distribution).
        """
        if temperature is None or temperature == 1.0:
            return logits
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        return logits / float(temperature)

    def apply_top_k(self, logits: torch.Tensor, top_k: int = 0) -> torch.Tensor:
        """Keep the ``top_k`` largest logits along the last dim; mask the rest with -inf.

        ``None``, a non-positive value, or a value not smaller than the size
        of the last dimension disables the filter and returns the input as is.
        """
        if top_k is None or top_k <= 0 or top_k >= logits.size(-1):
            return logits
        _, indices = torch.topk(logits, k=top_k, dim=-1)
        # Start with "drop everything", then clear the flag at the top-k positions.
        keep = torch.zeros_like(indices, dtype=torch.bool)
        mask = torch.ones_like(logits, dtype=torch.bool).scatter(dim=-1, index=indices, src=keep)
        # -inf (rather than a large finite constant) is valid for every
        # floating dtype and yields exactly zero probability after softmax.
        return logits.masked_fill(mask, float('-inf'))

    def apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Nucleus filtering: keep the smallest top set whose mass reaches ``top_p``.

        Tokens are ranked by probability along the last dimension. A token is
        removed only if the tokens ranked before it already cover ``top_p``
        of the probability mass, so the token that crosses the threshold is
        kept and the most likely token always survives. Removed positions are
        set to -inf. ``None`` or ``top_p >= 1.0`` returns the input unchanged.
        """
        if top_p is None or top_p >= 1.0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        reached = cumulative_probs >= float(top_p)
        # Shift right by one rank: position i is dropped when the mass up to
        # rank i-1 already reached top_p. Rank 0 is therefore never dropped.
        sorted_mask = torch.zeros_like(reached)
        sorted_mask[..., 1:] = reached[..., :-1]
        # Map the mask from sorted order back to vocabulary order.
        mask = torch.zeros_like(logits, dtype=torch.bool).scatter(dim=-1, index=sorted_indices, src=sorted_mask)
        return logits.masked_fill(mask, float('-inf'))

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Alias of ``apply_top_p``."""
        return self.apply_top_p(logits, top_p)

    def _apply_top_k(self, logits: torch.Tensor, top_k: int = None) -> torch.Tensor:
        """Alias of ``apply_top_k``; ``top_k=None`` is replaced by ``self.vocab_size``."""
        return self.apply_top_k(logits, top_k if top_k is not None else self.vocab_size)

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Continue ``prompt`` using the chosen decoding strategy.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens to generate.
            strategy: ``"greedy"``, ``"random"`` or ``"beam"``.
            temperature: Logit temperature (greedy / random only).
            top_k: Top-k cutoff; ``0`` disables it (greedy / random only).
            top_p: Nucleus cutoff; ``1.0`` disables it (greedy / random only).
            num_beams: Beam width (beam strategy only).

        Returns:
            The decoded text (prompt included, special tokens skipped).
            Generation stops early once the EOS token is produced.

        Raises:
            NotImplementedError: If ``strategy`` is not one of the above.
            ValueError: If ``temperature`` is not positive, ``num_beams < 1``
                for beam search, or the prompt has no tokens.
        """
        if strategy == "beam":
            return self._beam_search_generate(prompt=prompt, max_length=max_length, num_beams=num_beams)
        # Resolve the sampler before doing any work so that an unknown
        # strategy fails even when no token would be generated.
        if strategy == "greedy":
            pick_token = self.greedy_sampling
        elif strategy == "random":
            pick_token = self.random_sampling
        else:
            raise NotImplementedError(f"Strategy '{strategy}' not implemented in generate()")
        self.model.eval()
        device = next(self.model.parameters()).device
        generated = self._encode_prompt(prompt, device)
        eos_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            for _ in range(max_length):
                logits = self.model(generated).logits[:, -1, :]
                # Each filter returns its input unchanged when disabled.
                logits = self.apply_temperature(logits, temperature)
                logits = self.apply_top_k(logits, top_k)
                logits = self.apply_top_p(logits, top_p)
                next_token = pick_token(logits)
                next_token_tensor = torch.tensor([[next_token]], device=device)
                generated = torch.cat([generated, next_token_tensor], dim=1)
                if eos_id is not None and next_token == eos_id:
                    break
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [4]:
from time import perf_counter
import numpy as np
import torch

# Re-seed NumPy and PyTorch so the sampled outputs below are repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Build the model once and report how long loading took.
print('Создаю модель (это загрузит gpt2, займет время)...')
start = perf_counter()
model = Model()
print(f'Model loaded in {perf_counter()-start:.1f}s')

# Prompts fed to every decoding configuration.
prompts = ["Hello, my name is", "Once upon a time", "In a shocking turn of events"]

# Keyword arguments forwarded to `model.generate`, one dict per configuration.
cases = [
    dict(strategy="greedy", max_length=20),
    dict(strategy="random", max_length=20, temperature=1.0),
    dict(strategy="random", max_length=20, temperature=0.7, top_k=50),
    dict(strategy="random", max_length=20, top_p=0.9),
    dict(strategy="random", max_length=30, temperature=1.2, top_k=100, top_p=0.95),
    dict(strategy="beam", max_length=25, num_beams=3),
    dict(strategy="beam", max_length=25, num_beams=5),
    dict(strategy="greedy", max_length=40),
]


def run_case(prompt: str, case: dict):
    """Run one generation call on the global ``model`` and print a report.

    Prints a separator, the prompt and the parameters, then either the
    elapsed time with the generated text or, if anything raises, the error
    message. Nothing is returned.

    Args:
        prompt: Text passed to ``model.generate``.
        case: Keyword arguments forwarded to ``model.generate``.
    """
    separator = '=' * 60
    print('\n' + separator)
    print('Prompt:', prompt)
    print('Params:', case)
    try:
        started = perf_counter()
        generated_text = model.generate(prompt, **case)
        dt = perf_counter() - started
        print(f"(time: {dt:.2f}s)\nResult:\n", generated_text)
    except Exception as e:
        # Deliberate best-effort: one failing configuration must not abort the demo.
        print('Ошибка при генерации:', e)

# Run every decoding configuration against every prompt (prompt-major order).
for p, c in [(prompt, case) for prompt in prompts for case in cases]:
    run_case(p, c)

print('\nДемонстрация завершена.')

Создаю модель (это загрузит gpt2, займет время)...


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Model loaded in 13.1s

Prompt: Hello, my name is
Params: {'strategy': 'greedy', 'max_length': 20}
(time: 3.84s)
Result:
 Hello, my name is John. I'm a writer, and I'm a writer. I'm a writer. I'm

Prompt: Hello, my name is
Params: {'strategy': 'random', 'max_length': 20, 'temperature': 1.0}
(time: 3.57s)
Result:
 Hello, my name is Juno and I work for Fifa in Zurich."

EU Vice-President Frans Timmermans

Prompt: Hello, my name is
Params: {'strategy': 'random', 'max_length': 20, 'temperature': 0.7, 'top_k': 50}
(time: 3.43s)
Result:
 Hello, my name is N.Y.U.I.L.E.R. and I'm a member of

Prompt: Hello, my name is
Params: {'strategy': 'random', 'max_length': 20, 'top_p': 0.9}
(time: 5.39s)
Result:
 Hello, my name is Alex S. Cupp, a trained investigative journalist and human rights lawyer. I was a Senior Policy

Prompt: Hello, my name is
Params: {'strategy': 'random', 'max_length': 30, 'temperature': 1.2, 'top_k': 100, 'top_p': 0.95}
(time: 5.57s)
Result:
 Hello, my name is Jason Van Allen! 