# 1 Set Config

In [57]:
from rich import print
from lit_llama.model import Block, LLaMA, LLaMAConfig

In [58]:
# Context length (tokens per sequence). Shakespeare is a small corpus, so this is
# deliberately shorter than vanilla LLaMA's context.
block_size: int = 1024

In [59]:
# Inspect the stock 7B configuration with the context length overridden.
config = LLaMAConfig.from_name("7B")
setattr(config, "block_size", block_size)  # block_size is the sequence length
print("7B config", config)

In [60]:
# Tiny LLaMA variant used for the rest of the walkthrough.
config = LLaMAConfig.from_name("baby_llama")
config.block_size = block_size
# Both vocab sizes come from prepare_shakespeare.py.
config.vocab_size = config.padded_vocab_size = 32000
print("baby_llama config:", config)

model = LLaMA(config)
print(model)

# 2 Load Data

In [61]:
import torch

# Pre-tokenized inputs / targets. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered .pt file cannot execute code.
X = torch.load('data/input.pt', weights_only=True)
Y = torch.load('data/target.pt', weights_only=True)
# The flattened cross-entropy below pairs every input position with one target.
assert X.shape == Y.shape, f"input/target shape mismatch: {X.shape} vs {Y.shape}"

print(X.shape)
print(Y.shape)
print("batch:{}, length: {} ".format(X.shape[0],X.shape[1]))

# 3 Inference

In [62]:
from torch.nn import functional as F

# End-to-end forward pass: token ids -> per-position vocabulary logits.
logits = model(X)
print(logits.shape)
print('vocab size:', config.vocab_size)

# Flatten batch and time so that every position is one classification example.
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = Y.view(-1)
loss = F.cross_entropy(flat_logits, flat_targets)
print('loss:', loss)

torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


# 4 Model Forward

In [40]:
# Display the module tree as this cell's rich output.
model

