In [None]:
from llama2 import apply_rotary_emb,precompute_freqs_cis,Attention
import torch

# Demo sizes: a 288-wide model split across 6 heads gives 48 dims per head.
# Every shape below is derived from these values so the tensors and the
# rotary tables cannot drift out of sync.
batch_size, seq_len, n_heads, dim = 1, 50, 6, 288
head_dim = dim // n_heads

# Query / key tensors laid out as [bs, seq_len, n_heads, head_dim].
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

# Precompute the rotary cos / sin tables for this head size and sequence length.
cos, sin = precompute_freqs_cis(head_dim, seq_len)
print(cos.shape, sin.shape)

# Apply the rotary embedding to queries and keys.
xq_out, xk_out = apply_rotary_emb(xq, xk, cos, sin)

# Rotation should leave both tensor shapes unchanged.
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
xq_out.shape, xk_out.shape


In [2]:
from llama2 import Attention,ModelConfig,precompute_freqs_cis
import torch

args = ModelConfig()
attention_model = Attention(args)  # attention block built from the default config

# Dummy input: a single sequence of 50 positions, each of width args.dim.
batch_size, seq_len, dim = 1, 50, args.dim
x = torch.rand(batch_size, seq_len, dim)
print(f"x shape {x.shape}")

# The RoPE tables are sized by the per-head width (dim // n_heads),
# not by the full model width.
head_dim = dim // args.n_heads
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)
print(f"freqs_cos shape {freqs_cos.shape}, freqs_sin shape {freqs_sin.shape}")

# Forward pass through the attention block.
output = attention_model(x, freqs_cos, freqs_sin)

# The attention output keeps the input layout [batch_size, seq_len, dim].
print("Output shape:", output.shape)

x shape torch.Size([1, 50, 768])
freqs_cos shape torch.Size([50, 24]), freqs_sin shape torch.Size([50, 24])
调整维度之前： xq shape torch.Size([1, 50, 768]), xk shape torch.Size([1, 50, 384]), xv shape torch.Size([1, 50, 384]) 
调整维度之后： xq shape torch.Size([1, 50, 16, 48]), xk shape torch.Size([1, 50, 8, 48]), xv shape torch.Size([1, 50, 8, 48]) 
旋转位置嵌入： xq shape torch.Size([1, 50, 16, 48]), xk shape torch.Size([1, 50, 8, 48])
批次维度处理： xq shape torch.Size([1, 16, 50, 48]), xk shape torch.Size([1, 16, 50, 48]), xv shape torch.Size([1, 16, 50, 48]) 
Output shape: torch.Size([1, 50, 768])


In [2]:
from llama2 import Attention,ModelConfig,MLP
import torch

args = ModelConfig()

# Feed-forward block configured from the ModelConfig fields.
mlp = MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)

# Random activations for one sequence of 50 positions, each args.dim wide.
input_shape = (1, 50, args.dim)
x = torch.randn(input_shape)

# Run the MLP and report the shape of its output.
output = mlp(x)
print(output.shape)

dim:768 hidden_dim:2048 mutile_of:64 dropout:0.0
torch.Size([1, 50, 768])


In [2]:
from llama2 import Attention,ModelConfig,MLP,DecoderLayer,precompute_freqs_cis
import torch

args = ModelConfig()

# Build a single DecoderLayer. The leading 0 is its first positional argument
# (presumably the layer index — verify against DecoderLayer.__init__).
decoderlayer = DecoderLayer(0, args)

# Dummy hidden states laid out as [bs, seq_len, dim].
seq_len = 50
dim = args.dim
x = torch.randn(1, seq_len, dim)
print(x.shape)

# Per-head RoPE tables covering every position of the sequence.
head_dim = dim // args.n_heads
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)

out = decoderlayer(x, freqs_cos, freqs_sin)

# The layer output has the same shape as x: [batch_size, seq_len, dim].
print(out.shape)

dim:768 hidden_dim:2048 mutile_of:64 dropout:0.0
torch.Size([1, 50, 768])
调整维度之前： xq shape torch.Size([1, 50, 768]), xk shape torch.Size([1, 50, 384]), xv shape torch.Size([1, 50, 384]) 
调整维度之后： xq shape torch.Size([1, 50, 16, 48]), xk shape torch.Size([1, 50, 8, 48]), xv shape torch.Size([1, 50, 8, 48]) 
旋转位置嵌入： xq shape torch.Size([1, 50, 16, 48]), xk shape torch.Size([1, 50, 8, 48])
批次维度处理： xq shape torch.Size([1, 16, 50, 48]), xk shape torch.Size([1, 16, 50, 48]), xv shape torch.Size([1, 16, 50, 48]) 
torch.Size([1, 50, 768])


In [4]:
from llama2 import Attention,ModelConfig,MLP,DecoderLayer,precompute_freqs_cis,Transformer
import torch
args = ModelConfig()

# The model consumes integer token ids shaped [bs, seq_len]; per the original
# note its forward also accepts targets, which this demo does not pass.
# NOTE(review): 6144 presumably equals the configured vocabulary size (the
# recorded logits have 6144 entries in their last dimension) — confirm
# against ModelConfig.
x = torch.randint(0, 6144, (1, 50))
print(f"x {x}")

# Instantiate the Transformer from the default config.
model = Transformer(args=args)

# Total parameter count across every tensor registered on the model.
num_params = sum(param.numel() for param in model.parameters())
print('Number of parameters:', num_params)

# Forward pass without targets; the recorded run yields logits shaped
# [batch_size, 1, vocab_size].
out = model(x)
print(f"out {out.logits}")
print(out.logits.shape)

x tensor([[1585, 2038, 4952, 3093, 5099, 2860, 2580, 2518,  634, 3128, 1349, 4460,
         3617, 5235,  275, 1986, 5783,  463,  931, 2584, 2447, 4277,  820, 5148,
         4094, 1558, 5582, 4577, 3232, 2681, 2609, 1664, 1732, 3448, 1200, 2254,
         1621, 4233, 4794, 3279, 4542, 2308, 5400, 3364, 2047, 5510,  177,  679,
         5949, 4153]])
Number of parameters: 82594560
out tensor([[[-0.1328,  0.2435, -0.3125,  ...,  0.5583, -0.6154,  0.6677]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 1, 6144])
