In [1]:
from transformers import AutoTokenizer, BartForConditionalGeneration
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
torch.cuda.empty_cache() # 清空GPU缓存，清空CUDA，清空CUDA内存

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)

QUERY_DOCUMENT = (
    "I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. "
    "What do you think will happen?"
)
doc1_inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
doc2_inputs = tokenizer([QUERY_DOCUMENT], max_length=1024, return_tensors="pt")

# Inside GenerationMixin:model_inputs= dict_keys(['input_ids', 'encoder_outputs', 'past_key_values', 'decoder_input_ids', 'attention_mask', 'decoder_attention_mask', 'head_mask', 'decoder_head_mask', 'cross_attn_head_mask', 'use_cache'])
# dict_keys(['encoder_outputs', 'past_key_values', 'decoder_input_ids', 'attention_mask' 'use_cache'])
# Passed parameters:
    # encoder_outputs
    # past_key_values
    # decoder_input_ids
    # attention_mask
    # use_cache = True 意味着会存储之前计算的 past_key_values 会被存储下来。 也就意味着，使用的是window attention。每次只计算新的token的kv，然后保存下来。
# '''
# encoder_outputs: torch.Size([2, 56, 1024]) 一值不变， 56 是输入序列长度，但是输入只有一个，也就是batch size 应该是1，但是这里是2.不知道为什么。
# past_key_values: 12                        一值不变
# decoder_input_ids: torch.Size([2, 1])      一值不变
# attention_mask: torch.Size([2, 56])        一值不变
# Inside BART: input_ids= tensor([[2], [2]]) 会变，这里可以看到的是，batch中的两个的值可能是一样的，也可能不是一样的。
# Inside BART: input_ids= tensor([[1768], [ 717]])
# '''

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


用LLM 做生成任务的一般步骤：  
1. 获得用户输入，添加开始字符
2. 创建对应的 attention mask
3. 设置模型为生成模式，可能需要设置是否cache KV
4. tokenize，得到input_ids
5. 写一个循环：
    1. 判断是否达到终止条件
    2. 运行模型，得到模型对下一个token的概率预测
    3. 通过概率和采样算法获取下一个token的id -- 这被称之为 generation strategy https://huggingface.co/docs/transformers/en/generation_strategies  
    对应的算法解释： https://huggingface.co/blog/how-to-generate
    4. 将下一个token添加到input_ids序列中
6. 将最终输出的 input_ids 输入给Tokenizer 进行decode，得到字符串。


Example by Transformers module:
1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
2. Set generation parameters if not already defined
3. Define model inputs
4. Define other model kwargs
5. Prepare `input_ids` which will be used for auto-regressive generation
6. Prepare `max_length` depending on other stopping criteria.
7. determine generation mode
8. prepare distribution pre_processing samplers
9. prepare stopping criteria
10. go into different generation modes, e.g. Beam search
11. prepare beam search scorer 
12. interleave input_ids with `num_beams` additional sequences per batch
13. run beam search (i.e. generation loop)
    1. stopping checking
    2. predict next token
    3. choose next token
    4. concatenate

所以对于我的需求（自定义cross attention的KV），我其实可以给模型写一个set函数，直接设置。或者模型提供对应的参数，那么我就通过参数传进去就好。然后把模型的forward中的encoder的代码注释掉就好。  
或者呢我给模型多定义一个选择的参数。选择性使用Encoder即可。  
也就是说，我还是使用Transformers提供的 generate函数，只用修改模型代码就好。  
而且generate其实只对生成有影响，对Encoder其实没有影响。所以我要调用Encoder的话，我直接调用model.encoder(params)就行。  

In [2]:
# Reference experiment: one document for Encoder, none document for Decoder
# Generate Summary
summary_ids = model.generate(doc1_inputs["input_ids"].to(device), num_beams=2, min_length=0, max_length=30)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])


Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device=

In [3]:
# Experiment 1: one document for Encoder, none document for Decoder

# Encoder
encoder_representations = model.get_encoder().forward(doc1_inputs["input_ids"].to(device), doc1_inputs["attention_mask"].to(device))
# Decoder
# 但至少可以先验证，如果提供了 encoder_outputs， 模型是否会正常输出: -- 可以正常输出
# 这里的逻辑是，提供了 encoder_outputs 之后，在模型的forward函数中会跳过encoder
summary_ids = model.generate(num_beams=2, min_length=0, max_length=30, 
                             encoder_outputs=encoder_representations) 
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device=

In [4]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

In [5]:
# Experiment 1: one document for Encoder, none document for Decoder

# Encoder
encoder_representations = model.get_encoder().forward(doc1_inputs["input_ids"].to(device), doc1_inputs["attention_mask"].to(device))
# Decoder
# 验证添加 decoder_input_ids 会有什么影响 
decoder_input_ids = shift_tokens_right(
    doc1_inputs["input_ids"].to(device), model.config.pad_token_id, model.config.decoder_start_token_id
)
print(doc1_inputs["input_ids"].shape)
summary_ids = model.generate(decoder_input_ids=decoder_input_ids, 
                             num_beams=2, min_length=0, max_length=doc1_inputs["input_ids"].shape[1] + 30, 
                             encoder_outputs=encoder_representations) 
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