LLaMA(
  (lm_head): Linear(in_features=128, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 128)
    (h): ModuleList(
      (0-1): 2 x Block(
        (rms_1): RMSNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=False)
          (c_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (rms_2): RMSNorm()
        (mlp): MLP(
          (c_fc1): Linear(in_features=128, out_features=512, bias=False)
          (c_fc2): Linear(in_features=128, out_features=512, bias=False)
          (c_proj): Linear(in_features=512, out_features=128, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)

In [63]:
# The manual forward pass starts from the raw token ids.
idx = X
B, T = idx.shape
print(f"batch:{B}, length:{T} ")

In [64]:
print('---------------0. create RoPE, Mask----------------')
# Create the RoPE and mask tensors by slicing the model's caches down to length T.
# NOTE(review): assumes model.rope_cache / model.mask_cache were populated by the
# earlier model(X) call — TODO confirm against lit_llama.model.
max_seq_length = model.config.block_size
mask = model.mask_cache[:, :, :T, :T]
rope = model.rope_cache[:T]
print("rope: ", rope.shape)
print("mask: ", mask.shape)
print("length:", max_seq_length)

In [70]:
print('---------------1.embding----------------')
# Token-embedding lookup; x_embd keeps the embeddings while x is advanced through the blocks.
x_embd = model.transformer.wte(idx)
x = x_embd
print("n_embd: ", config.n_embd)
print("before embeding: ", idx.shape)
print("after embeding: ", x.shape)

In [66]:
print('---------------2.llama block attention ----------------')
# Number of transformer blocks — not to be confused with `block_size` (the sequence length).
print("number of blocks = ",len(model.transformer.h))
print("n_layer",config.n_layer)
assert len(model.transformer.h) == config.n_layer, "model depth should match config.n_layer"
# Start from the embeddings explicitly so re-running this cell is idempotent
# (otherwise block outputs would be fed through the blocks a second time).
x = x_embd
# Push the hidden states through every block in order; the second return value is discarded.
for block in model.transformer.h:
    x, _ = block(x, rope, mask, max_seq_length)
    print("LLama Block: ",x.shape)

torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


In [67]:
print('---------------3.llama output ----------------')
final_norm, head = model.transformer.ln_f, model.lm_head
# Final RMSNorm (scale only, no mean-centering), then project every position to vocabulary logits.
x = final_norm(x)
print("layer norm no center: ",x.shape)

logits = head(x)
print("output logits ",logits.shape)
print("config.vocab_size: ", config.vocab_size)

## 4.1 Block Forward

In [68]:
# Take the first transformer block to dissect it step by step.
block = next(iter(model.transformer.h))
print(block)

In [71]:
print(x_embd.shape)
# One full block applied to the embeddings; only the first return value
# (the hidden states) is kept — the second is presumably the kv cache.
x = block(x_embd, rope, mask, max_seq_length)[0]
print("rms_1 -> attention-> rms_2-> MLP")

torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


In [81]:
# Rebuild the block by hand: pre-norm attention sub-layer with a residual add ...
x_rms_1 = block.rms_1(x_embd)
x_attn = block.attn(x_rms_1, rope, mask, max_seq_length)[0]
x = x_attn + x_embd
print('block attention result:', x.shape)

# ... followed by a pre-norm MLP sub-layer with a residual add.
x_rms_2 = block.rms_2(x)
x_block_out = x + block.mlp(x_rms_2)
print('x + mlp(x) result:', x_block_out.shape)

torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


## 4.2 RMSNorm

In [82]:
# The first block's pre-attention RMSNorm, inspected below.
first_block = model.transformer.h[0]
rms_norm = first_block.rms_1
print(rms_norm)

In [84]:
x = x_embd
# RMSNorm parameters: the scale vector (compare its shape to n_embd), epsilon, and reduction dim.
for label, value in (
    ("rms_norm.scale", rms_norm.scale.shape),
    ("config.n_embd", config.n_embd),
    ("rms_norm.eps", rms_norm.eps),
    ("rms_norm.dim", rms_norm.dim),
):
    print(label, value)

In [87]:
# RMSNorm: y = x * rsqrt(mean(x^2) + eps) * scale   (i.e. x / sqrt(mean(x^2) + eps) * scale).
# Unlike LayerNorm there is no mean subtraction and no bias.
x = x_embd  # start from the embeddings explicitly so this cell does not depend on an earlier `x`
norm_x = torch.mean(x * x, dim=rms_norm.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + rms_norm.eps)
x_rms = rms_norm.scale * x_normed
print("归一化前", x_embd.shape)
print("归一化后", x_rms.shape)
# The hand-written version must reproduce the module's own output.
torch.testing.assert_close(x_rms, rms_norm(x_embd))

## 4.3 Simple RoPE (rotary position embedding)

In [89]:
seq_len = block_size
# RoPE rotates each attention head independently, so its dimension is the
# per-head size n_embd // n_head (not the full n_embd).
n_head = config.n_head
n_elem = config.n_embd // n_head
# Standard RoPE base; matches the base=10000 default of build_rope_cache.
base = 10000
print("输入:句长:{},单头维度:{},头:{}".format(seq_len, n_elem, n_head))

In [109]:
# theta_i = base ** (-2i / n_elem): one rotation frequency per pair of dimensions.
exponents = torch.arange(0, n_elem, 2) / n_elem
theta = 1.0 / (base ** exponents)
print(theta.shape)  # one entry per dimension pair
# Rotation angle for every (position, frequency) combination.
seq_idx = torch.arange(seq_len)
idx_theta = torch.outer(seq_idx, theta).float()
print(idx_theta.shape)
# Store cos and sin of each angle along a new trailing axis of size 2.
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
print(cache[:T].shape)

In [110]:
# torch.outer(a, b)[i, j] == a[i] * b[j]: rows follow the first argument, columns the second.
import torch

seq_idx = torch.arange(1, 4)          # [1, 2, 3], shape (3,)
theta = torch.tensor([0.5, 1.0])      # shape (2,)
result = torch.outer(seq_idx, theta)  # shape (3, 2)
# result:
# tensor([[0.5000, 1.0000],
#         [1.0000, 2.0000],
#         [1.5000, 3.0000]])

In [113]:
# flatten(start_dim, end_dim=-1) merges dims start_dim..end_dim into a single dim,
# so flatten(3) turns (1, 1, 1, 2, 3) into (1, 1, 1, 6).
a = torch.randn(1,1,1,2,3)
print(a.shape)
print(a)
flat = a.flatten(3)
print(flat.shape)
print(flat)

In [119]:
# torch.stack(tensors, dim=-1) creates a NEW trailing axis and lays the inputs side by
# side along it, so stacking k copies of a (2, 2) tensor yields shape (2, 2, k):
# each scalar of `a` gains a new innermost axis holding its k copies.
a = torch.tensor([[1, 2], [3, 4]])
print(a.shape)
print(a)
b, c = (torch.stack([a] * copies, -1) for copies in (2, 3))
print(b.shape)
print(b)
print(c.shape)
print(c)

In [125]:
# Full RoPE implementation. The cache is a plain tensor; this alias only documents intent.
RoPECache = torch.Tensor

def build_rope_cache(
    seq_len: int,
    n_elem: int,
    dtype: "torch.dtype | None" = None,
    device: "torch.device | None" = None,
    base: int = 10000,
) -> RoPECache:
    """Build the rotary-position-embedding (RoPE) cos/sin table, printing each step.

    Args:
        seq_len: number of positions to precompute.
        n_elem: per-head dimension; features are rotated in pairs, one frequency per pair.
        dtype: dtype of the position / frequency ``arange`` tensors (may be an integer
            dtype such as the token-id dtype). If it is float16, bfloat16 or int8 the
            returned cache is cast to float16. ``None`` uses torch's default.
        device: device on which the tensors are created (``None`` = default device).
        base: RoPE base, theta_i = base ** (-2i / n_elem).

    Returns:
        Tensor of shape ``(seq_len, ceil(n_elem / 2), 2)``; the last axis holds
        ``(cos, sin)`` of ``position * theta_i``. float32 unless cast to float16 above.
    """
    print("输入:句长:{},单头维度:{}".format(seq_len, n_elem))

    # The exponent must be divided by n_elem BEFORE exponentiating. Writing
    # `base ** arange / n_elem` computes (base ** arange) / n_elem instead, and with an
    # integer dtype `base ** arange` also overflows int64 for moderate exponents.
    exponents = torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem
    theta = 1.0 / (base ** exponents)
    print("theta:", theta)

    seq_idx = torch.arange(0, seq_len, dtype=dtype, device=device)
    print("seqidx:", seq_idx)

    # Rotation angle for every (position, frequency) pair.
    idx_theta = torch.outer(seq_idx, theta).float()
    print("position idx* theta :", idx_theta.shape)
    print("idx_theta[:4,:4]:", idx_theta[:4,:4])

    # New trailing axis of size 2 holds (cos, sin), in that order.
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
    print("cache: (8是因为单头d为16，两个一组则为8)", cache.shape)
    if seq_len > 1:  # position 1 only exists when there are at least two positions
        print(cache[1,:4,:2])

    # float16 is widely supported and benefits from hardware acceleration,
    # so half-precision / int8 callers get a float16 cache.
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    print(cache.shape)
    print(cache.dtype)

    return cache


# Build a RoPE cache for the per-head dimension and attach it to the model.
# NOTE(review): this rebinds the name `RoPECache` from the type alias to the cache tensor itself.
head_dim = model.config.n_embd // model.config.n_head
RoPECache = build_rope_cache(
    seq_len=model.config.block_size,
    n_elem=head_dim,
    dtype=idx.dtype,
    device=idx.device,
)
model.RoPECache = RoPECache

In [168]:
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles (RoPE).

    Args:
        x: 4-D tensor whose dim 1 is the sequence axis and whose last dim (even size)
            holds the features, e.g. ``(B, T, n_head, head_size)``.
        rope_cache: table of shape ``(>= T, head_size // 2, 2)`` holding ``(cos, sin)``
            per position and pair, as built by ``build_rope_cache``; only the first
            ``T`` rows are used.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    T = x.size(1)

    # Truncate the cache to the current sequence length.
    rope_cache = rope_cache[:T]
    # Compute in float32 and view the last axis as (pairs, 2).
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    print('x.shape', x.shape)
    print('xshaped.shape', xshaped.shape)

    # Reshape the cache to (1, T, 1, pairs, 2) so it broadcasts over batch and heads.
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    print('rope_cache ', rope_cache.shape)
    print('rope_cache ',rope_cache)

    # 2-D rotation of each pair (x0, x1) by angle a:
    #   x0' = x0 * cos(a) - x1 * sin(a)
    #   x1' = x1 * cos(a) + x0 * sin(a)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    print('x_out2', x_out2.shape)
    x_out2 = x_out2.flatten(3)
    print('flat_x_out2', x_out2.shape)
    # Cast back so the result matches the input's dtype as well as its shape.
    return x_out2.type_as(x)

# Toy input laid out as (batch, T, n_head, head_size) = (1, 3, 1, 16).
toy_shape = (1, 3, 1, 16)
x = torch.rand(*toy_shape)
print('x', x)
out = apply_rope(x, model.RoPECache)
print('out', out)

In [170]:
# Worked example: rotate the first 4 features of the last position by hand and
# compare the result with apply_rope's output for the same position.
pos = x.size(1) - 1  # index of the last position (2 for the toy input)
tmp_x = x[0, pos, 0, :4]
print('tmp_x', tmp_x)
print(tmp_x.shape)

# Use the same cache that was passed to apply_rope. Indexing is
# [position, pair, (cos, sin)] — the last axis holds cos first, then sin.
tmp_rope = model.RoPECache[pos, :2, :2]
print('tmp_rope',tmp_rope)
cos_rope = tmp_rope[:,0]
print('cos_rope', cos_rope)
sin_rope = tmp_rope[:,1]
print('sin_rope',sin_rope)

# Pairwise rotation: features (0, 1) share one angle (cos_rope[0], sin_rope[0]);
# features (2, 3) share the next angle (cos_rope[1], sin_rope[1]).
'''
配对旋转：RoPE 编码通常将连续的两个元素视为一个“配对”，对每一对元素应用相同的旋转角度。
这种做法能够让编码对位置的变化有更敏感的反应，从而增强模型对序列位置的感知。

共享相同角度：在 tmp_x 的前两个元素中，tmp_x[0] 和 tmp_x[1] 都使用相同的角度（cos_rope[0] 和 sin_rope[0]）进行旋转，
确保这对元素按照统一的方式进行旋转变换。
类似地，tmp_x[2] 和 tmp_x[3] 使用另一个角度（cos_rope[1] 和 sin_rope[1]），让下一对元素的旋转方式保持一致。
'''
rope_x0 = tmp_x[0] * cos_rope[0] - tmp_x[1] * sin_rope[0]
rope_x1 = tmp_x[1] * cos_rope[0] + tmp_x[0] * sin_rope[0]
rope_x2 = tmp_x[2] * cos_rope[1] - tmp_x[3] * sin_rope[1]
rope_x3 = tmp_x[3] * cos_rope[1] + tmp_x[2] * sin_rope[1]

# Original x vector
print(tmp_x)

# RoPE result computed by hand
print("手动计算rope结果")
print(rope_x0,
      rope_x1,
      rope_x2,
      rope_x3)

# The same values as produced by apply_rope
print("apply_rope计算的rope结果")
print(out.shape)
print(out[0,-1,0,:4])

# The manual rotation must agree with apply_rope at the same position.
torch.testing.assert_close(
    torch.stack([rope_x0, rope_x1, rope_x2, rope_x3]).to(out.dtype),
    out[0, pos, 0, :4],
)

## 4.4 Block Attention

In [172]:
# Block attention: display the first block's CausalSelfAttention module
# (c_attn projects to 3 * n_embd for q, k, v; c_proj is the output projection).
model.transformer.h[0].attn

CausalSelfAttention(
  (c_attn): Linear(in_features=128, out_features=384, bias=False)
  (c_proj): Linear(in_features=128, out_features=128, bias=False)
)

In [176]:
model.transformer.h[0]  # attn's input is the output of rms_1

Block(
  (rms_1): RMSNorm()
  (attn): CausalSelfAttention(
    (c_attn): Linear(in_features=128, out_features=384, bias=False)
    (c_proj): Linear(in_features=128, out_features=128, bias=False)
  )
  (rms_2): RMSNorm()
  (mlp): MLP(
    (c_fc1): Linear(in_features=128, out_features=512, bias=False)
    (c_fc2): Linear(in_features=128, out_features=512, bias=False)
    (c_proj): Linear(in_features=512, out_features=128, bias=False)
  )
)

In [179]:
# Reference output from the existing block attention module, for comparison with the manual version below.
block_attn = model.transformer.h[0].attn
# help() prints the docstring itself and returns None, so it is not wrapped in print().
help(block_attn.forward)

x_attn, _ = block_attn(x_rms_1, rope, mask, max_seq_length)
print(x_attn.shape) # B, T, D

Help on method forward in module lit_llama.model:

forward(x: torch.Tensor, rope: torch.Tensor, mask: torch.Tensor, max_seq_length: int, input_pos: Union[torch.Tensor, NoneType] = None, kv_cache: Union[Tuple[torch.Tensor, torch.Tensor], NoneType] = None) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], NoneType]] method of lit_llama.model.CausalSelfAttention instance
    Defines the computation performed at every call.
    
    Should be overridden by all subclasses.
    
    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.



torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([16, 8, 1024, 16])
torch.Size([1, 1, 1024, 1024])


In [181]:
# Hand-written block attention: dimensions of the rms_1 output that feeds it.
B, T, D = x_rms_1.shape
print(f"batch:{B}, length:{T}, n_embding:{D}")

In [192]:
'''
1.block_attn.c_attn(x) 通过一个线性变换，生成了一个形状为 (batch_size, seq_len, 3 * n_embd) 的张量；
2..split(block_attn.n_embd, dim=2) 将最后一维（3 * n_embd）分割为 3 个大小为 n_embd 的部分；
3.q, k, v 分别保存分割后的每部分，形状均为 (batch_size, seq_len, n_embd)。
'''
print('--------------1. attenion split------------------')
x = x_rms_1
# One fused projection yields q, k, v side by side on the last axis; split into three n_embd chunks.
q, k, v = block_attn.c_attn(x).split(block_attn.n_embd, dim=-1)
# Multi-head view: (B, T, n_embd) -> (B, T, n_head, head_size).
head_size = D // block_attn.n_head
assert head_size * block_attn.n_head == D, "n_embd must be divisible by n_head"
q = q.view(B, T, block_attn.n_head, head_size)
k = k.view(B, T, block_attn.n_head, head_size)
v = v.view(B, T, block_attn.n_head, head_size)
print("batch, length, head, head_size: {}".format(k.shape))

In [195]:
print('--------------2. qk RoPE 旋转相对位置编码------------------')
print('RoPE编码作用在每个block的attention计算QK里')
# Only queries and keys are rotated; values are left untouched.
q_rope_before = q
q_rope_after = q = apply_rope(q, rope)
k = apply_rope(k, rope)
print("q_rope前:", q_rope_before.shape)
print("q_rope后:", q_rope_after.shape)

# (B, T, n_head, head_size) -> (B, n_head, T, head_size) for batched per-head attention.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

In [199]:
print('--------------3. 计算scale dot product 和前向传播------------------')
# Masked scaled dot-product attention per head ...
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
# ... then merge the heads back into (B, T, D) and apply the output projection.
merged_heads = y.transpose(1, 2).contiguous().view(B, T, D)
y = block_attn.c_proj(merged_heads)
print(y.shape)

## 4.5 MLP (SwiGLU with SiLU)

In [200]:
# Display the first block's MLP: c_fc1 (gate) and c_fc2 (up) expand n_embd, c_proj projects back.
model.transformer.h[0].mlp

MLP(
  (c_fc1): Linear(in_features=128, out_features=512, bias=False)
  (c_fc2): Linear(in_features=128, out_features=512, bias=False)
  (c_proj): Linear(in_features=512, out_features=128, bias=False)
)

In [202]:
mlp  = model.transformer.h[0].mlp
# Inside the block the MLP consumes rms_2's output (the normalized post-attention
# residual stream), not rms_1's output — see the block walkthrough above.
x = x_rms_2
print("SiLU(x) = x * sigmoid(x)")
# SwiGLU feed-forward: silu(gate(x)) * up(x), then project back to n_embd.
x = F.silu(mlp.c_fc1(x)) * mlp.c_fc2(x)
print("c_fc1 is gate")
print("c_fc2 is up")
x = mlp.c_proj(x)
print(x.shape)
# The manual computation must match the module's own forward pass.
torch.testing.assert_close(x, mlp(x_rms_2))