### Let us code Mixture of Experts (MoE) From Stratch

### Step 0 Load packages and import data

In [3]:
## import the necessary packages and set for reproductibity.For this notebook , pytorch is all you need

import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x7f35f01d9a10>

In [4]:
## Downloading the tiny shapespare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-05-19 07:34:23--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-05-19 07:34:23 (23.9 MB/s) - ‘input.txt’ saved [1115394/1115394]



### Step 1 Define each expert as a neural network



![alt text](IMG/number_of_experts.jpg)

In [5]:
## Expert Module
class Expert(nn.Module):
  """An MLP is a simple linear layer followed by non linearuty i.e each expert"""

  def __init__(self, n_embed,dropout):
    super().__init__()
    self.net=nn.Sequential(
        nn.Linear(n_embed,4*n_embed),
        nn.ReLU(),
        nn.Linear(4*n_embed,n_embed),
        nn.Dropout(dropout)


    )
  def forward(self,x):
    return self.net(x)

### Step 2: Implement the Router

![alt text](<IMG/Routing Matrix.jpg>)

The router , determines which expert network recieves the output for each token from the multi head attention

In [6]:
## Understanding how gating works
num_experts=3
top_k=2
n_embed=8

## Example of multi head attention output for a simple illustrative example consider n_embed=32 ,context length=

mh_output=torch.randn(1,4,n_embed)

routing_matrix=nn.Linear(n_embed,num_experts) ## nn.Linear(32,4)

expert_selector_matrix=routing_matrix(mh_output) ## (1,4,4)

print(expert_selector_matrix)

tensor([[[ 0.0238, -0.2771, -0.5070],
         [-0.5727, -0.9081,  0.1839],
         [ 0.8137,  0.1781,  1.5661],
         [ 0.6523,  0.4525,  0.0062]]], grad_fn=<ViewBackward0>)


### Step 3 : Implement Topk Load Balancing

![alt text](<IMG/expert selector matrix.jpg>)

In [7]:
top_k_logits,top_k_indices=expert_selector_matrix.topk(top_k,dim=-1) ## Get top-k experts
print(top_k_logits)
print(top_k_indices)

tensor([[[ 0.0238, -0.2771],
         [ 0.1839, -0.5727],
         [ 1.5661,  0.8137],
         [ 0.6523,  0.4525]]], grad_fn=<TopkBackward0>)
tensor([[[0, 1],
         [2, 0],
         [2, 0],
         [0, 1]]])


### Step 4: Use -inf and apply Softmax

![alt text](<IMG/apply softmax on expert matrix.jpg>)

In [8]:
## Full_like clones a tensor and fill it with a specified value
zeros=torch.full_like(expert_selector_matrix,float("-inf"))
sparse_logits=zeros.scatter(-1,top_k_indices,top_k_logits)
print(sparse_logits)

tensor([[[ 0.0238, -0.2771,    -inf],
         [-0.5727,    -inf,  0.1839],
         [ 0.8137,    -inf,  1.5661],
         [ 0.6523,  0.4525,    -inf]]], grad_fn=<ScatterBackward0>)


In [9]:
gating_output=F.softmax(sparse_logits,dim=-1)
print(gating_output)

tensor([[[0.5747, 0.4253, 0.0000],
         [0.3194, 0.0000, 0.6806],
         [0.3203, 0.0000, 0.6797],
         [0.5498, 0.4502, 0.0000]]], grad_fn=<SoftmaxBackward0>)


### Step 5: Creare a class for TopKRoting

In [10]:
## First define the topk router module
class TopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k) :
    super(TopkRouter,self).__init__()
    self.top_k=top_k
    self.linear=nn.Linear(n_embed,num_experts)


  def forward(self,mh_output):
    ## mh_output is the output tensor from multihead self attention block

    routing_matirx=self.linear(mh_output)

    top_k_logits,top_k_indices=routing_matirx.topk(self.top_k,dim=-1)

    zeros=torch.full_like(routing_matirx,float("-inf"))

    sparse_logits=zeros.scatter(-1,top_k_indices,top_k_logits)

    expert_select_weight_matrix=F.softmax(sparse_logits,dim=-1)

    return expert_select_weight_matrix,top_k_indices


