In [1]:
import torch
import torch.nn as nn

In [18]:
class RMSNORM(nn.Module):
    def __init__(self,emb_dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim)).float()

    def forward(self, x):
        means = x.pow(2).mean(dim = -1, keepdim = True)
        x_normed = x / torch.sqrt(self.eps + means)
        return (x_normed * self.weight).to(dtype=x.dtype)


In [21]:
torch.manual_seed(123)

example_batch = torch.randn(2, 3, 4)

rms_norm = RMSNORM(emb_dim=example_batch.shape[-1])

rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)


In [22]:
rms_norm(example_batch)

tensor([[[ 0.8834, -0.4655, -0.7948, -1.5398],
         [ 0.8053,  1.5254, -0.5074, -0.8759],
         [-0.1619, -0.4049,  0.7016, -1.8214]],

        [[ 1.3348, -0.1232,  0.9638, -1.1288],
         [-0.8511, -1.5285, -0.4420,  0.8624],
         [ 0.9398,  0.2082,  1.7241, -0.3180]]], grad_fn=<MulBackward0>)

In [24]:
rmsnorm_pytorch(example_batch)

tensor([[[ 0.8834, -0.4655, -0.7948, -1.5398],
         [ 0.8053,  1.5254, -0.5074, -0.8759],
         [-0.1619, -0.4049,  0.7016, -1.8214]],

        [[ 1.3348, -0.1232,  0.9638, -1.1288],
         [-0.8511, -1.5285, -0.4420,  0.8624],
         [ 0.9398,  0.2082,  1.7241, -0.3180]]], grad_fn=<MulBackward0>)

In [36]:
class Silu(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):    
        value = torch.sigmoid(x) * x 
        return value
        

In [37]:
fn = Silu()
value = fn(example_batch)

In [38]:
value

tensor([[[ 0.1969, -0.0810, -0.1289, -0.2100],
         [ 0.2044,  0.4354, -0.0978, -0.1541],
         [-0.0739, -0.1610,  0.4642, -0.2549]],

        [[ 1.6484, -0.0799,  1.0913, -0.2686],
         [-0.2459, -0.2767, -0.1628,  0.5480],
         [ 0.7297,  0.1228,  1.5791, -0.1407]]])

In [41]:
class FeedForward(nn.Module):
    def __init__(self,cfg): #cfg is the configuration here
        super().__init__()
        self.layer1 = nn.Linear(
            cfg['emb_dim'],
            cfg['hidden_dim'],
            bias=False,
            dtype=cfg["dtype"]
        )
        self.layer2 = nn.Linear(
            cfg['emb_dim'],
            cfg['hidden_dim'],
            bias=False,
            dtype=cfg["dtype"]
        )
        self.layer3 = nn.Linear(
            cfg['hidden_dim'],
            cfg['emb_dim'],
            bias=False,
            dtype=cfg["dtype"]
            
        )
        self.silu = Silu()

    def forward(self,x):
        o1 = self.layer1(x)
        o1 = self.silu(o1)
        o2 = self.layer2(x)
        o1Xo2 = torch.dot(o1 ,o2)
        o3 = self.layer3(o1Xo2)

        return o3 

