In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [None]:
class router(nn.Module):
  """Top-k gating network for a mixture-of-experts layer.

  A single linear layer scores every token against every expert. The
  scores are turned into probabilities, the ``top_k`` most probable
  experts are kept for each token, and their probabilities are
  renormalised so that they sum to 1.
  """

  def __init__(self, embedding_size: int, num_experts: int, top_k: int):
    """Build the gate.

    Args:
      embedding_size: Size of the last dimension of the input tokens.
      num_experts: Number of experts the gate scores.
      top_k: Number of experts selected per token.

    Raises:
      ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """
    # nn.Module has to be initialised before any sub-module is assigned;
    # without this call, setting ``self.gate`` raises AttributeError.
    super().__init__()

    # Fail at construction time rather than deep inside torch.topk.
    if not 1 <= top_k <= num_experts:
      raise ValueError(
          f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
      )

    self.gate = nn.Linear(embedding_size, num_experts)
    self.num_experts = num_experts
    self.topk = top_k

  def forward(self, x: torch.Tensor):
    """Route each token to its ``top_k`` experts.

    Args:
      x: 2-D tensor of flattened tokens, shape ``(b*s, embedding_size)``.

    Returns:
      A tuple ``(topk_val, selected_experts, expert_mask)``:

      * ``topk_val`` -- shape ``(b*s, top_k)``, dtype of ``x``; the
        renormalised routing weights (each row sums to 1).
      * ``selected_experts`` -- shape ``(b*s, top_k)``; integer indices
        of the chosen experts.
      * ``expert_mask`` -- shape ``(num_experts, top_k, b*s)``; one-hot
        encoding of ``selected_experts`` with the expert axis first.
    """
    # logits.shape = (b*s, num_experts)
    logits = self.gate(x)
    probs = F.softmax(logits, dim=-1)

    # Select along the expert axis; both results have shape (b*s, top_k).
    topk_val, selected_experts = torch.topk(probs, k=self.topk, dim=-1)

    # Renormalise the kept probabilities so they sum to 1 per token.
    topk_val = topk_val / topk_val.sum(dim=-1, keepdim=True)
    topk_val = topk_val.to(x.dtype)

    # one_hot yields (b*s, top_k, num_experts); the permute moves the
    # expert axis to the front -> (num_experts, top_k, b*s).
    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
    expert_mask = expert_mask.permute(2, 1, 0)

    return topk_val, selected_experts, expert_mask

@dataclass
class MOEconfig:
  """Default hyperparameters for a mixture-of-experts setup.

  The field names mirror the constructor arguments of ``router``; the
  code that consumes this config is not visible here -- TODO confirm.
  """
  # Number of experts; defaults to 8.
  num_experts: int = 8
  # Number of experts selected per token; defaults to 2.
  top_k: int = 2
  # Embedding (feature) size; defaults to 768.
  embedding_size: int = 768