In [109]:
from datasets import load_dataset

dataset = load_dataset("opus100", "en-zh")
print(dataset)

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})


In [110]:
print(dataset['train'][0])

{'translation': {'en': 'Sixty-first session', 'zh': '第六十一届会议'}}


In [111]:
import torch
import torch.nn as nn 

class T5Embedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
    
    def forward(self, x):
        return self.embedding(x)

In [112]:
import math 

class PositionalEncoding(nn.Module): 
    def __init__(self, d_model, max_len=5000): 
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()*(-math.log(10000.) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
    
    def forward(self, x): 
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len].to(x.device)


pe = PositionalEncoding(512)
x = torch.randn(size=(2,5,512))
# print(x[1,0,:])

[para.numel() for para in pe.parameters()]

[]

In [113]:
# 旋转位置编码
from typing import Tuple


class RotaryEmbedding(nn.Module): 
    # 旋转位置编码实现
    def __init__(self, dim: int, max_position: int = 10000): 
        super().__init__()
        assert dim % 2 == 0, "dim必须是偶数"
        self.dim = dim
        self.max_position = max_position
        
        # 预计算频率
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        # 预计算所有位置的cos和sin
        self._precompute_embeddings()
    
    def _precompute_embeddings(self):
        """预计算最大长度的cos和sin"""
        positions = torch.arange(self.max_position, dtype=torch.float)
        freqs = positions[:, None] * self.inv_freq[None, :]  # (max_position, dim/2)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        self.register_buffer("cos_cached", cos, persistent=False)  # (max_position, dim/2)
        self.register_buffer("sin_cached", sin, persistent=False)
    
    def forward(self, seq_len: int, device: torch.device) -> tuple:
        """
        在训练端，我们固定最大尺寸输入是ok的。因为我们为了训练，对传进来的序列长度对齐的。设定了最大长度规则，截断规则。
        但是在预测段，最大固定尺寸max_position需要设置的大一些。大部分任务是单样本推理。此时，如果max_position设置过小，可能对生成结果有影响
        我们，一般可以和对话最大字符度相同。或者略低。
        """
        
        """根据序列长度返回对应的cos和sin"""
        assert seq_len <= self.max_position, f"seq_len ({seq_len}) 超过 max_position ({self.max_position})"
        
        # 从缓存中截取需要的部分
        cos = self.cos_cached[:seq_len].to(device)
        sin = self.sin_cached[:seq_len].to(device)
        return cos, sin
    
# 在单向or双向自注意力时可以这样操作。因为q，k同维度，当涉及交叉注意力时，此时就不能这么操作了，在训练时，由于我们设置max_length都一样长，可能并无感知。
# 但在训练时，编码端长度和解码端长度大多数情况下是未对齐的，此时就会报错。 
def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple:
    """应用旋转位置编码"""
    # 调整 cos 和 sin 的形状以匹配 q 和 k
    # cos, sin 原本shape是 (seq_len, d_k//2)
    cos = cos.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    sin = sin.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    q_ = q.float()
    k_ = k.float()
    trunc = q_.shape[-1] // 2
    
    q_rot = torch.cat([q_[..., :trunc] * cos - q_[..., trunc:] * sin,
                      q_[..., :trunc] * sin + q_[..., trunc:] * cos], dim=-1)
    k_rot = torch.cat([k_[..., :trunc] * cos - k_[..., trunc:] * sin,
                      k_[..., :trunc] * sin + k_[..., trunc:] * cos], dim=-1)
    return q_rot.type_as(q), k_rot.type_as(k)

# 在交叉注意力时，由于q，k的长度不一致。如果我们用常规的ROPE就不行，维度没有对齐。因此，在交叉注意力计算时，我们需要计算各自的ROPE值
def apply_rotary_emb_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """应用旋转位置编码到单个张量"""
    cos = cos.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    sin = sin.unsqueeze(-2)  # 扩展维度为(seq_len, 1, d_k//2) 最后通过广播机制维度扩展(batch_size, seq_len, num_heads, d_k//2)
    x_ = x.float()
    trunc = x_.shape[-1] // 2
    x_rot = torch.cat([
        x_[..., :trunc] * cos - x_[..., trunc:] * sin,
        x_[..., :trunc] * sin + x_[..., trunc:] * cos
    ], dim=-1)
    return x_rot.type_as(x)

In [114]:
from flash_attn import flash_attn_func

