In [175]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [176]:
class TokenEmbedding(nn.Embedding):
    def __init__(self,vocab_size,d_model):
        super(TokenEmbedding,self).__init__(vocab_size,d_model,padding_idx=1)
class PositionalEmbedding(nn.Module):
    def __init__(self,d_model,max_len,device):
        super(PositionalEmbedding,self).__init__()
        self.encoding=torch.zeros(max_len,d_model,device=device)
        self.encoding.requires_grad=False
        pos=torch.arange(0,max_len,device=device)
        pos=pos.float().unsqueeze(dim=1)
        _2i=torch.arange(0,d_model,step=2,device=device).float()
        self.encoding[:,0::2]=torch.sin(pos/(10000**(_2i/d_model)))
        self.encoding[:,1::2]=torch.cos(pos/(10000**(_2i/d_model)))
    def forward(self,x):
        seq_len=x.size(1)
        return self.encoding[:seq_len,:]

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size,d_model,max_len,dropout,device):
        super().__init__()
        self.tokenEmbedding=TokenEmbedding(vocab_size,d_model)
        self.positionalEmbedding=PositionalEmbedding(d_model,max_len,device)
        self.dropout=nn.Dropout(dropout)

    def forward(self,x):
        tokenEmb=self.tokenEmbedding(x)
        posEmb=self.positionalEmbedding(x)
        out=tokenEmb+posEmb
        return self.dropout(out)

In [177]:
class ScaledDotProductAttension(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,Q,K,V,mask=None):
        d_k=Q.size(-1)
        scores=torch.matmul(Q,K.transpose(-1,-2))/math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn=torch.softmax(scores,dim=-1)
        output=torch.matmul(attn,V)
        return output,attn
class MultiHeadAttention(nn.Module):
    def __init__(self,d_model,num_heads,dropout=0.1):
        super().__init__()
        assert d_model%num_heads==0
        self.num_heads=num_heads
        self.d_model=d_model
        self.d_k=d_model//num_heads
        self.linear_Q=nn.Linear(d_model,d_model)
        self.linear_K=nn.Linear(d_model,d_model)
        self.linear_V=nn.Linear(d_model,d_model)
        self.attention=ScaledDotProductAttension()
        self.linear_out=nn.Linear(d_model,d_model)
        self.dropout=nn.Dropout(dropout)

    def forward(self,Q,K,V,mask=None):
        batch_size=Q.size(0)
        Q=self.linear_Q(Q).view(batch_size,-1,self.num_heads,self.d_k).permute(0,2,1,3)
        K=self.linear_K(K).view(batch_size,-1,self.num_heads,self.d_k).permute(0,2,1,3)
        V=self.linear_V(V).view(batch_size,-1,self.num_heads,self.d_k).permute(0,2,1,3)

        out,attn=self.attention(Q,K,V,mask)
        out=out.transpose(1,2).contiguous().view(batch_size,-1,self.num_heads*self.d_k)
        out=self.linear_out(out)
        out=self.dropout(out)
        return out,attn



In [178]:
class LayerNorm(nn.Module):
    def __init__(self,normalized_shape,eps=1e-5):
        super().__init__()
        self.eps=eps
        self.gamma=nn.Parameter(torch.ones(normalized_shape))
        self.beta=nn.Parameter(torch.zeros(normalized_shape))
    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        var=x.var(dim=-1,keepdim=True)
        x_norm=(x-mean)/torch.sqrt(var+self.eps)
        return self.gamma*x_norm+self.beta

In [179]:
class EncodeLayer(nn.Module):
    def __init__(self,d_model,num_heads,d_ff,dropout=0.1):
        super().__init__()
        self.multiHeadAttention=MultiHeadAttention(d_model,num_heads)
        self.norm1=LayerNorm(d_model)
        self.ff=nn.Sequential(
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Linear(d_ff,d_model)
        )
        self.norm2=LayerNorm(d_model)
        self.dropout=nn.Dropout(dropout)
    def forward(self,x,mask=None):
        attn_out,out=self.multiHeadAttention(x,x,x,mask)
        x=self.norm1(x+self.dropout(attn_out))
        ff_out=self.ff(x)
        x=self.norm2(x+self.dropout(ff_out))
        return x


