In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## 归一化
注意：不论是LayerNorm还是RMSNorm，均值、方差（或均方根）等统计量都是沿最后一个维度计算的。这相当于把输入看作batch_size*seq_len个样本，对每个样本单独归一化：LayerNorm使每个样本的均值为0、方差为1；RMSNorm不做去均值，只使每个样本的均方根（RMS）为1。

### LayerNorm
$$
y = \frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

In [15]:
# 实现LayerNorm
class CustomLayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Computes ``(x - mean) / sqrt(var + eps) * gamma + beta``. The mean and the biased
    variance are taken over the last ``len(normalized_shape)`` dims of ``x``, mirroring
    ``torch.nn.LayerNorm`` with an elementwise affine transform.

    Args:
        normalized_shape: int or sequence of ints; the trailing shape to normalize over.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        # Was ``super().__init_()`` (single trailing underscore). That raised AttributeError,
        # so the module could never be constructed.
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps

        # Learnable scale (gamma) and shift (beta). Wrapping them in nn.Parameter registers
        # them with the module, so an optimizer updates them during training.
        self.gamma = nn.Parameter(torch.ones(self.normalized_shape))
        self.beta = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        """Normalize ``x`` over its trailing ``normalized_shape`` dims; returns x's shape.

        For the usual (batch_size, seq_len, embedding_dim) input with an int
        ``normalized_shape``, this normalizes each embedding vector independently.
        """
        # The last len(normalized_shape) axes, e.g. (-1,) for an int normalized_shape.
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dims, keepdim=True)
        # Biased (population) variance, as in the LayerNorm formula.
        var = x.var(dims, keepdim=True, unbiased=False)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_normalized + self.beta

### RMSNorm
$$
y = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2 + \epsilon}} \cdot \gamma
$$

In [16]:
# 实现RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    y = x / sqrt(mean(x_i^2) + eps) * gamma

    Unlike LayerNorm, the input is not mean-centered and there is no shift parameter.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale. It starts at 1, so the module initially performs
        # pure RMS normalization.
        self.gamma = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return ``x`` divided by its RMS along the last dim (no gamma applied)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)  # rsqrt(v) == 1 / sqrt(v)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x)
        return normalized * self.gamma

### test code

In [17]:
torch.manual_seed(0)  # reproducible input for the numerical checks below
x = torch.randn(2, 3, 4)
embedding_dim = x.shape[-1]

layer_norm = nn.LayerNorm(embedding_dim)
output = layer_norm(x)

print(output.shape)
# With its initial gamma=1, beta=0 and default eps=1e-5, LayerNorm must match the
# formula above (biased variance over the last dim) and keep the input shape.
expected_ln = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
    x.var(-1, keepdim=True, unbiased=False) + 1e-5
)
assert output.shape == x.shape
assert torch.allclose(output, expected_ln, atol=1e-5)

rms_norm = RMSNorm(embedding_dim)
output = rms_norm(x)

print(output.shape)
# With gamma initialised to 1, RMSNorm must match x / sqrt(mean(x^2) + eps).
expected_rms = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
assert output.shape == x.shape
assert torch.allclose(output, expected_rms, atol=1e-5)

torch.Size([2, 3, 4])
torch.Size([2, 3, 4])


## 激活函数

### GLU
$$
\text{GLU}(x, W, V, b, c) = (xW + b)\otimes \sigma(xV + c)
$$

In [23]:
class GLU(nn.Module):
    """Gated-linear-unit feed-forward block: out(xW ⊗ σ(xV)).

    A single fused projection produces both the content half (xW) and the gate half (xV).
    Shapes: in_features -> hidden_features -> out_features. hidden_features and
    out_features fall back to in_features when not given (or falsy).
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features

        # Both projections are bias-free. The author's rationale: the block is assumed to
        # follow a LayerNorm, whose beta already supplies a shift and which re-centers its
        # input (x - E[x]), so learning an extra bias here would be redundant.
        # Fused projection: the first `hidden` outputs are xW (content), the rest are xV (gate).
        self.wv = nn.Linear(in_features, 2 * hidden, bias=False)
        self.out = nn.Linear(hidden, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        xW ⊗ σ(xV), followed by the output projection.
        '''
        # Split the fused projection along the last dim into equal (content, gate) halves.
        content, gate = torch.chunk(self.wv(x), 2, dim=-1)
        gated = content * gate.sigmoid()  # shape (..., hidden_features)
        return self.out(gated)

### SwiGLU
$$
\text{SwiGLU}(x, W, V) = (xW) \otimes \text{SiLU}(xV)
$$
其中
$$
\text{SiLU}(x) = x\cdot\sigma(x)=\frac{x}{1+e^{-x}}
$$

In [25]:
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: out(xW ⊗ SiLU(xV)).

    Args:
        in_features: size of the input's last dim.
        hidden_features: width of the gated hidden layer (defaults to in_features).
        out_features: size of the output's last dim (defaults to in_features).
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        # These projections keep nn.Linear's default bias=True.
        self.w = nn.Linear(in_features, hidden_features)  # content branch: xW
        self.v = nn.Linear(in_features, hidden_features)  # gate branch: xV
        self.out = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.silu(z) == z * sigmoid(z). It uses PyTorch's built-in kernel instead of the
        # hand-written two-op version.
        gate = F.silu(self.v(x))
        content = self.w(x)
        return self.out(content * gate)

### GeGLU
$$
\text{GeGLU}(x, W, V) = (xW) \otimes \text{GELU}(xV)
$$

In [27]:
class GeGLU(nn.Module):
    """GeGLU feed-forward block: out(xW ⊗ GELU(xV)).

    Uses one fused (biased) projection for W and V. hidden_features and out_features
    fall back to in_features when not given (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features

        # First half of the fused outputs is xW (content), second half is xV (pre-gate).
        self.wv = nn.Linear(in_features, 2 * hidden)
        self.out = nn.Linear(hidden, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.wv(x)
        content, pre_gate = projected.chunk(2, dim=-1)
        return self.out(content * F.gelu(pre_gate))

### test code

In [28]:
in_dim = 64
x = torch.randn(2, 10, in_dim)

# GLU gets an explicit hidden width; SwiGLU/GeGLU fall back to hidden = in_dim.
# out_features defaults to in_features for all three, so each must preserve x's shape.
for activation in (
    GLU(in_features=in_dim, hidden_features=128),
    SwiGLU(in_features=in_dim),
    GeGLU(in_features=in_dim),
):
    output = activation(x)
    print(output.shape)
    assert output.shape == x.shape, f"{type(activation).__name__}: {tuple(output.shape)}"

torch.Size([2, 10, 64])
torch.Size([2, 10, 64])
torch.Size([2, 10, 64])


### RoPE: rotary position embeddings
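实现非复数（实数）版本的RoPE，参考[视频链接](https://www.youtube.com/watch?v=o29P0Kpobz0&t=351s)中的代码实现。注意：这里采用rotate_half的写法，把第$i$维与第$i+\frac{d}{2}$维配成一对，在位置$m$处旋转角度$m\theta_i$；原论文是把相邻的第$2i$、$2i+1$维配对，两种写法只是维度排列不同。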
$$
\theta_j=10000^{-\frac{2j}{d}},\quad j=0,1,\dots,\tfrac{d}{2}-1
$$

In [24]:
class RotaryEmbedding(nn.Module):
    '''
    Rotary position embedding (RoPE), real-valued "rotate-half" formulation.

    Dimension i is paired with dimension i + dim/2. At position m each pair is rotated
    by the angle m * theta_i, where theta_i = base ** (-2i / dim). The cos/sin tables
    for positions 0..max_seq_len-1 are precomputed once and stored as buffers.
    '''

    def __init__(self, dim, max_seq_len=2048, base=1e4):
        '''
        :dim: embedding dimension; must be even.
        :max_seq_len: number of positions covered by the cached tables (default 2048).
        :base: base used to compute the rotation frequencies (default 10000).
        :raises ValueError: if dim is odd.
        '''
        super(RotaryEmbedding, self).__init__()
        if dim % 2 != 0:
            raise ValueError("RotaryEmbedding dim must be even!")

        # Per-pair rotation frequency theta_i, shape (dim//2,)
        theta = 1 / (base ** (torch.arange(0, dim, 2) / dim))
        # Positions m, shape (max_seq_len,)
        m = torch.arange(max_seq_len)
        # Angle m * theta_i for every (position, pair), shape (max_seq_len, dim//2)
        m_theta = torch.outer(m, theta)

        cos = torch.cos(m_theta)
        sin = torch.sin(m_theta)

        # Repeat the angles for both halves so dims i and i + dim/2 share one angle,
        # matching rotate_half's pairing. Shape (max_seq_len, dim).
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)

        # Buffers are stored in state_dict and move with .to(device), but they are not
        # Parameters, so the optimizer never updates them.
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)

    @staticmethod
    def rotate_half(x):
        '''Map the halves (x1, x2) of the last dim to (-x2, x1); output has x's shape.'''
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, seq_start_pos=0):
        '''
        Rotate x by the angles of positions seq_start_pos .. seq_start_pos + seq_len - 1.

        :x: tensor whose last two dims are (seq_len, dim), e.g. (batch_size, seq_len, dim)
            or (batch_size, n_heads, seq_len, dim).
        :seq_start_pos: absolute position of x's first token, e.g. when decoding one
            token at a time; defaults to 0.
        :return: tensor with the same shape and dtype as x.
        :raises ValueError: if the requested positions fall outside [0, max_seq_len).
        '''
        seq_len = x.size(-2)
        end = seq_start_pos + seq_len
        max_len = self.cos_cached.size(0)
        if seq_start_pos < 0 or end > max_len:
            raise ValueError(
                f"RotaryEmbedding: positions [{seq_start_pos}, {end}) are outside the "
                f"cached range [0, {max_len}); increase max_seq_len"
            )

        # (seq_len, dim) tables broadcast over any leading batch/head dims of x.
        cos = self.cos_cached[seq_start_pos:end]
        sin = self.sin_cached[seq_start_pos:end]

        embed = x * cos + self.rotate_half(x) * sin
        # The tables use the default float dtype. Cast the result back so that, for
        # example, fp16/bf16 inputs are not silently upcast.
        return embed.type_as(x)

### test code

In [None]:
batch_size = 2
seq_len = 6
embedding_dim = 8

rope = RotaryEmbedding(dim=embedding_dim, max_seq_len=seq_len)

q = torch.randn(batch_size, seq_len, embedding_dim)
k = torch.randn(batch_size, seq_len, embedding_dim)

q_with_rope = rope(q)
k_with_rope = rope(k)

print(q_with_rope.shape)
assert q_with_rope.shape == q.shape

# A rotation must not change each vector's length.
original_norm = torch.linalg.norm(q, dim=-1)
rotated_norm = torch.linalg.norm(q_with_rope, dim=-1)

norm_preserved = torch.allclose(original_norm, rotated_norm)
print(norm_preserved)
assert norm_preserved, "RoPE changed vector norms"

# 与复数版本代码进行对比
class RotaryEmbeddingComplex(nn.Module):
    """Reference RoPE implemented with complex multiplication.

    The first half of the last dim is treated as the real part and the second half as the
    imaginary part. Multiplying by e^{i * m * theta} then rotates dim i together with
    dim i + dim/2, the same pairing as the rotate-half formulation.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(positions, inv_freq)
        # Unit-magnitude phasors e^{i * angle}, shape (max_seq_len, dim // 2)
        unit_phasors = torch.polar(torch.ones_like(angles), angles)
        self.register_buffer("freqs_cis", unit_phasors)

    def forward(self, x, seq_start_pos=0):
        # Assumes x is (batch, seq_len, dim) with even dim.
        half = x.shape[-1] // 2
        seq_len = x.shape[1]

        # Pack (first half, second half) as (real, imaginary) parts.
        as_complex = torch.complex(x[..., :half], x[..., half:])

        # Phasors for the requested positions, with a leading batch axis for broadcasting.
        phasors = self.freqs_cis[seq_start_pos: seq_start_pos + seq_len].unsqueeze(0)
        rotated = as_complex * phasors

        # Unpack back to [real parts, imaginary parts] along the last dim.
        out = torch.cat([rotated.real, rotated.imag], dim=-1)
        return out.type_as(x)
        
rope_complex = RotaryEmbeddingComplex(dim=embedding_dim, max_seq_len=seq_len)

q_with_rope_complex = rope_complex(q)

# Both implementations pair dim i with dim i + dim/2, so their outputs must agree.
matches_complex = torch.allclose(q_with_rope, q_with_rope_complex, atol=1e-6)
print(matches_complex)
assert matches_complex, "real and complex RoPE implementations disagree"

torch.Size([2, 6, 8])
True
True
