In [2]:
import torch
import torch.nn as nn 

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Bias-free gated feed-forward block (SwiGLU-style).

    Computes ``fc3(silu(fc1(x)) * fc2(x))`` over the last dimension of ``x``:
    ``fc1`` produces the gate, ``fc2`` the value branch, and ``fc3`` projects
    the gated hidden representation back to ``embed_dim``.

    Args:
        embed_dim: size of the last dimension of the input and the output.
        hidden_dim: width of the intermediate (gated) representation.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        # Registration order fc1 -> fc2 -> fc3 is preserved so parameter
        # initialisation and state_dict keys stay the same.
        self.fc1 = nn.Linear(embed_dim, hidden_dim, bias=False)  # gate branch
        self.fc2 = nn.Linear(embed_dim, hidden_dim, bias=False)  # value branch
        self.fc3 = nn.Linear(hidden_dim, embed_dim, bias=False)  # output projection

    def forward(self, x):
        """Return ``fc3(silu(fc1(x)) * fc2(x))``; output last dim is ``embed_dim``."""
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return self.fc3(gate * value)

In [79]:
# Scratch, functional version of the MoE forward pass: route each token to its
# top-E experts and mix their SwiGLU outputs with weights from a softmax over
# the selected top-E gate scores.
torch.manual_seed(0)  # reproducible random inputs/weights under Restart & Run All

B = 25   # batch size
T = 30   # sequence length
D = 200  # embedding dim
H = 400  # expert hidden (intermediate) dim
N = 6    # no of experts
E = 3    # no of experts per token

inp = torch.rand((B, T, D))

gate_proj = torch.rand((D, N))
fc1 = [torch.rand((D, H)) for _ in range(N)]
fc2 = [torch.rand((D, H)) for _ in range(N)]
fc3 = [torch.rand((H, D)) for _ in range(N)]

# Router: keep the E highest-scoring experts per token, softmax over just those.
scores = (inp @ gate_proj).reshape(-1, N)                  # (B*T, N)
topk_scores, topk_indices = torch.topk(scores, E, dim=-1)  # (B*T, E)
topk_probs = torch.softmax(topk_scores, dim=-1)            # (B*T, E)

inp_flat = inp.reshape(-1, D)                              # (B*T, D)
output_flat = torch.zeros((B * T, D))

unique_experts = torch.unique(topk_indices)
for unique_expert_tensor in unique_experts:
    unique_expert = unique_expert_tensor.item()

    mask = topk_indices == unique_expert   # (B*T, E): where this expert sits in each token's top-E
    select_indices = mask.any(dim=-1)       # (B*T,): tokens routed to this expert

    tokens_to_expert = inp_flat[select_indices, :]
    output = F.silu(tokens_to_expert @ fc1[unique_expert]) * (tokens_to_expert @ fc2[unique_expert])
    output = output @ fc3[unique_expert]

    # topk never returns the same expert twice for one token, so each selected
    # row of `mask` has exactly one True; argmax recovers its slot in topk_probs.
    slot = mask[select_indices].to(torch.int).argmax(dim=-1)
    scale = topk_probs[select_indices, slot]
    output *= scale[:, None]

    output_flat[select_indices, :] += output

# Sanity check: a dense reference that runs every expert on every token and
# weights it by its routing probability (0 for experts outside the top-E).
dense_weights = torch.zeros((B * T, N)).scatter(1, topk_indices, topk_probs)  # (B*T, N)
dense_ref = sum(
    dense_weights[:, n:n + 1] * ((F.silu(inp_flat @ fc1[n]) * (inp_flat @ fc2[n])) @ fc3[n])
    for n in range(N)
)
assert torch.allclose(output_flat, dense_ref, rtol=1e-3), "sparse MoE routing disagrees with dense reference"

In [81]:
# torch.gather(x, dim=1, index=idx) returns out[i][j] = x[i][idx[i][j]]:
# each row of `indices` lists which columns to read from the same row of `x`,
# and the output has the shape of `indices`. MoEFeedForward.forward uses the
# same pattern to read each token's routing probability at a given top-k slot.
x = torch.tensor([
    [10, 20, 30],
    [40, 50, 60],
    [70, 80, 90]
])

indices = torch.tensor([
    [2],   # row 0: pick col 2 -> 30
    [0],   # row 1: pick col 0 -> 40
    [1]    # row 2: pick col 1 -> 80
])

y = torch.gather(x, dim=1, index=indices)
assert y.tolist() == [[30], [40], [80]]
print(y)

tensor([[30],
        [40],
        [80]])


In [None]:
class MoEFeedForward(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k routing.

    Each token is routed to the ``num_experts_per_tok`` experts with the highest
    gate scores. Every expert is a bias-free gated FFN,
    ``fc3(silu(fc1(x)) * fc2(x))``, and the expert outputs are mixed with a
    softmax taken over the selected top-k gate scores only.

    Args:
        cfg: dict with keys ``"num_experts_per_tok"``, ``"num_experts"``,
            ``"emb_dim"``, ``"moe_intermediate_size"`` and ``"dtype"`` (the
            dtype used for all Linear weights).
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        self.emb_dim = cfg["emb_dim"]
        self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])

        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                  for _ in range(cfg["num_experts"])])
        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"])
                                  for _ in range(cfg["num_experts"])])
        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"])
                                  for _ in range(cfg["num_experts"])])

    def forward(self, x):
        """Apply the MoE layer over the last dimension of ``x``.

        Args:
            x: tensor of shape ``(..., emb_dim)``; the usual case is
                ``(batch, seq_len, emb_dim)``. Leading dims may be empty.

        Returns:
            Tensor of shape ``(..., emb_dim)`` with the same leading dims,
            device and dtype as ``x``.
        """
        *lead_shape, _ = x.shape
        scores = self.gate(x)  # (..., num_experts)
        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        topk_probs = torch.softmax(topk_scores, dim=-1)

        # Flatten all leading dims into one token axis. reshape(-1, D) (rather
        # than reshape(n_tokens, -1)) also works when there are zero tokens.
        x_flat = x.reshape(-1, x.shape[-1])  # (num_tokens, emb_dim)
        num_tokens = x_flat.shape[0]
        out_flat = torch.zeros(num_tokens, self.emb_dim, device=x.device, dtype=x.dtype)

        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)  # (num_tokens, k)
        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)      # (num_tokens, k)

        # Only experts chosen by at least one token are run; every id returned
        # by unique() therefore has at least one routed token.
        for expert_id_tensor in torch.unique(topk_indices_flat):
            expert_id = int(expert_id_tensor.item())
            mask = topk_indices_flat == expert_id  # (num_tokens, k)
            selected_idx = mask.any(dim=-1).nonzero(as_tuple=False).squeeze(-1)

            expert_input = x_flat.index_select(0, selected_idx)
            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[expert_id](expert_input)
            expert_out = self.fc3[expert_id](hidden)

            # topk returns distinct experts per token, so each selected row of
            # `mask` has exactly one True; argmax gives that top-k slot.
            slot_indices = mask[selected_idx].int().argmax(dim=-1, keepdim=True)
            selected_probs = torch.gather(topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices).squeeze(-1)

            # Accumulate the weighted contribution into the routed tokens' rows.
            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))

        return out_flat.reshape(*lead_shape, self.emb_dim)