In [5]:
import torch
from torch import nn
import math

In [6]:
# 建立Dataloader

from toy_datasets import get_dataset_AddSeq

# 这里只考虑batch_first=True的情况
data_num = 10000
dataset, collate_fn, (index_bos, index_eos, index_pad, index_add, vocab_size) = (
    get_dataset_AddSeq(data_num=data_num)
)  # 一个带符号的个位数加法任务的数据集

train_size = 0.2
train_dataset = dataset[: int(data_num * train_size)]
eval_dataset = dataset[int(data_num * train_size) :]
batch_size = 64
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset, batch_size=batch_size, collate_fn=collate_fn
)

In [7]:
train_dataset[:5]

[{'src': tensor([12,  9, 10,  4, 13]), 'tgt': tensor([12,  1,  3, 13])},
 {'src': tensor([12,  9, 10,  5, 13]), 'tgt': tensor([12,  1,  4, 13])},
 {'src': tensor([12,  8, 10,  6, 13]), 'tgt': tensor([12,  1,  4, 13])},
 {'src': tensor([12,  9, 10,  3, 13]), 'tgt': tensor([12,  1,  2, 13])},
 {'src': tensor([12,  8, 10,  6, 13]), 'tgt': tensor([12,  1,  4, 13])}]

In [8]:
# 多头注意力
class MultiHeadAttention(nn.Module):
    """
    多头注意力层，用于计算多个注意力头的输出
    """
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 除取整

        self.W_q = nn.Linear(d_model, d_model)  # （输入的特征维度，输出的特征维度）
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        计算多头注意力的输出
        """
        # Q (N,n_head,S,d_k)
        # K (N,n_head,S,d_k)
        # V (N,n_head,S,d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # attn_scores (N,n_head,S,S)
        if mask is not None: 
            # mask (N,1,1,S)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9) # attn_scores里mask为0的地方，用负无穷填充
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # attn_probs (N,n_head,S,S)
        output = torch.matmul(attn_probs, V)
        # output (N,n_head,S,d_k)
        return output

    def split_heads(self, x):
        """
        分割多头注意力的输入，将输入的特征维度分割成多个头
        """
        # (N,S,D)
        batch_size, seq_length, d_model = x.size()
        # (N,S,n_head,d_k)
        # (N,n_head,S,d_k)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        # 若将tensor的第一维与第二维转置，则：
        # x_new[i][j][k] = x[i][k][j]
        # x_new[i][0][k] = x[i][k][0]
        # x_new[i][1][k] = x[i][k][1]
        # x_new[i][2][k] = x[i][k][2]
        # x_new[i][3][k] = x[i][k][3]

    def combine_heads(self, x):
        """
        组合多头注意力的输出，将多个头的输出拼接起来
        """
        # x (N,n_head,S,d_k)
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        # Q (N,S,D)
        Q = self.split_heads(self.W_q(Q))
        # Q (N,n_head,S,d_k)
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # attn_output (N,n_head,S,d_k)
        output = self.W_o(self.combine_heads(attn_output))

        return output


class PositionWiseFeedForward(nn.Module):
    """
    前馈神经网络，包含两个全连接层和ReLU激活函数
    """
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x))) 


# 位置编码
class PositionalEncoding(nn.Module):
    """
    位置编码，用于将输入序列中的每个位置映射到一个向量中。解决了原版attention中学习不到位置信息的问题。
    """
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(
            max_seq_length, d_model
        )  # 最大序列长度为行，特征维度为列，一个矩阵，和输入一样
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # 所有行取偶数索引（奇数列）
        pe[:, 1::2] = torch.cos(position * div_term)  # 所有行取奇数索引

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, : x.size(1)]


class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        # 实例化，用于forward函数
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # 输入为x，qkv都一样，在多头注意力层里的前向传播将相同的x转为qkv矩阵
        attn_output = self.self_attn(x, x, x, mask)
        # 短接
        x = self.norm1(x + self.dropout(attn_output))
        # 送入前馈神经网络
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):

        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        num_heads,
        num_layers,
        d_ff,
        max_seq_length,
        dropout,
    ):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, vocab_size) # 分类头
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt)->tuple[torch.Tensor,torch.Tensor]:
        # src (N,S)
        # tgt (N,T)
        src_mask = (src != index_pad).unsqueeze(1).unsqueeze(2) # (N,1,1,S)
        tgt_mask = (tgt != index_pad).unsqueeze(1).unsqueeze(3) # (N,1,T,1)

        seq_length = tgt.size(1)
        nopeak_mask = (
            1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)
        ).bool().to(tgt.device)
        # decoder的遮蔽未来信息
        tgt_mask = tgt_mask & nopeak_mask
        # src_mask是遮蔽填充信息，tgt_mask是遮蔽填充和未来信息
        
        return src_mask, tgt_mask

    def forward(self, src, tgt)->torch.Tensor:
        src_mask, tgt_mask = self.generate_mask(src, tgt) # 根据src和tgt，生成mask
        src_embedded = self.dropout(self.positional_encoding(self.embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask) # decoder这里做的是cross attention，所以需要src_mask
        output = self.fc_out(dec_output)
        return output
    @torch.inference_mode()
    def predict(self,src:torch.Tensor,tgt:torch.Tensor,max_seq_length:int,index_eos:int)->list:
        batch_size,tgt_len=tgt.shape
        output=[]
        for i in range(tgt_len,max_seq_length):
            logits=self.forward(src,tgt) # (batch_size,tgt_len,vocab_size)
            next_token=logits[:,-1,:].argmax(dim=-1) 
            output.append(next_token.item())
            if next_token.item()==index_eos:
                break
            tgt=torch.cat([tgt,next_token.unsqueeze(0)],dim=-1)
        return output

In [17]:
num_heads = 8
d_ff = 128
d_model = 64
num_layers = 3
dropout = 0.1
max_seq_length = 16  # 用来初始化位置编码的矩阵的最大长度
device = "cuda:0"

model = Transformer(
    vocab_size,
    d_model,
    num_heads,
    num_layers,
    d_ff,
    max_seq_length,
    dropout,
).to(device) # 把模型里所有继承了nn.Module的参数都移动到device上

loss_func = torch.nn.CrossEntropyLoss(ignore_index=index_pad)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), eps=1e-9)

In [18]:
def train(model, train_dataloader, loss_func, optimizer, device):
    model.train()
    losses = []
    for batch in train_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        tgt_input = tgt[:,:-1] # tgt的输入是tgt的前n-1个token
        logits = model(src, tgt_input)
        tgt_expected = tgt[:, 1:]  # tgt的期望输出是tgt的第2个token到第n个token
        tgt_expected = tgt_expected.reshape(-1)
        logits = logits.view(-1, logits.shape[-1])
        loss = loss_func(logits, tgt_expected)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)


@torch.inference_mode()  # 验证的时候关闭梯度计算
def eval(model, eval_dataloader, loss_func, device):
    model.eval()
    losses = []
    for batch in eval_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        tgt_input = tgt[:,:-1]
        logits = model(src, tgt_input)
        tgt_expected = tgt[:, 1:]  
        tgt_expected = tgt_expected.reshape(-1)
        logits = logits.view(-1, logits.shape[-1])
        loss = loss_func(logits, tgt_expected)
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)

In [19]:
epochs = 20 # transformer模型需要的训练轮数一般更多一些
for epoch in range(epochs):
    train_loss = train(
        model=model,
        train_dataloader=train_dataloader,
        loss_func=loss_func,
        optimizer=optimizer,
        device=device,
    )
    eval_loss = eval(
        model=model, eval_dataloader=eval_dataloader, loss_func=loss_func, device=device
    )
    print(f"epoch:{epoch}, train_loss:{train_loss}, eval_loss:{eval_loss}")

epoch:0, train_loss:0.0263607146916911, eval_loss:0.02027321571111679
epoch:1, train_loss:0.0184405570034869, eval_loss:0.01643575119972229
epoch:2, train_loss:0.01498383900616318, eval_loss:0.01360978977382183
epoch:3, train_loss:0.012155225442256778, eval_loss:0.010887534238398075
epoch:4, train_loss:0.009749313205247745, eval_loss:0.008803229186683893
epoch:5, train_loss:0.008024444163311273, eval_loss:0.0073898396268486976
epoch:6, train_loss:0.006653803313383833, eval_loss:0.0062240084148943425
epoch:7, train_loss:0.005472302509588189, eval_loss:0.005032935362309218
epoch:8, train_loss:0.004558530366921332, eval_loss:0.004113511744886637
epoch:9, train_loss:0.0036371960522956215, eval_loss:0.0031857222188264134
epoch:10, train_loss:0.0026768270108732395, eval_loss:0.002561786691658199
epoch:11, train_loss:0.002303018391103251, eval_loss:0.001735220354050398
epoch:12, train_loss:0.0014562915621354477, eval_loss:0.0010706791603006423
epoch:13, train_loss:0.0008603533333371161, eval_

In [20]:
# 个位数加法
a=9
b=7
@torch.inference_mode()
def add(a,b):
    model.eval()
    input_list=[index_bos]+[a,index_add,b]+[index_eos]
    tgt_list=[index_bos]
    src=torch.tensor(input_list).to(device).unsqueeze(0) # (1,seq_len)
    tgt=torch.tensor(tgt_list).to(device).unsqueeze(0) # (1,seq_len)
    output=model.predict(src,tgt,max_seq_length=max_seq_length,index_eos=index_eos)
    print(f'输入序列 {src}')
    print(f'输出序列 {tgt_list+output}')
    print(f'{a} + {b}={output[:-1]}')
add(a,b)

输入序列 tensor([[12,  9, 10,  7, 13]], device='cuda:0')
输出序列 [12, 1, 6, 13]
9 + 7=[1, 6]
