<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Mixture_of_Experts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfExperts(nn.Module):
    """Dense (soft) mixture of linear experts.

    Each expert is an ``nn.Linear(input_dim, expert_dim)``. A linear gating
    network produces one logit per expert; the softmax of those logits is
    used to form a weighted sum over the outputs of *all* experts.

    Args:
        num_experts: Number of linear experts; must be at least 1.
        expert_dim: Output feature size of every expert.
        input_dim: Feature size of the input (its last dimension).

    Raises:
        ValueError: If ``num_experts`` is less than 1.
    """

    def __init__(self, num_experts: int, expert_dim: int, input_dim: int) -> None:
        super().__init__()
        # Fail at construction time rather than inside torch.stack during the
        # first forward pass, where the error message is much less helpful.
        if num_experts < 1:
            raise ValueError(f"num_experts must be at least 1, got {num_experts}")
        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Combine all expert outputs using softmax gating weights.

        Args:
            x: Tensor of shape ``(..., input_dim)``. Any number of leading
                dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., expert_dim)``.
        """
        gating_weights = F.softmax(self.gating_network(x), dim=-1)  # (..., num_experts)
        # Stack experts on the second-to-last axis so the layout does not
        # depend on how many leading (batch/sequence) dimensions x has.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)  # (..., num_experts, expert_dim)
        gating_weights = gating_weights.unsqueeze(-1)  # (..., num_experts, 1)
        output = torch.sum(gating_weights * expert_outputs, dim=-2)  # (..., expert_dim)
        return output

# Demo: push one random batch through a small mixture-of-experts model.

# Hyperparameters of the demo model.
num_experts = 3
expert_dim = 5
input_dim = 10

# The model is built before the input is sampled, so the order of random
# number generator draws (parameter init first, then the batch) is fixed.
model = MixtureOfExperts(
    num_experts=num_experts,
    expert_dim=expert_dim,
    input_dim=input_dim,
)

# Random batch of 4 samples, each with input_dim features.
dummy_input = torch.randn((4, input_dim))

# Run the forward pass and show the (4, expert_dim) result.
output = model(dummy_input)
print(output)