class MultiHeadAttention(nn.Module): 
    def __init__(self, d_model, num_heads, max_position=10000, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须被num_heads整除"
        self.dropout = dropout
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.q_linear = nn.Linear(d_model, d_model)    
        self.k_linear = nn.Linear(d_model, d_model)    
        self.v_linear = nn.Linear(d_model, d_model)    
        self.out_linear = nn.Linear(d_model, d_model)
        
        # 初始化旋转位置编码
        self.rotary_emb = RotaryEmbedding(self.d_k, max_position) 
        
    # 始终应用掩码（如果 mask 为 None，则传入全 1 掩码） 这样设计的目的是，方便模型后续输出为onnx格式，该格式不建议if for等语句
    # src_mask shape可以是：(batch_size, 1, 1, seq_len)， 通过广播同步维度到 (batch_size, num_heads, seq_len, seq_len)
    # tgt_mask shape可以是：(batch_size, 1, seq_len, seq_len)， 通过广播同步维度到 (batch_size, num_heads, seq_len, seq_len)
    def forward(self, q, k, v, mask=None, use_flash_attn=False): 
        batch_size = q.size(0)
        q_seq_len = q.size(1)  # q 的序列长度
        k_seq_len = k.size(1)  # k 的序列长度
        
        # 线性变换并分割为多头
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k)
        
        # 为 q 和 k 分别生成旋转位置编码
        cos_q, sin_q = self.rotary_emb(q_seq_len, q.device)
        cos_k, sin_k = self.rotary_emb(k_seq_len, k.device)

        # 分别应用旋转位置编码
        q = apply_rotary_emb_single(q, cos_q, sin_q)
        k = apply_rotary_emb_single(k, cos_k, sin_k)


        # 单双向自注意力时可以使用，交叉注意力时，需要分别计算。
        # q, k = apply_rotary_emb(q, k, cos, sin)
        
        if use_flash_attn:
            # multi-attention 训练模式
            # multi-attention 要求输入格式为 (batch_size, num_heads, seq_len, d_k)
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
            
            # -1e9在半精度等训练时，会溢出，这里选择-1e4
            # scores = scores.masked_fill(mask==0, -1e9)
            scores = scores.masked_fill(mask==0, -1e4)
            attention = torch.softmax(scores, dim=-1)
            output = torch.matmul(attention, v)
            output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        else: 
            # flash-attention 推理模式
            # Flash Attention 要求输入格式为 (batch_size, seq_len, num_heads, d_k)
            output = flash_attn_func(
                q, k, v,
                dropout_p=0, # 推理时，不使用dropout
                softmax_scale=1. / math.sqrt(self.d_k), # 缩放因子
                causal=True, # 内置生成因果掩码，但是他不对ipaddng等处理。所以，在训练端不合适
                # 在推理端生成任务时，由于没有padding的影响。解码端或gpt等使用内置掩码即可，如果是编解码，编码端还是需要传入attention_mask
            )             
            output = output.view(batch_size, -1, self.d_model)
        
        return self.out_linear(output)
    

attn = MultiHeadAttention(512, 8)
q = torch.randn(2, 10, 512)

# attn(q, q, q)

In [115]:
params = [param.numel() for param in attn.parameters()]
print(params)
sum = 0
for i in params:
    sum += i 
print(sum)   

[262144, 512, 262144, 512, 262144, 512, 262144, 512]
1050624


In [116]:
class FeedForward(nn.Module): 
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x): 
        return self.linear2(torch.relu(self.linear1(x)))
    
ffn = FeedForward(512, 2048)
param = [para.numel() for para in ffn.parameters()]
print(param)
sum = 0 
for i in param:
    sum += i
print(sum)

[1048576, 2048, 1048576, 512]
2099712


In [117]:
class EncoderLayer(nn.Module): 
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        x = x + self.dropout(self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x 

encoder_block = EncoderLayer(512, 8, 2048, 0.1)
params = [param.numel() for param in encoder_block.parameters()]
print(params)
sum = 0 
for i in params:
    sum += i
    
# 参数量统计
print(sum)

[262144, 512, 262144, 512, 262144, 512, 262144, 512, 1048576, 2048, 1048576, 512, 512, 512, 512, 512]
3152384


In [118]:
class DecoderLayer(nn.Module): 
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model) 
        self.norm2 = nn.LayerNorm(d_model) 
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None, use_flash_attn=False): 
         x = x + self.dropout(self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x), tgt_mask, use_flash_attn))
         x = x + self.dropout(self.cross_attn(self.norm2(x), enc_output, enc_output, src_mask, use_flash_attn))
         x = x + self.dropout(self.ff(self.norm3(x)))
         return x 
     
