# テキスト生成

## 貪欲法によるデコード

In [None]:
# GPT-2をロード
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Checkpoint name passed to both from_pretrained calls below.
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal language model and move it to the selected device.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [None]:
# Prompt used for the step-by-step decoding walkthrough.
input_txt = "Transformers are the"
# Tokenize into PyTorch tensors, keep only the token ids, and place them on
# the same device as the model.
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
input_ids

In [None]:
# Settings for the manual decoding loop: how many tokens to generate and how
# many top candidates to record at every step.
n_steps, choices_per_step = 8, 5
# One dict per decoding step is appended here and later rendered as a table.
iterations = []

In [None]:
# Round-trip check: decode the prompt's token ids back into text.
tokenizer.decode(input_ids[0])

In [None]:
# Show each token id next to the text piece it decodes to.
for k in input_ids[0]:
    print(k, tokenizer.decode(k))

In [None]:
# Forward pass over the prompt. Gradients are not needed for inference, so
# run under no_grad (as the decoding loop below does) to avoid keeping the
# autograd graph of this large model in memory.
with torch.no_grad():
    output = model(input_ids=input_ids)
# List the fields of the model output; the logits are used next.
output.keys()

In [None]:
# Shape of the logits; the next cell indexes them as
# [batch element, position, vocabulary entry].
output.logits.shape

In [None]:
# Logits at the last position of the first batch element: the scores for
# whatever token comes next.
next_token_logits = output.logits[0, -1, :]
next_token_logits.shape

In [None]:
# Softmax over the last axis turns the logits into a probability distribution.
next_token_probs = torch.softmax(next_token_logits, dim=-1)
next_token_probs.shape

In [None]:
# Token ids ordered from most to least probable.
sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
sorted_ids

In [None]:
# The greedy choice: the id with the highest probability.
token_id = sorted_ids[0]
token_id

In [None]:
# Probability the model assigns to that token.
token_prob = next_token_probs[token_id]
token_prob

In [None]:
# Human-readable summary of the top choice: the decoded token followed by its
# probability as a percentage. The format spec has no space flag, so the text
# matches the rows produced by the decoding loop below.
token_choice = f"{tokenizer.decode(token_id)} ({100 * token_prob:.2f}%)"
token_choice

In [None]:
# Append the top token to the running input. Indexing with [None, 0, None]
# picks the first id and adds two axes so it can be concatenated along the
# last dimension of input_ids.
input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
input_ids.shape

In [None]:
# Print the tokens again; the newly appended token is now the last entry.
for k in input_ids[0]:
    print(k, tokenizer.decode(k))

In [None]:
# Run the model on the extended input and inspect the logits shape again.
# Inference only, so skip autograd bookkeeping with no_grad.
with torch.no_grad():
    output = model(input_ids=input_ids)
output.logits.shape

In [None]:
# Greedy decoding, written out by hand to show what happens at each step.
# Start from the original prompt with an empty table so this cell gives the
# same result regardless of which exploratory cells above were run (one of
# them appends a token to input_ids) or how often this cell is re-run.
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
iterations = []
with torch.no_grad():
    for _ in range(n_steps):
        # Each row records the text so far plus the top candidates.
        iteration = {"Input": tokenizer.decode(input_ids[0])}
        output = model(input_ids=input_ids)
        # Select logits of the first batch and the last token and apply softmax
        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
        # Store tokens with highest probabilities
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = next_token_probs[token_id].cpu().numpy()
            token_choice = (
                f"{tokenizer.decode(token_id)} ({100 * token_prob:.2f}%)"
            )
            iteration[f"Choice {choice_idx+1}"] = token_choice
        # Append predicted next token to input
        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
        iterations.append(iteration)

# One row per step: the input so far and the top candidates with probabilities.
pd.DataFrame(iterations)

In [None]:
# Do the same with the generate() method: re-tokenize the prompt and produce
# n_steps new tokens without sampling.
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_new_tokens=n_steps, do_sample=False)
tokenizer.decode(output[0])

In [None]:
# Try a somewhat longer prompt.
# max_length is passed to the generate() calls further down.
max_length = 128
input_txt = """In a shocking finding, scientist discovered \
a herd of unicorns living in a remote, previously unexplored \
valley, in the Andes Mountains. Even more surprising to the \
researchers was the fact that the unicorns spoke perfect English.\n\n
"""

In [None]:
# The output turns out repetitive => a common weakness of greedy decoding.
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output_greedy = model.generate(input_ids, max_length=max_length, do_sample=False)
tokenizer.decode(output_greedy[0])