class Encoder(nn.Module):
    def __init__(self,num_layers,d_model,num_heads,d_ff,dropout=0.1):
        super().__init__()
        self.layers=nn.ModuleList([
            EncodeLayer(d_model,num_heads,d_ff,dropout)
            for _ in range(num_layers)
        ])
    def forward(self,x,mask=None):
        for layer in self.layers:
            x=layer(x,mask)
        return x
        

In [180]:
class DecodeLayer(nn.Module):
    def __init__(self,d_model,num_heads,d_ff,dropout=0.1):
        super().__init__()
        self.maskedMultiHeadAttention=MultiHeadAttention(d_model,num_heads)
        self.multiHeadAttention=MultiHeadAttention(d_model,num_heads)
        self.norm1=LayerNorm(d_model)
        self.norm2=LayerNorm(d_model)
        self.norm3=LayerNorm(d_model)
        self.ff=nn.Sequential(
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Linear(d_ff,d_model)
        )
        self.dropout=nn.Dropout()

    def forward(self,x,encoder_output,tgt_mask,memory_mask):
        out1,_=self.maskedMultiHeadAttention(x,x,x,tgt_mask)
        x=self.norm1(x+self.dropout(out1))
        out2,_=self.multiHeadAttention(x,encoder_output,x,memory_mask)
        x=self.norm2(x+self.dropout(out2))
        out3=self.ff(x)
        x=self.norm3(x+self.dropout(out3))
        return x


class Decoder(nn.Module):
    def __init__(self,num_layers,d_model,num_heads,d_ff,dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecodeLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self,x,encoder_output,tgt_mask=None,memery_mask=None):
        for layer in self.layers:
            x=layer(x,encoder_output,tgt_mask,memery_mask)
        return x

In [181]:
class Transformer(nn.Module):
    def __init__(self, vocab_size,d_model,max_len,num_layers,num_heads,d_ff,dropout,device):
        super().__init__()
        self.embedding=TransformerEmbedding(vocab_size,d_model,max_len,dropout,device)
        self.encoder=Encoder(num_layers,d_model,num_heads,d_ff,dropout)
        self.decoder=Decoder(num_layers,d_model,num_heads,d_ff,dropout)
        self.output_layer=nn.Linear(d_model,vocab_size)

    def forward(self,src,tgt,src_mask=None,tgt_mask=None,memery_mask=None):
        src_emb=self.embedding(src)
        tgt_emb=self.embedding(tgt)
        memory=self.encoder(src_emb,src_mask)
        out=self.decoder(tgt_emb,memory,tgt_mask,memery_mask)
        logits=self.output_layer(out)
        return logits

In [182]:
import os
import gzip

DATASET_PATH = "./dataset/data/task1/raw"

def load_data(file_path):
    with gzip.open(file_path, "rt", encoding="utf-8") as f:
        return f.readlines() 

train_en = load_data(os.path.join(DATASET_PATH, "train.en.gz"))
train_de = load_data(os.path.join(DATASET_PATH, "train.de.gz"))
val_en = load_data(os.path.join(DATASET_PATH, "val.en.gz"))
val_de = load_data(os.path.join(DATASET_PATH, "val.de.gz"))
test_en = load_data(os.path.join(DATASET_PATH, "test_2017_flickr.en.gz"))
test_de = load_data(os.path.join(DATASET_PATH, "test_2017_flickr.de.gz"))

print(f"训练集英文样本数: {len(train_en)}")
print(f"训练集德文样本数: {len(train_de)}")

训练集英文样本数: 29000
训练集德文样本数: 29000


In [183]:
from collections import Counter
import math

