In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(42)

<torch._C.Generator at 0x79b980b62830>

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-07-23 19:39:10--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-07-23 19:39:11 (3.96 MB/s) - ‘input.txt’ saved [1115394/1115394]



# Define the experts of the NN

In [3]:
# Eppert mocule
class Expert(nn.Module):

  """An MLP is a simple linear layer folowed by an non-linearity """

  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)



In [14]:
# Understanding how gating works
num_experts = 3
top_k = 2
n_embd = 32

# Example multi-head attention output for simple illustrative example.
mh_output = torch.randn(1, 4, n_embd)

topkgate_linear = nn.Linear(n_embd, num_experts)
logits = topkgate_linear(mh_output)
print(logits)

tensor([[[ 0.2753, -1.8810, -0.1553],
         [-0.1253, -0.5594, -0.2212],
         [-0.3208, -0.1205, -0.0061],
         [-0.3639, -0.2565,  0.1207]]], grad_fn=<ViewBackward0>)


# Implementing load balancing

In [15]:
# We start with topk load balancing
top_k_logits, top_k_indicies = logits.topk(top_k, dim=-1)
top_k_logits, top_k_indicies

(tensor([[[ 0.2753, -0.1553],
          [-0.1253, -0.2212],
          [-0.0061, -0.1205],
          [ 0.1207, -0.2565]]], grad_fn=<TopkBackward0>),
 tensor([[[0, 2],
          [0, 2],
          [2, 1],
          [2, 1]]]))

# Apply -ing softmax

In [16]:
zeros = torch.full_like(logits, float("-inf"))
sparse_logits = zeros.scatter(-1, top_k_indicies, top_k_logits)
sparse_logits

tensor([[[ 0.2753,    -inf, -0.1553],
         [-0.1253,    -inf, -0.2212],
         [   -inf, -0.1205, -0.0061],
         [   -inf, -0.2565,  0.1207]]], grad_fn=<ScatterBackward0>)

In [17]:
gating_output = F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.6060, 0.0000, 0.3940],
         [0.5239, 0.0000, 0.4761],
         [0.0000, 0.4714, 0.5286],
         [0.0000, 0.4068, 0.5932]]], grad_fn=<SoftmaxBackward0>)

# Create a class for TopKRouting

In [19]:
# First define the top k router module
class TopKRouter(nn.Module):
  def __init__(self, n_embd, num_experts, top_k):
    super(TopKRouter, self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embd, num_experts)

  def forward(self, mh_output):
    # mh_output is the output tensor from multihead self attention block
    logits = self.linear(mh_output)
    top_k_logits, indicies = logits.topk(self.top_k, dim=-1)
    zeros = torch.full_like(logits, float('-inf'))
    sparse_logits = zeros.scatter(-1, indicies, top_k_logits)
    router_output = F.softmax(sparse_logits, dim=-1)
    return router_output, indicies

In [20]:
# chanign the above to accomadate noise top-k gating
class NoisyTopkTouter(nn.Module):
  def __init__(self, n_embd, num_experts, top_k):
    super(NoisyTopkTouter, self).__init__()
    self.top_k = top_k
    # layer for router logits
    self.topkroute_linear = nn.Linear(n_embd, num_experts)
    self.noise_linear  = nn.Linear(n_embd, num_experts)

  def forward(self, mh_output):
    # mh_output is the output tensor from multihead self attention block
    logits = self.topkroute_linear(mh_output)

    # Noise logits
    noise_logits = self.noise_linear(mh_output)

    # Adding scaled unit gaussian noise to the logits
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
    zeros = torch.full_like(noisy_logits, float('-inf'))
    sparse_logits = zeros.scatter(-1, indices, top_k_logits)
    router_output = F.softmax(sparse_logits, dim=-1)
    return router_output, indices

# Create a spares MoE model

In [21]:
class SparseMoE(nn.Module):
  def __init__(self, n_embd, num_experts, top_k):
    super(SparseMoE, self).__init__()
    self.router = NoisyTopkTouter(n_embd, num_experts, top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self, x):
    gating_output, indices = self.router(x)
    final_output = torch.zeros_like(x)

    # Reshape inputs for batch processing
    flat_x = x.view(-1, x.size(-1))
    flat_gating_output = gating_output.view(-1, gating_output.size(-1))

    # Process each expert in parallel
    for i, expert in enumerate (self.experts):
      # Create a mask for the inputs where the current expert is in top-k
      expert_mask = (indices == i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        # Extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        # Update final output additively by indexing and adding
        final_output[expert_mask] += weighted_output.squeeze(1)

    return final_output



# Code the entire transformer block: Part 1 (Multi-head Attention)

In [40]:
class Head(nn.Module):

  """one head of self-attention """

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x)
    v = self.value(x)
    q = self.query(x)

    #compute attention scores ("affinities")
    wei = q @ k.transpose(-2, -1) * C**-0.5
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    wei = F.softmax(wei,dim=-1)
    wei = self.dropout(wei)

    #perform the weighted aggregation of the values
    out = wei @ v
    return out