# document 1:
# "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
# "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
# "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

# 1. generate():
# PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. 
# The aim is to reduce the risk of wildfires

# 2. decoder-only with decoder_input_ids and encoder_outputs from the same document:
# PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. 
# The aim is to reduce the risk of wildfires. 
# Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to 
# last through at least midday tomorrow.

# 结论：就是原封不动地输出了 decoder_input_ids 的内容。但是感觉也合理，首先，因为输入给Decoder的ids会全部输出，
# 其次，因为 encoder_outputs 中没有额外的信息。

torch.Size([1, 56])
Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.


In [6]:
# Experiment 2: one document for Encoder, another document for Decoder

# Encoder
encoder_representations = model.get_encoder().forward(doc1_inputs["input_ids"].to(device), doc1_inputs["attention_mask"].to(device))
# Decoder
decoder_input_ids = shift_tokens_right(
    doc2_inputs["input_ids"].to(device), model.config.pad_token_id, model.config.decoder_start_token_id
)
print(doc2_inputs["input_ids"].shape)
summary_ids = model.generate(decoder_input_ids=decoder_input_ids, 
                             num_beams=2, min_length=0, max_length=doc2_inputs["input_ids"].shape[1] + 30, 
                             encoder_outputs=encoder_representations) 
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

# document 1:
# "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
# "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
# "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

# document 2:
# "I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. "
# "What do you think will happen?"

# 1. generate():
# PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. 
# The aim is to reduce the risk of wildfires

# 2. decoder-only with decoder_input_ids and encoder_outputs from the same document:
# I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. 
# What do you think will happen? Tell us in the comments below.

# 结论：就是原封不动地输出了 decoder_input_ids 的内容。然后还附加了一句 "Tell us in the comments below."。 说明实验很成功。
# 因为实际上输入给 Encoder的 文本 和 输入给Decoder的问题其实没有任何关系，因此不能将Encoder的输入内容添加到 Decoder的输出中。

torch.Size([1, 33])
Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0026],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 56, 1024])
Bart:  tensor([ 0.0138,  0.0368,  0.0209,  ...,  0.0042, -0.0008, -0.0

In [9]:
# Experiment 3: one document for Encoder, another document for Decoder. The two documents are related.

# REFERENCF_DOCUMENT = (
#     "If you use different documents for encoder and decoder, "
#     "the decoder will try to extract practical information from the reference of encoder."
#     " That means the words from the document for encoder may appear in the outputs of decoder."
# )
REFERENCF_DOCUMENT = QUERY_DOCUMENT + 'I do not know!'
doc3_inputs = tokenizer([REFERENCF_DOCUMENT], max_length=1024, return_tensors="pt")

# Encoder
encoder_representations = model.get_encoder().forward(doc3_inputs["input_ids"].to(device), doc3_inputs["attention_mask"].to(device))
print(encoder_representations[0].shape) # torch.Size([1, 49, 1024])
print(encoder_representations[0][0, 0, :])
# Decoder
decoder_input_ids = shift_tokens_right(
    doc2_inputs["input_ids"].to(device), model.config.pad_token_id, model.config.decoder_start_token_id
)
print(doc2_inputs["input_ids"].shape)
summary_ids = model.generate(decoder_input_ids=decoder_input_ids, 
                             num_beams=2, min_length=0, max_length=doc2_inputs["input_ids"].shape[1] + 30, 
                             encoder_outputs=encoder_representations) 
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

# 问题： 
# 1. 不知道模型在哪里调用了 Encoder，

# document 1:
# "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
# "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
# "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

# document 2:
# "I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. "
# "What do you think will happen?"

# document 3:
# "If you use different documents for encoder and decoder, "
# "the decoder will try to extract practical information from the reference of encoder."
# " That means the words from the document for encoder may appear in the outputs of decoder."

# 1. generate():
# PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. 
# The aim is to reduce the risk of wildfires

# 2. decoder-only with decoder_input_ids and encoder_outputs from the same document:
# I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. 
# What do you think will happen? 
# I do not know!

# 结论：就是原封不动地输出了 decoder_input_ids 的内容。然后还附加了一句 "I do not know!"。新增的这句话是来自于 reference document的。
# 说明实验很成功。
# 因为实际上输入给 Encoder的 文本 和 输入给Decoder的问题其实没有任何关系，因此不能将Encoder的输入内容添加到 Decoder的输出中。
# 存在的缺陷，模型只能续写，而不能进行推理。因为当我把document3 换成了相关但是需要推理的文本时，模型就直接忽略了 document3.
# Decoder 的行为主要还是复述 encoder中的内容，但是是基于自己的input来进行复述的。当两者的输入没有直接关系时，Decoder就会忽视encoder的内容。
# 一个猜测是，这种情况下的Decoder其实是在做一个转述/翻译的工作，而没有推理能力，或者说，没有in-context-learning的能力。所以会忽略掉 Decoder
# 给它的参考内容。
# 那么验证这个猜想的办法就是微调。

torch.Size([1, 38, 1024])
tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
       device='cuda:0', grad_fn=<SliceBackward0>)
torch.Size([1, 33])
Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 38, 1024])
Bart:  tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 38, 1024])
Bart:  tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 38, 1024])
Bart:  tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 38, 1024])
Bart:  tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 38, 1024])
Bart:  tensor([ 0.0190,  0.0207,  0.0212,  ..., -0.0005, -0.0067, -0.0023],
      

In [12]:
# Experiment 4: one document for Encoder, another document for Decoder. The two documents are related.

REFERENCF_DOCUMENT =  (
    "What do you think will happen?"
    "If you use different documents for encoder and decoder in BART, "
    "the decoder will try to extract practical information from the reference of encoder."
    " That means the words from the document for encoder may appear in the outputs of decoder."
)
doc3_inputs = tokenizer([REFERENCF_DOCUMENT], max_length=1024, return_tensors="pt")

# Encoder
encoder_representations = model.get_encoder().forward(doc3_inputs["input_ids"].to(device), doc3_inputs["attention_mask"].to(device))
print(encoder_representations[0].shape) # torch.Size([1, 49, 1024])
print(encoder_representations[0][0, 0, :])
# Decoder
decoder_input_ids = shift_tokens_right(
    doc2_inputs["input_ids"].to(device), model.config.pad_token_id, model.config.decoder_start_token_id
)
print(doc2_inputs["input_ids"].shape)
summary_ids = model.generate(decoder_input_ids=decoder_input_ids, 
                             num_beams=2, min_length=0, max_length=doc2_inputs["input_ids"].shape[1] + 30, 
                             encoder_outputs=encoder_representations) 
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

# 问题： 
# 1. 不知道模型在哪里调用了 Encoder，

# document 1:
# "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
# "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
# "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

# document 2:
# "I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. "
# "What do you think will happen?"

# document 3:
# "What do you think will happen?"
# "If you use different documents for encoder and decoder in BART, "
# "the decoder will try to extract practical information from the reference of encoder."
# " That means the words from the document for encoder may appear in the outputs of decoder."

# 1. generate():
# PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. 
# The aim is to reduce the risk of wildfires

# 2. decoder-only with decoder_input_ids and encoder_outputs from the same document:
# I am hacking into BART. I use different documents for encoder and decoder, and try to see the influence. 
# What do you think will happen? Share your thoughts.



torch.Size([1, 58, 1024])
tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
       device='cuda:0', grad_fn=<SliceBackward0>)
