In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [20]:
class basicExpert(nn.Module):
  """A single expert: one affine (nn.Linear) map from feature_in to feature_out.

  Args:
    feature_in: size of the last dimension of the input.
    feature_out: size of the last dimension of the output.
  """
  def __init__(self, feature_in, feature_out):
    super().__init__()
    self.expert = nn.Linear(feature_in, feature_out)

  def forward(self, x):
    """Apply the expert's linear layer over the last dimension of ``x``."""
    return self.expert(x)


In [19]:
# Toy walk-through of the routing bookkeeping used in sparseMOE.forward:
# each of 4 rows lists the 2 expert ids (out of 4 experts) it was routed to.
a = torch.tensor([[3,2],[1,3],[1,2],[2,1]])
print(a)

# one_hot gives (rows, top_k, experts); permute to (experts, top_k, rows),
# the same layout Router returns as expert_mask.
b = F.one_hot(a, 4).permute(2, 1, 0)

# Which (rank, row) pairs picked expert 3?
idx, top_x = torch.where(b[3])
print(f'idx: {idx}, top_x: {top_x}')

# Gather the selected rows, first directly and then through a leading batch
# axis; both give the same tensor.
# NOTE(review): reshape(-1, 4) uses the expert count rather than the row width
# (a.shape[-1] == 2), so the two gathered rows get flattened into one row.
# sparseMOE.forward reshapes by the feature width instead.
current = a[top_x, :].reshape(-1, 4)
print(current)

current_sqz = a.unsqueeze(0)[:, top_x, :].reshape(-1, 4)
print(current_sqz)

tensor([[3, 2],
        [1, 3],
        [1, 2],
        [2, 1]])
idx: tensor([0, 1]), top_x: tensor([0, 1])
tensor([[3, 2, 1, 3]])
tensor([[3, 2, 1, 3]])


In [29]:
class Router(nn.Module):
  """Linear top-k gating network for a mixture of experts.

  Args:
    embedding_size: feature size of each input token.
    num_experts: number of experts to score.
    top_k: number of experts each token is routed to.
  """
  def __init__(self, embedding_size, num_experts, top_k):
    super().__init__()
    self.num_experts = num_experts
    self.topk = top_k
    self.gate = nn.Linear(embedding_size, num_experts)

  def forward(self, x):
    """Score every expert for each token and keep the top-k.

    Args:
      x: 2-D token matrix of shape (b*s, embedding_size).

    Returns:
      A tuple (topk_weights, selected_experts, expert_mask):
        topk_weights: (b*s, top_k) softmax probabilities of the chosen
          experts, renormalised to sum to 1 per token, cast to x.dtype.
        selected_experts: (b*s, top_k) expert indices, highest probability
          first.
        expert_mask: (num_experts, top_k, b*s) one-hot routing mask;
          expert_mask[e, k, t] == 1 iff token t chose expert e as its
          k-th pick.
    """
    scores = F.softmax(self.gate(x), dim=1)
    best_probs, best_ids = scores.topk(self.topk, dim=-1)

    # Renormalise so the kept weights of each token sum to 1.
    renorm = best_probs / best_probs.sum(dim=1, keepdim=True)

    # (b*s, top_k, num_experts) -> (num_experts, top_k, b*s)
    routing = F.one_hot(best_ids, num_classes=self.num_experts).permute(2, 1, 0)

    return renorm.to(x.dtype), best_ids, routing

@dataclass
class MOEconfig:
  """Hyper-parameter bundle for a sparse mixture-of-experts layer.

  NOTE(review): not referenced by any code visible in this notebook;
  sparseMOE takes these values as separate constructor arguments. The field
  names presumably mirror those arguments; confirm before relying on it.
  """
  num_experts: int = 8  # number of experts (cf. sparseMOE's num_experts)
  top_k: int = 2  # experts per token (cf. sparseMOE's top_k)
  embedding_size: int = 768  # token feature size (cf. sparseMOE's embedding_size)

class sparseMOE(nn.Module):
  """Sparse mixture-of-experts layer with learned top-k token routing.

  Every token is scored by a linear Router, sent to its ``top_k``
  highest-probability experts, and the expert outputs are summed, weighted
  by the renormalised router probabilities.

  Args:
    num_experts: number of basicExpert (linear) experts.
    embedding_size: feature size of the input and of the output.
    top_k: number of experts each token is routed to.
  """
  def __init__(self, num_experts, embedding_size, top_k):
    super().__init__()
    self.embedding_size = embedding_size
    self.experts = nn.ModuleList(
        [basicExpert(embedding_size, embedding_size)
        for _ in range(num_experts)]
    )
    # Build the router once and register it as a submodule. If it were built
    # inside forward(), its gate would be re-randomised on every call, never
    # trained, absent from state_dict(), and left behind by .to(device).
    self.router = Router(embedding_size, num_experts, top_k)
    self.num_experts = num_experts
    self.topk = top_k

  def forward(self, x):
    """Route each token through its top-k experts.

    Args:
      x: tensor of shape (batch, seq_length, embedding_size).

    Returns:
      Tensor of shape (batch, seq_length, embedding_size) holding the
      routing-weighted sum of the selected experts' outputs per token.
    """
    batch, seq_length, embedding_size = x.size()

    # reshape (not view) so non-contiguous inputs are accepted too.
    x = x.reshape(-1, embedding_size)  # (b*s, embedding_size)

    topk_weights, selected_experts, expert_mask = self.router(x)

    # Accumulator with the same shape, dtype and device as the flattened input.
    hidden_states = torch.zeros_like(x)
    for expert_idx, expert_layer in enumerate(self.experts):
      # expert_mask[expert_idx] is (top_k, b*s). idx is the rank (0 = top-1,
      # 1 = top-2, ...), and topx is the token index routed to this expert.
      idx, topx = torch.where(expert_mask[expert_idx])
      if topx.numel() == 0:
        continue  # no token chose this expert
      current_state = x[topx, :]  # (n_tokens, embedding_size)
      current_hidden_states = expert_layer(
          current_state
      ) * topk_weights[topx, idx].unsqueeze(-1)

      # Scatter-add each token's weighted contribution back to its row. Cast
      # the source (not the in-place result) so index_add_ sees matching dtypes.
      hidden_states.index_add_(
          0, topx, current_hidden_states.to(hidden_states.dtype)
      )

    return hidden_states.reshape(batch, seq_length, embedding_size)

def test_sparseMOE(x, num_experts, embedding_size, top_k):
  """Build a fresh sparseMOE with the given sizes and run it once on ``x``.

  Args:
    x: input of shape (batch, seq_length, embedding_size).
    num_experts, embedding_size, top_k: forwarded to sparseMOE.

  Returns:
    The layer output, shaped (batch, seq_length, embedding_size).
  """
  layer = sparseMOE(num_experts, embedding_size, top_k)
  return layer(x)

# Smoke test: 3 sequences of 2 tokens with 16 features each, routed to 2 of 4 experts.
torch.manual_seed(0)  # input, experts and router are all randomly initialised
x = torch.rand([3,2,16])
output = test_sparseMOE(x, num_experts = 4, embedding_size = 16, top_k = 2)
assert output.shape == x.shape, "MoE output must keep the input shape"
print(output.shape)

torch.Size([3, 2, 16])
