In [1]:
from datasets import load_dataset

# OPUS-100 English–Chinese parallel corpus (train / validation / test splits).
dataset = load_dataset(path="opus100", name="en-zh")
print(dataset)

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})


In [2]:
# Peek at one raw training record: {'translation': {'en': ..., 'zh': ...}}.
print(dataset["train"][0])

{'translation': {'en': 'Sixty-first session', 'zh': '第六十一届会议'}}


In [3]:
import torch
import torch.nn as nn 

class T5Embedding(nn.Module):
    """Learned token embedding: maps integer token ids to ``d_model``-dim vectors.

    Args:
        vocab_size: number of rows in the lookup table (one per token id).
        d_model: size of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super(T5Embedding, self).__init__()
        # The attribute is named ``embedding`` so the state_dict key stays
        # ``embedding.weight``.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up the vectors for the token ids in ``x``; a trailing ``d_model`` dim is added."""
        vectors = self.embedding(x)
        return vectors

In [8]:
import math 

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * 10000^(-2i/d_model))`` and odd columns
    the matching ``cos``. The table is registered as a buffer, so it has no
    trainable parameters but is saved in ``state_dict`` and follows the module
    on ``.to(device)``.

    Args:
        d_model: feature size of the inputs. May be odd; the last (even) column
            then has no cosine partner.
        max_len: longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. Slicing div_term keeps an
        # odd d_model from failing with a shape mismatch; even d_model is unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: tensor laid out as (batch, seq_len, d_model); dim 1 is read as the
                sequence length and must not exceed ``max_len``.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len].to(x.device)


# Positional encoding for d_model=512 plus a random (2, 5, 512) input tensor.
pe = PositionalEncoding(512)
x = torch.randn(size=(2, 5, 512))

# The sinusoidal table is a buffer, not a parameter, so this list is empty.
[p.numel() for p in pe.parameters()]

[]

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values each get their own ``nn.Linear`` projection. The
    projections are split into ``num_heads`` heads of width
    ``d_k = d_model // num_heads`` and attended independently. The heads are then
    concatenated and passed through ``out_linear``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须被num_heads整除"
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k`` / ``v``.

        Args:
            q: queries, expected (batch, len_q, d_model).
            k, v: keys and values, expected (batch, len_k, d_model) with equal len_k.
            mask: optional tensor broadcastable to (batch, num_heads, len_q, len_k).
                Positions where ``mask == 0`` get (numerically) zero attention weight.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch_size = q.size(0)
        # (batch, len, d_model) -> (batch, num_heads, len, d_k)
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # (batch, num_heads, len_q, len_k)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value of the scores' dtype. A
            # literal -1e9 cannot be represented in float16 and makes masked_fill
            # raise under half precision; for float32 the softmax result is the same.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # Merge heads back: (batch, num_heads, len_q, d_k) -> (batch, len_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.out_linear(output)
    

# Demo instance for the parameter count below. 8 heads (d_k = 64) matches the
# EncoderLayer / DecoderLayer demos; the head count does not change the number of
# parameters, since all four projections are Linear(d_model, d_model).
attn = MultiHeadAttention(512, 8)

In [7]:
# Per-tensor parameter counts (weight and bias of each projection) and their total.
params = [param.numel() for param in attn.parameters()]
print(params)
# Accumulate under a dedicated name: assigning to `sum` would shadow the built-in
# sum() for every later cell in this kernel.
total_params = 0
for n in params:
    total_params += n
print(total_params)

[262144, 512, 262144, 512, 262144, 512, 262144, 512]
1050624


In [9]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear(d_model->d_ff) -> ReLU -> Linear(d_ff->d_model).

    Args:
        d_model: input and output width.
        d_ff: hidden width.
    """

    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Apply both projections over the last dimension of ``x`` with a ReLU in between."""
        hidden = self.linear1(x)
        hidden = torch.relu(hidden)
        return self.linear2(hidden)
    
ffn = FeedForward(512, 2048)
param = [para.numel() for para in ffn.parameters()]
print(param)
# Accumulate under a dedicated name: assigning to `sum` would shadow the built-in
# sum() for every later cell in this kernel.
total_params = 0
for n in param:
    total_params += n
print(total_params)

[1048576, 2048, 1048576, 512]
2099712


