In [None]:
import torch
import torch.nn as nn

class GatingNetwork(nn.Module):
  """Linear gate that scores every expert and softmax-normalises the scores.

  Args:
    input_size: number of features in the last dimension of ``x``.
    output_size: number of experts to produce weights for.
  """

  def __init__(self, input_size: int, output_size: int):
    super().__init__()
    # fc1 is registered before softmax so parameter init and state_dict match.
    self.fc1 = nn.Linear(input_size, output_size)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, x):
    """Return non-negative weights over the experts that sum to 1 along the last dim."""
    return self.softmax(self.fc1(x))

class Expert1(nn.Module):
  """Two-layer MLP expert: Linear -> ReLU -> Linear -> ReLU.

  Args:
    input_size: features in the last dimension of ``x``.
    hidden_size: width of the hidden layer.
    output_size: features produced for each input.
  """

  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    layers = [
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
        nn.ReLU(),
    ]
    # A single Sequential keeps the parameter names ff.0.* / ff.2.*.
    self.ff = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the MLP; the trailing ReLU makes every output non-negative."""
    out = self.ff(x)
    return out

class Expert2(nn.Module):
  """Second MLP expert with the same Linear -> ReLU -> Linear -> ReLU shape.

  Args:
    input_size: features in the last dimension of ``x``.
    hidden_size: width of the hidden layer.
    output_size: features produced for each input.
  """

  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    # Each (fan_in, fan_out) pair becomes a Linear followed by a ReLU, in order.
    dims = [(input_size, hidden_size), (hidden_size, output_size)]
    blocks = []
    for fan_in, fan_out in dims:
      blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    self.ff = nn.Sequential(*blocks)

  def forward(self, x):
    """Run ``x`` through the expert MLP."""
    return self.ff(x)

class MOE(nn.Module):
  """Mixture-of-experts layer that mixes the top-k experts chosen by a gate.

  Args:
    gating_network: module mapping ``x`` to one weight per expert along the
      last dimension; its output width must equal the number of experts (2).
    input_size: feature size of ``x`` fed to every expert.
    hidden_size: hidden width of each expert's MLP.
    output_size: feature size produced by each expert and by this layer.
    top_k: number of highest-weighted experts mixed per input
      (1 <= top_k <= number of experts).

  Raises:
    ValueError: if ``top_k`` is outside ``[1, number of experts]``.
  """

  def __init__(self, gating_network : GatingNetwork, input_size, hidden_size, output_size, top_k=2):
    super().__init__()
    self.gating_network = gating_network
    # Every expert produces ``output_size`` features so their outputs can be mixed.
    self.experts = nn.ModuleList([
        Expert1(input_size, hidden_size, output_size),
        Expert2(input_size, hidden_size, output_size),
    ])
    if not 1 <= top_k <= len(self.experts):
      raise ValueError(
          f"top_k must be between 1 and {len(self.experts)}, got {top_k}")
    self.top_k = top_k

  def forward(self, x):
    """Return the gate-weighted sum of the top-k experts' outputs.

    Args:
      x: tensor whose last dimension has ``input_size`` features; any leading
        (batch) dimensions are preserved.

    Returns:
      Tensor of shape ``(*x.shape[:-1], output_size)``.

    Raises:
      ValueError: if the gate does not emit exactly one weight per expert.
    """
    gating_weights = self.gating_network(x)  # (..., num_experts)
    num_experts = len(self.experts)
    if gating_weights.size(-1) != num_experts:
      raise ValueError(
          f"gating network produced {gating_weights.size(-1)} weights "
          f"for {num_experts} experts")

    topk_weights, topk_indices = torch.topk(gating_weights, self.top_k, dim=-1)
    # topk returns weights sorted by value, not by expert; scatter them back into
    # their own expert slots (others become 0) so each weight meets its expert.
    sparse_weights = torch.zeros_like(gating_weights).scatter(-1, topk_indices, topk_weights)

    expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)  # (..., E, out)
    # Weighted sum over the expert axis.
    return (sparse_weights.unsqueeze(-1) * expert_outputs).sum(dim=-2)






In [None]:
# Smoke test: route one random 3-feature sample through the 2-expert MoE.
torch.manual_seed(0)  # fixed seed so weights, input and the displayed output are reproducible
gating_network = GatingNetwork(3, 2)  # 3 input features -> one weight per expert (2 experts)
model = MOE(gating_network, 3, 10, 2)  # input_size=3, hidden_size=10, output_size=2
x = torch.randn((1, 3))
model(x)

torch.Size([1, 2]) torch.Size([2, 1, 4])


tensor([[0.3801, 0.6737, 0.0897, 0.0022]], grad_fn=<MmBackward0>)

In [None]:
# Element-wise cosine and sine of the integers 1..29 (arange's end is exclusive).
data1 = torch.arange(1, 30)
cos_data1 = data1.cos()  # integer input is promoted to the default float dtype
sin_data1 = data1.sin()
print(cos_data1, sin_data1)

tensor([ 0.5403, -0.4161, -0.9900, -0.6536,  0.2837,  0.9602,  0.7539, -0.1455,
        -0.9111, -0.8391,  0.0044,  0.8439,  0.9074,  0.1367, -0.7597, -0.9577,
        -0.2752,  0.6603,  0.9887,  0.4081, -0.5477, -1.0000, -0.5328,  0.4242,
         0.9912,  0.6469, -0.2921, -0.9626, -0.7481]) tensor([ 0.8415,  0.9093,  0.1411, -0.7568, -0.9589, -0.2794,  0.6570,  0.9894,
         0.4121, -0.5440, -1.0000, -0.5366,  0.4202,  0.9906,  0.6503, -0.2879,
        -0.9614, -0.7510,  0.1499,  0.9129,  0.8367, -0.0089, -0.8462, -0.9056,
        -0.1324,  0.7626,  0.9564,  0.2709, -0.6636])