torch.Size([1, 33])
Inside GenerationMinin:generation_mode=: GenerationMode.BEAM_SEARCH
Calling forward() False
Bart:  torch.Size([2, 58, 1024])
Bart:  tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 58, 1024])
Bart:  tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 58, 1024])
Bart:  tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 58, 1024])
Bart:  tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
       device='cuda:0')
Calling forward() False
Bart:  torch.Size([2, 58, 1024])
Bart:  tensor([ 0.0191,  0.0232,  0.0188,  ...,  0.0005, -0.0022, -0.0001],
      

TODOs:
fine-tuning 
https://github.com/facebookresearch/fairseq/tree/main/examples/bart
https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.glue.md
https://gluebenchmark.com/tasks
https://github.com/nyu-mll/GLUE-baselines
https://huggingface.co/facebook/bart-large
https://huggingface.co/docs/transformers/training#train-in-native-pytorch
https://github.com/Mooler0410/LLMsPracticalGuide?tab=readme-ov-file
https://medium.com/@ferlatti.aldo/fine-tuning-a-chat-summarizer-c18625bc817d


可用的数据集：  
DROP, LAMBADA, CBT (Children’s Book Test), RACE, SQuAD (Stanford Question Answering Dataset), SuperGLUE-boolq, SuperGLUE-copa, GSM8k(或许也可以用？), AlpacaEval(历史数据可以视为参考数据), AlpacaEval(instruction follow), MT Bench (Multi-Turn Benchmark), MMMU(或许部分数据集可用), MMLU, 


In [8]:
####################### Generative loop example (from https://www.junronglin.com/article/why_left_padding):
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True, enc=None):
    if start_token is None:
				 # if start_token is None, use context
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
            0).repeat(batch_size, 1)
    else:
				 # if start_token isn't None, use start_token as the beginning of each sentences
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token,
                             device=device, dtype=torch.long)
    prev = context
    output = context
    # past is KV-cache
    past = None
    with torch.no_grad():
        for i in trange(length):  # generate `length` tokens for all sentences
            logits, past = model(prev, past=past)

            # logits.shape=[batch, text, vocab_szie], in Causal model, the logits of the last token in each sentence is used to predict next token, so pick `-1` here
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)

            # concatenate the sampled tokens to the original sentences, 
            # e.g. output = [I have] and sampled `an`
            # output = [I have an]
            output = torch.cat((output, prev), dim=1)

    return output