In [2]:
from datasets import load_dataset

dataset = load_dataset("opus100", "en-zh")
print(dataset)

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})


In [3]:
print(dataset['train'][0])

{'translation': {'en': 'Sixty-first session', 'zh': '第六十一届会议'}}


In [4]:
import torch
import torch.nn as nn 

class GPTEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
    
    def forward(self, x):
        return self.embedding(x)

In [5]:
# 旋转位置编码
from typing import Tuple


class RotaryEmbedding(nn.Module): 
    # 旋转位置编码实现
    def __init__(self, dim: int, max_position: int = 10000): 
        super().__init__()
        assert dim % 2 == 0, "dim必须是偶数"
        self.dim = dim
        self.max_position = max_position
        
        # 预计算频率
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        # 预计算所有位置的cos和sin
        self._precompute_embeddings()
    
    def _precompute_embeddings(self):
        """预计算最大长度的cos和sin"""
        positions = torch.arange(self.max_position, dtype=torch.float)
        freqs = positions[:, None] * self.inv_freq[None, :]  # (max_position, dim/2)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        self.register_buffer("cos_cached", cos, persistent=False)  # (max_position, dim/2)
        self.register_buffer("sin_cached", sin, persistent=False)
    
    def forward(self, seq_len: int, device: torch.device) -> tuple:
        """
        在训练端，我们固定最大尺寸输入是ok的。因为我们为了训练，对传进来的序列长度对齐的。设定了最大长度规则，截断规则。
        但是在预测段，最大固定尺寸max_position需要设置的大一些。大部分任务是单样本推理。此时，如果max_position设置过小，可能对生成结果有影响
        我们，一般可以和对话最大字符度相同。或者略低。
        """
        
        """根据序列长度返回对应的cos和sin"""
        assert seq_len <= self.max_position, f"seq_len ({seq_len}) 超过 max_position ({self.max_position})"
        
        # 从缓存中截取需要的部分
        cos = self.cos_cached[:seq_len].to(device)
        sin = self.sin_cached[:seq_len].to(device)
        return cos, sin
    
# 在单向or双向自注意力时可以这样操作。因为q，k同维度，当涉及交叉注意力时，此时就不能这么操作。
def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple:
    """应用旋转位置编码"""
    cos = cos.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    sin = sin.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    q_ = q.float()
    k_ = k.float()
    trunc = q_.shape[-1]//2
    q_rot = torch.cat([q_[..., :trunc] * cos - q_[..., trunc:] * sin,
                      q_[..., :trunc] * sin + q_[..., trunc:] * cos], dim=-1)
    k_rot = torch.cat([k_[..., :trunc] * cos - k_[..., trunc:] * sin,
                      k_[..., :trunc] * sin + k_[..., trunc:] * cos], dim=-1)
    return q_rot.type_as(q), k_rot.type_as(k)

In [6]:
import math

class MultiHeadAttention(nn.Module): 
    def __init__(self, embed_dim, num_heads, max_position=10000, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads   
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)    
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float))
        
        # 初始化旋转位置编码
        self.rotary_emb = RotaryEmbedding(self.head_dim, max_position) 
        
    # 这个注意力前向传播需要调整下，由于我们编解码类型，在交叉注意力计算时，由于q，k维度不一致。不能套用常规的ROPE编码
    def forward(self, x, mask): 
        batch_size, seq_len, embed_dim = x.shape
        
        qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] # (batch_size, seq_len, num_heads, head_dim)

        # 生成旋转位置编码
        cos, sin = self.rotary_emb(seq_len, q.device)

        # q和k应用旋转位置编码
        q, k = apply_rotary_emb(q, k, cos, sin)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.scale
        
        # mask包含因果掩码和attention_mask
        scores = scores + mask.unsqueeze(1)
        
        weights = torch.softmax(scores, dim=-1)
        
        weights = self.dropout(weights)  # 在注意力权重上加 dropout
        
        output = torch.matmul(weights, v)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, embed_dim)
        
        return self.out_linear(output)
    

attn = MultiHeadAttention(512, 8)

In [7]:
params = [param.numel() for param in attn.parameters()]
print(params)
sum = 0
for i in params:
    sum += i 
print(sum)   

[786432, 1536, 262144, 512]
1050624


In [8]:
class FeedForward(nn.Module): 
    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * 4)
        self.fc2 = nn.Linear(embed_dim * 4, embed_dim)
        self.gelu = nn.GELU()  # 使用 GELU 替换 ReLU
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x): 
        x = self.fc1(x)
        x = self.gelu(x)  # 添加 GELU 激活
        x = self.fc2(x)
        x = self.dropout(x)  # 在 FFN 输出上加 dropout
        return x
    
