In [1]:
import sys
sys.path.append("..")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from llm_foundry.model.layers import RMSNorm, FeedForward
from llm_foundry.model.attention import MultiHeadAttention

In [3]:
## Laying down the overall architecture of the LLM.
"""
Overview of the LLM architecture being built here.

Input -> Tokenization -> Embedding -> Stack of Transformer Blocks -> Output Layer

Transformer Block:
    - Input
    - Residual Connection + Layer Norm -> Multi-Head Attention (with RoPE for positional encoding) -> Dropout -> Add Input (shortcut connection)
    - Residual Connection + Layer Norm -> Feed-Forward Network -> Dropout -> Add Previous (shortcut connection)
    - Output

The key improvements:
1. Shortcut connections (residual connections) around both attention and FFN components
2. Multiple stacked transformer blocks for deep representation
3. RoPE (Rotary Position Embedding) integrated directly into attention mechanism 
   instead of separate positional encoding
4. Layer normalization before attention and FFN (Pre-LN architecture)
"""



'\nOverview of the LLM architecture being built here.\n\nInput -> Tokenization -> Embedding -> Stack of Transformer Blocks -> Output Layer\n\nTransformer Block:\n    - Input\n    - Residual Connection + Layer Norm -> Multi-Head Attention (with RoPE for positional encoding) -> Dropout -> Add Input (shortcut connection)\n    - Residual Connection + Layer Norm -> Feed-Forward Network -> Dropout -> Add Previous (shortcut connection)\n    - Output\n\nThe key improvements:\n1. Shortcut connections (residual connections) around both attention and FFN components\n2. Multiple stacked transformer blocks for deep representation\n3. RoPE (Rotary Position Embedding) integrated directly into attention mechanism \n   instead of separate positional encoding\n4. Layer normalization before attention and FFN (Pre-LN architecture)\n'

In [4]:
from llm_foundry.utils.config import load_config

In [5]:
config = load_config(path="../configs/llm_270m.yaml")

print(config.keys())
print(config["model"].keys())

dict_keys(['data', 'model', 'training'])
dict_keys(['vocab_size', 'context_length', 'emb_dim', 'n_heads', 'n_layers', 'hidden_dim', 'head_dim', 'qk_norm', 'n_kv_groups', 'rope_local_base', 'rope_base', 'layer_types', 'dtype', 'query_pre_attn_scalar'])


In [6]:
model_cfg = config["model"]

for key, value in model_cfg.items():
    print(f"{key}: {value}")

vocab_size: 50257
context_length: 1024
emb_dim: 1024
n_heads: 8
n_layers: 12
hidden_dim: 2048
head_dim: 256
qk_norm: True
n_kv_groups: 8
rope_local_base: 10000.0
rope_base: 1000000.0
layer_types: ['full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention', 'full_attention']
dtype: bfloat16
query_pre_attn_scalar: 256


In [7]:
input_ids = torch.randint(0, model_cfg["vocab_size"], (2, 8))  # (batch_size, seq_length)

print("Input IDs shape:", input_ids.shape)  # Should be (2, 8)


Input IDs shape: torch.Size([2, 8])


In [8]:
tkn_emb = nn.Embedding(model_cfg["vocab_size"], model_cfg["emb_dim"])

x = tkn_emb(input_ids)  # (batch_size, seq_length, emb_dim)
print("Token Embeddings shape:", x.shape)  # Should be (2, 8, emb_dim)

## Scale the input independent of the embedding dimension
x = x * (model_cfg["emb_dim"] ** 0.5)
print("Scaled Embeddings shape:", x.shape)  # Should be (2, 8, emb_dim)

Token Embeddings shape: torch.Size([2, 8, 1024])
Scaled Embeddings shape: torch.Size([2, 8, 1024])


In [9]:
## Attention

att = MultiHeadAttention(
    d_in=model_cfg["emb_dim"],
    num_heads=model_cfg["n_heads"],
    head_dim=model_cfg["head_dim"],
    qk_norm=model_cfg["qk_norm"],
    dtype=torch.float32,
)

## FF Layer
ff = FeedForward(model_cfg)

## Norm Layer
input_layer_RMSNorm = RMSNorm(model_cfg["emb_dim"], eps=1e-6)
post_att_layer_RMSNorm = RMSNorm(model_cfg["emb_dim"], eps=1e-6)

In [10]:
## Running Attention
shortcut = x
x_norm = input_layer_RMSNorm(x)
x_att = att(x_norm, mask=True)  # (batch_size, seq_length, emb_dim)

x = shortcut + x_att  # Residual connection
print("Post-Attention shape:", x.shape)  # Should be (2, 8, emb_dim)

Post-Attention shape: torch.Size([2, 8, 1024])


In [11]:
## Running FF
shortcut = x
x_norm = post_att_layer_RMSNorm(x).to(torch.bfloat16)
x_ff = ff(x_norm)  # (batch_size, seq_length, emb_dim)
x = shortcut + x_ff  # Residual connection

print("Post-FF shape:", x.shape)  # Should be (2, 8, emb_dim)

Post-FF shape: torch.Size([2, 8, 1024])


