In [1]:
import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

In [2]:
@dataclass
class ModelArgs:
    """Hyperparameter container for the model; defaults correspond to Llama 7B.

    NOTE(review): several fields are not used anywhere in the visible part of
    this notebook, so their descriptions below are inferred from their names
    and should be confirmed against the code that consumes them.
    """
    # default hyperparameters for the Llama 7B model
    dim: int = 4096  # feature width; used as the RMSNorm dimension in this notebook
    n_layers: int = 32  # presumably the number of transformer blocks — TODO confirm
    n_heads: int = 32  # presumably the number of attention heads — TODO confirm
    n_kv_heads: Optional[int] = None  # presumably key/value head count; meaning of None is set by the consumer — TODO confirm
    vocab_size: int = 32000  # presumably the tokenizer vocabulary size — TODO confirm
    hidden_dim: Optional[int] = None  # presumably the MLP hidden size; meaning of None is set by the consumer — TODO confirm
    multiple_of: int = 256  # MLP hidden layer size will be a multiple of this value
    norm_eps: float = 1e-5  # epsilon handed to RMSNorm for numerical stability
    max_seq_len: int = 2048  # presumably the maximum sequence length — TODO confirm
    dropout: float = 0.0  # presumably a dropout probability — TODO confirm

In [3]:
# Instantiate the configuration with every field left at its default value.
args = ModelArgs()

Llama2的RMSNorm层的公式如下：

$$\text{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n}x_j^2 + \epsilon}} \cdot w_i$$

其中：

- ( $x$ ) 是层的输入，$x_i$ 是其最后一维上的第 $i$ 个分量。
- ( $w_i$ ) 代表层的可学习缩放权重（初始化为 1），在归一化之后逐元素相乘。
- ( $n$ ) 是输入最后一维的大小（即特征维度 `dim`），均方值在该维度上计算。
- ( $\epsilon$ ) 是一个小常数，用于数值稳定性（以避免除以零的情况）。

In [8]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Args:
        dim: Size of the last dimension of the inputs; also the length of the
            learnable ``weight`` vector, which is initialized to ones.
        eps: Constant added to the mean square before the reciprocal square
            root, to avoid division by zero.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize ``x`` in float32, cast back to ``x``'s dtype, then scale by ``weight``."""
        # The statistics are computed on a float32 copy; only the normalized
        # result is cast back to the caller's dtype before the learned scale.
        normalized = self._norm(x.float())
        restored = normalized.type_as(x)
        return restored * self.weight

In [9]:
# Build an RMSNorm layer sized and configured from the ModelArgs instance.
norm = RMSNorm(dim=args.dim, eps=args.norm_eps)

In [17]:
# Quick test of `norm`: feed it a random (1, 50, 4096) tensor and print input and output.
# NOTE(review): no RNG seed is set, so the printed values change on every run.

x = torch.randn(1, 50, 4096)
print(x)
print(norm(x))

tensor([[[-0.4045,  1.2928,  0.5496,  ...,  1.0508, -0.0746,  1.3544],
         [-0.2118, -0.4550, -0.3683,  ...,  1.4466, -0.1933,  0.8179],
         [-1.4608, -1.1372, -1.0112,  ...,  2.1685, -0.1298,  1.0042],
         ...,
         [-1.2684,  1.1784,  0.3849,  ...,  0.8639, -0.0211,  0.9200],
         [-0.1934,  1.0352,  0.2050,  ..., -0.7670,  2.0860,  0.3994],
         [ 1.4662, -0.1603, -0.2214,  ..., -0.4009, -0.0426, -2.1020]]])
tensor([[[-0.3914,  1.2511,  0.5318,  ...,  1.0169, -0.0722,  1.3107],
         [-0.2092, -0.4494, -0.3638,  ...,  1.4288, -0.1909,  0.8079],
         [-1.4585, -1.1354, -1.0096,  ...,  2.1651, -0.1296,  1.0026],
         ...,
         [-1.2429,  1.1547,  0.3771,  ...,  0.8465, -0.0207,  0.9015],
         [-0.1934,  1.0352,  0.2050,  ..., -0.7671,  2.0861,  0.3995],
         [ 1.4526, -0.1588, -0.2194,  ..., -0.3972, -0.0423, -2.0825]]],
       grad_fn=<MulBackward0>)


In [18]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute cosine and sine tables of position-times-frequency angles.

    Args:
        dim: Dimension from which the frequencies are derived; ``dim // 2``
            frequencies are produced, ``1 / theta ** (k / dim)`` for
            ``k = 0, 2, 4, ...``.
        end: Number of positions; angles are computed for ``0 .. end - 1``.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` tuple of float tensors, each of shape
        ``(end, dim // 2)``, holding the cosine (real part) and sine
        (imaginary part) of ``position * frequency``.
    """
    half_dim = dim // 2
    exponents = torch.arange(0, dim, 2)[:half_dim].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product gives one row per position and one column per frequency.
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)