ffn = FeedForward(512)
param = [para.numel() for para in ffn.parameters()]
print(param)
sum = 0 
for i in param:
    sum += i
print(sum)

[1048576, 2048, 1048576, 512]
2099712


In [9]:
class TransformerBlock(nn.Module): 
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = FeedForward(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim) 
        self.norm2 = nn.LayerNorm(embed_dim) 
        self.dropout = nn.Dropout(dropout)  # 残差连接后的 dropout
        
    def forward(self, x, mask=None): 
        attn_output = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
     
transformerBlock = TransformerBlock(512, 8)

params = [para.numel() for para in transformerBlock.parameters()]
print(params)
sum = 0
for i in params:
    sum += i
print(sum)

[786432, 1536, 262144, 512, 1048576, 2048, 1048576, 512, 512, 512, 512, 512]
3152384


In [10]:
torch.triu(torch.ones(size=(3, 3))*float('-inf'), diagonal=1)

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])

In [11]:
mask = torch.triu(torch.ones(size=(4, 4))*float('-inf'), diagonal=1)
mask = mask.unsqueeze(0)
print(mask)
padding_mask = torch.tensor([[True, True, True, False],[True, True, False, False]])
if padding_mask is not None: 
    padding_mask = padding_mask.unsqueeze(1)  # (batch_size, 1, seq_len)
    mask = mask.masked_fill(~padding_mask, float('-inf'))
    
print(mask)

tensor([[[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., 0., -inf],
         [0., 0., 0., 0.]]])
tensor([[[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., 0., -inf],
         [0., 0., 0., -inf]],

        [[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., -inf, -inf]]])


In [None]:
# 组合GPT模型
class GPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super().__init__()
        
        # 词嵌入层
        self.embedding = GPTEmbedding(vocab_size, embed_dim)      
        
        # Transformer blocks
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])
        
        # 输出头
        self.output_layer = nn.Linear(embed_dim, vocab_size)
    
    def forward(self, token_ids, padding_mask=None): 
        batch_size, seq_len = token_ids.shape
        x = self.embedding(token_ids)
        
        mask = torch.triu(torch.ones(size=(seq_len, seq_len))*float('-inf'), diagonal=1)
        mask = mask.unsqueeze(0).to(token_ids.device)
        
        if padding_mask is not None: 
            padding_mask = padding_mask.unsqueeze(1)  # (batch_size, 1, seq_len)
            mask = mask.masked_fill(~padding_mask, float('-inf'))
            
        for layer in self.layers:
            x = layer(x, mask)

        logits = self.output_layer(x)
        return logits


gpt = GPT(32000, 768, 6, 8)

sum = 0 

for name,param in gpt.named_parameters(): 
    print(f"{name} ==> 参数量：{param.numel()}")
    sum += param.numel()
    
print(sum)

embedding.embedding.weight ==> 参数量：24576000
layers.0.attention.qkv_proj.weight ==> 参数量：1769472
layers.0.attention.qkv_proj.bias ==> 参数量：2304
layers.0.attention.out_linear.weight ==> 参数量：589824
layers.0.attention.out_linear.bias ==> 参数量：768
layers.0.feed_forward.fc1.weight ==> 参数量：2359296
layers.0.feed_forward.fc1.bias ==> 参数量：3072
layers.0.feed_forward.fc2.weight ==> 参数量：2359296
layers.0.feed_forward.fc2.bias ==> 参数量：768
layers.0.norm1.weight ==> 参数量：768
layers.0.norm1.bias ==> 参数量：768
layers.0.norm2.weight ==> 参数量：768
layers.0.norm2.bias ==> 参数量：768
layers.1.attention.qkv_proj.weight ==> 参数量：1769472
layers.1.attention.qkv_proj.bias ==> 参数量：2304
layers.1.attention.out_linear.weight ==> 参数量：589824
layers.1.attention.out_linear.bias ==> 参数量：768
layers.1.feed_forward.fc1.weight ==> 参数量：2359296
layers.1.feed_forward.fc1.bias ==> 参数量：3072
layers.1.feed_forward.fc2.weight ==> 参数量：2359296
layers.1.feed_forward.fc2.bias ==> 参数量：768
layers.1.norm1.weight ==> 参数量：768
layers.1.norm1.bias ==> 参数量：

In [13]:
# # 生成词汇表
# from generate_tokenizer import gen_tokenizer

# # 传入语料库文件， 输出tokenizer的json文件
# src_path = "corpus.txt"
# tokenizer_path = "translation.json"
# gen_tokenizer(src_path, tokenizer_path)


In [14]:
from transformers import PreTrainedTokenizerFast

# 加载 tokenizer，显式设置特殊 token
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="translation.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    bos_token="[BOS]",
    eos_token="[EOS]"
)