def compute_bleu(candidate, references, max_n=4):
    """
    计算 BLEU 分数
    :param candidate: 模型生成的翻译结果（列表）
    :param references: 参考翻译（列表的列表）
    :param max_n: 最大 n-gram
    :return: BLEU 分数
    """
    weights = [1.0 / max_n] * max_n  # 平均权重
    p_n = [0] * max_n
    candidate_len = len(candidate)
    reference_lens = [len(ref) for ref in references]
    
    # 计算 n-gram 精确度
    for n in range(1, max_n + 1):
        candidate_ngrams = Counter([tuple(candidate[i:i+n]) for i in range(len(candidate) - n + 1)])
        max_counts = Counter()
        for ref in references:
            ref_ngrams = Counter([tuple(ref[i:i+n]) for i in range(len(ref) - n + 1)])
            for ngram in candidate_ngrams:
                max_counts[ngram] = max(max_counts[ngram], ref_ngrams[ngram])
        clipped_counts = {ngram: min(count, max_counts[ngram]) for ngram, count in candidate_ngrams.items()}
        p_n[n-1] = sum(clipped_counts.values()) / max(1, sum(candidate_ngrams.values()))
    
    # BP（brevity penalty）计算
    closest_ref_len = min(reference_lens, key=lambda ref_len: abs(ref_len - candidate_len))
    bp = math.exp(1 - closest_ref_len / candidate_len) if candidate_len < closest_ref_len else 1.0
    
    # BLEU 分数计算
    bleu = bp * math.exp(sum(w * math.log(p) for w, p in zip(weights, p_n) if p > 0))
    return bleu

In [184]:
import torch
from torch.nn.utils.rnn import pad_sequence


def data_process(src_data, tgt_data, SRC, TGT):
    data = []
    for src_line, tgt_line in zip(src_data, tgt_data):
        src_tensor = torch.tensor([SRC.vocab.stoi[token] for token in SRC.tokenize(src_line)], dtype=torch.long)
        tgt_tensor = torch.tensor([TGT.vocab.stoi[token] for token in TGT.tokenize(tgt_line)], dtype=torch.long)
        src_tensor = torch.cat([torch.tensor([SRC.vocab.stoi["<bos>"]]), src_tensor, torch.tensor([SRC.vocab.stoi["<eos>"]])])
        tgt_tensor = torch.cat([torch.tensor([TGT.vocab.stoi["<bos>"]]), tgt_tensor, torch.tensor([TGT.vocab.stoi["<eos>"]])])
        data.append((src_tensor, tgt_tensor))
    return data



def yield_tokens(data_iter, tokenizer):
    for text in data_iter:
        yield tokenizer(text)

from torchtext.data import Field



# 定义 Field
SRC = Field(tokenize=src_tokenizer, init_token="<bos>", eos_token="<eos>", pad_token="<pad>", unk_token="<unk>")
TGT = Field(tokenize=tgt_tokenizer, init_token="<bos>", eos_token="<eos>", pad_token="<pad>", unk_token="<unk>")

# 构建词汇表
SRC.build_vocab(train_en, specials=["<unk>", "<pad>", "<bos>", "<eos>"])
TGT.build_vocab(train_de, specials=["<unk>", "<pad>", "<bos>", "<eos>"])

src_vocab = SRC.vocab
tgt_vocab = TGT.vocab

train_data = data_process(train_en, train_de, SRC, TGT)
val_data = data_process(val_en, val_de, SRC, TGT)
test_data = data_process(test_en, test_de, SRC, TGT)

def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)
    src_batch = pad_sequence(src_batch, padding_value=SRC.vocab.stoi["<pad>"])
    tgt_batch = pad_sequence(tgt_batch, padding_value=TGT.vocab.stoi["<pad>"])
    return src_batch, tgt_batch

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_data, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型
vocab_size = len(SRC.vocab)
d_model = 512
max_len = 100
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1

model = Transformer(vocab_size, d_model, max_len, num_layers, num_heads, d_ff, dropout, device).to(device)

# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=SRC.vocab.stoi["<pad>"])

# 训练循环
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for src_batch, tgt_batch in dataloader:
        src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
        tgt_input = tgt_batch[:-1, :]
        tgt_output = tgt_batch[1:, :]
        optimizer.zero_grad()
        logits = model(src_batch, tgt_input)
        loss = criterion(logits.view(-1, vocab_size), tgt_output.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# 训练过程
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}")

