In [1]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [2]:
def seed_everything(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and
    switches cuDNN to its deterministic configuration.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # Only affects hash randomisation of interpreters started after this call
    # (e.g. worker subprocesses); the running interpreter has already chosen.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick possibly different kernels
    # from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed all generators once, ahead of the sampling cells further down.
seed_everything(42)

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде, где это возможно.

In [20]:
class Model:
    """GPT-2 wrapper with hand-written decoding strategies.

    Supported strategies (see :meth:`generate`): greedy decoding, random
    (ancestral) sampling with optional temperature / top-k / top-p filtering,
    and beam search. All tensors are created on the CPU.
    """

    _STRATEGIES = ("greedy", "random", "beam_search")

    def __init__(self, model_name: str = "gpt2"):
        """Load the pretrained language model and its tokenizer.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
        """
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # Hugging Face documents that from_pretrained returns the model in
        # evaluation mode; the explicit call only states the intent (no dropout).
        self.model.eval()
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        """Return the id of the highest-scoring token.

        Args:
            logits: Scores for a single position; the arg-max over the last
                dimension must contain exactly one element (e.g. shape
                ``(1, vocab)``), because the result is converted with ``.item()``.
        """
        return torch.argmax(logits, dim=-1).item()

    def random_sampling(self, logits: torch.Tensor) -> int:
        """Sample one token id from ``softmax(logits)``.

        Entries equal to ``-inf`` get probability zero, so logits filtered by
        :meth:`apply_top_k` / :meth:`apply_top_p` can be passed directly.
        """
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).item()

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int
    ) -> str:
        """Generate a continuation of ``prompt`` with beam search.

        A beam is a ``(token_ids, cumulative_log_prob)`` pair. A beam whose last
        token is EOS is finished: it is carried over unchanged and keeps
        competing with its score. No length normalisation is applied.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of expansion steps (new tokens).
            num_beams: Number of hypotheses kept after every step.

        Returns:
            The decoded text of the highest-scoring hypothesis.

        Raises:
            ValueError: If ``num_beams`` is smaller than 1.
        """
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")

        eos_id = self.tokenizer.eos_token_id
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        beams = [(input_ids, 0.0)]

        for _ in range(max_length):
            active = [
                idx for idx, (ids, _) in enumerate(beams)
                if ids[0, -1].item() != eos_id
            ]
            if not active:
                # Every hypothesis has emitted EOS; nothing left to expand.
                break

            # Active beams have all been extended the same number of times, so
            # they share a length and fit into one batched forward pass.
            batch = torch.cat([beams[idx][0] for idx in active], dim=0)
            with torch.no_grad():
                logits = self.model(batch).logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)

            # A beam cannot have more children than there are vocabulary entries.
            k = min(num_beams, log_probs.size(-1))
            top_log_probs, top_tokens = torch.topk(log_probs, k=k, dim=-1)
            top_log_probs = top_log_probs.tolist()

            row_of = {beam_idx: row for row, beam_idx in enumerate(active)}
            candidates = []
            # Walk the beams in their current order so that the stable sort
            # below breaks score ties deterministically.
            for beam_idx, (ids, score) in enumerate(beams):
                row = row_of.get(beam_idx)
                if row is None:
                    candidates.append((ids, score))
                    continue
                # (k, len + 1): the parent sequence followed by each top-k token.
                children = torch.cat(
                    [ids.expand(k, -1), top_tokens[row].unsqueeze(-1)], dim=-1
                )
                candidates.extend(
                    (children[j:j + 1], score + top_log_probs[row][j])
                    for j in range(k)
                )

            candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = candidates[:num_beams]

        best_sequence = beams[0][0].squeeze(0)
        return self.tokenizer.decode(best_sequence, skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Scale logits by ``1 / temperature``.

        Values below 1 sharpen the softmax distribution, values above 1 flatten it.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        return logits / temperature

    def apply_top_p(self, logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
        """Nucleus filtering: keep the smallest top set with mass above ``top_p``.

        Tokens are sorted by descending logit; a token is masked with ``-inf``
        when the cumulative probability of the tokens ranked before it already
        exceeds ``top_p``. The best token is always kept.

        Returns:
            A new tensor of the same shape as ``logits``.
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        remove_sorted = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
        return logits.masked_fill(remove, float('-inf'))

    def apply_top_k(self, logits: torch.Tensor, top_k: int = 0) -> torch.Tensor:
        """Mask every logit below the ``top_k``-th largest with ``-inf``.

        ``top_k <= 0`` disables the filter and returns ``logits`` unchanged.
        Tokens tied with the k-th largest value are all kept.
        """
        if top_k <= 0:
            return logits
        top_k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        return logits.masked_fill(logits < kth_largest, float('-inf'))

    # Underscore-prefixed names kept for backward compatibility.
    _apply_top_p = apply_top_p
    _apply_top_k = apply_top_k

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Continue ``prompt`` with the chosen decoding strategy.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens to generate.
            strategy: ``"greedy"``, ``"random"`` or ``"beam_search"``.
            temperature: Logit temperature; 1.0 leaves logits unchanged.
            top_k: Keep only the k best tokens; values <= 0 disable the filter.
            top_p: Nucleus threshold; values >= 1.0 disable the filter.
            num_beams: Beam width, used only by ``"beam_search"``.

        Temperature, top-k and top-p are applied in that order and only for the
        ``"greedy"`` and ``"random"`` strategies; beam search ignores them.
        Generation stops early when the EOS token is produced.

        Returns:
            The decoded prompt plus continuation, with special tokens removed.

        Raises:
            ValueError: If ``strategy`` is unknown or ``temperature`` is not
                strictly positive.
        """
        # Validate before running the model so bad arguments fail immediately.
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"Unknown strategy {strategy!r}; expected one of {self._STRATEGIES}"
            )
        if strategy == "beam_search":
            return self._beam_search_generate(prompt, max_length, num_beams)
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        pick_token = self.greedy_sampling if strategy == "greedy" else self.random_sampling
        eos_id = self.tokenizer.eos_token_id
        generated_ids = self.tokenizer.encode(prompt, return_tensors="pt")

        for _ in range(max_length):
            with torch.no_grad():
                next_token_logits = self.model(generated_ids).logits[:, -1, :]

            if temperature != 1.0:
                next_token_logits = self.apply_temperature(next_token_logits, temperature)
            if top_k > 0:
                next_token_logits = self.apply_top_k(next_token_logits, top_k)
            if top_p < 1.0:
                next_token_logits = self.apply_top_p(next_token_logits, top_p)

            next_token_id = pick_token(next_token_logits)
            next_token = torch.tensor([[next_token_id]], dtype=torch.long)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            if next_token_id == eos_id:
                break

        return self.tokenizer.decode(generated_ids.squeeze(0), skip_special_tokens=True)

In [21]:
# Build the wrapper with its default model_name ("gpt2"); this loads the
# pretrained weights and tokenizer via from_pretrained.
model = Model()


In [22]:
# Shared prompt for every decoding demo below.
# NOTE(review): the prompt ends with a space and the recorded outputs all show
# a double space after the comma — presumably because GPT-2 tokens carry their
# own leading space; consider dropping the trailing space. TODO confirm.
prompt = "Once upon a time in Hollywood, "

### 1. Greedy Sampling

In [23]:
# Greedy decoding with generate()'s defaults (max_length=50, no temperature,
# top-k or top-p filtering): always take the arg-max token.
greedy_output = model.generate(prompt, strategy="greedy")
print(greedy_output)

Once upon a time in Hollywood,  the first thing you would notice is that the actors were all wearing black.  The actors were all wearing black.  The actors were all wearing black.  The actors were all wearing black.  The actors were all wearing


### 2. Random Sampling

In [24]:
# Plain sampling from the full softmax distribution; the default
# temperature=1.0, top_k=0 and top_p=1.0 leave the logits untouched.
random_output = model.generate(prompt, strategy="random")
print(random_output)

Once upon a time in Hollywood,  Mick Carter became estranged from Chip Cullen -- who now lives by the name "Playboy" -- , who read in a screen check "Roger" on that day with Carter in 2001. Mickey Cullen \  was an accomplished playboy from


### 3. Temperature


In [25]:
# Sampling with temperature 0.5: logits are divided by 0.5 before the softmax,
# which concentrates probability mass on the higher-scoring tokens.
temp = model.generate(prompt, strategy="random", temperature=0.5)
print(temp)

Once upon a time in Hollywood,  the only way to be successful was to be like the rest of the world. People were curious, but they were not looking for another opportunity. So, in the fall of 2013, I started working on an application for an agency called Glam


### 4. Top-K

In [16]:
# Top-k sampling: logits below the 30th-largest value are masked to -inf
# before sampling.
topk = model.generate(prompt, strategy="random", top_k=30)
print(topk)

Once upon a time in Hollywood, _____________ was the most admired man in America with the world in his grip, the most respected journalist at the center of an organization that would make any man want to be a man again, and the most loved by the media and the nation, as


### 5. Top-P

In [17]:
# Nucleus (top-p) sampling: sample only from the highest-probability tokens
# whose cumulative probability first exceeds 0.75.
topp = model.generate(prompt, strategy="random", top_p=0.75)
print(topp)

Once upon a time in Hollywood,  Hollywood, based in the rural Southwest, decided to have a kid named Hank who lived in the hills.  But Hank went on to marry a woman he never knew existed.  And so it was that the sun shone on the


### 6. Beam Search

In [18]:
# Beam search keeping 10 hypotheses per step; generate() dispatches straight
# to _beam_search_generate, so temperature / top-k / top-p are not used here.
beam_search = model.generate(prompt, strategy="beam_search", num_beams=10)
print(beam_search)

Once upon a time in Hollywood,  there was a time when there was a time when there was a time when there was a time when there was a time when there was a time when there was a time when there was a time when there was a time when there was a time
