In [1]:
import torch
import torch.nn as nn

In [2]:
# 测试集
import json 
from data import TextDataset
from torch.utils.data import Dataset, DataLoader


def load_data(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data

test_data = load_data('./src/dataset/test.json')["test"]

In [3]:
from src.net import Net
from src.model.embedding.token_embedding import Embedding
from src.model.embedding.position import PositionalEmbedding

In [4]:
# 加载模型
gpt = Net()
gpt.load_state_dict(torch.load('./trainer/gpt_100_.pth', map_location='cuda:0' if torch.cuda.is_available() else 'cpu'))

<All keys matched successfully>

In [5]:
embedd = Embedding(vocab_size=128000, dim=64)
position_emb = PositionalEmbedding(max_len=3, dim=64)

In [6]:
test_data[0]['input']

'They are discussing a'

In [7]:
# 编码器
from tokenization.tokenizer import tokenizer
max_length = 3
tokenizer = tokenizer()

In [8]:
tokenizer.encode(test_data[0]['input'])

[7009, 527, 25394, 264]

In [9]:
# 词编码
word_embedding = embedd(torch.tensor(tokenizer.encode(test_data[0]['input'])[:max_length]))
# 位置编码
position_embedding = position_emb(word_embedding)
# 词向量
context = word_embedding + position_embedding

print(context.shape)

torch.Size([1, 3, 64])


In [10]:
next_word = gpt(context.to(device='cuda:0'if torch.cuda.is_available() else 'cpu'))

In [11]:
next_word = tokenizer.decode([next_word.argmax(-1).item()])

In [12]:
import sys
import time

def writer_output(text, delay=0.1):
    """
    模拟打字机效果逐字输出文本
    
    Parameters
    ----------
    text : str
        需要逐字输出的文本
    delay : float, optional
        每个字符输出的延迟时间，默认为 0.1 秒
    """
    for char in text:
        sys.stdout.write(char)  # 输出字符
        sys.stdout.flush()      # 刷新输出缓冲区
        time.sleep(delay)       # 延迟



In [13]:
text = test_data[0]['input'] + next_word

writer_output(text)

They are discussing agift

In [20]:
def generate_response(input_text, max_length):
    # 词编码
    word_embedding = embedd(torch.tensor(tokenizer.encode(input_text)[:max_length]))
    # 位置编码
    position_embedding = position_emb(word_embedding)
    # 词向量
    context = word_embedding + position_embedding

    next_word = gpt(context.to(device='cuda:0'if torch.cuda.is_available() else 'cpu'))
    next_word = tokenizer.decode([next_word.argmax(-1).item()])

    response = input_text + ' ' + next_word
    # 模拟逐步输出
    displayed_text = ""
    for char in response:
        displayed_text += char
        time.sleep(0.05)  # 每个字符延迟 50ms
        yield displayed_text  # 逐步更新输出

    # return input_text + ' ' + next_word

In [21]:
# 前端界面
import gradio as gr


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Mini GPT-2</h1>""")
    with gr.Row():
        with gr.Column(scale=3):
            query = gr.Textbox(placeholder='输入内容:', lines=2, label='Content')
            with gr.Row():
                answer = gr.Textbox(placeholder='对话结果：', lines=2, label='Content')
            with gr.Row():
                submit = gr.Button('提交', variant='primary')
                clear = gr.Button('清空', variant='secondary')

        with gr.Column(scale=1):
            max_length = gr.Slider(0, 3, value=99, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    submit.click(generate_response, inputs=[query, max_length], outputs=[answer], show_progress=True)
    clear.click(lambda: "", None, answer)  
    demo.queue().launch(share=False, inbrowser=True)

Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.
