In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [16]:
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Expert, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        return torch.softmax(self.layer2(x), dim=1)

In [7]:
class Gating(nn.Module):
    def __init__(self, input_dim, num_experts, dropout_rate=0.1):
        super(Gating, self).__init__()
        self.layer1 = nn.Linear(input_dim, 128)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.layer2 = nn.Linear(128, 256)
        self.leaky_relu1 = nn.LeakyReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.layer3 = nn.Linear(256, 128)
        self.leaky_relu2 = nn.LeakyReLU()
        self.dropout3 = nn.Dropout(dropout_rate)
        self.layer4 = nn.Linear(128, num_experts)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.dropout1(x)
        x = self.leaky_relu1(self.layer2(x))
        x = self.dropout2(x)
        x = self.leaky_relu2(self.layer3(x))
        x = self.dropout3(x)
        return torch.softmax(self.layer4(x), dim=1)

In [9]:
class MoE(nn.Module):
    def __init__(self, experts):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.gating = Gating(input_dim=experts[0].layer1.in_features, num_experts=len(experts))

    def forward(self, x):
        # Get the weights from the gating network
        weights = self.gating(x)
        # Calculate the outputs of each expert
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)
        # Adjust weights shape and apply them
        weights = weights.unsqueeze(1).expand_as(expert_outputs)
        return torch.sum(expert_outputs * weights, dim=2)

Generating Synthetic Data

In [12]:
num_samples = 5000
input_dim = 4
x_data = torch.randn(num_samples, input_dim)
y_data = torch.cat([torch.zeros(num_samples // 3),
                    torch.ones(num_samples // 3),
                    torch.full((num_samples - 2 * (num_samples // 3),), 2)]).long()

# Adding biases to data
for i in range(num_samples):
    if y_data[i] == 0:
        x_data[i, 0] += 1
    elif y_data[i] == 1:
        x_data[i, 1] -= 1
    elif y_data[i] == 2:
        x_data[i, 0] -= 1

# Shuffle and split data
shuffled_indices = torch.randperm(num_samples)
x_data, y_data = x_data[shuffled_indices], y_data[shuffled_indices]

In [18]:
hidden_dim = 32
output_dim = 3
epochs = 100
learning_rate = 0.001

experts = [Expert(input_dim, hidden_dim, output_dim) for _ in range(3)]
optimizers = [optim.Adam(expert.parameters(), lr=learning_rate) for expert in experts]

# Training each expert on tailored data
for i, expert in enumerate(experts):
    optimizer = optimizers[i]
    mask = y_data == i
    x_train, y_train = x_data[mask], y_data[mask]
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = expert(x_train)
        loss = nn.CrossEntropyLoss()(outputs, y_train)
        loss.backward()
        optimizer.step()

In [20]:
moe_model = MoE(experts)
optimizer_moe = optim.Adam(moe_model.parameters(), lr=learning_rate)

x_train_moe = x_data[int(num_samples * 0.8):]
y_train_moe = y_data[int(num_samples * 0.8):]

for epoch in range(epochs):
    optimizer_moe.zero_grad()
    outputs_moe = moe_model(x_train_moe)
    loss_moe = nn.CrossEntropyLoss()(outputs_moe, y_train_moe)
    loss_moe.backward()
    optimizer_moe.step()

In [22]:
def evaluate(model, x, y):
    with torch.no_grad():
        outputs = model(x)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == y).sum().item()
        return correct / len(y)

# Testing each expert and the MoE
accuracy_expert1 = evaluate(experts[0], x_data, y_data)
accuracy_expert2 = evaluate(experts[1], x_data, y_data)
accuracy_expert3 = evaluate(experts[2], x_data, y_data)
accuracy_moe = evaluate(moe_model, x_data, y_data)

print(f"Expert 1 Accuracy: {accuracy_expert1}")
print(f"Expert 2 Accuracy: {accuracy_expert2}")
print(f"Expert 3 Accuracy: {accuracy_expert3}")
print(f"Mixture of Experts Accuracy: {accuracy_moe}")

Expert 1 Accuracy: 0.3332
Expert 2 Accuracy: 0.3332
Expert 3 Accuracy: 0.3336
Mixture of Experts Accuracy: 0.6608
