## CS310 Natural Language Processing
## Lab 11: Explore Natural Language Generation

In this lab, we will practice using pre-trained transformer-based language models for natural language generation tasks.

In [6]:
import torch
import torch.nn.functional as F
import random
import numpy as np

## T1. Explore Pretrained GPT-2 Model

In this task, you will explore the GPT-2 model using the `transformers` library.

Just like in the previous lab, you will need to download the pretrained model and unzip it to `./gpt2zh`. 

Note that this is not the original version of GPT-2 provided by OpenAI (https://huggingface.co/openai-community/gpt2), but a fine-tuned version for Chinese text generation.

In [12]:
from transformers import AutoTokenizer, GPT2LMHeadModel

gpt2_tokenizer = AutoTokenizer.from_pretrained("./gpt2zh")
gpt2_model = GPT2LMHeadModel.from_pretrained("./gpt2zh")
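# Evaluation mode (disables dropout); from_pretrained() already sets this, but we make it explicit
gpt2_model = gpt2_model.eval()
# Cast the weights to float16 (half precision) to reduce memory usage
gpt2_model = gpt2_model.half()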

print('vocab size:', gpt2_tokenizer.vocab_size)
print(f'special token {gpt2_tokenizer.sep_token}:', gpt2_tokenizer.sep_token_id)
print(f'special token {gpt2_tokenizer.cls_token}:', gpt2_tokenizer.cls_token_id)
print(f'special token {gpt2_tokenizer.pad_token}:', gpt2_tokenizer.pad_token_id)

# Use [SEP] as end-of-sentence token
gpt2_model.config.eos_token_id = gpt2_tokenizer.sep_token_id

vocab size: 21128
special token [SEP]: 102
special token [CLS]: 101
special token [PAD]: 0


The tokenizer can return the token IDs and the attention mask that indicates which tokens are padding tokens (`1` for real tokens, `0` for padding tokens).

Since we only have one sentence in the "batch", there is no padding used, and thus no `0` in the attention mask.
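With a single sentence the mask is all ones. As a sketch of what the mask looks like once padding is involved, the hypothetical helper `pad_batch` below right-pads a toy batch of token-id lists with the `[PAD]` id (0, as printed above) and builds the matching masks. The toy ids are borrowed from the encoding above. In practice the tokenizer does this for you when called on a list of strings with `padding=True`; which side it pads on depends on its `padding_side` setting, so right padding here is only an assumption for illustration.

```python
from typing import List, Tuple

PAD_ID = 0  # assumption: same as gpt2_tokenizer.pad_token_id printed above ([PAD] = 0)

def pad_batch(batch_ids: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: right-pad token-id lists to equal length and build attention masks."""
    max_len = max(len(ids) for ids in batch_ids)
    padded, masks = [], []
    for ids in batch_ids:
        n_pad = max_len - len(ids)
        padded.append(ids + [pad_id] * n_pad)       # real tokens first, then padding
        masks.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return padded, masks

ids, mask = pad_batch([[101, 2110, 5445, 102], [101, 2110, 102]])
print(ids)   # [[101, 2110, 5445, 102], [101, 2110, 102, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```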

In [13]:
input_text = '学而时习之，不亦说乎！'
input_encoded = gpt2_tokenizer(input_text, return_tensors="pt")

print('input ids:', input_encoded['input_ids'])
print('input attention mask:', input_encoded['attention_mask'])

# Map token ids back to tokens
print('input tokens:', gpt2_tokenizer.convert_ids_to_tokens(input_encoded['input_ids'][0]))

input ids: tensor([[ 101, 2110, 5445, 3198,  739,  722, 8024,  679,  771, 6432,  725, 8013,
          102]])
input attention mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
input tokens: ['[CLS]', '学', '而', '时', '习', '之', '，', '不', '亦', '说', '乎', '！', '[SEP]']


The simplest way to generate text is to call the `generate` method directly:

In [4]:
input_text = "子曰：人"
input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
n_outputs = 5

output = gpt2_model.generate(**input_encoded,
                             max_length=20,  # total length, including the prompt tokens
                             num_return_sequences=n_outputs,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             temperature=0.7,
                             pad_token_id=0,  # [PAD] id printed above
                             )

for i in range(n_outputs):
    output_text = gpt2_tokenizer.decode(output[i], skip_special_tokens=True)
    print(output_text)

子 曰 ： 人 民 ， 我 为 人 民 ， 我 为 人 民 ， 我 为 人 民
子 曰 ： 人 物 有 情 ， 所 以 有 情 。 （ 《 世 纪 英 雄 》
子 曰 ： 人 之 道 ， 君 之 道 也 ， 子 之 道 也 。 君 之 道
子 曰 ： 人 不 可 以 不 可 以 不 可 以 不 可 以 不 可 以 不
子 曰 ： 人 生 ， 我 们 在 生 命 的 路 上 ， 在 生 命 的 路


The generated text is far from perfect: even with sampling, it often falls into repetitive loops, as in "我 为 人 民" or "不 可 以" above.
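The `generate` call above sets `temperature=0.7`. As a sketch of the standard definition (logits divided by the temperature $T$ before the softmax), the hypothetical helper below shows that a lower temperature sharpens the distribution toward the top token and a higher one flattens it. The exact order in which `generate` applies temperature, top-k and top-p is a library detail not shown here.

```python
import math
from typing import List

def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature, stabilised by subtracting the max before exp."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # toy logits for three candidate tokens
for t in (0.7, 1.0, 1.5):
    # Lower t puts more mass on the first (highest-logit) token
    print(t, [round(p, 3) for p in softmax(logits, temperature=t)])
```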

---

## T2. Implement Top-k Sampling Manually

Let's first try greedy search, i.e., top-1 sampling.

*Hint*: Call `argmax()` on the logits, then use the `convert_ids_to_tokens()` method to convert the token id to a string.

In [14]:
input_text = "今天天气"
input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
print('input size:', input_encoded.input_ids.shape[1])

output = gpt2_model(input_encoded.input_ids, 
                   attention_mask=input_encoded.attention_mask)
logits = output.logits
print(logits.shape)

# Get the probability distribution predicted at the last token's position
last_token_logits = logits[0, -1, :]

# Get the most likely token id from this distribution
most_likely_token_id = torch.argmax(last_token_logits).item()  # .item() returns a Python int

# Convert the token id to a token
most_likely_token = gpt2_tokenizer.convert_ids_to_tokens([most_likely_token_id])[0]  # pass a list, take the first token
print(most_likely_token)

# You should expect to see the following output:
# input size: 4
# torch.Size([1, 4, 21128])
# 预

input size: 4
torch.Size([1, 4, 21128])
预


Once the above code works, you can implement the full generation loop: at each iteration, select the most likely token, append it to the end of the input, and feed the extended input back to the model to predict the next token.

The loop continues until `max_gen_len` is reached, or a `"[SEP]"` token is generated.

**Note**: 
- Use `torch.cat` to append elements to input IDs
- The `attn_mask` also needs to be updated at each iteration.
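The loop structure can be sketched independently of GPT-2. In the sketch below, `next_token_logits` is a hypothetical stand-in for a model call that returns the scores for the position after the last token, and `toy_model` is a made-up 4-token "model". Since there is no padding, the attention mask would simply gain one `1` per appended token, so it is omitted here.

```python
from typing import Callable, List

def greedy_decode(next_token_logits: Callable[[List[int]], List[float]],
                  input_ids: List[int], eos_id: int, max_gen_len: int = 50) -> List[int]:
    """Greedy decoding: repeatedly append the arg-max token until EOS or max_gen_len steps."""
    ids = list(input_ids)
    for _ in range(max_gen_len):
        logits = next_token_logits(ids)  # scores for the next position
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if next_id == eos_id:
            break  # as in the lab loop, the EOS token itself is not appended
        ids.append(next_id)
    return ids

# Hypothetical toy "model" over a 4-token vocabulary: it prefers (last_id + 1),
# and predicts EOS (id 3) right after id 2.
def toy_model(ids: List[int]) -> List[float]:
    logits = [0.0] * 4
    logits[min(ids[-1] + 1, 3)] = 1.0
    return logits

print(greedy_decode(toy_model, [0], eos_id=3))  # [0, 1, 2]
```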

In [15]:
max_gen_len = 50

input_text = "今天天气"
input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
input_ids = input_encoded.input_ids
attn_mask = input_encoded.attention_mask

count = 0
while count < max_gen_len:
    output = gpt2_model(input_ids, attention_mask=attn_mask)
    logits = output.logits

    # Get the last token's logits
    last_token_logits = logits[0, -1, :]
    
    # Get the most likely token id
    sampled_token_id = torch.argmax(last_token_logits)
    
    if sampled_token_id == gpt2_tokenizer.sep_token_id:
        break

    # Append the sampled token id to the input
    input_ids = torch.cat([input_ids, sampled_token_id.unsqueeze(0).unsqueeze(0)], dim=1)
    # Increment the attention mask
    attn_mask = torch.cat([attn_mask, torch.ones(1, 1, dtype=attn_mask.dtype)], dim=1)

    count += 1


# Test
SPECIAL_TOKEN_IDS = set([gpt2_tokenizer.sep_token_id, 
                         gpt2_tokenizer.cls_token_id, 
                         gpt2_tokenizer.pad_token_id,
                         100]) # 100 for [UNK]

# Decode the generated tokens ids
for i in range(input_ids.shape[1]):
    tok_id = input_ids[0, i].item()
    # Skip the special tokens
    if tok_id not in SPECIAL_TOKEN_IDS:
        print(gpt2_tokenizer.convert_ids_to_tokens(input_ids[0, i].item()), end='')

# You should expect to see the following output:
# 今天天气预报：今天白天，我市阴天有小雨，气温：小雨转多云，气温：小雨转多云，气温：小雨转多云，气温：小雨转多

今天天气预报：今天白天，我市阴天有小雨，气温：小雨转多云，气温：小雨转多云，气温：小雨转多云，气温：小雨转多

As you can see, greedy search results in very repetitive text.

---

Now, let's implement a `top-k` sampling algorithm.

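The idea is to sample **uniformly** from the top-k most likely next tokens. PyTorch provides `torch.topk` (also available as the tensor method `.topk()`) to get the top-k values and their indices.

In the following example, you can check the **top 5** most likely tokens following the sentence "今天天气". The printed logits have `dtype=torch.float16` because the model was cast with `.half()`, so they differ slightly from the float32 values in the expected-output comments.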

In [16]:
k = 5
input_text = "今天天气"
input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
input_ids = input_encoded.input_ids
attn_mask = input_encoded.attention_mask

output = gpt2_model(input_ids, attention_mask=attn_mask)
logits = output.logits

### START YOUR CODE ###
# Get the last token's logits
last_token_logits = logits[0, -1, :]

# Get top-k logits and indices
topk_logits, topk_indices = torch.topk(last_token_logits, k)


# Test
print(topk_logits)
print(topk_indices)

for i in range(k):
    tok_id = topk_indices[i].item()
    print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end=' ')

# You should expect to see the following output:
# tensor([7.8924, 7.8550, 7.5893, 7.3502, 7.3069], grad_fn=<TopkBackward0>)
# tensor([7564, 2523,  679, 1962, 6820])
# 预 很 不 好 还 

tensor([7.8906, 7.8555, 7.5859, 7.3477, 7.3047], dtype=torch.float16,
       grad_fn=<TopkBackward0>)
tensor([7564, 2523,  679, 1962, 6820])
预 很 不 好 还 
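For intuition, here is a plain-Python sketch of what `torch.topk` returns for a 1-D input with its default `largest=True, sorted=True`: the k largest values in descending order, plus their positions. The helper name `topk` and the toy logits are illustrative only.

```python
import heapq
from typing import List, Tuple

def topk(logits: List[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest logits and their indices, in descending order of logit."""
    best = heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)
    return [logits[i] for i in best], best

values, indices = topk([0.1, 2.5, -1.0, 3.2, 0.7], k=2)
print(values, indices)  # [3.2, 2.5] [3, 1]
```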

Next, let's integrate top-k sampling into the generation process. Uniform sampling over the top-k indices can be implemented with `random.choice` or `random.randint`; the implementation below draws a random position with `random.randint(0, k-1)`.
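To see what "uniform" means here, the sketch below draws repeatedly from the five top-k ids printed above and counts how often each is chosen: every id gets roughly the same share, no matter how large its logit was. The helper `sample_uniform` is hypothetical, and the seed is only there to make the sketch reproducible.

```python
import random
from typing import List

def sample_uniform(topk_indices: List[int], rng: random.Random) -> int:
    """Pick one of the top-k token ids with equal probability, ignoring their scores."""
    return rng.choice(topk_indices)

rng = random.Random(0)  # seeded for reproducibility
counts = {i: 0 for i in [7564, 2523, 679, 1962, 6820]}  # top-5 ids printed above
for _ in range(10_000):
    counts[sample_uniform(list(counts), rng)] += 1
print(counts)  # each id is drawn about 1/5 of the time
```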

In [24]:
def generate_topk_uniform(input_text, k=5, max_gen_len=50):
    '''
    Generate tokens from the top-k logits, and yield the sampled token id.
    Tokens are sampled from a naive uniform distribution.
    '''
    input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids = input_encoded.input_ids
    attn_mask = input_encoded.attention_mask

    count = 0
    while count < max_gen_len:
        output = gpt2_model(input_ids, attention_mask=attn_mask)
        logits = output.logits

        # Get the last token's logits
        last_token_logits = logits[0, -1, :]
        
        # Get top-k logits and indices
        topk_logits, topk_indices = torch.topk(last_token_logits, k)
        
        # Sample uniformly from top-k indices
        sampled_token_id = topk_indices[random.randint(0, k-1)].item()  # .item() converts to a Python int
        
        yield sampled_token_id
        if sampled_token_id == gpt2_tokenizer.sep_token_id:
            break

        # Append the sampled token id to the input
        input_ids = torch.cat([input_ids, torch.tensor([[sampled_token_id]])], dim=1)
        # Increment the attention mask
        attn_mask = torch.cat([attn_mask, torch.ones(1, 1, dtype=attn_mask.dtype)], dim=1)

        count += 1

In [28]:
# Test
input_text = "今天天气"
print(input_text, end='')
for tok_id in generate_topk_uniform(input_text, k=50):
    if tok_id not in SPECIAL_TOKEN_IDS:
        print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

今天天气寒意丝件可可兰还不上？真要是看来可我得加温不到？又，为的要知晓他家一般来这有钱了没准一来想让吃，想想

In [29]:
# Test
input_text = "子曰：人"
print(input_text, end='')
for tok_id in generate_topk_uniform(input_text, k=50):
    if tok_id not in SPECIAL_TOKEN_IDS:
        print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

子曰：人间所述难于自白不用说你在家和母。何因他不喜那般于物易自是何愁天时日落花下我只见花时开当晚还下面笑如雪

Although uniform top-k sampling avoids the repetition problem, it produces *extremely incoherent* text: a low-ranked token among the top k is picked just as often as the most likely one. We can remedy this by sampling each token in proportion to its probability instead of uniformly.

There are several ways to implement proportional sampling:
- Build a list of cumulative relative probabilities of the top-k tokens. For instance, if the relative probabilities of $k=5$ tokens are $0.1$, $0.2$, $0.5$, $0.1$, and $0.1$, then the cumulative probability list is `cum_prob = [0.1, 0.3, 0.8, 0.9, 1.0]`. Then draw a random number $r$ from the uniform distribution on $[0,1)$ with `random.random()`, and sample the token whose bin in `cum_prob` contains $r$. For example, $r = 0.45$ falls into the bin $[0.3, 0.8)$, so the third token is sampled.
- Alternatively, use `torch.multinomial()` to do the same sampling. *Note*: the weights passed to `torch.multinomial` should be the relative probabilities of the top-$k$ tokens, which can be obtained by applying softmax to the top-k logits.
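The cumulative-bin method can be sketched in plain Python with `itertools.accumulate` and `bisect`. The helper name `sample_from_cum_probs` is hypothetical; it takes $r$ as an argument so the bin lookup can be checked deterministically.

```python
import bisect
import itertools
import random
from typing import List

def sample_from_cum_probs(probs: List[float], r: float) -> int:
    """Return the index of the bin of the cumulative distribution into which r in [0, 1) falls."""
    cum_prob = list(itertools.accumulate(probs))  # roughly [0.1, 0.3, 0.8, 0.9, 1.0] for the example
    idx = bisect.bisect_right(cum_prob, r)         # first bin whose upper edge is greater than r
    return min(idx, len(probs) - 1)                # guard against float round-off near 1.0

probs = [0.1, 0.2, 0.5, 0.1, 0.1]
print(sample_from_cum_probs(probs, 0.05))  # 0, since 0.05 < 0.1
print(sample_from_cum_probs(probs, 0.45))  # 2, since 0.3 <= 0.45 < 0.8
print(sample_from_cum_probs(probs, random.random()))  # a proportional random draw
```

In the standard library, `random.choices(range(len(probs)), weights=probs, k=1)[0]` performs the same proportional draw in one call, playing a role similar to `torch.multinomial` in the lab code.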

In [31]:
def generate_topk_proportion(input_text, k=50, max_gen_len=50):
    '''
    Generate tokens from the top-k logits, and yield the sampled token id.
    Tokens are sampled in proportion to their softmax probabilities over the top k.
    '''
    input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids = input_encoded.input_ids
    attn_mask = input_encoded.attention_mask

    count = 0
    while count < max_gen_len:
        output = gpt2_model(input_ids, attention_mask=attn_mask)
        logits = output.logits

        # Get the last token's logits
        last_token_logits = logits[0, -1, :]
        
        # Get top-k logits and indices
        topk_logits, topk_indices = torch.topk(last_token_logits, k)
        
        # Convert logits to probabilities
        topk_probs = F.softmax(topk_logits, dim=0)
        
        # Sample from top-k probabilities
        sampled_token_id = topk_indices[torch.multinomial(topk_probs, 1).item()].item()  # .item() converts to a Python int
        
        yield sampled_token_id
        if sampled_token_id == gpt2_tokenizer.sep_token_id:
            break

        # Append the sampled token id to the input
        input_ids = torch.cat([input_ids, torch.tensor([[sampled_token_id]])], dim=1)
        # Increment the attention mask
        attn_mask = torch.cat([attn_mask, torch.ones(1, 1, dtype=attn_mask.dtype)], dim=1)

        count += 1

In [32]:
# Test
input_text = "今天天气"
print(input_text, end='')
for tok_id in generate_topk_proportion(input_text, k=50):
    if tok_id not in SPECIAL_TOKEN_IDS:
        print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

今天天气转好，预计今天白天多云有阵雨，但是气温会比较平稳。今天是星期一，气温在22##℃左右，也许还会再##℃.3至23

In [33]:
# Test
input_text = "子曰：人"
print(input_text, end='')
for tok_id in generate_topk_proportion(input_text, k=50):
    if tok_id not in SPECIAL_TOKEN_IDS:
        print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

子曰：人生所有的经历都是不一样的，人生是一个有序循序渐进有结果有过程的，它不仅仅是个短暂的，如果需要时间的话

Do you think the proportional sampling produces better text?

Have fun sampling! :)

## T3. Implement Top-p Sampling

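Next, we will implement top-p (nucleus) sampling, which can be combined with top-k sampling: the top-k filter is applied first, and the top-p filter is then applied to the logits that remain.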

In `filter_topk_topp()`, we first filter out the logits that are not in the top-k by setting their values to `-float('inf')`.

Then, after sorting the remaining logits in descending order, we filter out those at which the cumulative probability (computed with softmax over the logits from the previous step) exceeds `p`.

Note that it is possible that the first logit alone dominates the distribution, and its cumulative probability is greater than `p`. In this case, we want to keep this logit, and remove all other logits.

In [34]:
def filter_topk_topp(logits: torch.Tensor, k=50, p=0.9) -> torch.Tensor: 
    '''
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    '''
    assert logits.dim() == 1
    logits = logits.clone()

    if k > 0:
        # Get top-k logits and indices
        topk_logits, topk_indices = torch.topk(logits, k)
        # Create a mask for logits to keep
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask[topk_indices] = True
        # Set non-top-k logits to -inf
        logits[~mask] = -float('Inf')
    
    if p > 0.0:
        # Sort logits in descending order
        logits_sorted, indices_sorted = torch.sort(logits, descending=True)
        # Compute cumulative probabilities
        probs_sorted = F.softmax(logits_sorted, dim=0)
        cum_probs = torch.cumsum(probs_sorted, dim=0)
        
        # Find indices to remove
        indices_to_remove = cum_probs > p
        # Always keep the first token
        indices_to_remove[0] = False
        
        # Set filtered logits to -inf
        logits[indices_sorted[indices_to_remove]] = -float('Inf')
    
    return logits

In [35]:
# Test filter_topk_topp
logits = torch.tensor(list(range(10))).float()
print('original logits:', logits)

logits2 = filter_topk_topp(logits, k=5, p=0.0)
print('\nk=5, p=0.0:', logits2)

logits3 = filter_topk_topp(logits, k=0, p=0.9)
print('\nk=0, p=0.9:', logits3)

logits4 = filter_topk_topp(logits, k=0, p=0.9999999)
print('\nk=0, p=0.9999999:', logits4)

logits5 = filter_topk_topp(logits, k=5, p=0.9999999)
print('\nk=5, p=0.9999999:', logits5)


# You are expected to see the following output:
# original logits: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
# k=5, p=0.0: tensor([-inf, -inf, -inf, -inf, -inf, 5., 6., 7., 8., 9.])
# k=0, p=0.9: tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 8., 9.])
# k=0, p=0.9999999: tensor([-inf, 1., 2., 3., 4., 5., 6., 7., 8., 9.])
# k=5, p=0.9999999: tensor([-inf, -inf, -inf, -inf, -inf, 5., 6., 7., 8., 9.])

original logits: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

k=5, p=0.0: tensor([-inf, -inf, -inf, -inf, -inf, 5., 6., 7., 8., 9.])

k=0, p=0.9: tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 8., 9.])

k=0, p=0.9999999: tensor([-inf, 1., 2., 3., 4., 5., 6., 7., 8., 9.])

k=5, p=0.9999999: tensor([-inf, -inf, -inf, -inf, -inf, 5., 6., 7., 8., 9.])
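To see why `k=0, p=0.9` keeps only the logits 8 and 9, the sketch below recomputes the cumulative probabilities by hand with the same rule as `filter_topk_topp`: drop every token from the first position where the running sum exceeds `p`, but always keep the most likely token. The helper `topp_keep_count` is hypothetical and only counts how many tokens survive.

```python
import math
from typing import List

def topp_keep_count(logits: List[float], p: float) -> int:
    """Number of tokens kept by the top-p rule used in filter_topk_topp (with k=0)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = sorted((e / total for e in exps), reverse=True)  # most likely first
    cum, kept = 0.0, 0
    for prob in probs:
        cum += prob
        if kept > 0 and cum > p:  # drop from the first token whose running sum exceeds p...
            break
        kept += 1                 # ...but the most likely token is always kept
        print(f"keep: prob={prob:.4f}, cumulative={cum:.4f}")
    return kept

print(topp_keep_count([float(i) for i in range(10)], p=0.9))  # 2 (logits 9 and 8)
```

The third candidate (logit 7) would push the running sum to about 0.95, above 0.9, so it and all smaller logits are set to `-inf`.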


In the following test, if all logits are `-inf`, then your top-p sampling is not correctly implemented. 

You want to keep at least one element in the logits: the one whose logit value dominates the distribution.

In [36]:
logits_special = torch.tensor(np.arange(10)**2).float()
print('original logits:', logits_special)

logits6 = filter_topk_topp(logits_special, k=0, p=0.9)
print('\nk=0, p=0.9:', logits6)


# You are expected to see the following output:
# original logits: tensor([ 0.,  1.,  4.,  9., 16., 25., 36., 49., 64., 81.])
# k=0, p=0.9: tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 81.])

original logits: tensor([ 0.,  1.,  4.,  9., 16., 25., 36., 49., 64., 81.])

k=0, p=0.9: tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 81.])
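A quick calculation shows why this case is special: the logit 81 exceeds the next largest (64) by 17, so its softmax probability alone is about $1 - e^{-17}$, already above `p = 0.9`. The cumulative sum therefore exceeds `p` at the very first position, and without the "always keep the first token" rule every logit would be set to `-inf`.

```python
import math

logits = [float(i * i) for i in range(10)]  # 0, 1, 4, ..., 81 as in the test above
m = max(logits)
exps = [math.exp(x - m) for x in logits]
p_top = exps[-1] / sum(exps)  # softmax probability of the logit 81
print(f"{p_top:.10f}")
print(p_top > 0.9)            # True: the top token alone already exceeds p
```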


Finally, we integrate the filtering into the generation process.

In [37]:
def generate_topk_topp(input_text, k=50, p=0.9, max_gen_len=20):
    '''
    Generate tokens from the top-k and top-p filtered logits, and yield the sampled token id.
    '''
    input_encoded = gpt2_tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids = input_encoded.input_ids
    attn_mask = input_encoded.attention_mask

    count = 0
    while count < max_gen_len:
        output = gpt2_model(input_ids, attention_mask=attn_mask)
        logits = output.logits

        # Get last token logits
        last_token_logits = logits[0, -1, :]
        
        # Get the filtered logits
        filtered_logits = filter_topk_topp(last_token_logits, k=k, p=p)
        
        # Sample from the remaining tokens
        filtered_probs = F.softmax(filtered_logits, dim=0)
        sampled_index = torch.multinomial(filtered_probs, 1).item()

        # Yield the sampled token id
        yield sampled_index
        if sampled_index == gpt2_tokenizer.sep_token_id:
            break

        # Append the sampled token id to the input_ids, and extend the attention mask
        input_ids = torch.cat([input_ids, torch.tensor([[sampled_index]])], dim=1)
        attn_mask = torch.cat([attn_mask, torch.ones(1, 1, dtype=attn_mask.dtype)], dim=1)

        count += 1
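Setting filtered logits to `-inf` works because the softmax maps them to exactly zero probability, so a proportional draw can never pick them. The sketch below mirrors the `F.softmax` + `torch.multinomial` step in plain Python; `sample_from_logits` is a hypothetical helper and the filtered vector is a toy example.

```python
import math
import random
from typing import List

def sample_from_logits(logits: List[float], rng: random.Random) -> int:
    """Exponentiate logits that may contain -inf, then draw one index proportionally."""
    m = max(logits)  # finite as long as at least one logit survived filtering
    weights = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0, so filtered tokens get zero weight
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

filtered = [-math.inf, -math.inf, 8.0, 9.0]  # toy filtered logits: only indices 2 and 3 survive
rng = random.Random(0)
draws = [sample_from_logits(filtered, rng) for _ in range(1000)]
print(sorted(set(draws)))  # [2, 3]
```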

In [38]:
# Test
input_text = "今天天气"
print(input_text, end='')
for tok_id in generate_topk_topp(input_text, k=50, p=0.95):
    print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

今天天气预报[UNK]的字样被封杀了，而且在官方发布的文

In [39]:
# Test
input_text = "子曰：人"
print(input_text, end='')
for tok_id in generate_topk_topp(input_text, k=50, p=0.95):
    print(gpt2_tokenizer.convert_ids_to_tokens(tok_id), end='')

子曰：人心不古，有圣人之德，必有子民之德。古人是