In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
class TokenEmbedding(nn.Embedding):
    """Lookup table mapping token ids to ``d_model``-dimensional vectors.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding width.
        padding_idx: id treated as padding by ``nn.Embedding`` (its row is
            initialised to zeros and gets no gradient). Defaults to 1, the pad
            id this notebook assumes elsewhere (e.g. the loss ``ignore_index``).
    """
    def __init__(self, vocab_size, d_model, padding_idx=1):
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Row ``pos`` of the table holds ``sin(pos / 10000**(2i/d_model))`` in even
    columns and ``cos(pos / 10000**(2i/d_model))`` in odd columns.

    Args:
        d_model: encoding width (may be odd).
        max_len: longest sequence length the table supports.
        device: device the table is created on.
    """
    def __init__(self, d_model, max_len, device):
        # nn.Module.__init__ takes no size/device arguments; passing them raised TypeError.
        super(PositionalEmbedding, self).__init__()
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, step=2, device=device).float()  # (ceil(d_model/2),)
        angle = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model/2))
        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angle)
        # With odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])
        # A buffer follows model.to(...) but is never trained. persistent=False keeps
        # state_dict keys unchanged; the table was previously a plain attribute.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the (seq_len, d_model) slice of the table, seq_len = ``x.size(1)``.

        The result broadcasts against a (batch, seq_len, d_model) embedding.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.encoding.size(0)}"
            )
        # The table is 2-D; the previous 3-index slice raised IndexError.
        return self.encoding[:seq_len, :]

class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: vocabulary size for the token lookup table.
        d_model: embedding width.
        max_len: longest sequence the positional table supports.
        dropout: dropout probability applied to the summed embedding.
        device: device the positional table is built on.
    """

    def __init__(self, vocab_size, d_model, max_len, dropout, device):
        super().__init__()
        # Registration order matches the original (affects parameter/state_dict order).
        self.tokenEmbedding = TokenEmbedding(vocab_size, d_model)
        self.positionalEmbedding = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Token lookup is evaluated first, then the positional slice; they are summed by broadcasting.
        combined = self.tokenEmbedding(x) + self.positionalEmbedding(x)
        return self.dropout(combined)

In [None]:
class ScaledDotProductAttension(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    ``forward(Q, K, V, mask=None)`` returns ``(output, weights)``. Where
    ``mask == 0``, the score is set to -inf before the softmax, so that key
    receives zero weight. ``d_k`` is taken from the last dimension of ``Q``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        scale = math.sqrt(Q.size(-1))
        logits = Q.matmul(K.transpose(-2, -1)) / scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        weights = logits.softmax(dim=-1)
        return weights.matmul(V), weights
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, concatenate, project back.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads, each of width ``d_model // num_heads``.
        dropout: dropout probability applied to the final projected output.

    ``forward(Q, K, V, mask=None)`` returns ``(out, attn)``:
      * Q is (batch, q_len, d_model); K and V are (batch, k_len, d_model).
      * ``out`` is (batch, q_len, d_model).
      * ``attn`` is the per-head weight tensor from ``self.attention``.

    ``mask`` marks blocked positions with 0 and may be:
      * 2-D (q_len, k_len);
      * 3-D (batch, q_len or 1, k_len);
      * 4-D, broadcastable to (batch, num_heads, q_len, k_len).
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.linear_Q = nn.Linear(d_model, d_model)
        self.linear_K = nn.Linear(d_model, d_model)
        self.linear_V = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttension()
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).permute(0, 2, 1, 3)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        Q = self._split_heads(self.linear_Q(Q), batch_size)
        K = self._split_heads(self.linear_K(K), batch_size)
        V = self._split_heads(self.linear_V(V), batch_size)

        if mask is not None and mask.dim() == 3:
            # Scores are (batch, heads, q_len, k_len). A 3-D per-batch mask would
            # otherwise be right-aligned so that its batch axis meets the head axis;
            # insert a head axis so it broadcasts over heads instead.
            mask = mask.unsqueeze(1)

        out, attn = self.attention(Q, K, V, mask)
        # (batch, heads, q_len, d_k) -> (batch, q_len, heads * d_k)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        out = self.linear_out(out)
        out = self.dropout(out)
        return out, attn



