In [None]:
!pip install transformers

下载模型：https://huggingface.co/uer/gpt2-chinese-cluecorpussmall 。只需要下载 `config.json`、`pytorch_model.bin`、`vocab.txt`、`flax_model.msgpack` 这四个文件，并放到与本 notebook 同级的 `gpt2_chinese_base/` 目录下（下面的代码从该目录加载词表和模型）。

In [7]:
import torch
import torch.nn.functional as F
from transformers import BertTokenizerFast
from transformers import GPT2LMHeadModel

if __name__ == '__main__':
    # Greedy-decoding demo: feed a prompt to a local GPT-2 checkpoint and
    # extend it one token at a time, always taking the most probable token.

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device = {device}")

    # Build the tokenizer from the local vocabulary file.
    tokenizer = BertTokenizerFast(vocab_file='gpt2_chinese_base/vocab.txt',
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")

    text = '[CLS]今天天气真'  # prompt; the [CLS] marker is written out explicitly
    print(f"input: {text}")
    # Tokenize the prompt into vocabulary ids. add_special_tokens=False because
    # the prompt string already carries its own [CLS].
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # unsqueeze(0) adds the leading batch axis, giving shape (1, seq_len).
    input_ids = torch.tensor(text_ids).long().to(device).unsqueeze(0)
    print(f"input_ids: {input_ids}")
    print("")

    pretrained_model = './gpt2_chinese_base'
    # Load the pretrained weights and switch to evaluation mode.
    model = GPT2LMHeadModel.from_pretrained(pretrained_model)
    model = model.to(device)
    model.eval()

    max_len = 20  # number of tokens to generate after the prompt
    response = list()  # ids of the generated tokens only (prompt excluded)

    # Pure inference: disable autograd so no computation graph is built or
    # kept alive for each forward pass (saves memory and time; the predicted
    # tokens are unchanged).
    with torch.no_grad():
        for i in range(max_len):  # one new token per iteration
            # Forward pass over the whole sequence generated so far.
            outputs = model(input_ids=input_ids)
            # Distribution over the vocabulary for the last position.
            probs = F.softmax(outputs.logits[0, -1, :], dim=-1)
            # Greedy choice: index of the highest-probability token, shape (1,).
            next_token = torch.max(probs, dim=-1)[1].unsqueeze(0)
            response.append(next_token.item())
            # Append the new token (reshaped to (1, 1)) to the running input.
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

            print(f"i = {i}")
            # Index the single batch row instead of squeeze(): this drops only
            # the batch axis, so the result is always a list of ids.
            input_ids_list = input_ids[0].tolist()
            # Show the running sequence both as ids and as tokens.
            print(f"input_ids: {input_ids_list}")
            input_ids_text = tokenizer.convert_ids_to_tokens(input_ids_list)
            print(f"input_ids_text: {input_ids_text}")
            print("")

device = cuda
input: [CLS]今天天气真
input_ids: tensor([[ 101,  791, 1921, 1921, 3698, 4696]], device='cuda:0')

i = 0
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的']

i = 1
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638, 2523]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的', '很']

i = 2
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638, 2523, 1962]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的', '很', '好']

i = 3
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638, 2523, 1962, 8024]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的', '很', '好', '，']

i = 4
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638, 2523, 1962, 8024, 2769]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的', '很', '好', '，', '我']

i = 5
input_ids: [101, 791, 1921, 1921, 3698, 4696, 4638, 2523, 1962, 8024, 2769, 812]
input_ids_text: ['[CLS]', '今', '天', '天', '气', '真', '的', '很', '好', '，', '我', '们']

i = 6
input_ids: [101, 791, 1921, 1921