decoder_block = DecoderLayer(512, 8, 2048)

params = [para.numel() for para in decoder_block.parameters()]
print(params)
sum = 0
for i in params:
    sum += i
print(sum)

[262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 1048576, 2048, 1048576, 512, 512, 512, 512, 512, 512, 512]
4204032


In [130]:
# 组合T5模型
class T5Model(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = T5Embedding(vocab_size, d_model)
        
        # 使用了相对位置编码ROPE，这里注释掉去
        # self.pos_encoding = PositionalEncoding(d_model)
        
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, vocab_size)
    
    def forward(self, src_input, tgt_input, src_mask=None, tgt_mask=None, use_flash_attn=False): 
        # src_emb = self.pos_encoding(self.embedding(src_input)) # 不需要经过绝对位置编码
        # tgt_emb = self.pos_encoding(self.embedding(tgt_input)) # 不需要经过绝对位置编码
        src_emb = self.embedding(src_input)
        tgt_emb = self.embedding(tgt_input)
        
        enc_output = src_emb 
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)
        
        dec_output = tgt_emb
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask, use_flash_attn)
        
        return self.output_layer(dec_output)
    
t5 = T5Model(20000, 768, 8, 2048, 6, 0.1)

sum = 0 

for name,param in t5.named_parameters(): 
    print(f"{name} ==> 参数量：{param.numel()}")
    sum += param.numel()
    
print(sum)

embedding.embedding.weight ==> 参数量：15360000
encoder_layers.0.self_attn.q_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.q_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.k_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.k_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.v_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.v_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.out_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.out_linear.bias ==> 参数量：768
encoder_layers.0.ff.linear1.weight ==> 参数量：1572864
encoder_layers.0.ff.linear1.bias ==> 参数量：2048
encoder_layers.0.ff.linear2.weight ==> 参数量：1572864
encoder_layers.0.ff.linear2.bias ==> 参数量：768
encoder_layers.0.norm1.weight ==> 参数量：768
encoder_layers.0.norm1.bias ==> 参数量：768
encoder_layers.0.norm2.weight ==> 参数量：768
encoder_layers.0.norm2.bias ==> 参数量：768
encoder_layers.1.self_attn.q_linear.weight ==> 参数量：589824
encoder_layers.1.self_attn.q_linear.bias ==> 参数量：768
encoder_layers.1.self_attn.k_linear.weig

In [131]:
# # 生成词汇表
# from generate_tokenizer import gen_tokenizer

# # 传入语料库文件， 输出tokenizer的json文件
# src_path = "corpus.txt"
# tokenizer_path = "translation.json"
# gen_tokenizer(src_path, tokenizer_path)


In [132]:
from transformers import PreTrainedTokenizerFast

# 加载 tokenizer，显式设置特殊 token
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="translation.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    bos_token="[BOS]",
    eos_token="[EOS]"
)

# 打印特殊 token 和映射
print("Pad token:", tokenizer.pad_token)
print("特殊 token 映射：", tokenizer.special_tokens_map)
print("Pad token ID:", tokenizer.pad_token_id)
print("BOS token ID:", tokenizer.convert_tokens_to_ids("[BOS]"))
print("BOS token ID:", tokenizer.convert_tokens_to_ids("[BOS]"))
print("EOS token ID:", tokenizer.convert_tokens_to_ids("[EOS]"))
print("PAD token ID:", tokenizer.convert_tokens_to_ids("[PAD]"))
print("EOS token ID:", tokenizer.eos_token_id)

Pad token: [PAD]
特殊 token 映射： {'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'pad_token': '[PAD]'}
Pad token ID: 0
BOS token ID: 2
BOS token ID: 2
EOS token ID: 3
PAD token ID: 0
EOS token ID: 3


In [133]:
# 输入文本
text = "Hello, world! 你好，世界！"

# 转换为 token ID
input_ids = tokenizer.encode(text)
print(f"Token ID: {input_ids}")

# 查看分词后的 token
tokens = tokenizer.tokenize(text)
print("分词结果：", tokens)

encoded = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
)
print("Encoded 输出:", encoded)

# 提取 input_ids 和 attention_mask
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)