In [11]:
### Testing this out:
num_experts=3
top_k=2
n_embed=8

mh_output=torch.randn(1,4,n_embed)

top_k_gate=TopkRouter(n_embed,num_experts,top_k)

router_output=top_k_gate(mh_output)

print(router_output)

(tensor([[[0.6177, 0.0000, 0.3823],
         [0.6445, 0.3555, 0.0000],
         [0.0000, 0.3600, 0.6400],
         [0.5666, 0.4334, 0.0000]]], grad_fn=<SoftmaxBackward0>), tensor([[[0, 2],
         [0, 1],
         [2, 1],
         [0, 1]]]))


### Step 6: Create a class for NosiyTopkRouting

![alt text](<IMG/Noisy Top K Routing.jpg>)

Nosiy top k gating is an important tool in training Moe models

Essentially you dont want all the tokens to sent to the same set of favoured experts.

You want a fine balance of exploration and exploitation.for this purpose ,to load balance ,it is helpful to add standard normal to logits from the gating linear layer.This makes training more efficient.


In [14]:
## Changing the above to accomdate nosiy top k gating

class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k=top_k

    ## Layer for router logits
    self.topkroute_linear=nn.Linear(n_embed,num_experts)
    self.noise_linear=nn.Linear(n_embed,num_experts)


  def forward(self,mh_output):
    ## Mh_output is the output tensor from multihead self attention block

    logits=self.topkroute_linear(mh_output)

    ## Noise logits
    noise_logits=self.noise_linear(mh_output)

    ## Adding scaled unit gaussian noise to the logits
    noise=torch.randn_like(logits)*F.softplus(noise_logits)

    logits=logits+noise

    top_k_logits,top_k_indices=logits.topk(self.top_k,dim=-1)

    zeros=torch.full_like(logits,float("-inf"))

    sparse_logits=zeros.scatter(-1,top_k_indices,top_k_logits)

    expert_selector_weight_matrix=F.softmax(sparse_logits,dim=-1)

    return expert_selector_weight_matrix,top_k_indices


In [15]:
### Testing this out , again
num_experts=4
top_k=2
n_embed=8

mh_output=torch.randn(1,4,n_embed)

top_k_gate=NoisyTopkRouter(n_embed,num_experts,top_k)

expert_selector_weight_matrix,indices=top_k_gate(mh_output)
print(expert_selector_matrix)
print(indices)

tensor([[[ 0.0238, -0.2771, -0.5070],
         [-0.5727, -0.9081,  0.1839],
         [ 0.8137,  0.1781,  1.5661],
         [ 0.6523,  0.4525,  0.0062]]], grad_fn=<ViewBackward0>)
tensor([[[3, 1],
         [0, 2],
         [3, 1],
         [3, 1]]])


### Step 7: Create the sparse Mixture of experts (Moe)

![alt text](<IMG/expert calculations.jpg>)

![alt text](<IMG/expert cal 1.jpg>)

![alt text](<IMG/expert cal 2.jpg>)

- The primary aspect of this process involves the expert selector weight matrix.

- After acquiring the expert selector weight matrix,top k values are selectively multipied with the outputs from the corresponding top-k
experts for a given token.

- This Selective multiplication forms a weighted sum, which constitues the spareseMoe blocks output.

In [44]:
flat_x=mh_output.view(-1,mh_output.size(-1))
flat_x.shape

torch.Size([4, 8])

In [17]:
expert_selector_weight_matrix

tensor([[[0.0000, 0.3531, 0.0000, 0.6469],
         [0.9226, 0.0000, 0.0774, 0.0000],
         [0.0000, 0.4685, 0.0000, 0.5315],
         [0.0000, 0.2103, 0.0000, 0.7897]]], grad_fn=<SoftmaxBackward0>)

