In [16]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [17]:
def seed_everything(seed: int) -> None:
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and every
    CUDA device), and configures cuDNN for deterministic kernel selection.

    Args:
        seed: value applied to every generator.
    """
    random.seed(seed)
    # NOTE: only affects subprocesses; the running interpreter's str hashing
    # is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick different algorithms between
    # runs, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде, где это возможно.

In [18]:
class Model:
    """Wrapper around a GPT-2 checkpoint implementing several decoding strategies.

    Supported strategies: greedy decoding, random sampling and beam search,
    each optionally combined with temperature scaling and top-k / top-p
    (nucleus) filtering of the next-token logits.
    """

    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # Use EOS as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        """Return the id of the highest-scoring token.

        Args:
            logits: 1-D tensor of next-token scores (a higher-rank tensor
                would be flattened by ``argmax``).
        """
        return int(torch.argmax(logits).item())

    def random_sampling(self, logits: torch.Tensor) -> int:
        """Sample one token id from ``softmax(logits)``.

        Entries equal to ``-inf`` get zero probability and are never drawn.

        Args:
            logits: 1-D tensor of next-token scores.
        """
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        return int(token.item())

    def _encode_prompt(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` into a ``(1, L)`` id tensor on ``self.device``.

        An empty prompt produces zero tokens, and ``logits[:, -1, :]`` needs at
        least one position, so the EOS token is used as the starting context.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        if input_ids.shape[-1] == 0:
            input_ids = torch.tensor([[self.tokenizer.eos_token_id]], device=self.device)
        return input_ids

    def _process_logits(
        self,
        logits: torch.Tensor,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0
    ) -> torch.Tensor:
        """Apply temperature, then top-k, then top-p filtering along the last dim.

        ``top_k <= 0`` / ``top_k=None`` and ``top_p >= 1`` / ``top_p=None``
        disable the respective filter.

        Raises:
            ValueError: if ``temperature <= 0``.
        """
        logits = self.apply_temperature(logits, temperature)
        if top_k and top_k > 0:
            logits = self._apply_top_k(logits, top_k)
        if top_p is not None and top_p < 1.0:
            logits = self._apply_top_p(logits, top_p)
        return logits

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0
    ) -> str:
        """Decode ``prompt`` with beam search and return the best text.

        A hypothesis is scored by the sum of its token log-probabilities,
        computed after temperature / top-k / top-p processing. At each step the
        ``num_beams`` best hypotheses are kept out of (a) hypotheses that have
        already emitted EOS and (b) the ``num_beams`` best one-token extensions
        of every active hypothesis. The best finished hypothesis is returned if
        one exists, otherwise the best unfinished one.

        Active hypotheses all grew by exactly one token per step from the same
        prompt, so they share one length and are scored in a single batched
        forward pass.

        Args:
            prompt: text to continue.
            max_length: maximum total length in tokens, prompt included.
            num_beams: beam width, must be >= 1.
            temperature, top_k, top_p: see ``_process_logits``.

        Raises:
            ValueError: if ``num_beams < 1`` or ``temperature <= 0``.
        """
        if num_beams < 1:
            raise ValueError("num_beams должен быть >= 1")
        eos_id = self.tokenizer.eos_token_id
        prompt_ids = self._encode_prompt(prompt)

        # Finished hypotheses stop growing, so they live outside the batch.
        done_ids = []     # 1-D token-id tensors
        done_scores = []  # matching cumulative log-probs (Python floats)
        # Active hypotheses: (B, L) token ids and (B,) cumulative log-probs.
        active_ids = prompt_ids
        active_scores = torch.zeros(1, device=self.device)

        with torch.no_grad():
            for _ in range(max_length - prompt_ids.shape[-1]):
                if active_ids.size(0) == 0:
                    break

                logits = self.model(active_ids).logits[:, -1, :]
                logits = self._process_logits(logits, temperature, top_k, top_p)
                log_probs = F.log_softmax(logits, dim=-1)

                k = min(num_beams, log_probs.size(-1))
                step_vals, step_tokens = torch.topk(log_probs, k=k, dim=-1)  # (B, k)

                # All B*k extensions, flattened parent-major.
                cand_scores = (active_scores.unsqueeze(1) + step_vals).reshape(-1)
                cand_tokens = step_tokens.reshape(-1)
                cand_ids = torch.cat(
                    [active_ids.repeat_interleave(k, dim=0), cand_tokens.unsqueeze(1)],
                    dim=1,
                )

                # Candidate pool: finished hypotheses first, then new extensions.
                n_done = len(done_scores)
                pool_scores = torch.cat([
                    torch.tensor(done_scores, dtype=cand_scores.dtype, device=self.device),
                    cand_scores,
                ])
                order = torch.argsort(pool_scores, descending=True)
                # Extensions through a filtered-out (-inf) token are not valid beams.
                order = order[torch.isfinite(pool_scores[order])][:num_beams]
                if order.numel() == 0:
                    break

                pool_list = pool_scores.tolist()
                is_eos = (cand_tokens == eos_id).tolist()
                next_done_ids, next_done_scores = [], []
                next_rows, next_scores = [], []
                for i in order.tolist():
                    if i < n_done:
                        next_done_ids.append(done_ids[i])
                        next_done_scores.append(done_scores[i])
                    elif is_eos[i - n_done]:
                        next_done_ids.append(cand_ids[i - n_done])
                        next_done_scores.append(pool_list[i])
                    else:
                        next_rows.append(i - n_done)
                        next_scores.append(pool_list[i])

                done_ids, done_scores = next_done_ids, next_done_scores
                active_ids = cand_ids.index_select(
                    0, torch.tensor(next_rows, dtype=torch.long, device=self.device)
                )
                active_scores = torch.tensor(next_scores, dtype=cand_scores.dtype, device=self.device)

        if done_scores:
            best_ids = done_ids[max(range(len(done_scores)), key=done_scores.__getitem__)]
        else:
            best_ids = active_ids[int(torch.argmax(active_scores))]
        return self.tokenizer.decode(best_ids.tolist(), skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Divide ``logits`` by ``temperature`` (``None`` or 1.0 is a no-op).

        Raises:
            ValueError: if ``temperature <= 0``.
        """
        if temperature is None or temperature == 1.0:
            return logits
        if temperature <= 0.0:
            raise ValueError("температура должна быть > 0")
        return logits / temperature

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Nucleus filtering: keep the smallest set of top tokens whose total
        probability reaches ``top_p``; set every other logit to ``-inf``.

        Works on any rank, filtering along the last dim. The most likely token
        is always kept, so ``top_p <= 0`` degenerates to greedy filtering
        rather than masking everything. ``None`` or ``>= 1`` is a no-op.
        """
        if top_p is None or top_p >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all strictly better-ranked tokens.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # A token is dropped only once the nucleus already holds >= top_p; the
        # token that crosses the threshold is therefore kept.
        remove = mass_before >= top_p
        remove[..., 0] = False

        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Undo the sort: put every filtered logit back at its vocabulary index.
        return torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)

    def _apply_top_k(self, logits: torch.Tensor, top_k: int = None) -> torch.Tensor:
        """Keep the ``top_k`` largest logits along the last dim, set the rest to ``-inf``.

        ``None``, ``<= 0`` or ``>=`` vocabulary size is a no-op.
        """
        if top_k is None or top_k <= 0 or top_k >= logits.size(-1):
            return logits
        topk_vals, topk_idx = torch.topk(logits, k=top_k, dim=-1)
        return torch.full_like(logits, float("-inf")).scatter(-1, topk_idx, topk_vals)

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Continue ``prompt`` and return the decoded text (prompt included).

        Args:
            prompt: text to continue.
            max_length: maximum total length in tokens, prompt included.
            strategy: ``"greedy"``, ``"random"`` or ``"beam"`` (case-insensitive).
            temperature: logit temperature (> 0; 1.0 disables it).
            top_k: keep only the k most likely tokens (0 disables it).
            top_p: nucleus threshold (1.0 disables it).
            num_beams: beam width, used only by ``strategy="beam"``.

        Returns:
            The decoded token sequence with special tokens stripped. Generation
            stops early once EOS is produced.

        Raises:
            ValueError: on an unknown strategy, ``temperature <= 0`` or (beam)
                ``num_beams < 1``.
        """
        strategy = strategy.lower()
        if strategy not in {"greedy", "random", "beam"}:
            raise ValueError("неясное значение")

        if strategy == "beam":
            return self._beam_search_generate(prompt, max_length, num_beams,
                                              temperature=temperature, top_k=top_k, top_p=top_p)

        pick_token = self.greedy_sampling if strategy == "greedy" else self.random_sampling
        input_ids = self._encode_prompt(prompt)
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_length - input_ids.shape[-1]):
                logits = self.model(input_ids).logits[:, -1, :].squeeze(0)
                logits = self._process_logits(logits, temperature, top_k, top_p)
                next_token_id = pick_token(logits)

                next_token = torch.tensor([[next_token_id]], device=self.device)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                if next_token_id == eos_id:
                    break

        return self.tokenizer.decode(input_ids.squeeze(0).tolist(), skip_special_tokens=True)