# 转换为 token 查看
tokens_with_special = tokenizer.convert_ids_to_tokens(input_ids[0])
print("带特殊 token 的分词结果：", tokens_with_special)

Token ID: [2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3]
分词结果： ['Hello', ',', 'world', '!', '你好', '，', '世界', '！']
Encoded 输出: {'input_ids': tensor([[    2, 14860,    18, 10298,     7, 13791,  9209,  9863,  9198,     3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
Input IDs: tensor([[    2, 14860,    18, 10298,     7, 13791,  9209,  9863,  9198,     3]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
带特殊 token 的分词结果： ['[BOS]', 'Hello', ',', 'world', '!', '你好', '，', '世界', '！', '[EOS]']


In [134]:
# 输入文本
text = "Hello, world! 你好，世界！"

# 转换为token ID 
input_ids = tokenizer.encode(text)
print(f"Token ID:{input_ids}")

# 查看分词后的token
tokens = tokenizer.tokenize(text)
print("分词结果：", tokens)

encoded = tokenizer(
    text,                   
    # return_tensors="pt",        # 返回 PyTorch 张量（"tf" 表示 TensorFlow，None 表示普通列表）
    padding="max_length",               # 自动填充（如果处理批量文本）
    truncation=True,            # 自动截断（如果超过最大长度）
    max_length=512              # 最大序列长度
)

print("Encoded 输出:", encoded)

# 提取 input_ids 和 attention_mask
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)

Token ID:[2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3]
分词结果： ['Hello', ',', 'world', '!', '你好', '，', '世界', '！']
Encoded 输出: {'input_ids': [2, 14860, 18, 10298, 7, 13791, 9209, 9863, 9198, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [135]:
# 创建Dataset
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset): 
    def __init__(self, dataset, tokenizer, max_length = 10):
        self.dataset = dataset
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.pad_token_id = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)]

    def __len__(self): 
        return len(self.dataset)
    
    def __getitem__(self, index):
        translation = self.dataset[index]['translation']
        input = self.tokenizer(
            translation["en"],                   
            # return_tensors="pt",        # 返回 PyTorch 张量（"tf" 表示 TensorFlow，None 表示普通列表）
            padding="max_length",               # 自动填充（如果处理批量文本）
            truncation=True,            # 自动截断（如果超过最大长度）
            max_length=self.max_length              # 最大序列长度
        )
        output = self.tokenizer(
            translation["zh"],                   
            # return_tensors="pt",        # 返回 PyTorch 张量（"tf" 表示 TensorFlow，None 表示普通列表）
            padding="max_length",               # 自动填充（如果处理批量文本）
            truncation=True,            # 自动截断（如果超过最大长度）
            max_length=self.max_length              # 最大序列长度
        )

        # 提取 input_ids（去掉 batch 维度）
        src_input = input["input_ids"] # [seq_len]
        src_attention_mask = input["attention_mask"]
        tgt_input = output["input_ids"] # [seq_len] 
        tgt_output = tgt_input[1:] + self.pad_token_id # 去掉第一个 token（通常是 <BOS>）
        
        return [src_input, src_attention_mask, tgt_input, tgt_output]

def collate_fn(batch):
    src_input, src_attention_mask, tgt_input, tgt_output = zip(*batch)
    return (torch.tensor(src_input, dtype=torch.long)
            , torch.tensor(src_attention_mask, dtype=torch.long).unsqueeze(1).unsqueeze(2)
            , torch.tensor(tgt_input, dtype=torch.long)
            , torch.tensor(tgt_output, dtype=torch.long))
        
train_dataset = TranslationDataset(dataset=dataset["train"], tokenizer=tokenizer, max_length=10)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

In [136]:
#训练模型

vocab_size = tokenizer.vocab_size
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6
dropout = 0.1

model = T5Model(vocab_size, d_model, num_heads, d_ff, num_layers, dropout)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

In [137]:
param_num = [para.numel() for para in model.parameters()]
sum = 0 
for n in param_num:
    sum += n
    
print(sum)

76938496


In [138]:
print(next(iter(train_loader))[1].shape)
aa = next(iter(train_loader))[1].unsqueeze(1).unsqueeze(2)

print(aa.shape)


tgt_test = next(iter(train_loader))[2]
tgt_test

torch.Size([256, 1, 1, 10])
torch.Size([256, 1, 1, 1, 1, 10])


tensor([[    2, 18129,  1701,  ...,  2480,  5911,     3],
        [    2,  2060, 12497,  ...,  9574, 10031,     3],
        [    2, 10464, 10264,  ...,  4038, 24211,     3],
        ...,
        [    2, 27628,    20,  ..., 11399, 12046,     3],
        [    2,  7432, 10031,  ..., 11077,  7050,     3],
        [    2, 17166,  2402,  ...,     0,     0,     0]])

In [139]:
def create_mask(tgt, pad_idx): 
    tgt_seq_len = tgt.size(1)
    tgt_mask = torch.tril(torch.ones((tgt_seq_len, tgt_seq_len))).bool().to(tgt.device)
    tgt_mask = tgt_mask & (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
    return tgt_mask 

pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
print(pad_token_id)

test_size = tgt_test.shape[1]

tgt_mask = torch.tril(torch.ones((test_size, test_size))).bool()
tgt_mask = tgt_mask & (tgt_test != pad_token_id).unsqueeze(1).unsqueeze(2)
print(test_size)
tgt_mask

0
10


tensor([[[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ...,  True, False, False],
          [ True,  True,  True,  ...,  True,  True, False],
          [ True,  True,  True,  ...,  True,  True,  True]]],


        [[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ...,  True, False, False],
          [ True,  True,  True,  ...,  True,  True, False],
          [ True,  True,  True,  ...,  True,  True,  True]]],


        [[[ True, False, False,  ..., False, False, False],
          [ True,  True, False,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [ True,  True,  True,  ...,  True, Fa

In [140]:
from tqdm import tqdm
from torch.amp import autocast, GradScaler

scaler = GradScaler()

pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
num_epochs = 1
use_flash_attn = False
for epoch in range(num_epochs): 
    model.train()
    total_loss = 0
    for (src_input, src_attention_mask, tgt_input, tgt_output) in tqdm(train_loader):
        src_input = src_input.to(device)
        tgt_input = tgt_input.to(device)
        tgt_output = tgt_output.to(device)
        src_attention_mask = src_attention_mask.to(device)
        
        tgt_mask = create_mask(tgt_input, pad_token_id)
        
        optimizer.zero_grad()
        with autocast(device_type="cuda", dtype=torch.float16):
            logits = model(src_input, tgt_input, src_attention_mask, tgt_mask, use_flash_attn)    
            loss = loss_fn(logits.view(-1, vocab_size), tgt_output.view(-1))
        
        scaler.scale(loss).backward()        
        scaler.step(optimizer)
        scaler.update()
                
        
        total_loss += loss.item()
        
    print(f"Epoch {epoch + 1} / {num_epochs}, Loss:{total_loss / len(train_loader)}")
    
        

  6%|▋         | 252/3907 [00:47<11:32,  5.28it/s]


KeyboardInterrupt: 

In [141]:
def translate(model, src_text, max_length=10000): 
    model.eval()
    input = tokenizer(text=src_text, return_tensors="pt")
    src_input = input['input_ids'].to(device)
    src_mask_attention = input['attention_mask'].to(device)
    print(src_input)
    print(src_mask_attention)

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    tgt_ids = [bos_id]
    for _ in range(max_length):
        tgt_input = torch.tensor([tgt_ids], dtype=torch.long).to(device)
        # 使用falsh_attention模式，不需要mask
        tgt_mask = torch.tril(torch.ones((len(tgt_ids), len(tgt_ids)))).bool().to(device)
        tgt_mask = tgt_mask & (tgt_input != pad_token_id).unsqueeze(1).unsqueeze(2)
        
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=torch.float16):
                logits = model(src_input, tgt_input, src_mask_attention, use_flash_attn=True)
            next_token = logits[0, -1].argmax().item()
            
        if next_token == eos_id:
            break 
        
        tgt_ids.append(next_token)
        

    return ''.join(tokenizer.convert_ids_to_tokens(tgt_ids)[1:])

In [142]:


src_text = "maybe you know!"

translated = translate(model, src_text)

print(f"翻译结果：{translated}")

tensor([[    2, 15363,  9359,  9674,     7,     3]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')


TypeError: masked_fill() received an invalid combination of arguments - got (bool, float), but expected one of:
 * (Tensor mask, Tensor value)
      didn't match because some of the arguments have invalid types: (!bool!, !float!)
 * (Tensor mask, Number value)
      didn't match because some of the arguments have invalid types: (!bool!, !float!)
