In [16]:
import torch
import torch.nn as nn

In [17]:
class RMSNORM(nn.Module):
    def __init__(self,emb_dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim)).float()

    def forward(self, x):
        means = x.pow(2).mean(dim = -1, keepdim = True)
        x_normed = x / torch.sqrt(self.eps + means)
        return (x_normed * self.weight).to(dtype=x.dtype)

In [18]:
class Silu(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):    
        value = torch.sigmoid(x) * x 
        return value

In [19]:
class FeedForward(nn.Module):
    def __init__(self,cfg): #cfg is the configuration here
        super().__init__()
        self.layer1 = nn.Linear(
            cfg['emb_dim'],
            cfg['hidden_dim'],
            bias=False,
            dtype=cfg["dtype"]
        )
        self.layer2 = nn.Linear(
            cfg['emb_dim'],
            cfg['hidden_dim'],
            bias=False,
            dtype=cfg["dtype"]
        )
        self.layer3 = nn.Linear(
            cfg['hidden_dim'],
            cfg['emb_dim'],
            bias=False,
            dtype=cfg["dtype"]
            
        )
        self.silu = Silu()

    def forward(self,x):
        o1 = self.layer1(x)
        o1 = self.silu(o1)
        o2 = self.layer2(x)
        o1Xo2 = torch.dot(o1 ,o2)
        o3 = self.layer3(o1Xo2)

        return o3 

In [20]:
def precompute_rope_params(head_dim, device, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0
    
    # theta(i) = 10000^(-2(i-1)/d)
    # where i = 1,2,3,4,5,6 ... d/2
    theta_numerator = torch.arange(0, head_dim, 2).float()
    print("theta_numerator:--->",theta_numerator.shape)
    
    theta = 1.0 / (theta_base ** (theta_numerator / head_dim)).to(device)
    print("theta_:--->",theta.shape)
    
    m = torch.arange(context_length, device=device)
    angles = m[:, None] * theta[None, :]  
    # the dim of angle is (context_length, head_dim/2)
    print("angles_shape-->",angles.shape)
    
    # This creates complex numbers in polar form (R=1, theta=angles)
    euler_form = torch.polar(torch.ones_like(angles), angles)
    print("euler_form:--->",euler_form.shape)
    
    return euler_form

def apply_rotary_embeddings(token_to_be_applied, euler_form, device):
    # Convert to complex representation
    print("token_to---->",token_to_be_applied.shape)
    x_complex = torch.view_as_complex(
        token_to_be_applied.float().reshape(*token_to_be_applied.shape[:-1], -1, 2)
    )
    print("X_complex---->",x_complex.shape)
    
   
    # Assuming token shape is (batch, seq_len, heads, head_dim)
    # euler_form shape is (context_length, head_dim/2)
    
    # We need to match the sequence length dimension, and broadcast across batch and heads
    # [:token_to_be_applied.shape[1]] ensures we only use angles for the actual sequence length
    reshaped_euler = euler_form[:token_to_be_applied.shape[1]].unsqueeze(0).unsqueeze(2)
    print("reshaped_euler--->",reshaped_euler.shape)
    
    rotated_embeddings = x_complex * reshaped_euler

    print("rotated_embeddings shape--->",rotated_embeddings.shape)
    
    # Convert back from complex to real
    x_out = torch.view_as_real(rotated_embeddings)
    print("x__out--->",x_out.shape)
    
    # Reshape back to original shape
   
    x_out = x_out.reshape(*token_to_be_applied.shape)
    print("x__out--->",x_out.shape)
    
    return x_out.type_as(token_to_be_applied).to(device)

In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"
class MultiHeadAttentionModule(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dtype=None, dropout=0.1 ,device = device):
        super().__init__()
        assert d_out % num_heads == 0
        
        # self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        
        self.head_dim = d_out // num_heads
        
        self.W_query = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        self.W_key = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        self.W_value = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        
        self.projection_layer = nn.Linear(
            d_out, d_out, bias=False, device=device, dtype=dtype
        )
        
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
        rope_params = precompute_rope_params(
            head_dim=self.head_dim,
            device=device,
        )
        self.register_buffer("rope_params", rope_params)
    
    def forward(self, x):
        b, num_tokens, dim = x.shape
        
        key_vec = self.W_key(x)
        query_vec = self.W_query(x)
        value_vec = self.W_value(x)
        
        keys = key_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        values = value_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        queries = query_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        # applying rope to key and query part
        
        keys = apply_rotary_embeddings(keys, self.rope_params,device=device) #(batch, num_heads, num_tokens, head_dim)
        queries = apply_rotary_embeddings(queries, self.rope_params,device=device) #(batch, num_heads, num_tokens, head_dim)
        
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))
        
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        scaled_attn_scores = torch.softmax(attn_scores / torch.sqrt(torch.tensor(self.head_dim)), dim=-1)
        #shape of scaled attn-->(#(batch, num_heads, num_tokens, num_tokens)
        #shape of values-->(#(batch, num_heads, num_tokens, head_dim)
        
        context_vec = scaled_attn_scores @ values
        
        # shape of context_vec ----> shape of values-->(#(batch, num_heads, num_tokens, head_dim)
        
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        
        context_vec = self.projection_layer(context_vec)
        
        return context_vec

