In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x7cb47211be10>

In [2]:
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

--2025-08-16 09:14:55--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-08-16 09:14:56 (27.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



# Making the experts

In [3]:
class Expert(nn.Module):
  '''
  The expert can be seen as an neural network that expands the input embedding into 4 times it's dimention and re contracts it back to it's original dimention
  '''
  def __init__(self, n_embed):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embed, 4* n_embed), # take input and expand to 4 times the input dimenstion
        nn.ReLU(), #we can also use GeLU
        nn.Linear(4 * n_embed, n_embed), # contract it back to original dimension
        nn.Dropout(dropout),
    )

  def forward(self,x):
    return self.net(x)

# Now we will need the routing matrix to route the tokens to the correct experts, this will help implement sparsity. After taking the dot product of the routing matrix and the original matrix we will get the expert selector matrix

Let:

- $X \in \mathbb{R}^{B \times D}$ : Input matrix (batch size $B$, input dimension $D$)  
- $R \in \mathbb{R}^{D \times E}$ : Routing (or gating) matrix (maps inputs to experts, with $E$ experts)  
- $S \in \mathbb{R}^{B \times E}$ : Expert selector matrix  

The expert selector matrix is obtained as:

$$
S = X \cdot R
$$

where:  

- Each row of \( S \) corresponds to one input token’s affinity/score for each expert.  
- A softmax is often applied along the expert dimension:

$$
\hat{S}_{i,j} = \frac{\exp(S_{i,j})}{\sum_{k=1}^{E} \exp(S_{i,k})}
$$

to produce a probability distribution over experts.  

Finally, the top-\(k\) experts are chosen per token:

$$
\text{Experts}(i) = \text{Top-}k(\hat{S}_{i,:})
$$


# Load balancing :

now we only select 2 experts for the tokens, Meaning we will take the top 2 values in the expert selector matrix and make every other value 0

#Now to we make he every other value negative infinity and apply softmax, this will result in every other value than top k values to be 0 and the top k values to have a sum of 1

# Top-k Masking with Softmax

Let $S \in \mathbb{R}^{B \times E}$ be the expert score matrix.

---

### 1. Top-k Masking
For each token $i$, define a masked score matrix $M$ as:

$$
M_{i,j} =
\begin{cases}
S_{i,j}, & \text{if } j \in \text{Top-}k(S_{i,:}) \\[6pt]
-\infty, & \text{otherwise}
\end{cases}
$$

---

### 2. Softmax over Masked Scores
We then apply softmax:

$$
\hat{S}_{i,j} = \frac{\exp(M_{i,j})}{\sum_{m=1}^{E} \exp(M_{i,m})}
$$

---

### 3. Why Non-Top-$k$ Entries Become Zero
Since

$$
\exp(-\infty) = 0,
$$

all masked (non-top-$k$) entries vanish in the numerator and denominator.  

Thus, only the top-$k$ entries survive, and their probabilities normalize to sum to 1.

---

### 4. Result
- Non-top-$k$ experts: $\hat{S}_{i,j} = 0$  
- Top-$k$ experts: $\hat{S}_{i,j} > 0$ and  

$$
\sum_{j \in \text{Top-}k} \hat{S}_{i,j} = 1
$$



In [9]:
class TopKRouter(nn.Module):
  def __init__(self, n_embed, num_experts, top_k):
    super(TopKRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed, num_experts)

  def forward(self, mh_output):
    logits = self.linear(mh_output)
    top_k_logits,indices = logits.topk(self.top_k, dim = -1)
    zeros = torch.full_like(logits, float('-inf'))
    sparse_logits = zeros.scatter(-1, indices, top_k_logits)
    router_output = F.softmax(sparse_logits, dim =-1)
    return router_output,indices

# Implementing noisy top-k gating

In [10]:
class NoisyTopKRouter(nn.Module):
      def __init__(self, n_embed, num_experts, top_k):

        super(NoisyTopKRouter, self).__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

      def forward(self, mh_output):
        # Calculate the base logits for routing
        logits = self.topkroute_linear(mh_output)
        # Calculate the noise logits
        noise_logits = self.noise_linear(mh_output)
        # Add scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        noisy_logits = logits + noise
        # Get the top-k logits and their corresponding indices
        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        # Create a sparse tensor with negative infinity for unselected experts
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        # Apply softmax to get the final routing probabilities
        router_output = F.softmax(sparse_logits, dim=-1)

        return router_output, indices

# Implementing Mixture of Experts with sparsity

* Every top k value from the expert selector matrxi nor get multipled with the corresponding top k experts.

* But in practice we itterate over all the experts and get the values for the tokens



