In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(emb_dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim)
        )
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        B, T, C = x.size()

        # 生成 causal mask，保证第 t 个位置只能看到 <= t 的位置
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).repeat(B, 1, 1)
        # nn.MultiheadAttention 需要 bool mask，True 表示被遮挡
        attn_mask = ~mask.bool()[0]  # (T, T) bool，True 表示遮挡

        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out
        x = self.norm1(x)
        mlp_out = self.mlp(x)
        x = x + mlp_out
        return self.norm2(x)

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=512, n_heads=16, n_layers=12, block_size=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, block_size, emb_dim))
        self.blocks = nn.Sequential(*[
            TransformerBlock(emb_dim, n_heads) for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(emb_dim)
        self.fc = nn.Linear(emb_dim, vocab_size)

    def forward(self, x):
        tok_emb = self.token_embedding(x)
        x = tok_emb + self.pos_embedding[:, :x.size(1), :]
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.fc(x)
        return logits


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载 tokenizer
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("tiny_tokenizer.json")

# 初始化模型并加载参数
vocab_size = tokenizer.get_vocab_size()
model = TinyTransformer(vocab_size).to(device)

checkpoint = torch.load("tiny_transformer_checkpoint.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


  checkpoint = torch.load("tiny_transformer_checkpoint.pth", map_location=device)


TinyTransformer(
  (token_embedding): Embedding(8192, 512)
  (blocks): Sequential(
    (0): TransformerBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=204

In [3]:
def generate_text(prompt, model, tokenizer, max_new_tokens=100, device="cuda"):
    from torch.nn import functional as F

    input_ids = tokenizer.encode(prompt).ids
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

    eos_token_id = tokenizer.token_to_id("<|endoftext|>")  # 终止符
    # print(eos_token_id)
    # model.eval()
    for _ in range(max_new_tokens):
        if input_ids.size(1) > model.pos_embedding.size(1):
            input_ids = input_ids[:, -model.pos_embedding.size(1):]  # 截断上下文

        with torch.no_grad():
            logits = model(input_ids)
            next_token_logits = logits[:, -1, :]
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)  # [1, 1]
            # print(next_token.tolist)
            # print(f"Next token text: {tokenizer.decode([next_token.item()])}")
        # 拼接生成的新 token
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # 检查是否生成了 <eos>
        if next_token.item() == eos_token_id:
            print("here")
            break

    output_ids = input_ids[0].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)


In [4]:
prompt = ""
generated = generate_text(prompt, model, tokenizer, max_new_tokens=200, device=device)
print(generated)



Lily and Tom were playing in the garden. They liked to pretend they were chefs and had a big box. They had a lot of fun with a hat and a hat and a hat.
"Look, a hat!" Tom said. "It is a lot of a hat and a hat!"
"Wow!" Lily said. "It is a hat and a hat!"
"OK!" Tom said. "It is a flower!"
"OK!" Lily said. "They are very happy and see what we can make a lot of fun!"
"OK!" Tom said. "We can make a lot of fun!"
"OK!" Lily said. "We can make a lot of the box and see the box!"
"OK!" Tom said. "But we can make a lot of a lot of a hat and a hat and a hat and a hat!"
"Wow!" Lily said. "We can make a lot of a hat and a hat