In [19]:
m = Model("gpt2")

prompt = "Once upon a time"
print("Prompt:", prompt)

# Each entry: (header printed before the sample, keyword arguments for generate).
demo_configs = [
    ("\nGreedy:", {"strategy": "greedy"}),
    ("\nRandom (temperature=1.0):", {"strategy": "random", "temperature": 1.0}),
    ("\nRandom (temperature=0.7, top_k=50):", {"strategy": "random", "temperature": 0.7, "top_k": 50}),
    ("\nBeam search (num_beams=4):", {"strategy": "beam", "num_beams": 4, "temperature": 1.0, "top_k": 50}),
]
for header, gen_kwargs in demo_configs:
    print(header)
    print(m.generate(prompt, max_length=40, **gen_kwargs))

Prompt: Once upon a time

Greedy:
Once upon a time, the world was a place of great beauty and great danger. The world was a place of great danger, and the world was a place of great danger. The world was a

Random (temperature=1.0):
Once upon a time, there was a furnace at Charles Town Hall, Zurich, to brew electricity from and would supply power. The fuel was pumped off tightly inside the building, and its Salernoelve

Random (temperature=0.7, top_k=50):
Once upon a time, there was a long, long time before the gods could be brought to order, and the people of the world were divided into many sects. Then came a time when the gods

Beam search (num_beams=4):
Once upon a time, the world was filled with the sounds of the sun and the moon, the sounds of the wind, the sounds of the waves, the sounds of the waves, the sounds of