In [22]:
llama_3_context_len = 8192
llama_3_theta_base = 500_000
    

In [32]:
# using shared buffers 

class SharedBuffers:
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, dtype = torch.float32):
        key = (context_length, head_dim, rope_base, dtype)

        if key not in SharedBuffers._buffers:
            mask = torch.triu(torch.ones(context_length, context_length),diagonal=1)
            precomputed_rope_params = precompute_rope_params(head_dim=head_dim,device=device, theta_base=llama_3_theta_base, context_length = llama_3_context_len)

            if dtype is not None:
                precomputed_rope_params = precomputed_rope_params.to(dtype)

            SharedBuffers._buffers[key] = (mask , precomputed_rope_params)    

        return SharedBuffers._buffers[key]    



In [36]:
class GroupedQueryAttention(nn.Module):
    def __init__(
            self,
            d_in, 
            d_out, 
            context_length,
            num_heads,
            num_kv_groups,
            rope_base=50_000,
            rope_config=None,
            dtype=None,
            device=None 
    ):
        super().__init__()
        assert d_out % num_heads == 0
        assert num_heads % num_kv_groups == 0
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  
        self.d_out = d_out  

        self.W_key = nn.Linear(
            d_in, 
            self.head_dim * num_kv_groups,
            device=device,
            bias=False,
            dtype=dtype  
        )
        self.W_value = nn.Linear(
            d_in, 
            self.head_dim * num_kv_groups,
            device=device,
            bias=False,
            dtype=dtype  
        )
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups  

        self.W_query = nn.Linear(
            d_in, d_out, bias=False, dtype=dtype, device=device 
        )
        self.out_proj = nn.Linear(
            d_out, d_out, device=device, dtype=dtype
        )

        mask, rope_config = SharedBuffers.get_buffers(
            context_length=context_length, 
            head_dim=self.head_dim,
            rope_base=rope_base,
            dtype=dtype
        )

        self.register_buffer("mask", mask)
        self.register_buffer("computed_rope", rope_config)

    def forward(self, x):
        b, num_tokens, d_in = x.shape 

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

       
        keys = apply_rotary_embeddings(keys, self.computed_rope,device=device)
        queries = apply_rotary_embeddings(queries, self.computed_rope,device=device)  

        
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        
        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1) 

        
        
       
        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  

        return context_vec

In [37]:
batch_size = 1
context_len = 3000
max_context_len = 8192
embed_dim = 4096
num_heads = 32


example_batch = torch.randn((batch_size, context_len, embed_dim))

mha = MultiHeadAttentionModule(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=max_context_len,
    num_heads=num_heads
)

mha(example_batch)

print("W_key:", mha.W_key.weight.shape)
print("W_value:", mha.W_value.weight.shape)
print("W_query:", mha.W_query.weight.shape)

