In [212]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class Expert(nn.Module):
    """Single feed-forward expert: Linear -> GELU -> Linear.

    Args:
        ff_dim: Width of the hidden layer.
        embed_dim: Size of the input and output feature dimension.
    """

    def __init__(self, ff_dim: int, embed_dim: int):
        super().__init__()
        # Up-projection first, then down-projection (creation order also fixes
        # the order in which the layers draw their initial weights).
        self.lin1 = nn.Linear(embed_dim, ff_dim)
        self.lin2 = nn.Linear(ff_dim, embed_dim)

    def forward(self, x):
        """Apply the expert to the last dimension of ``x`` and return the result."""
        hidden = self.lin1(x)
        activated = F.gelu(hidden)
        return self.lin2(activated)

class Router(nn.Module):
    """Gating network that maps each token to a probability per expert.

    Args:
        embed_dim: Size of the input feature dimension.
        n_experts: Number of experts to score.
    """

    def __init__(self, embed_dim: int, n_experts: int):
        super().__init__()
        self.lin1 = nn.Linear(embed_dim, n_experts)

    def forward(self, x):
        """Return softmax-normalised expert scores over the last dimension."""
        logits = self.lin1(x)
        return F.softmax(logits, dim=-1)
    
class MoE(nn.Module):
    """Top-k mixture of experts evaluated with a Python loop over the experts.

    Args:
        ff_dim: Hidden width of every expert.
        embed_dim: Token feature size (input and output).
        n_experts: Number of experts and of router outputs.
        k: Default number of experts each token is sent to.
    """

    def __init__(self, ff_dim: int, embed_dim: int, n_experts: int, k: int):
        super().__init__()
        self.router = Router(embed_dim, n_experts)
        expert_list = []
        for _ in range(n_experts):
            expert_list.append(Expert(ff_dim, embed_dim))
        self.experts = nn.ModuleList(expert_list)
        self.k = k

    def forward(self, x_ble, k=None):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x_ble: Input tensor indexed as (batch, seq, embed).
            k: Optional override for the number of experts per token;
                ``self.k`` is used when it is ``None``.

        Returns:
            A tensor shaped like ``x_ble``.
        """
        top_k = self.k if k is None else k

        # Router probabilities, then the k largest per token.
        gate_bln = self.router(x_ble)
        weight_blk, choice_blk = torch.topk(gate_bln, top_k)
        # Rescale the kept probabilities so they sum to 1 for each token.
        weight_blk = F.normalize(weight_blk, p=1, dim=-1)

        out_ble = torch.zeros_like(x_ble)
        for expert_id, expert in enumerate(self.experts):
            # Positions (and top-k slot) where this expert was selected.
            b_idx, l_idx, slot_idx = torch.where(choice_blk == expert_id)
            tokens = x_ble[b_idx, l_idx]
            gate = weight_blk[b_idx, l_idx, slot_idx].unsqueeze(-1)
            out_ble[b_idx, l_idx] += expert(tokens) * gate
        return out_ble


class MoE1(nn.Module):
    """Dense, einsum-based top-k mixture of experts with bias-free experts.

    Every expert is evaluated for every token; experts outside a token's
    top-k receive a gate weight of zero.

    Args:
        ff_dim: Hidden width of every expert.
        embed_dim: Token feature size (input and output).
        n_experts: Number of experts.
        k: Number of experts kept per token.
    """

    def __init__(self, ff_dim: int, embed_dim: int, n_experts: int, k: int):
        super().__init__()
        self.k = k
        self.router = Router(embed_dim, n_experts)

        # Stacked expert weights, drawn uniformly from [-bound, bound) with
        # bound = 1 / sqrt(size of the dimension each matrix contracts over).
        bound_w1 = 1. / math.sqrt(embed_dim)
        w1_init = 2 * bound_w1 * torch.rand(n_experts, ff_dim, embed_dim) - bound_w1
        self.w1_nfe = nn.Parameter(w1_init)

        bound_w2 = 1. / math.sqrt(ff_dim)
        w2_init = 2 * bound_w2 * torch.rand(n_experts, embed_dim, ff_dim) - bound_w2
        self.w2_nef = nn.Parameter(w2_init)

    def forward(self, x_ble):
        """Mix the outputs of all experts using renormalised top-k gates.

        Args:
            x_ble: Input tensor indexed as (batch, seq, embed).

        Returns:
            A tensor shaped like ``x_ble``.
        """
        gate_bln = self.router(x_ble)
        _, chosen_blk = torch.topk(gate_bln, self.k, sorted=False)

        # 1 where an expert is among the token's top-k, 0 elsewhere.
        keep_bln = torch.zeros_like(gate_bln).scatter(2, chosen_blk, 1)
        # Zero the dropped experts and rescale the survivors to sum to 1.
        gate_bln = F.normalize(gate_bln * keep_bln, p=1, dim=-1)

        hidden_blnf = F.gelu(torch.einsum('ble,nfe->blnf', x_ble, self.w1_nfe))
        expert_out_blne = torch.einsum('blnf,nef->blne', hidden_blnf, self.w2_nef)
        # Gate-weighted sum over the expert axis.
        return torch.einsum('blne,bln->ble', expert_out_blne, gate_bln)




In [222]:
# 1. Setup constants
BATCH_SIZE = 32
SEQ_LEN = 64
EMBED_DIM = 8
FF_DIM = 32
N_EXPERTS = 4
K = 2 # Top-2 experts
SEED = 0 # Fixed seed so a fresh kernel reproduces the same tensors

# Seed before any random input or parameter is created.
torch.manual_seed(SEED)

# 2. Create dummy input
# Shape: [Batch, Seq, Embed]
x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM)

# 3. Initialize Model
model = MoE1(FF_DIM, EMBED_DIM, N_EXPERTS, K)

# 4. Run Forward Pass
# MoE1.forward mixes the expert outputs with the router gates, so the result
# has the same shape as the input.
output = model(x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([32, 64, 8])
Output shape: torch.Size([32, 64, 8])


In [171]:
# Sanity check: does at least one element of the output differ from zero?
print("Is output non-zero?", (output != 0).any().item())

Is output non-zero? True


In [None]:
# Naive matrix product C = A @ B written as explicit loops:
#   C[i, j] = sum_k A[i, k] * B[k, j]
A, B = torch.rand(5, 5), torch.rand(5, 5)
assert A.shape[1] == B.shape[0], "inner dimensions of A and B must match"

# The result has A's rows and B's columns (not simply A's shape).
C = A.new_zeros(A.shape[0], B.shape[1])

for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        total = 0.0
        # Contract over the shared inner dimension (columns of A / rows of B).
        for k in range(A.shape[1]):
            total += A[i, k] * B[k, j]
        C[i, j] = total

In [183]:
# Matrix product as an einsum: C[i, j] = sum over k of A[i, k] * B[k, j].
# The subscripts require A and B to be 2-D with matching inner dimension.
C = torch.einsum('ik,kj->ij',A,B)

In [189]:
# Dot product as an einsum: the repeated index i is summed away, leaving a scalar.
A = torch.rand(5)
B = torch.rand(5)
C = torch.einsum('i,i->', A, B)

In [206]:
# Sizes for the einsum walkthrough below.
batch, length, embed, ff, n_exp = 32, 64, 16, 32, 8

# Random stand-ins: tokens (b, l, e) and two stacked expert weights (n, f, e) / (n, e, f).
A_ble = torch.rand(batch, length, embed)
B_nfe = torch.rand(n_exp, ff, embed)
C_nef = torch.rand(n_exp, embed, ff)

# Random (un-normalised) per-token expert weights (b, l, n).
probs_bln = torch.rand(batch, length, n_exp)

In [208]:
# Dense expert feed-forward written as three einsums.
# First projection, every token through every expert: (b,l,e) x (n,f,e) -> (b,l,n,f)
out_blnf = torch.einsum('ble,nfe->blnf',A_ble,B_nfe)
out_blnf = F.gelu(out_blnf)
# Second projection, per expert: (b,l,n,f) x (n,e,f) -> (b,l,n,e)
out_blne = torch.einsum('blnf,nef->blne',out_blnf,C_nef)
# Weighted sum over the expert axis n using probs_bln: (b,l,n,e) x (b,l,n) -> (b,l,e)
out_ble = torch.einsum('blne,bln->ble',out_blne,probs_bln)


In [227]:
# Ten random integers drawn from {0, 1}; `high` is exclusive.
a = torch.randint(low=0, high=2, size=(10,))

In [228]:
# Count how often each integer value occurs in `a` (index = value, entry = count).
torch.bincount(a)

tensor([6, 4])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Router(nn.Module):
    """Gating network that maps each token to a probability per expert.

    Args:
        embed_dim: Size of the input feature dimension.
        n_experts: Number of experts to score.
    """

    def __init__(self, embed_dim: int, n_experts: int):
        super().__init__()
        self.lin1 = nn.Linear(embed_dim, n_experts)

    def forward(self, x):
        """Return softmax-normalised expert scores over the last dimension."""
        logits = self.lin1(x)
        probs = F.softmax(logits, dim=-1)
        return probs

class MoE(nn.Module):
    """Sort-based top-k mixture of experts with a per-expert capacity (work in progress).

    The forward pass currently performs routing, groups the (token, expert)
    requests by expert and applies the capacity limit. Dispatching tokens to
    the expert weights and combining the results is not implemented yet, so
    ``forward`` still returns the flattened input.

    Args:
        ff_dim: Hidden width of every expert.
        embed_dim: Token feature size.
        n_experts: Number of experts.
        k: Number of experts each token is routed to.
        capacity: Maximum number of requests kept per expert.
    """

    def __init__(self, ff_dim: int, embed_dim: int, n_experts: int, k: int, capacity: int = 512):
        super().__init__()
        self.k = k
        self.n_experts = n_experts
        # forward() reads this to cap the number of requests per expert.
        self.capacity = capacity
        self.router = Router(embed_dim, n_experts)

        # Dense parameters for all experts (Vectorized)
        stdv = 1. / math.sqrt(embed_dim)
        self.w1_nfe = nn.Parameter(2 * stdv * torch.rand(n_experts, ff_dim, embed_dim) - stdv)
        stdv = 1. / math.sqrt(ff_dim)
        self.w2_nef = nn.Parameter(2 * stdv * torch.rand(n_experts, embed_dim, ff_dim) - stdv)

    def forward(self, x_ble):
        """Route tokens, group requests by expert and apply the capacity limit.

        Args:
            x_ble: Input tensor indexed as (batch, seq, embed).

        Returns:
            The input flattened to (batch * seq, embed). This is a placeholder
            until the expert dispatch/combine step is written.
        """
        device = x_ble.device

        # --- Step 1: Routing ---
        scores_bln = self.router(x_ble)
        vals_blk, idxs_blk = torch.topk(scores_bln, self.k, sorted=False)
        vals_blk = F.normalize(vals_blk, p=1, dim=-1)

        # --- Step 2: Flattening ---
        # Treat batch+seq as one long list of tokens
        x_te = x_ble.flatten(0, 1)   # t = b*l
        idxs_r = idxs_blk.flatten()  # All expert choices (r = b*l*k)

        # --- Step 3: Permutation / Sorting ---
        # 3a. Source token of every request. Each token issues self.k requests,
        #     so use the module's own k rather than a notebook-level constant.
        src_r = torch.repeat_interleave(torch.arange(x_te.shape[0], device=device), self.k)

        # 3b. Sort by expert ID to group requests together. A stable sort keeps
        #     requests of the same expert in token order, so the capacity limit
        #     below deterministically keeps the earliest tokens.
        perm_r = torch.argsort(idxs_r, stable=True)

        # 3c. Apply the sort
        idxs_r = idxs_r[perm_r]   # Sorted expert IDs: [0, 0, ..., 1, 1, ...]
        src_r = src_r[perm_r]     # Corresponding source tokens

        # --- Step 4: Grouping & Capacity ---
        # Number of requests per expert.
        counts_n = torch.bincount(idxs_r, minlength=self.n_experts)

        # Exclusive cumulative sum = start offset of each expert's block in the
        # sorted list, e.g. counts [3, 2, ...] -> starts [0, 3, 5, ...].
        starts_n = torch.cumsum(counts_n, dim=0) - counts_n

        # Position of every request inside its expert's block.
        ranks_r = torch.arange(idxs_r.shape[0], device=device) - starts_n[idxs_r]

        # Keep only requests that fit within the expert's capacity.
        mask = ranks_r < self.capacity
        idxs_leqc = idxs_r[mask]
        src_leqc = src_r[mask]
        ranks_leqc = ranks_r[mask]

        # TODO: gather x_te[src_leqc] into per-expert buffers using
        # (idxs_leqc, ranks_leqc), apply w1_nfe / w2_nef, and combine the
        # results with the gate weights in vals_blk.
        return x_te

In [230]:
# Boolean list mask that keeps only the first of three elements.
mask = [True, False, False]

# Small integer array to index with the mask.
x = np.array([1, 2, 3])

In [231]:
# Boolean-mask indexing: returns the elements of x where mask is True.
x[mask]

array([1])