In [20]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tiktoken  # 用于解码 token
import re

# GPTConfig 和 GPT 模型架构（和训练时一致）
class GPTConfig:
    def __init__(self, block_size=512, batch_size=12, n_layer=6, n_head=12, n_embd=768, dropout=0.1, vocab_size=50257):
        self.block_size = block_size
        self.batch_size = batch_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_size = n_embd // n_head
        self.dropout = dropout
        self.vocab_size = vocab_size

class SingleHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.key = nn.Linear(config.n_embd, config.head_size)
        self.value = nn.Linear(config.n_embd, config.head_size)
        self.query = nn.Linear(config.n_embd, config.head_size)

        self.register_buffer(
            'attention_mask', 
            torch.tril(torch.ones(config.block_size, config.block_size))
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.size()
        k = self.key(x)
        v = self.value(x)
        q = self.query(x)
        weight = q @ k.transpose(-2, -1)
        weight = weight.masked_fill(self.attention_mask[:seq_len, :seq_len] == 0, float('-inf')) / math.sqrt(hidden_size)
        weight = F.softmax(weight, dim=-1)
        weight = self.dropout(weight)
        out = weight @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList([SingleHeadAttention(config) for _ in range(config.n_head)])
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        output = torch.cat([h(x) for h in self.heads], dim=-1)
        output = self.proj(output)
        output = self.dropout(output)
        return output

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout)
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        head_size = config.n_embd // config.n_head
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # Store config to access block_size
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.ln_final = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        batch, seq_len = idx.size()
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(seq_len, device=idx.device))
        x = token_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            batch, seq_len, vocab_size = logits.size()
            logits = logits.view(batch * seq_len, vocab_size)
            targets = targets.view(batch * seq_len)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            
            logits = logits[:, -1, :]  # Get the last token prediction
            probs = F.softmax(logits, dim=-1)  # Apply softmax to get probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # Sample the next token

            idx = torch.cat((idx, idx_next), dim=1)
    
        return idx

# 加载模型
def load_model(checkpoint_path='C:/Users/98705/Desktop/LLMs-Zero-to-Hero-master/LLMs-Zero-to-Hero-master/src/video/checkpoints/model_epoch_1.pt', config=None):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"The model checkpoint file was not found at {checkpoint_path}")
    model = GPT(config)  # 使用训练时的配置
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model

# 安全解码函数：去除无效字符
def safe_decode(tokens):
    try:
        # 尝试解码 token 为文本
        decoded = encoder.decode(tokens)
        # 去除掉可能的乱码字符
        decoded = re.sub(r'[^\w\s\u4e00-\u9fa5]', '', decoded)  # 保留中文字符和常规字母
        return decoded
    except Exception as e:
        print(f"解码错误: {e}")
        return "无法解码的文本"

# 交互式聊天函数
def chat_with_model():
    # 定义训练时的配置
    config = GPTConfig()
    
    # 加载模型
    model = load_model('C:/Users/98705/Desktop/LLMs-Zero-to-Hero-master/LLMs-Zero-to-Hero-master/src/video/checkpoints/model_epoch_1.pt', config)

    # 初始化tokenizer
    encoder = tiktoken.get_encoding("gpt2")

    print("与模型开始对话，输入 'exit' 退出。")

    while True:
        # 获取用户输入
        user_input = input("你说: ")
        if user_input.lower() == 'exit':
            print("退出对话。")
            break

        # 将用户输入转换为 token IDs
        tokens = encoder.encode(user_input)  # 将用户输入的文本转换为 tokens
        idx = torch.tensor([tokens], dtype=torch.long)  # 转换为 tensor 并确保它是二维的 (batch_size, seq_len)

        # 使用模型生成新的 tokens
        generated_tokens = model.generate(idx, max_new_tokens=50)

        # 解码生成的 token IDs
        decoded_text = safe_decode(generated_tokens[0].tolist())
        print(f"模型说: {decoded_text}")

# 启动聊天
if __name__ == "__main__":
    chat_with_model()




与模型开始对话，输入 'exit' 退出。
模型说: hi得啚通新赫法实不选整康成出为稓都将襭带
模型说: 你为什么不说话诠期网中应而学校有有师你计言就的人是化者以识的女


IndexError: index -1 is out of bounds for dimension 1 with size 0