theta_numerator:---> torch.Size([64])
theta_:---> torch.Size([64])
angles_shape--> torch.Size([4096, 64])
euler_form:---> torch.Size([4096, 64])
token_to----> torch.Size([1, 32, 3000, 128])
X_complex----> torch.Size([1, 32, 3000, 64])
reshaped_euler---> torch.Size([1, 32, 1, 64])
rotated_embeddings shape---> torch.Size([1, 32, 3000, 64])
x__out---> torch.Size([1, 32, 3000, 64, 2])
x__out---> torch.Size([1, 32, 3000, 128])
token_to----> torch.Size([1, 32, 3000, 128])
X_complex----> torch.Size([1, 32, 3000, 64])
reshaped_euler---> torch.Size([1, 32, 1, 64])
rotated_embeddings shape---> torch.Size([1, 32, 3000, 64])
x__out---> torch.Size([1, 32, 3000, 64, 2])
x__out---> torch.Size([1, 32, 3000, 128])
W_key: torch.Size([4096, 4096])
W_value: torch.Size([4096, 4096])
W_query: torch.Size([4096, 4096])


In [38]:
gqa = GroupedQueryAttention(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=max_context_len,
    num_heads=num_heads,
    num_kv_groups=8,
    rope_base=llama_3_theta_base
)

gqa(example_batch)

print("W_key:", gqa.W_key.weight.shape)
print("W_value:", gqa.W_value.weight.shape)
print("W_query:", gqa.W_query.weight.shape)

token_to----> torch.Size([1, 8, 3000, 128])
X_complex----> torch.Size([1, 8, 3000, 64])
reshaped_euler---> torch.Size([1, 8, 1, 64])
rotated_embeddings shape---> torch.Size([1, 8, 3000, 64])
x__out---> torch.Size([1, 8, 3000, 64, 2])
x__out---> torch.Size([1, 8, 3000, 128])
token_to----> torch.Size([1, 32, 3000, 128])
X_complex----> torch.Size([1, 32, 3000, 64])
reshaped_euler---> torch.Size([1, 32, 1, 64])
rotated_embeddings shape---> torch.Size([1, 32, 3000, 64])
x__out---> torch.Size([1, 32, 3000, 64, 2])
x__out---> torch.Size([1, 32, 3000, 128])
W_key: torch.Size([1024, 4096])
W_value: torch.Size([1024, 4096])
W_query: torch.Size([4096, 4096])


In [39]:
del mha
del gqa

In [40]:
class TransformerBlock(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.attn = GroupedQueryAttention(
            d_in=cfg['emb_dim'],
            d_out=cfg['emb_dim'],
            context_length=cfg['context_length'],
            num_heads=cfg['n_heads'],
            num_kv_groups=cfg['n_kv_groups'],
            rope_base=cfg['rope_base'],
            rope_config=None,
            dtype=cfg['dtype']
        )
        self.FFN = FeedForward(cfg)
        self.norm1 = RMSNORM(emb_dim=cfg["emb_dim"],eps=1e-5)
        self.norm2 = RMSNORM(emb_dim=cfg["emb_dim"],eps=1e-5)

    def forward(self,x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + x 
        shortcut = x 
        x =  self.norm2(x)
        x =  self.FFN(x)
        x = x + shortcut


        return x 

In [41]:
class LLAMA3MODEL(nn.Module):
    def __init__(self,cfg):
        super().__init__()

        self.embeddingLayer = nn.Embedding(
            cfg["vocab_size"],
            cfg["emb_dim"],
            dtype=cfg['dtype']
        )
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg)for _ in range(cfg["num_layers"])]
        )
        self.finalNorm = RMSNORM(cfg["emb_dim"],eps=1e-5)
        self.out_head = nn.Linear(cfg['emb_dim'],cfg["vocab_dim"],bias=False,dtype=cfg['dtype'])

    def forward(self,in_idx):
        tok_embeds = self.embeddingLayer(in_idx)
        x = tok_embeds
        x = self.trf_blocks(x)
        x = self.finalNorm(x)

        logits = self.out_head(x.to(torch.bfloat16))
        return logits    



In [42]:
LLAMA2_CONFIG_7B = {
    "vocab_size": 32_000,    # Vocabulary size
    "context_length": 4096,  # Context length
    "emb_dim": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11_008,    # Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # Lower-precision dtype to reduce memory usage
}