In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``. ``mean`` and the
    biased (population) variance ``var`` are taken over the last axis, which
    matches ``torch.nn.LayerNorm``.

    Args:
        normalized_shape: size of the last input dimension (shape of gamma/beta).
        eps: added to the variance for numerical stability.
    """
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # Layer norm uses the biased variance. The previous call also misspelled
        # keepdim ("keemdim") and raised TypeError, and it defaulted to unbiased variance.
        var = (centered * centered).mean(dim=-1, keepdim=True)
        x_norm = centered / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

In [None]:
class EncodeLayer(nn.Module):
    """One post-norm encoder layer: ``x = norm(x + sublayer(x))`` for self-attention,
    then for a position-wise feed-forward network.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability for both sub-layers.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Forward this layer's dropout rate rather than silently using MHA's default.
        self.multiHeadAttention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # MultiHeadAttention already applies dropout to its projected output, so the
        # attention branch is not dropped out a second time here.
        attn_out, _ = self.multiHeadAttention(x, x, x, mask)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x


class Encoder(nn.Module):
    """A stack of ``num_layers`` EncodeLayer modules applied one after another.

    ``forward(x, mask=None)`` threads ``x`` through every layer with the same mask
    and returns the final hidden states. With zero layers it returns ``x`` unchanged.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        blocks = [EncodeLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        hidden = x
        for encode_layer in self.layers:
            hidden = encode_layer(hidden, mask)
        return hidden
        

In [None]:
class DecodeLayer(nn.Module):
    """One post-norm decoder layer, each sub-layer wrapped as ``norm(x + sublayer(x))``.

    The sub-layers, in order, are:
      1. masked self-attention over the target sequence;
      2. cross-attention whose keys and values come from the encoder output;
      3. a position-wise feed-forward network.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability for all sub-layers.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.maskedMultiHeadAttention = MultiHeadAttention(d_model, num_heads, dropout)
        self.multiHeadAttention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        # Previously nn.Dropout() used its default p=0.5, ignoring `dropout`.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask=None, memory_mask=None):
        """Apply the layer.

        Args:
            x: decoder hidden states (batch, tgt_len, d_model).
            encoder_output: encoder memory (batch, src_len, d_model).
            tgt_mask: self-attention mask (0 = blocked), e.g. a causal mask.
            memory_mask: cross-attention mask over encoder positions (0 = blocked).
        """
        # MultiHeadAttention applies dropout to its own output, so the attention
        # branches are not dropped out again here.
        self_attn, _ = self.maskedMultiHeadAttention(x, x, x, tgt_mask)
        x = self.norm1(x + self_attn)
        # Cross-attention: queries from the decoder; keys AND values from the encoder.
        # Using x as V was wrong and failed whenever src_len != tgt_len.
        cross_attn, _ = self.multiHeadAttention(x, encoder_output, encoder_output, memory_mask)
        x = self.norm2(x + cross_attn)
        ff_out = self.ff(x)
        x = self.norm3(x + self.dropout(ff_out))
        return x


class Decoder(nn.Module):
    """A stack of ``num_layers`` DecodeLayer modules applied one after another.

    ``forward`` passes the same encoder output and masks to every layer.
    ``memery_mask`` keeps its original spelling so keyword callers keep working.
    With zero layers it returns ``x`` unchanged.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        stack = [DecodeLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, encoder_output, tgt_mask=None, memery_mask=None):
        hidden = x
        for decode_layer in self.layers:
            hidden = decode_layer(hidden, encoder_output, tgt_mask, memery_mask)
        return hidden

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with one embedding table shared by source and target.

    Args:
        vocab_size: size of the shared vocabulary, used for both the input
            embedding and the output logits.
        d_model, max_len, num_layers, num_heads, d_ff, dropout, device: forwarded
            to TransformerEmbedding / Encoder / Decoder.
    """
    def __init__(self, vocab_size, d_model, max_len, num_layers, num_heads, d_ff, dropout, device):
        super().__init__()
        self.embedding = TransformerEmbedding(vocab_size, d_model, max_len, dropout, device)
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.output_layer = nn.Linear(d_model, vocab_size)

    @staticmethod
    def generate_square_subsequent_mask(size, device=None):
        """Causal mask of shape (size, size): True where query i may attend key j (j <= i)."""
        return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memery_mask=None):
        """Return vocabulary logits for every target position.

        Args:
            src: source token ids, sequence along dim 1.
            tgt: decoder-input token ids, sequence along dim 1.
            src_mask: encoder self-attention mask (0/False = blocked).
            tgt_mask: decoder self-attention mask. When None, a causal mask is built.
                Without one, teacher-forced training lets position i read the token
                it must predict.
            memery_mask: cross-attention mask over encoder memory. The name is kept
                for backward compatibility.
        """
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        memory = self.encoder(src_emb, src_mask)
        out = self.decoder(tgt_emb, memory, tgt_mask, memery_mask)
        logits = self.output_layer(out)
        return logits

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam

