# Implimenting LLaMA model in PyTorch

In [1]:
import RMSNorm

In [2]:
from typing import Optional
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1 # Later set in the build method
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    theta: float = 10000.0
    
    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: str = "cuda"

In [3]:
import torch
from RoPE import rope 

rope = rope(ModelArgs.dim, ModelArgs.max_seq_len, ModelArgs.device, ModelArgs.theta)

x = torch.randn(1, 2048, 4096).to("cuda")
rope(x).shape

--------------------------------------------------
torch.Size([1, 2048, 2048]) torch.Size([1, 1, 2048, 2048])
--------------------------------------------------


torch.Size([1, 2048, 4096])

In [4]:
# import torch
# from torch import nn
# class MHA_KVCache(nn.Module):
#     def __init__(self, args:ModelArgs,  qkv_bias=False):
#         super().__init__()
#         # Indicates the number of heads for the Keys and Values
#         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
#         # Indicates the number of heads for the Queries
#         self.n_heads_q = args.n_heads
#         # Indicates how many times the Keys and Values should be repeated
#         self.n_rep = self.n_heads_q // self.n_kv_heads
#         # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
#         self.head_dim = args.dim // args.n_heads
        
#         dim = args.dim
#         self.dim = dim
#         self.head_dim  = dim // args.n_heads 
#         self.W_query = nn.Linear(dim, dim, bias=qkv_bias)
#         self.W_key = nn.Linear(dim, dim, bias=qkv_bias)
#         self.W_value = nn.Linear(dim, dim, bias=qkv_bias)
#         self.out_proj = nn.Linear(dim, dim) 
        
#         self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
#         self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

        
#     def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
#         batch_size, seq_len, n_kv_heads, head_dim = x.shape
#         if n_rep == 1:
#             return x
#         return (
#             # (B, Seq_Len, N_KV_Heads, 1, Head_Dim)
#             x[:, :, :, None, :]
#             # (B, Seq_Len, N_KV_Heads, N_Rep, Head_Dim)
#             .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
#             # (B, Seq_Len, N_KV_Heads * N_Rep, Head_Dim)
#             .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
#         )
        
#     def forward(self,
#         x: torch.Tensor):
#         batch_size, seq_len, d_in = x.shape
        
#         xk = self.W_key(x) 
#         xq = self.W_query(x) 
#         xv = self.W_value(x)
        
#         xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 
#         xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 
#         xq = xq.view(batch_size, seq_len, self.n_heads_q,  self.head_dim) 
        
#         xq = rope(xq)
#         xk = rope(xk)
#         start_pos = 0
        
#         # Replace the entry in the cache
#         self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
#         self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

#         keys = self.cache_k[:batch_size, : start_pos + seq_len]
#         values = self.cache_v[:batch_size, : start_pos + seq_len]
        
#         # Since every group of Q shares the same K and V heads, just repeat the K and V heads for every Q in the same group.
#         keys = repeat_kv(keys, self.n_rep)
#         values = repeat_kv(values, self.n_rep)

#         xq = xq.transpose(1, 2)
#         keys = keys.transpose(1, 2)
#         values = values.transpose(1, 2)

#         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
#         scores = F.softmax(scores.float(), dim=-1).type_as(xq)

#         output = torch.matmul(scores, values)
#         output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
#         return self.out_proj(output) 

In [5]:
import torch
from torch import nn
import math
import torch.nn.functional as F

class MHA_KVCache(nn.Module):
    def __init__(self, args, qkv_bias=False):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.dim = args.dim
        
        self.W_query = nn.Linear(self.dim, self.head_dim*self.n_heads_q , bias=qkv_bias).to("cuda")
        self.W_key = nn.Linear(self.dim, self.n_kv_heads*self.head_dim, bias=qkv_bias).to("cuda")
        self.W_value = nn.Linear(self.dim, self.n_kv_heads*self.head_dim, bias=qkv_bias).to("cuda")
        self.out_proj = nn.Linear(args.n_heads*self.head_dim, self.dim).to("cuda")
        
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to("cuda")
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to("cuda")

    @staticmethod
    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        batch_size, seq_len, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
            .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
        )
        
    def forward(self, x: torch.Tensor, start_pos: int = 0):
        batch_size, seq_len, _ = x.shape
        
        # Project queries, keys, values
        xq, xk, xv = self.W_query(x), self.W_key(x), self.W_value(x)
        
        # Reshape to (batch_size, seq_len, num_heads, head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim).to("cuda")
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).to("cuda")
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).to("cuda")
        
        print(xq.shape, xk.shape, xv.shape)
        
        # Apply RoPE with position offset
        xq = rope(xq)
        xk = rope(xk)
        
        # Update KV cache
        self.cache_k[:batch_size, start_pos:start_pos+seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos+seq_len] = xv
        
        # Retrieve cached keys/values including current sequence
        keys = self.cache_k[:batch_size, :start_pos+seq_len]
        values = self.cache_v[:batch_size, :start_pos+seq_len]
        
        # Repeat KV heads to match Q heads
        keys = self.repeat_kv(keys, self.n_rep)
        values = self.repeat_kv(values, self.n_rep)
        
        # Transpose for attention computation
        xq = xq.transpose(1, 2)  # (bs, n_heads_q, seq_len, hd)
        keys = keys.transpose(1, 2)  # (bs, n_heads_kv*rep, cache_len, hd)
        values = values.transpose(1, 2)  # (bs, n_heads_kv*rep, cache_len, hd)
        
        # Compute attention scores
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        
        # Weighted sum of values
        output = torch.matmul(scores, values)  # (bs, n_heads_q, seq_len, hd)
        
        # Combine heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output).to("cuda")

In [6]:
mha = MHA_KVCache(ModelArgs)

In [7]:
x= torch.randn(1, 2048, 4096).to("cuda")

In [8]:
rope(x).shape

--------------------------------------------------
torch.Size([1, 2048, 2048]) torch.Size([1, 1, 2048, 2048])
--------------------------------------------------


torch.Size([1, 2048, 4096])

In [10]:
128//32

4

In [9]:
out= mha(x, start_pos=0)

torch.Size([1, 2048, 32, 128]) torch.Size([1, 2048, 32, 128]) torch.Size([1, 2048, 32, 128])
--------------------------------------------------
torch.Size([1, 2048, 32, 64]) torch.Size([1, 1, 2048, 2048])
--------------------------------------------------


RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 3

In [9]:
out.shape , x.shape

NameError: name 'out' is not defined