# Week 09: Advanced LLM Architectures

Exploring building blocks of modern LLM architectures: Rotary Positional Embeddings (RoPE), the SwiGLU gated activation, and sparse Mixture of Experts (MoE) layers.

## Learning Objectives
1. Implement Rotary Positional Embeddings (RoPE)
2. Build a Mixture of Experts (MoE) Layer
3. Understand Gated Linear Units (GLU)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## 1. Rotary Positional Embeddings (RoPE)

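RoPE rotates each query and key vector by an angle proportional to its token position, so the query-key dot product depends only on the relative distance between the two positions.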

In [None]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the table of unit-magnitude complex rotations used by RoPE.

    Args:
        dim: Size of the vectors to be rotated; one frequency is produced
            per pair of features, i.e. ``dim // 2`` frequencies.
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, j]``
        is ``exp(i * p * theta ** (-2j / dim))``.
    """
    n_pairs = dim // 2
    # Exponents 0/dim, 2/dim, 4/dim, ... (truncated to n_pairs entries).
    exponents = torch.arange(0, dim, 2)[:n_pairs].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device)
    # Rotation angle for every (position, frequency) combination.
    angles = torch.outer(positions, inv_freq).float()  # (end, dim/2)

    # Magnitude 1, phase = angle  ->  cos(angle) + i*sin(angle), complex64.
    magnitude = torch.ones_like(angles)
    return torch.polar(magnitude, angles)

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    Apply RoPE to queries and keys.
    Input shape: (batch, seq_len, n_head, head_dim)

    Args:
        xq: Query tensor; ``head_dim`` must be even because adjacent
            feature pairs are treated as one complex number.
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: Complex rotation table. A 2-D table of shape
            ``(n_positions, head_dim // 2)`` may cover more positions than
            the current sequence; only its first ``seq_len`` rows are used.
            Any other input must be reshapeable to
            ``(1, seq_len, 1, head_dim // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` of rotated tensors with the same shapes
        and dtypes as ``xq`` and ``xk``.
    """
    # View each adjacent (even, odd) feature pair as one complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    seq_len, n_pairs = xq_.shape[1], xq_.shape[-1]

    # The table is normally precomputed once for the maximum sequence
    # length, so trim it to the positions actually present in this batch.
    if freqs_cis.dim() == 2 and freqs_cis.shape[0] > seq_len:
        freqs_cis = freqs_cis[:seq_len]

    # Insert singleton batch and head axes so the table broadcasts.
    freqs_cis = freqs_cis.reshape(1, seq_len, 1, n_pairs)

    # Complex multiplication by a unit-magnitude number is a 2-D rotation.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

In [None]:
# Sanity check for RoPE: a pure rotation must leave vector norms unchanged.
dim, seq_len = 64, 10
freqs_cis = precompute_freqs_cis(dim, seq_len)

# q is drawn before k so the random stream is consumed in the same order.
q, k = torch.randn(1, seq_len, 4, dim), torch.randn(1, seq_len, 4, dim)

q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)

# The two printed norms should agree up to floating-point rounding.
print(f"Original Q norm: {torch.norm(q):.4f}")
print(f"Rotated Q norm:  {torch.norm(q_rot):.4f} (should be same)")

## 2. SwiGLU Activation

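$\mathrm{SwiGLU}(x) = \big(\mathrm{Swish}(x W_1) \odot (x W_3)\big)\, W_2$, where $\odot$ is element-wise multiplication, $\mathrm{Swish}(z) = z \cdot \sigma(z)$ (called SiLU in PyTorch), and $W_1$, $W_3$, $W_2$ correspond to `w1`, `w3`, `w2` in the code below.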

In [None]:
class SwiGLU(nn.Module):
    """
    Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Size of the input and output feature dimension.
        hidden_dim: Size of the gated hidden representation.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x):
        """
        Apply the gated feed-forward transform.

        Args:
            x: Tensor whose last dimension is ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gated = F.silu(self.w1(x)) * self.w3(x)
        # Call the module rather than multiplying by ``w2.weight.T`` so that
        # forward hooks and any wrapper installed on ``w2`` are honoured, and
        # so the result does not rely on ``*`` / ``@`` operator precedence.
        return self.w2(gated)

## 3. Mixture of Experts (MoE)

Sparse MoE layer with Top-K gating.

In [None]:
class MoELayer(nn.Module):
    """
    Sparse Mixture-of-Experts layer with softmax Top-K gating.

    Every token is routed to its ``top_k`` highest-scoring experts; the
    expert outputs are combined using the gate probabilities renormalised
    over the selected experts.

    Args:
        dim: Feature dimension of the input and output.
        num_experts: Number of expert MLPs.
        top_k: Number of experts evaluated per token.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, dim, num_experts, top_k=2):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts ({num_experts}), got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k

        # Experts: independent two-layer MLPs with a 4x hidden expansion.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, 4*dim),
                nn.ReLU(),
                nn.Linear(4*dim, dim)
            ) for _ in range(num_experts)
        ])

        # Gating network: one logit per expert for every token.
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """
        Route each token through its Top-K experts.

        Args:
            x: Tensor of shape ``(..., dim)``, typically
                ``(batch, seq_len, dim)``. It need not be contiguous.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, orig_shape[-1])

        # Gate scores
        gate_logits = self.gate(x_flat)
        probs = F.softmax(gate_logits, dim=-1)

        # Top-K experts per token, with weights renormalised to sum to 1.
        top_probs, top_indices = torch.topk(probs, self.top_k, dim=-1)
        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(x_flat)

        # Visit each expert once and process all tokens routed to it in a
        # single batch, whichever top-k slot selected it.
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(top_indices == expert_id)
            if token_idx.numel() == 0:
                continue

            weight = top_probs[token_idx, slot_idx].unsqueeze(-1)
            expert_output = expert(x_flat[token_idx])

            # Accumulate the weighted expert output into the token rows.
            out.index_add_(0, token_idx, weight * expert_output)

        return out.reshape(orig_shape)

In [None]:
# Smoke test for the MoE layer: the output shape must match the input shape.
# The layer is built before sampling x so parameter initialisation and the
# input draw consume the random stream in the same order.
moe = MoELayer(dim=128, num_experts=8, top_k=2)
x = torch.randn(4, 10, 128)
output = moe(x)
print(f"MoE Input: {x.shape}", f"MoE Output: {output.shape}", sep="\n")