In [94]:
def precompute_rope_params(head_dim, device, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0
    
    # theta(i) = 10000^(-2(i-1)/d)
    # where i = 1,2,3,4,5,6 ... d/2
    theta_numerator = torch.arange(0, head_dim, 2).float()
    print("theta_numerator:--->",theta_numerator.shape)
    
    theta = 1.0 / (theta_base ** (theta_numerator / head_dim)).to(device)
    print("theta_:--->",theta.shape)
    
    m = torch.arange(context_length, device=device)
    angles = m[:, None] * theta[None, :]  
    # the dim of angle is (context_length, head_dim/2)
    print("angles_shape-->",angles.shape)
    
    # This creates complex numbers in polar form (R=1, theta=angles)
    euler_form = torch.polar(torch.ones_like(angles), angles)
    print("euler_form:--->",euler_form.shape)
    
    return euler_form

def apply_rotary_embeddings(token_to_be_applied, euler_form, device):
    # Convert to complex representation
    print("token_to---->",token_to_be_applied.shape)
    x_complex = torch.view_as_complex(
        token_to_be_applied.float().reshape(*token_to_be_applied.shape[:-1], -1, 2)
    )
    print("X_complex---->",x_complex.shape)
    
   
    # Assuming token shape is (batch, seq_len, heads, head_dim)
    # euler_form shape is (context_length, head_dim/2)
    
    # We need to match the sequence length dimension, and broadcast across batch and heads
    # [:token_to_be_applied.shape[1]] ensures we only use angles for the actual sequence length
    reshaped_euler = euler_form[:token_to_be_applied.shape[1]].unsqueeze(0).unsqueeze(2)
    print("reshaped_euler--->",reshaped_euler.shape)
    
    rotated_embeddings = x_complex * reshaped_euler

    print("rotated_embeddings shape--->",rotated_embeddings.shape)
    
    # Convert back from complex to real
    x_out = torch.view_as_real(rotated_embeddings)
    print("x__out--->",x_out.shape)
    
    # Reshape back to original shape
   
    x_out = x_out.reshape(*token_to_be_applied.shape)
    print("x__out--->",x_out.shape)
    
    return x_out.type_as(token_to_be_applied).to(device)

In [95]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [96]:
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16

rope_params = precompute_rope_params(head_dim, device)

theta_numerator:---> torch.Size([8])
theta_:---> torch.Size([8])
angles_shape--> torch.Size([4096, 8])
euler_form:---> torch.Size([4096, 8])


In [97]:
rope_params.shape

torch.Size([4096, 8])

In [98]:
sample_tokens = torch.randn(batch_size, context_len, num_heads, head_dim, device=device)

In [99]:
sample_tokens.shape


torch.Size([2, 5, 4, 16])

In [100]:
rotated_tokens = apply_rotary_embeddings(sample_tokens, rope_params, device)

token_to----> torch.Size([2, 5, 4, 16])
X_complex----> torch.Size([2, 5, 4, 8])
reshaped_euler---> torch.Size([1, 5, 1, 8])
rotated_embeddings shape---> torch.Size([2, 5, 4, 8])
x__out---> torch.Size([2, 5, 4, 8, 2])
x__out---> torch.Size([2, 5, 4, 16])


In [101]:
rotated_tokens.shape

torch.Size([2, 5, 4, 16])

In [108]:
class MultiHeadAttentionModule(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dtype=None, dropout=0.1):
        super().__init__()
        assert d_out % num_heads == 0
        
        # self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        
        self.head_dim = d_out // num_heads
        
        self.W_query = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        self.W_key = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        self.W_value = torch.nn.Linear(
            in_features=d_in,
            out_features=d_out,
            device=device,
            dtype=dtype,
            bias=False
        )
        
        self.projection_layer = nn.Linear(
            d_out, d_out, bias=False, device=device, dtype=dtype
        )
        
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
        rope_params = precompute_rope_params(
            head_dim=self.head_dim,
            device=device,
        )
        self.register_buffer("rope_params", rope_params)
    
    def forward(self, x):
        b, num_tokens, dim = x.shape
        
        key_vec = self.W_key(x)
        query_vec = self.W_query(x)
        value_vec = self.W_value(x)
        
        keys = key_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        values = value_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        queries = query_vec.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        # applying rope to key and query part
        
        keys = apply_rotary_embeddings(keys, self.rope_params,device=device) #(batch, num_heads, num_tokens, head_dim)
        queries = apply_rotary_embeddings(queries, self.rope_params,device=device) #(batch, num_heads, num_tokens, head_dim)
        
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))
        
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        scaled_attn_scores = torch.softmax(attn_scores / torch.sqrt(torch.tensor(self.head_dim)), dim=-1)
        #shape of scaled attn-->(#(batch, num_heads, num_tokens, num_tokens)
        #shape of values-->(#(batch, num_heads, num_tokens, head_dim)
        
        context_vec = scaled_attn_scores @ values
        
        # shape of context_vec ----> shape of values-->(#(batch, num_heads, num_tokens, head_dim)
        
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        
        context_vec = self.projection_layer(context_vec)
        
        return context_vec

In [109]:
batch_size = 1
context_len = 100
max_context_len = 4096
embed_dim = 128
num_heads = 4


example_batch = torch.randn((batch_size, context_len, embed_dim))

mha = MultiHeadAttentionModule(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=max_context_len,
    num_heads=num_heads
)

mha(example_batch)

theta_numerator:---> torch.Size([16])
theta_:---> torch.Size([16])
angles_shape--> torch.Size([4096, 16])
euler_form:---> torch.Size([4096, 16])
token_to----> torch.Size([1, 4, 100, 32])
X_complex----> torch.Size([1, 4, 100, 16])
reshaped_euler---> torch.Size([1, 4, 1, 16])
rotated_embeddings shape---> torch.Size([1, 4, 100, 16])
x__out---> torch.Size([1, 4, 100, 16, 2])
x__out---> torch.Size([1, 4, 100, 32])
token_to----> torch.Size([1, 4, 100, 32])
X_complex----> torch.Size([1, 4, 100, 16])
reshaped_euler---> torch.Size([1, 4, 1, 16])
rotated_embeddings shape---> torch.Size([1, 4, 100, 16])
x__out---> torch.Size([1, 4, 100, 16, 2])
x__out---> torch.Size([1, 4, 100, 32])


tensor([[[-0.3530, -0.2181, -0.2645,  ...,  0.0304, -0.5108, -0.3729],
         [-0.3053,  0.2656,  0.0188,  ...,  0.2956, -0.0616, -0.2711],
         [-0.5983,  0.3538,  0.0800,  ...,  0.1081,  0.1835, -0.1742],
         ...,
         [-0.0266,  0.0862,  0.0463,  ...,  0.0458,  0.0166, -0.0078],
         [-0.0265,  0.0971,  0.0789,  ...,  0.0437,  0.0252,  0.0165],
         [-0.0232,  0.0928,  0.0422,  ...,  0.0566,  0.0107,  0.0040]]],
       grad_fn=<UnsafeViewBackward0>)

In [110]:
del mha

In [112]:
class Transformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttentionModule(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dtype=cfg["dtype"]  
            
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNORM(cfg["emb_dim"])
        self.norm2 = RMSNORM(cfg["emb_dim"])

    def forward(self,x):
        shortcut = x 
        x = self.norm1(x)
        x = self.att(x)
        x = x + shortcut
        shortcut = x 
        x = self.norm2(x)
        x = self.ff(x)
        x = x +shortcut
        x = rms_norm(x)
        return x

In [113]:
class LLAMA2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.Sequential(
            *[Transformer(cfg) for _ in range(cfg["n_layers"])])
        
        self.final_norm = RMSNORM(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

    def forward(self, in_idx):
       
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds  
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
        

In [114]:
LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     # Vocabulary size
    "context_length": 4096,  # Context length
    "emb_dim": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11008,     # NEW: Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage
}