In [11]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
torch.manual_seed(42)

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2

In [12]:
with open('./data/input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

print('length of text file:', len(text))

chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

data = torch.tensor(encode(text), dtype=torch.long)

length of text file: 998


In [13]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [14]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
     

In [15]:
def compute_reward(selected_experts, correct_experts):
    """
    Compute reward based on whether the correct expert(s) were chosen.
    - Reward = 1 if at least 1 correct expert is chosen.
    - Penalty = -1 if none of the correct experts are chosen.
    """
    batch_size = selected_experts.size(0)
    rewards = torch.zeros(batch_size, device=selected_experts.device)
    
    for i in range(batch_size):
        # Check if any of the selected experts match the correct ones
        if any(expert in correct_experts[i] for expert in selected_experts[i]):
            rewards[i] = 1.0  # Reward for correct expert selection
        else:
            rewards[i] = -1.0  # Penalty for wrong selection
    
    return rewards

def compute_policy_loss(logits, selected_experts, rewards):
    """
    Compute policy loss using REINFORCE.
    """
    # Convert logits to probabilities
    prob_dist = F.softmax(logits, dim=-1)
    
    # Get log-probabilities of selected experts
    selected_log_probs = torch.log(torch.gather(prob_dist, -1, selected_experts.unsqueeze(-1)).squeeze(-1))
    
    # REINFORCE Loss: Negative log-prob weighted by reward
    loss = -torch.mean(rewards * selected_log_probs)
    return loss

In [None]:
import torch.optim as optim

def train_router(model, router, experts, train_data, num_epochs):
    """Train the router to choose the best experts for different tasks."""

    router_optimizer = optim.Adam(router.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.eval()  # Freeze the main model
        for expert in experts:
            expert.eval()  # Freeze all experts during router training

        router.train()  # Train only the router
        total_reward = 0
        total_loss = 0

        for i in range(0, len(train_data) - block_size, batch_size):
            x = torch.stack([train_data[i:i + block_size] for _ in range(batch_size)]).to(device)
            y = torch.stack([train_data[i + 1:i + block_size + 1] for _ in range(batch_size)]).to(device)

            with torch.no_grad():
                mh_output = model(x)  # Get output after multi-head attention
                router_output, indices = router(mh_output)  # Route to experts

            # -------------------------
            # 🧠 Choose top-k experts
            # -------------------------
            expert_outputs = []
            for j, idx in enumerate(indices):  # For each sequence in batch
                selected_experts = [experts[k](mh_output[j]) for k in idx]  # Top-k experts
                avg_expert_output = sum(selected_experts) / len(selected_experts)  # Average output
                expert_outputs.append(avg_expert_output)

            expert_outputs = torch.stack(expert_outputs).to(device)  # Stack outputs

            # -------------------------
            # 🎯 Define reward function
            # -------------------------
            with torch.no_grad():
                correct_loss = loss_fn(expert_outputs.view(-1, vocab_size), y.view(-1))  # Expert output loss
                baseline_loss = torch.ones_like(correct_loss) * 2.0  # Baseline for reward comparison
                reward = (baseline_loss - correct_loss).detach()  # Higher reward for lower loss

            # -------------------------
            # 🏆 Reinforce the router
            # -------------------------
            log_probs = torch.log(router_output + 1e-9)  # Avoid NaN
            selected_log_probs = torch.gather(log_probs, -1, indices)  # Log probs of chosen experts
            loss = -(reward * selected_log_probs.sum(-1)).mean()  # Policy gradient loss

            router_optimizer.zero_grad()
            loss.backward()
            router_optimizer.step()

            total_reward += reward.mean().item()
            total_loss += loss.item()

        avg_reward = total_reward / (len(train_data) // batch_size)
        avg_loss = total_loss / (len(train_data) // batch_size)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Avg Reward: {avg_reward:.4f}, Loss: {avg_loss:.4f}")

    print("Router training complete! 🚀")

experts = nn.ModuleList([Expert(n_embed).to(device) for _ in range(num_experts)])
router = NoisyTopkRouter(n_embed=n_embed, num_experts=num_experts, top_k=top_k).to(device)

# Define your main model if it's not yet defined
# Example placeholder for a model that outputs after MHA
class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embed)
        self.mha = MultiHeadAttention(n_head, head_size)
        self.proj = nn.Linear(n_embed, n_embed)
        
    def forward(self, x):
        x = self.embedding(x)  # (B, T) -> (B, T, n_embed)
        x = self.mha(x)  # Multi-head attention
        x = self.proj(x)  # Projection
        return x

# Initialize the model
model = MainModel().to(device)

train_data = data[:n]

train_router(model, router, experts, train_data, 10)

print(f"Model: {model}")
print(f"Router: {router}")
print(f"Number of Experts: {len(experts)}")
print(f"Train Data Size: {len(train_data)}")

TypeError: only integer tensors of a single element can be converted to an index

: 