# 打印特殊 token 和映射
print("Pad token:", tokenizer.pad_token)
print("特殊 token 映射：", tokenizer.special_tokens_map)
print("Pad token ID:", tokenizer.pad_token_id)
print("BOS token ID:", tokenizer.convert_tokens_to_ids("[BOS]"))
print("BOS token ID:", tokenizer.convert_tokens_to_ids("[BOS]"))
print("EOS token ID:", tokenizer.convert_tokens_to_ids("[EOS]"))
print("PAD token ID:", tokenizer.convert_tokens_to_ids("[PAD]"))
print("EOS token ID:", tokenizer.eos_token_id)

Pad token: [PAD]
特殊 token 映射： {'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'pad_token': '[PAD]'}
Pad token ID: 0
BOS token ID: 2
BOS token ID: 2
EOS token ID: 3
PAD token ID: 0
EOS token ID: 3


In [15]:
# 输入文本
text = "Hello, world! 你好，世界！"

# 转换为 token ID
input_ids = tokenizer.encode(text)
print(f"Token ID: {input_ids}")

# 查看分词后的 token
tokens = tokenizer.tokenize(text)
print("分词结果：", tokens)

encoded = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
)
print("Encoded 输出:", encoded)

# 提取 input_ids 和 attention_mask
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)

# 转换为 token 查看
tokens_with_special = tokenizer.convert_ids_to_tokens(input_ids[0])
print("带特殊 token 的分词结果：", tokens_with_special)

