In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import numpy
import math

## Multi-Head Attention

In [2]:
# Fix the RNG so the random input below and every later weight initialisation are reproducible
# under Restart & Run All.
torch.manual_seed(0)
x = torch.randn(2, 32, 16)  # (batch_size, seq_len, hidden_dim)
print(x.shape)

torch.Size([2, 32, 16])


In [3]:
# Toy hyper-parameters; batch size and sequence length are read off the example input x.
d_model, n_head = 16, 2
(batch_size, seq_len, _) = x.shape

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads; each head works on ``d_model // n_head`` features.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head) -> None:
        super().__init__()
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")

        self.d_model = d_model
        self.n_head = n_head
        self.n_d = self.d_model // self.n_head  # per-head width

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, n_head, length, n_d) using t's own length."""
        batch_size, length, _ = t.shape
        return t.view(batch_size, length, self.n_head, self.n_d).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` to keys ``k`` / values ``v``.

        Args:
            q: (batch, q_len, d_model) query input.
            k, v: (batch, kv_len, d_model) key / value inputs. ``kv_len`` may differ from
                ``q_len`` (e.g. encoder-decoder cross-attention).
            mask: optional tensor broadcastable to (batch, n_head, q_len, kv_len). Scores where
                ``mask == 0`` are set to -1e9 before the softmax; e.g. a causal mask is
                ``torch.tril(torch.ones(q_len, kv_len))``.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        batch_size, q_len, _ = q.shape
        # Each input is split using its own length, so k/v may be longer or shorter than q.
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.n_d)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.matmul(self.softmax(scores), v)

        # view() needs contiguous memory; permute() breaks contiguity, so call contiguous() first.
        attn = attn.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_combine(attn)

In [5]:
# Smoke test: self-attention (query = key = value = x) keeps the input shape.
transformer = MultiHeadAttention(d_model, n_head)
output = transformer(q=x, k=x, v=x)
print(output.shape)

torch.Size([2, 32, 16])


## Token and Position Embedding

In [6]:
class TokenEmbedding(nn.Embedding):
    """Lookup table mapping token ids to ``embedding_dim``-wide vectors.

    Args:
        vocab_size: number of rows (distinct token ids).
        embedding_dim: width of each embedding vector.
        padding_idx: id of the padding token. nn.Embedding zero-initialises this row and
            excludes it from gradient updates. Defaults to 1, the value previously hard-coded.
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx=1):
        super().__init__(vocab_size, embedding_dim, padding_idx=padding_idx)

In [7]:
class PositionEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    PE[pos, 2i] = sin(pos / 10000^(2i / d_model)) and PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model)).

    Args:
        max_len: longest sequence that can be encoded.
        d_model: encoding width (odd values are supported).
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        # pos must be a column vector (max_len, 1) so it broadcasts against the frequency row;
        # without unsqueeze(1) the division below could not broadcast.
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        _2i = torch.arange(0, d_model, step=2).float()
        angles = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model / 2))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # A buffer (not a Parameter): never trained, but follows .to(device) / .half().
        # persistent=False keeps state_dict identical to the former plain-attribute version.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the (seq_len, d_model) encoding for x's sequence length.

        ``x`` may be token ids (batch, seq_len) or embeddings (batch, seq_len, d_model);
        only ``x.shape[1]`` is used.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]

In [8]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: vocabulary size for the token embedding.
        max_len: longest supported sequence for the positional encoding.
        d_model: embedding width.
        dropout: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, max_len, d_model, dropout):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionEmbedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed token ids ``x`` (presumably (batch, seq_len)) into (batch, seq_len, d_model)."""
        tok_emb = self.tok_emb(x)
        # Give the positional module the 3-D token embeddings rather than the 2-D ids, since it
        # may unpack (batch, seq_len, d_model). Match device/dtype in case the encoding is a
        # plain CPU float tensor rather than a registered buffer.
        pos_emb = self.pos_emb(tok_emb).to(device=tok_emb.device, dtype=tok_emb.dtype)
        return self.dropout(tok_emb + pos_emb)

In [9]:
# Demo: positional encodings are (seq_len, d_model); tile them across the batch dimension.
position_emb = PositionEmbedding(seq_len, d_model)
emb = position_emb(x)
print(emb.shape)
# Add an explicit leading batch axis, then copy it batch_size times. This is equivalent to
# emb.repeat(batch_size, 1, 1): when repeat() gets more sizes than the tensor has dims, it
# prepends new leading dims first and then tiles.
emb = emb.unsqueeze(0).repeat(batch_size, 1, 1)
print(emb.shape)

torch.Size([32, 16])
torch.Size([2, 32, 16])


In [10]:
def compute_freqs(d_model, seq_len, theta=10000.0):
    """Precompute the rotary-embedding factors e^{i * m * theta_j}.

    theta_j = 1 / theta^(2j / d_model) for j = 0 .. d_model // 2 - 1 and positions
    m = 0 .. seq_len - 1. Returns a complex tensor of shape (seq_len, d_model // 2) whose
    entries all have magnitude 1.
    """
    even_dims = torch.arange(0, d_model, 2)[: d_model // 2]
    inv_freq = 1.0 / theta ** (even_dims / d_model)
    positions = torch.arange(0, seq_len)
    angles = torch.outer(positions, inv_freq)  # (seq_len, d_model // 2), entry = m * theta_j
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_pos_emb(x, freqs):
    """Apply rotary position encoding (RoPE) to ``x``.

    Consecutive feature pairs (x[..., 2j], x[..., 2j + 1]) are treated as complex numbers and
    multiplied by ``freqs``, i.e. rotated by angle m * theta_j.

    Args:
        x: real tensor of shape (..., seq_len, d) with even d, e.g. (batch, seq_len, d) or
            per-head (batch, n_head, seq_len, d). May be non-contiguous or half precision.
        freqs: complex tensor broadcastable to (..., seq_len, d // 2), e.g. the
            (seq_len, d // 2) output of ``compute_freqs``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Work in at least float32: view_as_complex has limited support for half-precision inputs.
    work_dtype = torch.promote_types(x.dtype, torch.float32)
    # reshape (not view) so non-contiguous inputs such as permuted per-head q/k also work.
    x_pairs = x.to(work_dtype).reshape(*x.shape[:-1], -1, 2)
    rotated = torch.view_as_complex(x_pairs) * freqs
    return torch.view_as_real(rotated).flatten(-2).type_as(x)

In [11]:
print(x.shape)
freqs = compute_freqs(d_model, seq_len)
print(freqs.shape)
# Store the result under a new name. Overwriting x would make this cell non-idempotent
# (re-running it would rotate x twice) and would silently change x for any later cell.
x_rotated = apply_rotary_pos_emb(x, freqs)
print(x_rotated.shape)

torch.Size([2, 32, 16])
torch.Size([32, 8])
torch.Size([2, 32, 16])


## LayerNorm, BatchNorm, and RMSNorm

In [12]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension.

    Normalises with the (biased) mean and variance of the last axis, then applies a learnable
    per-feature scale ``gamma`` and shift ``beta``, both of shape (d_model,).
    """

    def __init__(self, d_model, eps=1e-10):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma2 = x.var(-1, unbiased=False, keepdim=True)
        normalized = (x - mu) / (sigma2 + self.eps).sqrt()
        # gamma/beta match x's trailing dimension, so they broadcast over all leading dims.
        return normalized * self.gamma + self.beta

In [13]:
class BatchNorm(nn.Module):
    """Batch normalization for 2-D (batch, features) input.

    BatchNorm is mainly used in CV models; this educational version only handles 2-D input.
    In training mode it normalises with the current batch's mean/variance and updates the
    running estimates. In eval mode it normalises with the running estimates.

    Args:
        d_model: number of features (size of dim 1).
        eps: added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, d_model, eps=1e-10, momentum=0.1):
        super(BatchNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
        self.momentum = momentum

        # Buffers: saved in state_dict and moved by .to(device), but never trained.
        self.register_buffer("running_mean", torch.zeros(d_model))
        self.register_buffer("running_var", torch.ones(d_model))

    def forward(self, x):
        batch_size, d_model = x.shape  # only (batch, features) input is supported
        if self.training:
            mean = x.mean(0)
            var = x.var(0, unbiased=False)
            # Running stats must not keep the autograd graph alive across iterations.
            with torch.no_grad():
                # Like torch.nn.BatchNorm1d, the running variance uses the unbiased estimate;
                # max(..., 1) avoids dividing by zero when batch_size == 1.
                unbiased_var = var * batch_size / max(batch_size - 1, 1)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean, var = self.running_mean, self.running_var
        out = (x - mean) / (var + self.eps).sqrt()
        return self.gamma * out + self.beta
        

In [14]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes gamma * x / sqrt(mean(x^2) + eps). Unlike LayerNorm there is no mean-centering
    and no bias term.
    """

    def __init__(self, d_model, eps=1e-10):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.gamma * (x * inv_rms)

## Feed-Forward Networks (FFN)

In [15]:
class PositionwiseFeedForward(nn.Module):
    """Transformer feed-forward block: fc2(dropout(relu(fc1(x)))), applied at every position.

    Args:
        d_model: input/output width.
        hidden_size: inner width.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden_size, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden_size)
        self.fc2 = nn.Linear(hidden_size, d_model)
        # Inverted dropout: during training each activation is zeroed with probability p, and
        # the survivors are scaled by 1 / (1 - p). This keeps the expected output equal to eval
        # mode, where dropout is disabled.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        # torch.relu is used directly so this block does not depend on which module the
        # notebook binds to `F` (torch.functional has no relu).
        x = torch.relu(x)
        # Dropout sits between the two projections: after fc1's activation, before fc2.
        x = self.dropout(x)
        return self.fc2(x)

In [16]:
class LlamaFeedForward(nn.Module):
    """LLaMA-style SwiGLU feed-forward block: fc2(silu(fc1(x)) * fc3(x)).

    SwiGLU uses three weight matrices instead of two. To keep the parameter count close to a
    standard 4 * d_model FFN, LLaMA shrinks the hidden width to (2/3) * 4 * d_model =
    (8/3) * d_model and rounds it up to a multiple of 256. For example, d_model = 4096 gives
    8/3 * 4096 ~= 10923, which rounds up to hidden_size = 11008.

    Args:
        d_model: input/output width.
        hidden_size: inner width of the gated projection.
    """

    def __init__(self, d_model, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_size)  # gate branch (passed through SiLU)
        self.fc2 = nn.Linear(hidden_size, d_model)  # down projection back to d_model
        self.fc3 = nn.Linear(d_model, hidden_size)  # linear (value) branch of the gate

    def forward(self, x):
        # nn.functional.silu is used directly; torch.functional (bound to `F` above) has no silu.
        gate = nn.functional.silu(self.fc1(x))
        return self.fc2(gate * self.fc3(x))

## Encoder

In [None]:
class EncoderLayer(nn.Module):
    """Post-LN Transformer encoder layer.

    Two sub-layers: self-attention, then a position-wise FFN. Each is followed by dropout, a
    residual connection and LayerNorm: x = norm(x + dropout(sublayer(x))).
    """

    def __init__(self, d_model, n_head, hidden_size, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        self.ffn = PositionwiseFeedForward(d_model, hidden_size, dropout)
        self.drop2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, s_mask=None):
        """Encode ``x``.

        ``s_mask`` (optional, e.g. a source padding mask) is passed to self-attention as its
        ``mask`` argument; ``None`` attends everywhere.
        """
        residual = x
        x = self.mha(x, x, x, s_mask)
        x = self.norm1(residual + self.drop1(x))

        residual = x
        x = self.ffn(x)
        x = self.norm2(residual + self.drop2(x))

        return x
    

In [None]:
class Encoder(nn.Module):
    """Token + positional embedding followed by a stack of ``n_layer`` encoder layers."""

    def __init__(self, vocab_size, max_len, d_model, n_head, hidden_size, n_layer, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = TransformerEmbedding(vocab_size, max_len, d_model, dropout)
        blocks = [EncoderLayer(d_model, n_head, hidden_size, dropout) for _ in range(n_layer)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, s_mask):
        hidden = self.embedding(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, s_mask)
        return hidden

## Decoder

In [None]:
class DncoderLayer(nn.Module):
    """Post-LN Transformer decoder layer.

    Three sub-layers, each followed by dropout, a residual connection and LayerNorm:
    1. masked self-attention over the decoder input (``t_mask``, e.g. a causal mask);
    2. cross-attention: queries from the decoder, keys/values from the encoder output
       ``enc`` (``s_mask``, e.g. a source padding mask); skipped when ``enc`` is None;
    3. a position-wise feed-forward network.
    """

    def __init__(self, d_model, n_head, hidden_size, dropout=0.1):
        super().__init__()
        self.mha1 = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        self.mha2 = MultiHeadAttention(d_model, n_head)
        self.drop2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        self.ffn = PositionwiseFeedForward(d_model, hidden_size, dropout)
        self.drop3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, dec, enc, t_mask, s_mask):
        residual = dec
        x = self.mha1(dec, dec, dec, t_mask)  # target mask, e.g. lower-triangular
        x = self.norm1(residual + self.drop1(x))

        if enc is not None:
            residual = x
            # Cross-attention takes a single mask: the source (padding) mask.
            x = self.mha2(x, enc, enc, s_mask)
            x = self.norm2(residual + self.drop2(x))

        residual = x
        x = self.ffn(x)
        x = self.norm3(residual + self.drop3(x))

        return x


# Correctly spelled alias; the original name is kept for backward compatibility.
DecoderLayer = DncoderLayer
    

In [None]:
class Dncoder(nn.Module):
    """Transformer decoder: embedding, ``n_layer`` decoder layers, and a vocab projection.

    Returns logits of shape (..., vocab_size) from the final linear layer.
    """

    def __init__(self, vocab_size, max_len, d_model, n_head, hidden_size, n_layer, dropout=0.1):
        super().__init__()
        self.embedding = TransformerEmbedding(vocab_size, max_len, d_model, dropout)
        # Decoder layers (not EncoderLayer): forward() passes (dec, enc, t_mask, s_mask).
        self.layers = nn.ModuleList([DncoderLayer(d_model, n_head, hidden_size, dropout) for _ in range(n_layer)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, dec, enc, t_mask, s_mask):
        dec = self.embedding(dec)
        for layer in self.layers:
            dec = layer(dec, enc, t_mask, s_mask)
        return self.fc(dec)


# Correctly spelled alias; the original name is kept for backward compatibility.
Decoder = Dncoder

## Transformer