In [18]:
expert=nn.ModuleList([Expert(n_embed,0) for _ in range(num_experts)])

In [19]:
indices

tensor([[[3, 1],
         [0, 2],
         [3, 1],
         [3, 1]]])

In [20]:
indices==0

tensor([[[False, False],
         [ True, False],
         [False, False],
         [False, False]]])

In [45]:
flat_mask=(indices==0).any(dim=-1).view(-1)
flat_mask.shape

torch.Size([4])

In [22]:
if flat_mask.any():
  expert_input=flat_x[flat_mask]
  print(expert_input)
  expert_output=expert[0](expert_input)
  print(expert_output)


tensor([[ 0.3416, -0.2214,  1.2554, -0.7150,  0.8539,  0.5130,  0.5397,  0.5655]])
tensor([[ 0.1063,  0.1303,  0.1469,  0.0185, -0.2671,  0.0029, -0.0806, -0.0068]],
       grad_fn=<AddmmBackward0>)


In [23]:
flat_expert_selector_weight_matrix=expert_selector_weight_matrix.view(-1,expert_selector_weight_matrix.size(1))
flat_expert_selector_weight_matrix

tensor([[0.0000, 0.3531, 0.0000, 0.6469],
        [0.9226, 0.0000, 0.0774, 0.0000],
        [0.0000, 0.4685, 0.0000, 0.5315],
        [0.0000, 0.2103, 0.0000, 0.7897]], grad_fn=<ViewBackward0>)

In [24]:
gating_scores=flat_expert_selector_weight_matrix[flat_mask,0].unsqueeze(1)

In [25]:
weighted_output=expert_output*gating_scores
weighted_output

tensor([[ 0.0980,  0.1202,  0.1356,  0.0171, -0.2465,  0.0027, -0.0743, -0.0063]],
       grad_fn=<MulBackward0>)

In [26]:
weighted_output.squeeze(1)

tensor([[ 0.0980,  0.1202,  0.1356,  0.0171, -0.2465,  0.0027, -0.0743, -0.0063]],
       grad_fn=<SqueezeBackward1>)

In [46]:
class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
     super(SparseMoE,self).__init__()
     self.router=NoisyTopkRouter(n_embed,num_experts,top_k)
     self.experts=nn.ModuleList([Expert(n_embed,0) for _ in range(num_experts)])

  def forward(self,x):
    ## GETTING THE EXPERT SELECT WEIGHT MATRIX AND INDICES
    expert_selector_weight_matrix,indices=self.router(x)

    ## CREATE TENSOR FOR OUTPUT OF SIZE OF MHA OUTPUT
    final_output=torch.zeros_like(x)

    ## RESHAPE INPUTS FOR BATCH PROCESSING

    flat_x=x.view(-1,x.size(-1))

    flat_expert_selector_weight_matrix=expert_selector_weight_matrix.view(-1,expert_selector_weight_matrix.size(-1))

    ## PROCESS EACH EXPERT IN PARALLEL
    for i,expert in enumerate(self.experts):
      ## CREATE A MASK FOR THE INPUTS WHERE THE CURRENT EXPERTS IS IN TOP K

      ## IT GIVE A TRUE AND FALSE VALUE FOR EACH EXPERT WHERE THE TOKEN IS SELECT FOR THAT EXPERT.
      expert_mask=(indices==i).any(dim=-1)
      flat_mask=expert_mask.view(-1)

      if flat_mask.any():
        ## IT FILTER OUT TOKENS FROM THE MHA OUTPUT FOR SELECT EXPERT.
        expert_input=flat_x[flat_mask]
        ## AND THEN OUTPUT SENT TO EXPERT
        expert_output=expert(expert_input)

        ## EXTRACT AND APPLY GATING SCORES
        gating_scores=flat_expert_selector_weight_matrix[flat_mask,i].unsqueeze(1)

        weighted_output=gating_scores*expert_output

        ## UPDATE FINAL OUTPUT ADDITIVELY BY INDEXING AND ADDING.
        final_output[expert_mask]=weighted_output.squeeze(1)

    return final_output



