In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

# Apply the plotting style sheets globally for every figure in this notebook.
# NOTE(review): the 'science' / 'notebook' / 'grid' / 'ieee' style names presumably
# come from the `scienceplots` import above — confirm before removing that import.
plt.style.use(['science', 'notebook', 'grid', 'ieee'])

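- Converting the model's probabilistic output (a classification over the whole vocabulary) into text (tokens)
    - Decoding is iterative (one forward pass per generated token), which makes it computationally expensive
    - Both the quality and the diversity of the generated text are important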
- Two algorithms to use:
    - Greedy search decoding.
    - Beam search decoding.
- Sampling methods
- Top-k & nucleus sampling


- Autoregressive language models.
- $\mathbf{x} = \{x_1, ..., x_k\}$.
- $\mathbf{y} = \{y_1, ..., y_t\}$.
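    - The chain rule of probability factorizes $P(\mathbf{y} \mid \mathbf{x})$ into a product of conditional probabilities: $P(y_1, ..., y_t \mid \mathbf{x}) = \prod_{s=1}^{t} P(y_s \mid y_{<s}, \mathbf{x})$.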


- Detailed encoding
$$P(y_t=w_i\mid y_{<t}, x) = \text{softmax}(z_{t, i})$$
$$\hat{y}_t = \text{argmax}_{y_t}P(y_t \mid y_{<t}, x), \quad y_{<t} = (y_1, y_2, ..., y_{t-1})$$

#### Decoding 

- Greedy search decoding: 重复性较高, 多样性不足, 整体未必是最优解 (输入法都选1)

In [4]:
# Load GPT-2 XL with its language-modeling head, plus the matching tokenizer.
# `AutoModelForCausalLM` is imported together with the other dependencies in the
# notebook's import cell, so this cell only performs the (expensive) loading.
# NOTE: the first run downloads the checkpoint; the recorded progress bar below
# shows a 6.43G `pytorch_model.bin`.
model_name = 'gpt2-xl'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Downloading pytorch_model.bin:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [15]:
# Prompt that the decoding experiments below will continue.
sample_text = 'A long long time ago,'
# Tokenize the prompt and request PyTorch tensors ('pt').
model_inputs = tokenizer(sample_text, return_tensors='pt')
# Token ids with a leading batch dimension (the recorded output shows one row of 6 ids).
input_ids = model_inputs['input_ids']
# Trailing bare expression: display the ids as the cell output.
input_ids

tensor([[  32,  890,  890,  640, 2084,   11]])

In [16]:
# Index away the batch dimension to show the 1-D id sequence of the single prompt.
input_ids[0]

tensor([  32,  890,  890,  640, 2084,   11])

#### Greedy search

In [17]:
# Greedy search decoding, one token at a time, while recording the most probable
# candidates at every step so they can be tabulated afterwards.
n_steps = 10            # number of tokens to generate
choices_per_step = 5    # how many top candidates to record per step

# Restart from the tokenized prompt so that re-running this cell is idempotent.
# Without this reset, the loop below would keep extending whatever sequence a
# previous execution of the cell left in `input_ids`.
input_ids = model_inputs['input_ids']

iterations = []
with torch.no_grad():  # inference only, no gradients needed
    for _ in range(n_steps):
        iteration = {}
        # Text generated so far (prompt + previously selected tokens).
        iteration['input'] = tokenizer.decode(input_ids[0])

        output = model(input_ids=input_ids)

        # Logits of the first (only) batch element at the last sequence position,
        # turned into a probability distribution and ranked from most to least likely.
        last_token_logits = output.logits[0, -1, :]
        last_token_probs = torch.softmax(last_token_logits, dim=-1)
        sorted_ids = torch.argsort(last_token_probs, dim=-1, descending=True)

        # Record the top candidates as "token(probability%)" strings.
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = last_token_probs[token_id]
            token_choice = f'{tokenizer.decode(token_id)}({100*token_prob:.2f}%)'
            iteration[f'choice {choice_idx +1}'] = token_choice

        # Greedy step: append the single most probable token. The `[None, 0, None]`
        # indexing reshapes that one id to (1, 1) so it can be concatenated along
        # the sequence dimension; torch.cat returns a new tensor, so the prompt
        # tensor held by `model_inputs` is left untouched.
        print('before append input_ids.shape', input_ids.shape)
        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
        print('after append input_ids.shape', input_ids.shape)

        iterations.append(iteration)

before append input_ids.shape torch.Size([1, 6])
after append input_ids.shape torch.Size([1, 7])
before append input_ids.shape torch.Size([1, 7])
after append input_ids.shape torch.Size([1, 8])
before append input_ids.shape torch.Size([1, 8])
after append input_ids.shape torch.Size([1, 9])
before append input_ids.shape torch.Size([1, 9])
after append input_ids.shape torch.Size([1, 10])
before append input_ids.shape torch.Size([1, 10])
after append input_ids.shape torch.Size([1, 11])
before append input_ids.shape torch.Size([1, 11])
after append input_ids.shape torch.Size([1, 12])
before append input_ids.shape torch.Size([1, 12])
after append input_ids.shape torch.Size([1, 13])
before append input_ids.shape torch.Size([1, 13])
after append input_ids.shape torch.Size([1, 14])
before append input_ids.shape torch.Size([1, 14])
after append input_ids.shape torch.Size([1, 15])
before append input_ids.shape torch.Size([1, 15])
after append input_ids.shape torch.Size([1, 16])


In [18]:
pd.DataFrame(iterations)

Unnamed: 0,input,choice 1,choice 2,choice 3,choice 4,choice 5
0,"A long long time ago,",in(23.47%),I(9.72%),there(7.42%),the(5.44%),a(5.23%)
1,"A long long time ago, in",a(80.26%),the(6.61%),an(4.45%),another(0.85%),my(0.24%)
2,"A long long time ago, in a",galaxy(50.88%),land(11.56%),far(2.65%),place(2.60%),kingdom(2.51%)
3,"A long long time ago, in a galaxy",far(90.35%),not(5.81%),very(0.62%),that(0.35%),much(0.28%)
4,"A long long time ago, in a galaxy far",",(88.23%)",far(8.82%),away(2.12%),distant(0.14%),",(0.03%)"
5,"A long long time ago, in a galaxy far,",far(99.57%),distant(0.05%),Far(0.05%),very(0.05%),long(0.04%)
6,"A long long time ago, in a galaxy far, far",away(97.65%),",(1.90%)",distant(0.05%),far(0.05%),Away(0.05%)
7,"A long long time ago, in a galaxy far, far away",",(26.92%)",...(19.00%),…(13.21%),.(5.36%),"…""(4.21%)"
8,"A long long time ago, in a galaxy far, far away,",there(20.43%),a(18.31%),the(8.79%),in(5.54%),I(3.01%)
9,"A long long time ago, in a galaxy far, far awa...",was(66.84%),lived(18.80%),were(6.99%),existed(1.50%),once(1.04%)