In [10]:
class EncoderLayer(nn.Module):
    """One pre-norm Transformer encoder layer.

    ``x = x + Dropout(SelfAttn(LN1(x)))`` followed by ``x = x + Dropout(FFN(LN2(x)))``.
    LayerNorm is applied to each sub-layer's input, and the residual adds the
    un-normalized ``x``.

    Args:
        d_model: model width.
        num_heads: attention heads in the self-attention sub-layer.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention then the feed-forward network over ``x``.

        Args:
            x: input hidden states, expected (batch, seq_len, d_model).
            mask: optional attention mask passed through to ``self_attn``.
        """
        # Normalize once and reuse the result as query, key and value.
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, normed, normed, mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

encoder_block = EncoderLayer(512, 8, 2048, 0.1)
params = [param.numel() for param in encoder_block.parameters()]
print(params)

# Parameter count. Accumulate under a dedicated name: assigning to `sum` would
# shadow the built-in sum() for every later cell in this kernel.
total_params = 0
for n in params:
    total_params += n
print(total_params)

[262144, 512, 262144, 512, 262144, 512, 262144, 512, 1048576, 2048, 1048576, 512, 512, 512, 512, 512]
3152384


In [11]:
class DecoderLayer(nn.Module):
    """One pre-norm Transformer decoder layer.

    There are three residual sub-layers, each with LayerNorm on its input and
    dropout on its output:

    1. masked self-attention,
    2. cross-attention over ``enc_output``,
    3. the feed-forward network.

    Args:
        d_model: model width.
        num_heads: attention heads in both attention sub-layers.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Run self-attention, cross-attention and feed-forward over ``x``.

        Args:
            x: decoder hidden states, expected (batch, tgt_len, d_model).
            enc_output: encoder output used as keys/values for cross-attention
                (it is not normalized here).
            src_mask: mask over encoder positions, used by cross-attention.
            tgt_mask: mask for decoder self-attention (in this notebook: causal + padding).
        """
        # Normalize once and reuse the result as query, key and value.
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, normed, normed, tgt_mask))
        x = x + self.dropout(self.cross_attn(self.norm2(x), enc_output, enc_output, src_mask))
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x
     
decoder_block = DecoderLayer(512, 8, 2048)

params = [para.numel() for para in decoder_block.parameters()]
print(params)
# Accumulate under a dedicated name: assigning to `sum` would shadow the built-in
# sum() for every later cell in this kernel.
total_params = 0
for n in params:
    total_params += n
print(total_params)

[262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 262144, 512, 1048576, 2048, 1048576, 512, 512, 512, 512, 512, 512, 512]
4204032