In [51]:
import torch
import torch.nn as nn

## Let's test this output
num_experts=3
top_k=2
n_embed=16
dropout=0.1
mh_output=torch.randn(4,8,n_embed) ## Example multi Head Attention Output
sparse_moe=SparseMoE(n_embed,num_experts,top_k)
final_output=sparse_moe(mh_output)
print("Shape of the final Output:",final_output.shape)
print(final_output)


Shape of the final Output: torch.Size([4, 8, 16])
tensor([[[-9.4351e-02,  2.5056e-01, -7.4708e-02,  8.4360e-02,  2.6926e-01,
           1.4315e-01, -3.4850e-01,  1.2705e-01, -6.9078e-02,  1.9349e-01,
          -2.3063e-02,  8.1250e-02, -1.4750e-01,  2.3005e-01, -1.3541e-01,
          -1.4158e-01],
         [-5.1728e-02,  1.3329e-01,  2.3243e-01,  2.0983e-01, -1.2289e-01,
          -1.1335e-01, -8.7145e-02,  9.0008e-02, -2.0408e-02,  2.6926e-02,
           6.0160e-02,  2.1683e-01,  1.6879e-01,  1.7361e-01,  2.7809e-02,
           1.7415e-01],
         [-6.4685e-02,  7.3002e-02, -2.3478e-02, -3.5024e-02,  1.7841e-02,
           3.7098e-02,  1.5010e-01,  2.2208e-01,  2.3342e-02,  6.4208e-02,
          -1.8596e-02, -3.7774e-02,  4.1073e-02, -5.6595e-02, -6.7111e-03,
          -1.5239e-01],
         [-4.5738e-01,  1.5445e-01,  2.3443e-01, -1.2563e-01,  1.3437e-01,
          -3.3661e-01,  2.4439e-01,  4.4006e-01,  1.2974e-01,  1.1116e-01,
           3.1039e-01,  1.6280e-01,  4.8311e-02,  3.3

## **STEP : Putting together all the building blocks of MoE**

In [52]:
## EXPERT MODULE
## Expert Module
class Expert(nn.Module):
  """An MLP is a simple linear layer followed by non linearuty i.e each expert"""

  def __init__(self, n_embed,dropout):
    super().__init__()
    self.net=nn.Sequential(
        nn.Linear(n_embed,4*n_embed),
        nn.ReLU(),
        nn.Linear(4*n_embed,n_embed),
        nn.Dropout(dropout)


    )
  def forward(self,x):
    return self.net(x)

## CHANGING THE ABOVE TO ACCOMDATE NOISY TOP K GATING

class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k=top_k

    ## Layer for router logits
    self.topkroute_linear=nn.Linear(n_embed,num_experts)
    self.noise_linear=nn.Linear(n_embed,num_experts)


  def forward(self,mh_output):
    ## Mh_output is the output tensor from multihead self attention block

    logits=self.topkroute_linear(mh_output)

    ## Noise logits
    noise_logits=self.noise_linear(mh_output)

    ## Adding scaled unit gaussian noise to the logits
    noise=torch.randn_like(logits)*F.softplus(noise_logits)

    logits=logits+noise

    top_k_logits,top_k_indices=logits.topk(self.top_k,dim=-1)

    zeros=torch.full_like(logits,float("-inf"))

    sparse_logits=zeros.scatter(-1,top_k_indices,top_k_logits)

    expert_selector_weight_matrix=F.softmax(sparse_logits,dim=-1)

    return expert_selector_weight_matrix,top_k_indices


## CREATE THE SPARSE MIXTURE OF EXPERTS MODULE
class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
     super(SparseMoE,self).__init__()
     self.router=NoisyTopkRouter(n_embed,num_experts,top_k)
     self.experts=nn.ModuleList([Expert(n_embed,0) for _ in range(num_experts)])

  def forward(self,x):
    ## GETTING THE EXPERT SELECT WEIGHT MATRIX AND INDICES
    expert_selector_weight_matrix,indices=self.router(x)

    ## CREATE TENSOR FOR OUTPUT OF SIZE OF MHA OUTPUT
    final_output=torch.zeros_like(x)

    ## RESHAPE INPUTS FOR BATCH PROCESSING

    flat_x=x.view(-1,x.size(-1))

    flat_expert_selector_weight_matrix=expert_selector_weight_matrix.view(-1,expert_selector_weight_matrix.size(-1))

    ## PROCESS EACH EXPERT IN PARALLEL
    for i,expert in enumerate(self.experts):
      ## CREATE A MASK FOR THE INPUTS WHERE THE CURRENT EXPERTS IS IN TOP K

      ## IT GIVE A TRUE AND FALSE VALUE FOR EACH EXPERT WHERE THE TOKEN IS SELECT FOR THAT EXPERT.
      expert_mask=(indices==i).any(dim=-1)
      flat_mask=expert_mask.view(-1)

      if flat_mask.any():
        ## IT FILTER OUT TOKENS FROM THE MHA OUTPUT FOR SELECT EXPERT.
        expert_input=flat_x[flat_mask]
        ## AND THEN OUTPUT SENT TO EXPERT
        expert_output=expert(expert_input)

        ## EXTRACT AND APPLY GATING SCORES
        gating_scores=flat_expert_selector_weight_matrix[flat_mask,i].unsqueeze(1)

        weighted_output=gating_scores*expert_output

        ## UPDATE FINAL OUTPUT ADDITIVELY BY INDEXING AND ADDING.
        final_output[expert_mask]=weighted_output.squeeze(1)

    return final_output




## **Code the Entire transformer block: Part1 (Multi Head Attention)**

![
](IMG/mha_1jpg.jpg)

![
](IMG/mha_2.jpg)
![alt text](IMG/mha_3.jpg)
![alt text](IMG/mha_4.jpg)


In [54]:
class Head(nn.Module):
  """One head of self attention"""

  def __init__(self,head_size):
    super().__init__()
    self.key=nn.Linear(n_embed,head_size,bias=False)
    self.query=nn.Linear(n_embed,head_size,bias=False)
    self.value=nn.Linear(n_embed,head_size,bias=False)
    self.register_buffer("tril",torch.tril(torch.ones(block_size,block_size)))

    self.dropout=nn.Dropout(dropout)

  def forward(self,x):
    B,T,C=x.shape
    k=self.key(x) ## (B,T,C)
    q=self.query(x) ## (B,T,C)
    ## Compute attention scores ("affinities")
    wei=q @ k.transpose(-2,-1) * C**-0.5 ## (B,T,C) @ (B,C,T) -> (B,T,T)
    wei=wei.masked_fill(self.tril[:T,:T]==0,float("-inf")) ## (B,T,T)
    wei=F.softmax(wei,dim=-1) ## (B,T,T)
    wei=self.dropout(wei)
    ## Perform the weighted aggregation of the values
    v=self.value(x) ## (B,T,C)
    out=wei@v ## (B,T,T)@(B,T,C)->(B,T,C)

    return out

## MULTIHEAD SELF ATTENTION
class MultiHeadAttention(nn.Module):
  """Multiple heads of self attention in parallel"""

  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads=nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj=nn.Linear(n_embed,n_embed)
    self.dropout=nn.Dropout(dropout)

  def forward(self,x):
    out=torch.cat([h(x) for h in self.heads],dim=-1)
    out=self.dropout(self.proj(out))
    return out

## **STEP 10: Code the Entire Transformer Block Part:2(Assemble all layers)**

![alt text](IMG/transformer.jpg)

In [64]:
## First create a self attention +mixture of experts block that may be repeated several number of times
## Copy pasting key architecture variables for clarity

class Block(nn.Module):
  """Mixture of Experts Transformer block: Communicution followed by computation (multi head self atttention+sparse mixture of experts)"""
  def __init__(self,n_embed,n_heads,num_experts,top_k,dropout):
    super().__init__()
    head_size=n_embed//n_heads
    self.sa=MultiHeadAttention(n_heads,head_size)
    self.moe=SparseMoE(n_embed,num_experts,top_k)
    self.ln1=nn.LayerNorm(n_embed)
    self.ln2=nn.LayerNorm(n_embed)

  def forward(self,x):
    x=x+self.sa(self.ln1(x))
    x=x+self.moe(self.ln2(x))
    return x


## **STEP 11: Define Entire Language Model Architecture**

![alt text](IMG/sparse_modeljpg.jpg)

In [61]:
## Finally putting it all together to create a sparse mixture of experts language model

class SparseMoELanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table=nn.Embedding(vocab_size,n_embed)
    self.position_embedding_table=nn.Embedding(block_size,n_embed)
    self.blocks=nn.Sequential(*[Block(n_embed,n_heads,num_experts,top_k,dropout) for _ in range(n_layers)])
    self.ln_f=nn.LayerNorm(n_embed)
    self.lm_head=nn.Linear(n_embed,vocab_size)

  def forward(self,idx,targets=None):
    B,T=idx.shape

    ## idx and targets are both (B,T) tensor of integers
    tok_emb=self.token_embedding_table(idx) ## (B,T,C)
    pos_emb=self.position_embedding_table(torch.arange(T,device=device)) ## (T,C)
    x=tok_emb+pos_emb ## (B,T,C)
    x=self.blocks(x) ## (B,T,C)
    x=self.ln_f(x) ## (B,T,C)
    logits=self.lm_head(x) ## (B,T,vocab_size)
    if targets is None:
      loss=None
    else:
      B,T,C=logits.shape
      logits=logits.view(B*T,C)
      targets=targets.view(B*T)
      loss=F.cross_entropy(logits,targets)
    return logits,loss

  def generate(self,idx,max_new_tokens):
    ## idx is (B,T) array of indices in the current context
    for _ in range(max_new_tokens):
      ## crop idx to the last block_size tokens
      idx_cond=idx[:,-block_size:]
      ## get the predictions
      logits,loss=self(idx_cond)
      ## focus only on the last time step
      logits=logits[:,-1,:]
      ## apply softmax to get probabilities
      probs=F.softmax(logits,dim=-1) ## (B,C)
      ## sample from the distribution
      idx_next=torch.multinomial(probs,num_samples=1) ## (B,1)
      ## Append sampled index to the running sequence
      idx=torch.cat((idx,idx_next),dim=1) ## (B,T+1)
    return idx

## **STEP 12: Creating Training and Testing Data**

In [67]:
torch.manual_seed(1337)

with open("input.txt","r",encoding="utf-8") as f:
  text=f.read()

## Here are all the unique characters that occur in this text
chars=sorted(list(set(text)))
vocab_size=len(chars)

## Create a mapping from characters to integers
stoi={ch:i for i,ch in enumerate(chars)}
itos={i:ch for i,ch in enumerate(chars)}
encode=lambda s:[stoi[c] for c in s] ## encoder: take a string, output a list of integers
decode=lambda l:"".join([itos[i] for i in l]) ## decoder: take a list of integers, output a string

## Train and test splits
data=torch.tensor(encode(text),dtype=torch.long)
n=int(0.9*len(data))
train_data=data[:n]
val_data=data[n:]

## Data Loading
## Data Loading
def get_batch(split):
  ## Generate a small batch of data of inputs x and targets y
  data=train_data if split=="train" else val_data
  ix=torch.randint(len(data)-block_size,(batch_size,))
  x=torch.stack([data[i:i+block_size] for i in ix])
  y=torch.stack([data[i+1:i+block_size+1] for i in ix])
  # Move the tensors to the specified device
  x, y = x.to(device), y.to(device)
  return x,y

## **STEP 13: Define LLM Loss**

In [57]:
@torch.no_grad()
def estimate_loss():
  out={}
  model.eval()
  for split in ["train","val"]:
    losses=torch.zeros(eval_iters)
    for k in range(eval_iters):
      X,Y=get_batch(split)
      logits,loss=model(X,Y)
      losses[k]=loss.item()
    out[split]=losses.mean()
  model.train()
  return out

## **STEP 14: Define Training loop parameters and other hyperparameters**

In [58]:
## First defining hyperparameters and bioler plate code and data preparation code is repeated for convenice
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
## Hyperparameters
batch_size=16
block_size=32
max_iters=20
eval_interval=100
learning_rate=1e-3
device="cuda" if torch.cuda.is_available() else "cpu"
eval_iters=400
head_size=32
n_layers=8
n_heads=8
dropout=0.1
n_embed=128
num_experts=8
top_k=2

## **STEP 15 : Initialize the Entire Model**

In [62]:
def kaiming_init(m):
  if isinstance(m,(nn.Linear)):
    init.kaiming_uniform_(m.weight)

In [65]:
model=SparseMoELanguageModel()
model.apply(kaiming_init)

SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (moe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
              

## **STEP 16 : Run The Pre Training Loop**

In [68]:
## Not using Mlflow
m=model.to(device)
## Print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6,"M parameters")

## Create a PyTorch optimizer
optimizer=torch.optim.AdamW(model.parameters(),lr=learning_rate)

for iter in range(max_iters):
  ## every once in a while evaluate the loss on train and val sets
  if iter%eval_interval==0 or iter==max_iters-1:
    losses=estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  ## sample a batch of data
  xb,yb=get_batch("train")

  ## Evaluate the loss
  logits,loss=model(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

8.996545 M parameters
step 0: train loss 5.4057, val loss 5.3837
step 100: train loss 2.6960, val loss 2.6898
step 199: train loss 2.5163, val loss 2.5302


## STEP 17: Inference

In [69]:
## Generate from the model .Not great .Not too bad either
context=torch.zeros((1,1),dtype=torch.long,device=device)
print(decode(m.generate(context,max_new_tokens=2000)[0].tolist()))


CClas
Kp pe y wiso stathatherrd llathe y,
th,
Dha pofome me thisatr! po e toud vey, thesor he preng ak isor s wu.
APAshrd w thiealo k d pesu t!
I :

OLB he inde!
Ror d tofolrcpess? f cese cerormarnonowinkom, ard ttie cehe
UK mr sthe thado sonwosCutthe gst uu hu f weinoinNo,Ulleand kedXer
Dur the kidse irth arg er buw heng od,
Wamr foplor myoucaf? lt nfw f t Yosond tthoondin ah, iger atonz?


AROO:
Kind mathit tord INII o thed these uthed y bfoond thuve sendesky trir

Whitoipe hy thrort bo mm is ee, arro Angote?
Whu e billartheBcand, cor ltimaivul rl funchenomt s;
Ter
D c;
T hande manofa ben y ivendeeshet th!
Anferd, swhe gunsural
Tot thikinck cry d othor;
Sothy borig oule brth fl-e te w.
Tleld thethe then wo concucinle tose- ngnenete baver
pbe.

RLId md n on Bus pstond hine hasthe.
WeXe hilt he path akn'se sy?
PLourrw CWessimy d th!
Whet mor w wanes, Four I myondtowe m I mmazeve Mive:
He
Se aton hevecagheoven, pw thod cakyyotow,
Ye we, n chouoe kir oulagr, gn w nthe.

Se malee sin tho