In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExpertMLP(nn.Module):
    """Three-layer perceptron mapping an input feature vector to class logits.

    Layer stack (held in ``self.net``):
        Linear(in_dim -> hidden) -> ReLU -> Dropout
        -> Linear(hidden -> hidden // 2) -> ReLU -> Dropout
        -> Linear(hidden // 2 -> num_classes)

    Args:
        in_dim: Size of the input feature dimension.
        num_classes: Number of output logits.
        hidden: Width of the first hidden layer; the second is ``hidden // 2``.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, in_dim, num_classes, hidden=256, dropout=0.2):
        super().__init__()
        narrow = hidden // 2
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, narrow),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(narrow, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class scores (no softmax applied) for ``x``."""
        logits = self.net(x)
        return logits

class GatingNetwork(nn.Module):
    """Two-layer MLP that emits one raw score per expert.

    Layer stack (held in ``self.net``):
        Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> num_experts)

    No normalization is applied here; the output is raw logits.

    Args:
        in_dim: Size of the input feature dimension.
        num_experts: Number of scores produced (one per expert).
        hidden: Width of the single hidden layer.
    """

    def __init__(self, in_dim, num_experts, hidden=128):
        super().__init__()
        to_hidden = nn.Linear(in_dim, hidden)
        to_scores = nn.Linear(hidden, num_experts)
        self.net = nn.Sequential(to_hidden, nn.ReLU(), to_scores)

    def forward(self, x):
        """Return the unnormalized per-expert gate logits for ``x``."""
        scores = self.net(x)
        return scores

class MoEClassifier(nn.Module):
    """Dense (soft) mixture-of-experts classifier.

    Every expert is evaluated on every input. A gating network yields a softmax
    distribution over the experts, and the returned logits are the
    gate-weighted sum of the expert logits.

    Args:
        in_dim: Size of the last (feature) dimension of the input.
        num_classes: Number of logits produced by each expert and by the fused
            output.
        num_experts: Number of ``ExpertMLP`` instances; must be >= 1.
        expert_hidden: Hidden width passed to each ``ExpertMLP``.
        expert_dropout: Dropout probability passed to each ``ExpertMLP``.
        gate_hidden: Hidden width passed to the ``GatingNetwork``.

    Raises:
        ValueError: If ``num_experts`` is less than 1.
    """

    def __init__(self, in_dim, num_classes, num_experts=3, *,
                 expert_hidden=256, expert_dropout=0.2, gate_hidden=128):
        super().__init__()
        if num_experts < 1:
            # With zero experts forward() would fail inside the stacking step;
            # fail early with a clear message instead.
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.num_experts = num_experts
        self.experts = nn.ModuleList([
            ExpertMLP(in_dim, num_classes, hidden=expert_hidden, dropout=expert_dropout)
            for _ in range(num_experts)
        ])
        self.gate = GatingNetwork(in_dim, num_experts, hidden=gate_hidden)

    def forward(self, x):
        """Return ``(fused_logits, gate_weights)``.

        ``x`` may have any number of leading dimensions, with the feature axis
        last. For input of shape ``(*, in_dim)``:

        * ``fused_logits`` has shape ``(*, num_classes)``;
        * ``gate_weights`` has shape ``(*, num_experts)`` and sums to 1 over
          its last axis.
        """
        # Normalize over the expert axis (always last) rather than a fixed
        # axis 1, so inputs with extra leading dims are gated per position.
        weights = F.softmax(self.gate(x), dim=-1)

        # Stack expert logits along a new expert axis: (*, num_experts, num_classes).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)

        # Broadcast weights across the class axis, then reduce over experts.
        fused = (weights.unsqueeze(-1) * expert_outputs).sum(dim=-2)
        return fused, weights