In [12]:
# Assemble the full encoder-decoder model.
class T5Model(nn.Module):
    """Encoder-decoder Transformer built from T5Embedding, PositionalEncoding,
    EncoderLayer and DecoderLayer.

    Source and target share one embedding table and one positional encoding.
    Each stack has ``num_layers`` layers, and the decoder output is projected to
    ``vocab_size`` logits.

    ``encode`` and ``decode`` expose the two halves separately, so that callers
    such as autoregressive decoding can compute the encoder output once and reuse it.
    ``forward`` runs both halves.

    Args:
        vocab_size: size of the shared vocabulary (embedding rows / output logits).
        d_model: model width.
        num_heads: attention heads per attention sub-layer.
        d_ff: feed-forward hidden width.
        num_layers: number of encoder layers and of decoder layers.
        dropout: dropout probability inside every layer.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = T5Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, vocab_size)

    def encode(self, src_input, src_mask=None):
        """Embed source token ids and run them through every encoder layer."""
        enc_output = self.pos_encoding(self.embedding(src_input))
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)
        return enc_output

    def decode(self, tgt_input, enc_output, src_mask=None, tgt_mask=None):
        """Embed target token ids, run the decoder stack against ``enc_output`` and return logits."""
        dec_output = self.pos_encoding(self.embedding(tgt_input))
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)
        return self.output_layer(dec_output)

    def forward(self, src_input, tgt_input, src_mask=None, tgt_mask=None):
        """Full pass: encode ``src_input``, then decode ``tgt_input``; returns vocabulary logits."""
        enc_output = self.encode(src_input, src_mask)
        return self.decode(tgt_input, enc_output, src_mask, tgt_mask)
    
t5 = T5Model(20000, 768, 8, 2048, 6, 0.1)

# Per-tensor parameter counts and their total. Accumulate under a dedicated name:
# assigning to `sum` would shadow the built-in sum() for every later cell.
total_params = 0
for name, param in t5.named_parameters():
    print(f"{name} ==> 参数量：{param.numel()}")
    total_params += param.numel()

print(total_params)

embedding.embedding.weight ==> 参数量：15360000
encoder_layers.0.self_attn.q_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.q_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.k_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.k_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.v_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.v_linear.bias ==> 参数量：768
encoder_layers.0.self_attn.out_linear.weight ==> 参数量：589824
encoder_layers.0.self_attn.out_linear.bias ==> 参数量：768
encoder_layers.0.ff.linear1.weight ==> 参数量：1572864
encoder_layers.0.ff.linear1.bias ==> 参数量：2048
encoder_layers.0.ff.linear2.weight ==> 参数量：1572864
encoder_layers.0.ff.linear2.bias ==> 参数量：768
encoder_layers.0.norm1.weight ==> 参数量：768
encoder_layers.0.norm1.bias ==> 参数量：768
encoder_layers.0.norm2.weight ==> 参数量：768
encoder_layers.0.norm2.bias ==> 参数量：768
encoder_layers.1.self_attn.q_linear.weight ==> 参数量：589824
encoder_layers.1.self_attn.q_linear.bias ==> 参数量：768
encoder_layers.1.self_attn.k_linear.weig

In [13]:
# Build the vocabulary.
# Are punctuation marks included? English is split on whitespace only, so
# punctuation stays glued to the neighbouring word ("test!" and "test" are
# different tokens). Chinese is split per character, so every punctuation mark
# becomes a token of its own.
from collections import Counter

# Read the "translation" column once instead of iterating the Dataset row by row
# twice (once per language).
train_translations = dataset["train"]["translation"]
train_en = [pair["en"] for pair in train_translations]
train_zh = [pair["zh"] for pair in train_translations]
del train_translations  # later cells only use train_en / train_zh

en_words = [word.lower() for sentence in train_en for word in sentence.split()]
zh_chars = [char for sentence in train_zh for char in sentence]

en_counter = Counter(en_words)
zh_counter = Counter(zh_chars)

In [14]:
# Sanity check: list lengths and Counter totals should agree for each language.
print(
    len(en_words),
    len(zh_chars),
    en_counter.total(),
    zh_counter.total(),
    sep="\n",
)

16702317
32233389
16702317
32233389


In [15]:
# Keep tokens seen at least `min_freq` times; rarer ones will encode as <UNK>.
min_freq = 5
en_vocab = [tok for tok in en_counter if en_counter[tok] >= min_freq]
zh_vocab = [tok for tok in zh_counter if zh_counter[tok] >= min_freq]
print(len(en_vocab), len(zh_vocab), sep="\n")

70658
6334


In [16]:
# Merge the two vocabularies and drop duplicates. dict.fromkeys keeps first-seen
# order, so token ids are identical across kernel restarts. The order of
# list(set(...)) depends on per-process string hash randomization, which would
# reshuffle ids and break any saved model.
vocab = list(dict.fromkeys(en_vocab + zh_vocab))

print(len(vocab))

# Special tokens take ids 0-3 (<PAD> = 0).
vocab = ['<PAD>','<UNK>','<BOS>','<EOS>'] + vocab

print(len(vocab))

76901
76905


In [17]:
# Token -> id and id -> token lookup tables.
word_to_id = dict(zip(vocab, range(len(vocab))))
id_to_word = dict((idx, word) for word, idx in word_to_id.items())

print(f"词汇表大小：{len(vocab)}")

词汇表大小：76905


In [18]:
# Text encoding
def encode_text(text, word_to_id, is_zh=False):
    """Turn a sentence into token ids wrapped in <BOS> ... <EOS>.

    Chinese text (``is_zh=True``) is tokenized per character; anything else is
    lower-cased and split on whitespace. Tokens missing from ``word_to_id`` map to
    the '<UNK>' id. No truncation or padding happens here.

    Args:
        text: sentence to encode.
        word_to_id: token -> id mapping; expected to contain '<UNK>', '<BOS>', '<EOS>'.
        is_zh: per-character (True) or whitespace (False) tokenization.

    Returns:
        list[int]: ``[<BOS>, *token_ids, <EOS>]``.
    """
    tokens = list(text) if is_zh else text.lower().split()
    ids = [word_to_id.get(tok, word_to_id['<UNK>']) for tok in tokens]
    ids.insert(0, word_to_id['<BOS>'])
    ids.append(word_to_id['<EOS>'])
    return ids

In [31]:
# Build the Dataset
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    """Map-style dataset of fixed-length id tensors for en->zh sentence pairs.

    Each item is a dict of three LongTensors, each truncated or right-padded with
    <PAD> to exactly ``max_length`` ids:

    * ``src_input``: encoded English sentence (<BOS> ... <EOS>),
    * ``tgt_input``: encoded Chinese sentence without its final <EOS>,
    * ``tgt_output``: the Chinese sentence without its leading <BOS>, i.e.
      ``tgt_input`` shifted left by one.

    Truncation can drop the trailing <EOS>.

    Args:
        dataset: indexable records of the form ``{'translation': {'en': str, 'zh': str}}``.
        word_to_id: token -> id mapping passed to ``encode_text``.
        max_length: fixed length of all three sequences.
    """

    def __init__(self, dataset, word_to_id, max_length = 10):
        self.dataset = dataset
        self.word_to_id = word_to_id
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _fit(self, ids):
        """Truncate ``ids`` to ``max_length``, or right-pad it with <PAD> up to that length."""
        pad_count = self.max_length - len(ids)
        return ids[:self.max_length] + [self.word_to_id['<PAD>']] * pad_count

    def __getitem__(self, index):
        pair = self.dataset[index]['translation']
        src_ids = encode_text(pair["en"], self.word_to_id, is_zh=False)
        tgt_ids = encode_text(pair["zh"], self.word_to_id, is_zh=True)
        return {
            'src_input': torch.tensor(self._fit(src_ids), dtype=torch.long),
            'tgt_input': torch.tensor(self._fit(tgt_ids[:-1]), dtype=torch.long),
            'tgt_output': torch.tensor(self._fit(tgt_ids[1:]), dtype=torch.long),
        }
        
train_dataset = TranslationDataset(dataset=dataset["train"], word_to_id=word_to_id)
# batch_size=2 over 1,000,000 training pairs gives 500,000 steps per epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

In [20]:
# Training setup

# Model hyperparameters.
vocab_size = len(vocab)
d_model, num_heads, d_ff = 512, 8, 2048
num_layers, dropout = 6, 0.1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = T5Model(vocab_size, d_model, num_heads, d_ff, num_layers, dropout).to(device)

# Targets equal to <PAD> are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=word_to_id['<PAD>'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7c69b676b550>>
Traceback (most recent call last):
  File "/home/edwin/miniconda3/envs/pytorch/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


In [21]:
param_num = [para.numel() for para in model.parameters()]
# Accumulate under a dedicated name: assigning to `sum` would shadow the built-in
# sum() for every later cell in this kernel.
total_params = 0
for n in param_num:
    total_params += n

print(total_params)

122966121


In [38]:
def create_mask(src, tgt, pad_idx):
    """Build attention masks for a batch of padded source and target id tensors.

    Args:
        src: source token ids, (batch, src_len).
        tgt: target (decoder input) token ids, (batch, tgt_len).
        pad_idx: id of the <PAD> token.

    Returns:
        (src_mask, tgt_mask):

        * src_mask: bool, (batch, 1, 1, src_len); True at non-pad source positions.
        * tgt_mask: bool, (batch, 1, tgt_len, tgt_len); True where query i may
          attend key j, i.e. j <= i (causal) and key j is not padding.

        Both broadcast over the head dimension of the attention scores.
    """
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_seq_len = tgt.size(1)
    # Build the causal (lower-triangular) mask directly as bool on tgt's device.
    # This avoids allocating a float tensor on the host and copying it every batch.
    causal = torch.ones((tgt_seq_len, tgt_seq_len), dtype=torch.bool, device=tgt.device).tril()
    tgt_not_pad = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_mask = causal & tgt_not_pad
    return src_mask, tgt_mask

In [None]:
from tqdm import tqdm

# Debug dry run: take ONE batch, build its masks, print the target-mask shape,
# then stop. No forward/backward pass runs, so total_loss stays 0 and the printed
# "Loss" is 0.0. It is not a real training loss.
num_epochs = 1 
for epoch in range(num_epochs): 
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        src_input = batch['src_input'].to(device)
        tgt_input = batch['tgt_input'].to(device)
        tgt_output = batch['tgt_output'].to(device)
        
        src_mask, tgt_mask = create_mask(src_input, tgt_input, word_to_id['<PAD>'])
        # Expected (batch, 1, max_length, max_length), e.g. (2, 1, 10, 10) here.
        print(tgt_mask.shape)
        
        # Stop after the first batch; this cell only inspects shapes.
        break
        
    print(f"Epoch {epoch + 1} / {num_epochs}, Loss:{total_loss / len(train_loader)}")
    
        

  0%|          | 0/500000 [00:00<?, ?it/s]

torch.Size([2, 1, 10, 10])
Epoch 1 / 1, Loss:0.0





In [25]:
from tqdm import tqdm

# Teacher-forced training: predict tgt_output from (src_input, tgt_input); the
# epoch loss printed is the mean per-batch loss.
num_epochs = 10
pad_id = word_to_id['<PAD>']
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        src_input, tgt_input, tgt_output = (
            batch[key].to(device) for key in ('src_input', 'tgt_input', 'tgt_output')
        )
        src_mask, tgt_mask = create_mask(src_input, tgt_input, pad_id)

        optimizer.zero_grad()
        logits = model(src_input, tgt_input, src_mask, tgt_mask)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CrossEntropyLoss.
        loss = loss_fn(logits.view(-1, vocab_size), tgt_output.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    mean_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} / {num_epochs}, Loss:{mean_loss}")
    
        

  0%|          | 15/10417 [00:05<1:06:30,  2.61it/s]


KeyboardInterrupt: 

In [24]:
def translate(model, src_text, word_to_id, id_to_word, max_length=50):
    """Greedy-decode an English sentence into Chinese with ``model``.

    The source is encoded with ``encode_text`` and truncated or padded to
    ``max_length`` ids. Decoding starts from <BOS> and always takes the arg-max
    token, stopping when <EOS> is produced or after ``max_length`` generated tokens.
    The full model, encoder included, is re-run at every step.

    Args:
        model: model called as ``model(src_input, tgt_input, src_mask, tgt_mask)``
            that returns (batch, tgt_len, vocab) logits; it is switched to eval mode.
        src_text: English sentence.
        word_to_id: token -> id mapping (must contain <PAD>, <BOS>, <EOS>, <UNK>).
        id_to_word: id -> token mapping used to render the output.
        max_length: source pad/truncate length and maximum number of generated tokens.

    Returns:
        str: generated tokens (without <BOS> / <EOS>) concatenated without spaces.
    """
    model.eval()
    # Run on whichever device the model actually lives on instead of a global.
    device = next(model.parameters()).device
    pad_id = word_to_id['<PAD>']
    eos_id = word_to_id['<EOS>']

    src_ids = encode_text(src_text, word_to_id, is_zh=False)
    src_ids = src_ids[:max_length] + [pad_id] * (max_length - len(src_ids))
    src_input = torch.tensor([src_ids], dtype=torch.long, device=device)

    tgt_ids = [word_to_id['<BOS>']]
    with torch.no_grad():
        for _ in range(max_length):
            tgt_input = torch.tensor([tgt_ids], dtype=torch.long, device=device)
            # Same padding + causal masks as during training.
            src_mask, tgt_mask = create_mask(src_input, tgt_input, pad_id)
            logits = model(src_input, tgt_input, src_mask, tgt_mask)
            next_token = logits[0, -1].argmax().item()

            if next_token == eos_id:
                break

            tgt_ids.append(next_token)

    return ''.join([id_to_word.get(idx, '<UNK>') for idx in tgt_ids[1:]])


# Quick qualitative check on a single sentence ("翻译结果" = "Translation result").
src_text = "This is a test!"
translated = translate(model, src_text, word_to_id=word_to_id, id_to_word=id_to_word)
print(f"翻译结果：{translated}")

翻译结果：这是个疯狂的