In [None]:
# Raw token ids of the greedy output.
output_greedy[0]

## ビームサーチによるデコード

In [None]:
# Compute the probability of the next-token candidates for each input.
# Normalizing the logits emitted at each time step gives a probability distribution.

# Goal: the generation probability of the 128-token sequence generated above.
labels = output_greedy
labels.shape

In [None]:
# Score every time step of the generated sequence with a single forward pass.
# Inference only, so run under no_grad (as sequence_logprob below does) to
# avoid keeping the autograd graph in memory.
with torch.no_grad():
    output = model(labels)
output.logits.shape

In [None]:
# Align predictions with targets: the logits at position i score the token at
# position i + 1, so drop the last logit and the first label.
# NOTE: labels is rebound here, so re-running this cell shifts it again.
logits = output.logits[:, :-1, :]
labels = labels[:, 1:]
logits.shape, labels.shape

In [None]:
import torch.nn.functional as F

# Work with log-probabilities so that tiny probabilities do not underflow.
logp = F.log_softmax(logits, dim=-1)
logp.shape

In [None]:
# Labels with an extra trailing axis, the index shape used by torch.gather below.
labels.unsqueeze(2).shape

In [None]:
# From logp, pick out for every position the log-probability of the token
# that was actually generated there.
logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
logp_label.shape

In [None]:
# The sequence log-likelihood is the sum of the per-token log-probabilities.
# Only generated tokens are counted. After the one-position shift above,
# column j of logp_label scores token j + 1, so the first generated token
# (index = prompt length) sits in column (prompt length - 1). The prompt
# length is read from input_ids rather than hard-coded.
seq_log_prob = torch.sum(logp_label[:, input_ids.shape[-1] - 1:])
seq_log_prob

In [None]:
# 上の処理を関数にまとめると
import torch.nn.functional as F

def log_probs_from_logits(logits, labels):
    """Return the log-probability assigned to each label token.

    Args:
        logits: Unnormalized scores whose last axis ranges over the vocabulary.
        labels: Integer token ids shaped like ``logits`` without its last axis.

    Returns:
        A tensor shaped like ``labels`` holding, for every position, the
        log-softmax value of ``logits`` at that position's label.
    """
    # log_softmax stays in log space, so tiny probabilities do not underflow.
    logp = F.log_softmax(logits, dim=-1)
    # Address the vocabulary axis from the end so that batched
    # (batch, seq, vocab) and unbatched (seq, vocab) inputs both work.
    return torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)

In [None]:
def sequence_logprob(model, labels, input_len=0):
    """Return the summed log-probability of the tokens after the prompt.

    Args:
        model: Callable that takes the token ids and returns an object with a
            ``logits`` attribute.
        labels: Token ids of the full sequence (prompt plus continuation),
            indexed here as ``[batch, position]``.
        input_len: Number of leading prompt tokens to leave out of the sum.

    Returns:
        The total log-probability as a zero-dimensional NumPy array.
    """
    with torch.no_grad():
        output = model(labels)
        # Logits at position i score the token at position i + 1, hence the
        # one-position shift between logits and labels.
        log_probs = log_probs_from_logits(
            output.logits[:, :-1, :], labels[:, 1:])
        # Because of that shift, log_probs[:, j] belongs to token j + 1, so
        # the first token after the prompt lives in column input_len - 1.
        # Clamp at 0 so the default input_len=0 still sums every column.
        first_generated = max(input_len - 1, 0)
        seq_log_prob = torch.sum(log_probs[:, first_generated:])
    return seq_log_prob.cpu().numpy()

In [None]:
# Log-likelihood of the sequence produced by greedy decoding.
logp = sequence_logprob(model, output_greedy, input_len=len(input_ids[0]))
print(tokenizer.decode(output_greedy[0]))
print(f"log_prob: {logp:.2f}")

In [None]:
# Log-likelihood of the sequence produced by beam search (5 beams).
# Observed: the log-likelihood is higher than with greedy decoding, i.e. beam
# search produced a more probable sequence.
output_beam = model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=False)
logp = sequence_logprob(model, output_beam, input_len=len(input_ids[0]))
print(tokenizer.decode(output_beam[0]))
print(f"log_prob: {logp:.2f}")

In [None]:
# Setting no_repeat_ngram_size keeps n-grams that already appeared from
# appearing again, which prevents the text from repeating itself.
output_beam = model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=False, no_repeat_ngram_size=2)
logp = sequence_logprob(model, output_beam, input_len=len(input_ids[0]))
print(tokenizer.decode(output_beam[0]))
print(f"log_prob: {logp:.2f}")