class MultiHeadAttention(nn.Module):
  """ We run multiple heads of self-attention in parallel"""

  def __init__(self, n_head, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_embd, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

# code the entire transformer block

In [23]:
 # First create a self attnetnio + mixture of experts block, that may be repeated several times
 # Xopy pasting key architecture variables for clarity

class Block(nn.Module):
  def __init__(self, n_embd, n_head, num_experts, top_k):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.smoe = SparseMoE(n_embd, num_experts, top_k)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.smoe(self.ln2(x))
    return x

# Defnine the language model architecture

In [53]:
# Finally we put it all together
class SparseMoELanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    # each token directly reads of the logits for the next token from a lookup table
    self.t_embd = nn.Embedding(vocab_size, n_embd)
    self.p_embd = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head, num_experts=num_experts, top_k=top_k) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape

    # idx and targets are both (B,T) tensor of integers
    tok_emb = self.t_embd(idx)
    pos_emb = self.p_embd(torch.arange(T, device=device))
    x = tok_emb + pos_emb
    x = self.blocks(x)
    x = self.ln_f(x)
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, X = logits.shape
      logits = logits.view(B*T, X)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B,T) array of indices in the current context
    for _ in range(max_new_tokens):
      # crop idx to the last block_size tokens
      idx_cond = idx[:, -block_size:]
      # get predictions
      logits, loss = self(idx_cond)
      #focus only on the last time step
      logits = logits[:, -1, :]
      # apply softmax to get the probabilities
      probs = F.softmax(logits, dim=-1)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=-1)
    return idx

# Create training and testing data

In [47]:
torch.manual_seed(1137)

with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

# here are all the unique characters that occur in teh test
chars = sorted(list(set(text)))
vocab_size = len(chars)
#creat mapping from chars to integers
stoi = { ch:i for i, ch in enumerate(chars)}
itos = { i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# trian and test sets
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

#data loading
def get_batch(split):
  # generate a small batcho f data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x, y

# Define LLM Loss

In [49]:
@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X,Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

# Define training loop params

In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.nn import init

# hyperparameters
batch_size = 16
block_size = 32
max_iters = 200
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_interval = 100
eval_iters = 400
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2

# Initializ the model

In [57]:
def kaiming_init_weights(m):
  if isinstance (m, (nn.Linear)):
    init.kaiming_normal_(m.weight)

In [58]:
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)
model.to(device)

SparseMoELanguageModel(
  (t_embd): Embedding(65, 128)
  (p_embd): Embedding(32, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkTouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
              (0): Linear(in_features=128, out

# Run the pretrained model

In [59]:
m = model.to(device)

# display the params in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

for iter in range(max_iters):

  if iter % eval_iters == 0 or iter == max_iters - 1:

    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  xb, yb = get_batch('train')

  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

8.996545 M parameters
step 0: train loss 5.3314, val loss 5.3018
step 199: train loss 2.5056, val loss 2.4987


# Inference - Generate from the model

In [60]:
context = torch.zeros((1,1), dtype = torch.long, device = device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


Draner hak.
Wh nd mithe
ARI ty iroo n us at mandwith fat ncor me no ous' helo Pyongote!
Buderivet w and: y re sier d,e.
e, blas, ODorofe bewie d:
Ane be ge so hovend be itourtou
m f: Lure hine? seLue htoro one gipraste imo hit ovktho y saot o wgethiconrllar fs s, oat I o mer r guris.
Atikaru imethe rim l se inof o le?eran oramy, Ckd, nqorar ndetor kis uth br:
Whis I 
Ihane nered fe l,
Thoreakee he werenocllyowis ais orice ur an he l'oust he ane ame e tan'ds, y t fpainor gowarrosf o the sjrs ANpeer
Pour y ild hofut hongoonor, n o mes t,
Asur;
kik, wfo arN the we ginche, r he alesatho ko s;
That meeBy wi as o Ctot th sisure witounch,
Aneove wheo mee. s
Sillld my Rore parold; hie
IAyor h he e
Tpulird I lalle.
Ygoughe; mith pre hit'e; ndithEThe:

And at m, orr be; ule 'ld ithyor gierine.
We cthomear nouTher!
An ; d be wise I horbe tee Wof!
He s s, hikele.
ANLat msuit prte jus fr, be nsom s. aneatowiceors.
Your ores oud s se y, aray we, tharicd tethya; nce be.
Sonnicqie: Le our lllo pillth

In [62]:
!git init

[33mhint: Using 'master' as the name for the initial branch. This default branch name[m
[33mhint: is subject to change. To configure the initial branch name to use in all[m
[33mhint: [m
[33mhint: 	git config --global init.defaultBranch <name>[m
[33mhint: [m
[33mhint: Names commonly chosen instead of 'master' are 'main', 'trunk' and[m
[33mhint: 'development'. The just-created branch can be renamed via this command:[m
[33mhint: [m
[33mhint: 	git branch -m <name>[m
Initialized empty Git repository in /content/.git/