In [11]:
class SparseMoE(nn.Module):
  def __init__(self,n_embed, num_experts, top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopKRouter(n_embed, num_experts, top_k)
    self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self, x):
    gating_output, indices = self.router(x)
    final_output = torch.zeros_like(x)

    flat_x = x.view(-1, x.size(-1))
    flat_gating_output = gating_output.view(-1, gating_output.size(-1))

    for i,expert in enumerate(self.experts):
      expert_mask = (indices == i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        final_output[expert_mask] += weighted_output.squeeze(1)

    return final_output

#Complete MoE

In [12]:
class Expert(nn.Module):
  '''
  The expert can be seen as an neural network that expands the input embedding into 4 times it's dimention and re contracts it back to it's original dimention
  '''
  def __init__(self, n_embed):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embed, 4* n_embed), # take input and expand to 4 times the input dimenstion
        nn.ReLU(), #we can also use GeLU
        nn.Linear(4 * n_embed, n_embed), # contract it back to original dimension
        nn.Dropout(dropout),
    )

  def forward(self,x):
    return self.net(x)

class NoisyTopKRouter(nn.Module):
      def __init__(self, n_embed, num_experts, top_k):

        super(NoisyTopKRouter, self).__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

      def forward(self, mh_output):
        logits = self.topkroute_linear(mh_output)
        noise_logits = self.noise_linear(mh_output)
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        noisy_logits = logits + noise
        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)

        return router_output, indices

class SparseMoE(nn.Module):
  def __init__(self,n_embed, num_experts, top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopKRouter(n_embed, num_experts, top_k)
    self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self, x):
    gating_output, indices = self.router(x)
    final_output = torch.zeros_like(x)

    flat_x = x.view(-1, x.size(-1))
    flat_gating_output = gating_output.view(-1, gating_output.size(-1))

    for i,expert in enumerate(self.experts):
      expert_mask = (indices == i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        final_output[expert_mask] += weighted_output.squeeze(1)

    return final_output

#Transformer block

In [13]:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)


        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [14]:
class Block(nn.Module):

    def __init__(self, n_embed, n_head, num_experts, top_k):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x

In [31]:
class SparseMoELanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head, num_experts=num_experts, top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
      for _ in range(max_new_tokens):
          idx_cond = idx[:, -block_size:]
          logits, loss = self(idx_cond)
          logits = logits[:, -1, :]
          probs = F.softmax(logits, dim=-1)
          idx_next = torch.multinomial(probs, num_samples=1)
          idx = torch.cat((idx, idx_next), dim=1)
      return idx

In [32]:
torch.manual_seed(67)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])


data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [33]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# hyper parameters

In [34]:
batch_size = 16
block_size = 32
max_iters = 20
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2

In [35]:
from torch.nn import init
def kaiming_init_weights(m):
  if isinstance(m,(nn.Linear)):
    nn.init.kaiming_normal_(m.weight)

In [36]:
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)

SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopKRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
             

In [37]:

m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

8.996545 M parameters
step 0: train loss 5.1506, val loss 5.1520
step 19: train loss 3.3006, val loss 3.3420


In [38]:
context = torch.zeros((1,1), dtype =torch.long, device=device)
print(decode(m.generate(context,max_new_tokens=2000)[0].tolist()))



uGhto  n  wnyE GeaS gYAb dfdtnirten,
nE tBusrQhUntksicytoo

cadobEn ou  e giirfdgxnqcyr rdtrt t  y Oase a?L,kd.orn t tg wo  p uouasa H ,hh  r  n,yu 
eT
o o,h ws ,spy. bai n ioCIIpr dchyt
 hhetGeVrrBue  tdf Eayiooottto hdAI an  i, AokoaomsTBitwe? onasbwh eA, c    r,muaeTnrr Ido mt:tm e
tn   eo dr a; oMnt c a me rtainoAtstoo  whrbC meB:tyhusng uI e ,ptIswaroro L mr sUsg :rooe'sn tsdagwr ,rItneItulAa h
hin m,e ntWr I Urht: H sts.
,  tgres  ohonvaa aeduadt tKno y a
ioi I ue,t in Bo m  ieCTT
fe  I,If
n i teil!wI  nt io$tiima'a ss
ola trcdo ld tchola  rs  rt
uw ,en,s d,thiIho rib i
M o-. hd: nliu
  hhs
hiu woe m 
iidm  t ,anhniq Krwlr.r r rp tycNb e ha in,e H,tkqeBd :yth srro hhatt Cu tocrnbo,Wi ynhl: donoun i
-uu,ah;r tEEiH hngdtmi   eio nIgnohs oorrtsip
Eestoidtdotttsmrrru ,,InUtlt MaPcsyoaehI mNehauss i,,si.nu
lahue m rB rt
 :anesaee s.edofTtira  art nirdee  nrrDrob Im,aseUrs e se
iut dgfs ap du:e,in wlynw tIddntdtiag   ns r nlti r al m  gym tn
r asntls.h ,ia. od  rtenpea mtoOgtyddrsi  