torch.save(model.state_dict(),"transformer_model.pth")

RuntimeError: shape '[24, -1, 8, 64]' is invalid for input of size 409600

In [None]:
model.load_state_dict(torch.load("transformer_model.pth"))
model.eval()

def greedy_decode(model, src, max_len, start_symbol, device):
    src = src.to(device)
    memory = model.encoder(model.embedding(src))
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src).to(device)
    for i in range(max_len - 1):
        out = model.decoder(model.embedding(ys), memory)
        prob = model.output_layer(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == SRC.vocab.stoi["<eos>"]:
            break
    return ys
    
def evaluate_bleu(model, dataloader, SRC, TGT, device):
    """
    测试模型并计算 BLEU 分数
    :param model: Transformer 模型
    :param dataloader: 测试数据加载器
    :param SRC: 源语言 Field
    :param TGT: 目标语言 Field
    :param device: 设备
    :return: BLEU 分数
    """
    model.eval()
    total_bleu = 0
    num_samples = 0
    with torch.no_grad():
        for src_batch, tgt_batch in dataloader:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
            for i in range(src_batch.size(1)):
                # 贪婪解码生成翻译结果
                result = greedy_decode(model, src_batch[:, i], max_len=100, start_symbol=TGT.vocab.stoi["<bos>"], device=device)
                candidate = [TGT.vocab.itos[token.item()] for token in result if token.item() not in [TGT.vocab.stoi["<pad>"], TGT.vocab.stoi["<bos>"], TGT.vocab.stoi["<eos>"]]]
                reference = [[TGT.vocab.itos[token.item()] for token in tgt_batch[:, i] if token.item() not in [TGT.vocab.stoi["<pad>"], TGT.vocab.stoi["<bos>"], TGT.vocab.stoi["<eos>"]]]]
                # 计算 BLEU 分数
                total_bleu += compute_bleu(candidate, reference)
                num_samples += 1
    return total_bleu / num_samples

# 计算 BLEU 分数
bleu_score = evaluate_bleu(model, test_dataloader, SRC, TGT, device)
print(f"BLEU Score: {bleu_score:.4f}")

def print_translations(model, dataloader, SRC, TGT, device, num_samples=5):
    """
    打印模型翻译结果与参考翻译
    :param model: Transformer 模型
    :param dataloader: 测试数据加载器
    :param SRC: 源语言 Field
    :param TGT: 目标语言 Field
    :param device: 设备
    :param num_samples: 打印样本数量
    """
    model.eval()
    count = 0
    with torch.no_grad():
        for src_batch, tgt_batch in dataloader:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
            for i in range(src_batch.size(1)):
                if count >= num_samples:
                    return
                # 贪婪解码生成翻译结果
                result = greedy_decode(model, src_batch[:, i], max_len=100, start_symbol=TGT.vocab.stoi["<bos>"], device=device)
                source = [SRC.vocab.itos[token.item()] for token in src_batch[:, i] if token.item() not in [SRC.vocab.stoi["<pad>"], SRC.vocab.stoi["<bos>"], SRC.vocab.stoi["<eos>"]]]
                target = [TGT.vocab.itos[token.item()] for token in tgt_batch[:, i] if token.item() not in [TGT.vocab.stoi["<pad>"], TGT.vocab.stoi["<bos>"], TGT.vocab.stoi["<eos>"]]]
                prediction = [TGT.vocab.itos[token.item()] for token in result if token.item() not in [TGT.vocab.stoi["<pad>"], TGT.vocab.stoi["<bos>"], TGT.vocab.stoi["<eos>"]]]
                print(f"Source: {' '.join(source)}")
                print(f"Target: {' '.join(target)}")
                print(f"Prediction: {' '.join(prediction)}")
                print("-" * 50)
                count += 1

# 打印翻译结果
print_translations(model, test_dataloader, SRC, TGT, device, num_samples=5)