In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### RMSNorm: a custom implementation equivalent to PyTorch's built-in `nn.RMSNorm`

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Every vector along the last axis is multiplied by
    ``rsqrt(mean(x ** 2) + eps)`` and then scaled element-wise by a learnable
    weight vector that starts out as all ones.

    Args:
        embed_dim: Size of the last input dimension; also the length of the
            learnable scale vector.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, embed_dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.embed_dim = embed_dim
        # Create the data in float32 *before* wrapping it. Converting the dtype
        # of an existing nn.Parameter yields a plain tensor whenever a real
        # conversion happens, and that tensor would not be registered as a
        # parameter of this module.
        self.rms_weights = nn.Parameter(torch.ones(embed_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize ``x`` along its last axis and return it in ``x``'s dtype."""
        # Remember the caller's dtype: `x` is rebound below, so reading
        # `x.dtype` at the end would no longer give the input dtype.
        input_dtype = x.dtype
        if x.is_floating_point():
            # Compute in at least float32 so half-precision inputs do not lose
            # accuracy in the mean of squares; float64 stays float64.
            x = x.to(torch.promote_types(input_dtype, torch.float32))
        means = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(means + self.eps)
        x = x * self.rms_weights
        return x.to(dtype=input_dtype)

In [44]:
# Compare the hand-written layer with PyTorch's built-in one on a random batch
x = torch.randn(1, 2, 10)  # (batch, seq, embed)
rms = RMSNorm(x.size(-1), eps=1e-5)
torch_rms_norm = nn.RMSNorm(x.size(-1), eps=1e-5)

custom_out, reference_out = rms(x), torch_rms_norm(x)
print(torch.allclose(custom_out, reference_out))

True


### Testing Config 

In [48]:
from types import SimpleNamespace

# Hyperparameters held in an attribute-style container
# (read them as config.dtype, config.embed_dim, config.hidden_dim).
config = SimpleNamespace(
    dtype=torch.bfloat16,
    embed_dim=512,
    hidden_dim=2048,
)

In [51]:
# Display the runtime type of `config` to confirm the SimpleNamespace conversion.
type(config)

types.SimpleNamespace

### Computing and initializing RoPE embeddings

In [52]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    """Build the cosine and sine lookup tables used by rotary embeddings.

    Args:
        head_dim: Per-head feature size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to tabulate.

    Returns:
        A ``(cos, sin)`` pair of tensors, each of shape
        ``(context_length, head_dim)``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # One inverse frequency per channel pair: theta_base ** -(2i / head_dim)
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product of positions and frequencies -> (context_length, head_dim // 2)
    position_ids = torch.arange(context_length)
    half_angles = position_ids.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Repeat the angles so both halves of a head share the same rotation angle
    full_angles = torch.cat([half_angles, half_angles], dim=1)

    return torch.cos(full_angles), torch.sin(full_angles)

def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor of shape ``(batch_size, num_heads, seq_len, head_dim)`` with
            an even ``head_dim``.
        cos: Cosine table; only its first ``seq_len`` rows are used.
        sin: Sine table; only its first ``seq_len`` rows are used.

    Returns:
        The rotated tensor, cast back to ``x.dtype``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    # Unpacking four values also rejects inputs that are not 4-D
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Trim the tables to the sequence length and add broadcast axes:
    # (seq_len, head_dim) -> (1, 1, seq_len, head_dim)
    cos_table = cos[None, None, :seq_len, :]
    sin_table = sin[None, None, :seq_len, :]

    # Pair each channel of the first half with its partner in the second half
    first_half, second_half = x.chunk(2, dim=-1)
    partner = torch.cat((-second_half, first_half), dim=-1)

    rotated_x = x * cos_table + partner * sin_table
    return rotated_x.to(dtype=x.dtype)

In [59]:
# Settings
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16

# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)

In [60]:
# Show the query tensor's shape after and before applying RoPE, side by side.
queries_rot.shape, queries.shape

(torch.Size([2, 4, 5, 16]), torch.Size([2, 4, 5, 16]))

In [3]:
from huggingface_hub import list_models

# Search the Hugging Face Hub for models matching "Llama2-7B" and print the
# id of the first result only (nothing is printed when there are no results).
models = list_models(search="Llama2-7B")
model = next(iter(models), None)
if model is not None:
    print(model.modelId)


meta-llama/Llama-2-7b-chat-hf


In [5]:
import os

from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login

# Load environment variables via python-dotenv, then authenticate with the
# Hub using whatever value HF_ACCESS_TOKEN holds.
load_dotenv()
login(token=os.getenv("HF_ACCESS_TOKEN"))

# Fetch only tokenizer.json from the Llama 3.1 8B repository into a local folder.
download_kwargs = dict(
    repo_id="meta-llama/Llama-3.1-8B",
    filename="tokenizer.json",
    local_dir="Llama-3.1-8B",
)
tokenizer_file = hf_hub_download(**download_kwargs)

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

str