Token ID: [2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3]
分词结果： ['Hello', ',', 'world', '!', '你好', '，', '世界', '！']
Encoded 输出: {'input_ids': tensor([[    2, 14860,    18, 10298,     7, 13791,  9209,  9863,  9198,     3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
Input IDs: tensor([[    2, 14860,    18, 10298,     7, 13791,  9209,  9863,  9198,     3]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
带特殊 token 的分词结果： ['[BOS]', 'Hello', ',', 'world', '!', '你好', '，', '世界', '！', '[EOS]']


In [16]:
# 输入文本
text = "Hello, world! 你好，世界！"

# 转换为token ID 
input_ids = tokenizer.encode(text)
print(f"Token ID:{input_ids}")

# 查看分词后的token
tokens = tokenizer.tokenize(text)
print("分词结果：", tokens)

encoded = tokenizer(
    text,                   
    # return_tensors="pt",        # 返回 PyTorch 张量（"tf" 表示 TensorFlow，None 表示普通列表）
    padding="max_length",               # 自动填充（如果处理批量文本）
    truncation=True,            # 自动截断（如果超过最大长度）
    max_length=512              # 最大序列长度
)

print("Encoded 输出:", encoded)

# 提取 input_ids 和 attention_mask
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)

Token ID:[2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3]
分词结果： ['Hello', ',', 'world', '!', '你好', '，', '世界', '！']
Encoded 输出: {'input_ids': [2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [17]:
# 输入文本
text = "该句子的中文翻译为："

# 转换为token ID 
input_ids = tokenizer.encode(text)[1:-1]
print(f"Token ID:{input_ids}")

Token ID:[7036, 1678, 22447, 30075, 19849, 1031, 9223]


In [18]:
text = "该句子的中文翻译为："
text = dataset["train"][2]['translation']["en"]+text+dataset["train"][2]['translation']["zh"]

input = tokenizer(
            text,                   
            # return_tensors="pt",        # 返回 PyTorch 张量（"tf" 表示 TensorFlow，None 表示普通列表）
            padding="max_length",               # 自动填充（如果处理批量文本）
            truncation=True,            # 自动截断（如果超过最大长度）
            max_length=100              # 最大序列长度
        )

input

{'input_ids': [2, 9553, 13, 89, 71, 14306, 20, 11351, 9313, 27651, 9359, 20, 9617, 13, 89, 16990, 9359, 71, 9371, 18425, 20, 7036, 1678, 22447, 30075, 19849, 1031, 9223, 14481, 2060, 11186, 1180, 9209, 13750, 11540, 26767, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [19]:
tokenizer.pad_token_id

0

In [20]:
# 创建Dataset
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset): 
    def __init__(self, dataset, tokenizer, default_text, max_length = 100):
        self.default_text = default_text
        self.dataset = dataset
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.pad_token_id = [self.tokenizer.pad_token_id]

    def __len__(self): 
        return len(self.dataset)
    
    def __getitem__(self, index):
        translation = self.dataset[index]['translation']
        text = translation["en"] + self.default_text + translation["zh"]
        input = self.tokenizer(
            text,                   
            padding="max_length",               # 自动填充（如果处理批量文本）
            truncation=True,            # 自动截断（如果超过最大长度）
            max_length=self.max_length              # 最大序列长度
        )
        
        # 提取 input_ids（去掉 batch 维度）
        input_ids = input["input_ids"] # [seq_len]
        attention_mask = input["attention_mask"]
        label = input_ids[1:] + self.pad_token_id # 去掉第一个 token（ <BOS>）
        
        return [input_ids, attention_mask, label]

def collate_fn(batch):
    input_ids, attention_masks, labels = zip(*batch)
    return (torch.tensor(input_ids, dtype=torch.long)
            , torch.tensor(attention_masks, dtype=torch.bool)
            , torch.tensor(labels, dtype=torch.long))
        
train_dataset = TranslationDataset(dataset=dataset["train"], tokenizer=tokenizer, default_text="该句子的中文翻译为：", max_length=100)
train_loader = DataLoader(train_dataset, batch_size=48, shuffle=True, collate_fn=collate_fn)

In [21]:
next(iter(train_loader))[1].shape

torch.Size([48, 100])

In [22]:
#训练模型

vocab_size = tokenizer.vocab_size
embed_dim = 512
num_heads = 8
d_ff = 2048
num_layers = 6
dropout = 0.1

model = GPT(vocab_size, embed_dim, num_layers, num_heads)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

In [23]:
param_num = [para.numel() for para in model.parameters()]
sum = 0 
for n in param_num:
    sum += n
    
print(sum)

51714304


In [24]:
print(next(iter(train_loader))[1].shape)
aa = next(iter(train_loader))[1].unsqueeze(1).unsqueeze(2)

print(aa.shape)


tgt_test = next(iter(train_loader))[2]
tgt_test

torch.Size([48, 100])
torch.Size([48, 1, 1, 100])


tensor([[10037,  9313,  9356,  ...,     0,     0,     0],
        [ 9604, 13419, 20922,  ...,     0,     0,     0],
        [10089,  9468,  9709,  ...,     0,     0,     0],
        ...,
        [   47, 10340, 10236,  ...,     0,     0,     0],
        [14423,    18,  9320,  ...,  1425,     3,     0],
        [   47,    13,    83,  ...,     0,     0,     0]])

In [25]:
from tqdm import tqdm
from torch.amp import autocast, GradScaler

scaler = GradScaler()

pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
num_epochs = 10
for epoch in range(num_epochs): 
    model.train()
    total_loss = 0
    for (input_ids, attention_masks, labels) in tqdm(train_loader):
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        
        optimizer.zero_grad()
        
        with autocast(device_type="cuda", dtype=torch.float16):
            logits = model(input_ids, attention_masks)
            loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
                
        total_loss += loss.item()
        
    print(f"Epoch {epoch + 1} / {num_epochs}, Loss:{total_loss / len(train_loader)}")
    
        

 22%|██▏       | 4553/20834 [07:48<27:56,  9.71it/s]


KeyboardInterrupt: 

In [84]:
def translate(model, src_text, max_length=10000):
    model.eval()
    input = tokenizer(text=src_text)["input_ids"][:-1]
    eos_id = tokenizer.eos_token_id

    for _ in range(max_length):
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=torch.float16):
                logits = model(torch.tensor([input], dtype=torch.long).to(device))
            next_token = logits[0, -1].argmax().item()
        if next_token == eos_id:
            break
        input.append(next_token)

    tokens = tokenizer.convert_ids_to_tokens(input)[1:]
    output = ""
    for i, token in enumerate(tokens):
        # 清洗BPE空格符并分类字符
        clean_token = token.replace('▁', '')
        is_chinese = all(0x4E00 <= ord(c) <= 0x9FFF for c in clean_token)
        is_eng_punct = clean_token in {".", ",", "!", "?", ":"}  # 英文标点
        is_cn_punct = clean_token in {"，", "。", "！", "？", "："}  # 中文标点
        
        # 空格判断逻辑
        space_needed = False
        if i > 0:
            prev_clean = tokens[i-1].replace('▁', '')
            prev_eng_punct = prev_clean in {".", ",", "!", "?", ":"}
            prev_cn_punct = prev_clean in {"，", "。", "！", "？", "："}
            
            # 英文标点后需要空格的条件
            if prev_eng_punct and not (is_chinese or is_cn_punct):
                space_needed = True
            # 连续英文非标点需要空格
            elif not (is_chinese or is_eng_punct or is_cn_punct) and \
                 not (prev_cn_punct or prev_eng_punct or prev_clean == ""):
                space_needed = True
        
        output += (" " if space_needed else "") + clean_token
    
    return output

In [79]:
import torch
torch.cuda.is_available()

True

In [86]:


src_text = "Hello? I'm here" 

translated = translate(model, src_text)

print(f"翻译结果：{translated}")

翻译结果：Hello? I ' m here.该句子的中文翻译为：-哦，我