# Example: training and inference with the Transformer model.
# Uses a small English translation task (e.g. English-French); torchtext datasets can supply the data.
# If automatic download fails, download the Multi30k dataset manually: https://github.com/multi30k/dataset


# Assumes the following variables and datasets already exist:
# src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx
# train_dataloader, valid_dataloader
# They can be built with torchtext.datasets.Multi30k and torchtext.vocab.build_vocab_from_iterator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
SEED = 0
vocab_size = 10000  # placeholder; use your real vocabulary size
d_model = 512
max_len = 100
num_layers = 4
num_heads = 8
d_ff = 2048
dropout = 0.1
num_epochs = 10
PAD_IDX = 1  # must match TokenEmbedding's padding_idx (1) and the loss ignore_index

torch.manual_seed(SEED)

# Build the model
model = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    max_len=max_len,
    num_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    device=device
).to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


def build_masks(src, tgt_input, pad_idx=PAD_IDX):
    """Build attention masks (True = attend, False = blocked) for one batch.

    Returns ``(src_mask, tgt_mask, memory_mask)``:
      * ``src_mask`` (batch, 1, 1, src_len) hides padded source keys in encoder self-attention.
      * ``tgt_mask`` (batch, 1, tgt_len, tgt_len) combines a causal mask with target padding.
      * ``memory_mask`` (batch, 1, 1, src_len) hides padded source keys in cross-attention.
    """
    src_keep = (src != pad_idx)[:, None, None, :]
    tgt_len = tgt_input.size(1)
    causal = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt_input.device))
    tgt_mask = (tgt_input != pad_idx)[:, None, None, :] & causal
    return src_keep, tgt_mask, src_keep


# Training loop. `train_dataloader` must come from an earlier data-preparation cell;
# skip instead of crashing so Restart & Run All still completes.
if "train_dataloader" not in globals():
    print("train_dataloader is not defined; build it first (see the dataset links below). Skipping training.")
else:
    for epoch in range(num_epochs):
        model.train()
        total_loss, num_batches = 0.0, 0
        for src, tgt in train_dataloader:
            src, tgt = src.to(device), tgt.to(device)
            # Teacher forcing: decoder input is tgt without its last token,
            # and the target is tgt shifted left by one.
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            # Without a causal tgt_mask the decoder could read the very tokens it must predict.
            src_mask, tgt_mask, memory_mask = build_masks(src, tgt_input)
            logits = model(src, tgt_input, src_mask, tgt_mask, memory_mask)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Report the epoch mean rather than just the last batch's loss.
        print(f"Epoch {epoch} Loss: {total_loss / max(num_batches, 1)}")

# 推理示例（贪婪解码）
def greedy_decode(model, src, max_len, start_symbol, end_symbol=1):
    """Greedily decode one source sequence, choosing the arg-max token at each step.

    Args:
        model: Transformer-like module exposing ``embedding``, ``encoder``,
            ``decoder`` and ``output_layer``.
        src: source token ids with batch size 1 (``.item()`` below assumes this).
        max_len: maximum length of the returned sequence, start symbol included.
        start_symbol: token id the decoder is seeded with.
        end_symbol: decoding stops right after this id is generated. Defaults to 1
            for backward compatibility. NOTE(review): 1 is also the padding index
            elsewhere in this notebook, so pass the real <eos> id.

    Returns:
        A (1, L) tensor of token ids, L <= max_len, beginning with ``start_symbol``.

    The model is put in eval mode (dropout off) for decoding, and its previous
    train/eval mode is restored afterwards.
    """
    # Use the model's own device rather than relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else src.device
    src = src.to(run_device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            memory = model.encoder(model.embedding(src), None)
            ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=run_device)
            for _ in range(max_len - 1):
                steps = ys.size(1)
                # Same causal masking as in training, so earlier positions cannot see later ones.
                causal = torch.tril(torch.ones(steps, steps, dtype=torch.bool, device=run_device))
                out = model.decoder(model.embedding(ys), memory, causal)
                logits = model.output_layer(out[:, -1])
                next_word = logits.argmax(dim=1)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                if next_word.item() == end_symbol:
                    break
    finally:
        model.train(was_training)
    return ys

# Dataset preparation references:
# https://pytorch.org/text/stable/tutorials/translation_transformer.html
# 或 https://github.com/multi30k/dataset