In [12]:
## Stacking multiple transformer blocks
block = [
    {
        "att": MultiHeadAttention(model_cfg["emb_dim"], model_cfg["n_heads"], model_cfg["head_dim"], qk_norm=model_cfg["qk_norm"]),
        "ff": FeedForward(model_cfg),
        "ln1": RMSNorm(model_cfg["emb_dim"], eps=1e-6),
        "ln2": RMSNorm(model_cfg["emb_dim"], eps=1e-6)
    }
    for number_of_layers in range(model_cfg["n_layers"])
]

In [13]:
block[0]

{'att': MultiHeadAttention(
   (W_query): Linear(in_features=1024, out_features=2048, bias=False)
   (W_key): Linear(in_features=1024, out_features=2048, bias=False)
   (W_value): Linear(in_features=1024, out_features=2048, bias=False)
   (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
   (q_norm): RMSNorm()
   (k_norm): RMSNorm()
 ),
 'ff': FeedForward(
   (fc1): Linear(in_features=1024, out_features=2048, bias=False)
   (fc2): Linear(in_features=1024, out_features=2048, bias=False)
   (fc3): Linear(in_features=2048, out_features=1024, bias=False)
 ),
 'ln1': RMSNorm(),
 'ln2': RMSNorm()}

In [14]:
x = tkn_emb(input_ids)* (model_cfg["emb_dim"] ** 0.5)

for i, block in enumerate(block):
    # Attention
    shortcut = x
    x_norm = block["ln1"](x)
    x_att = block["att"](x_norm, mask=True)
    x = shortcut + x_att  # Residual connection

    # Feed Forward
    shortcut = x
    x_norm = block["ln2"](x).to(torch.bfloat16)
    x_ff = block["ff"](x_norm)
    x = shortcut + x_ff  # Residual connection

    print(f"After Block {i+1}, shape: {x.shape}")  # Should be (2, 8, emb_dim)

After Block 1, shape: torch.Size([2, 8, 1024])
After Block 2, shape: torch.Size([2, 8, 1024])
After Block 3, shape: torch.Size([2, 8, 1024])
After Block 4, shape: torch.Size([2, 8, 1024])
After Block 5, shape: torch.Size([2, 8, 1024])
After Block 6, shape: torch.Size([2, 8, 1024])
After Block 7, shape: torch.Size([2, 8, 1024])
After Block 8, shape: torch.Size([2, 8, 1024])
After Block 9, shape: torch.Size([2, 8, 1024])
After Block 10, shape: torch.Size([2, 8, 1024])
After Block 11, shape: torch.Size([2, 8, 1024])
After Block 12, shape: torch.Size([2, 8, 1024])


In [15]:
## Final Layer Norm
final_norm = RMSNorm(model_cfg["emb_dim"], eps=1e-6)
out_head = nn.Linear(model_cfg["emb_dim"], model_cfg["vocab_size"])

x = final_norm(x)
logits = out_head(x)  # (batch_size, seq_length, vocab_size)

print("Logits shape:", logits.shape)  # Should be (2, 8, vocab_size)

Logits shape: torch.Size([2, 8, 50257])


In [16]:
### Computing Loss

targets = torch.randint(0, model_cfg["vocab_size"], (2, 8))     # (batch_size, seq_length)

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

print("Loss:", loss.item())

Loss: 11.239240646362305


In [22]:
blocks = [
    {
        "att": MultiHeadAttention(model_cfg["emb_dim"], model_cfg["n_heads"], model_cfg["head_dim"], qk_norm=model_cfg["qk_norm"]),
        "ff": FeedForward(model_cfg),
        "ln1": RMSNorm(model_cfg["emb_dim"], eps=1e-6),
        "ln2": RMSNorm(model_cfg["emb_dim"], eps=1e-6)
    }
    for _ in range(model_cfg["n_layers"])  # Fixed: use '_' instead of 'number_of_layers'
]

In [25]:
### Generate Loop
def generate(input_ids, max_new_tokens, blocks, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        x = tkn_emb(input_ids) * (model_cfg["emb_dim"] ** 0.5)
        for block in blocks:
            sc = x
            x_norm = block["ln1"](x)
            x_att = block["att"](x_norm, mask=None)
            x = sc + x_att

            sc = x
            x_norm = block["ln2"](x).to(torch.bfloat16)
            x_ff = block["ff"](x_norm)
            x = sc + x_ff

        x = final_norm(x)
        logits = out_head(x[:, -1, :]) / temperature  # Focus on the last token only

        if top_k is not None:
            v, index = torch.topk(logits, top_k)
            # Smallest values for each.
            smallest = v[:, -1].unsqueeze(1)
            logits[logits < smallest] = -float('Inf')

        probs = logits.softmax(dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat((input_ids, next_token), dim=1)
    
    return input_ids

In [26]:
start_ids = torch.randint(0, model_cfg["vocab_size"], (1, 5))
print("Generated:", generate(start_ids, max_new_tokens=5, blocks=blocks))

Generated: tensor([[24160, 14289, 15789, 45289,  8577, 10633, 46144, 39431, 26047, 25751]])
