In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# CIFAR-10 dataset and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 49.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# Define the Expert Network
class Expert(nn.Module):
    def __init__(self):
        super(Expert, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [4]:
# Define the Router Network
class Router(nn.Module):
    def __init__(self, num_experts):
        super(Router, self).__init__()
        self.fc = nn.Linear(3 * 32 * 32, num_experts)  # Flattened CIFAR-10 input

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten input
        logits = self.fc(x)
        probs = F.softmax(logits, dim=1)  # Routing probabilities
        return probs

In [5]:
# Mixture of Experts Network
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts):
        super(MixtureOfExperts, self).__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([Expert() for _ in range(num_experts)])
        self.router = Router(num_experts)

    def forward(self, x):
        routing_probs = self.router(x)  # Get routing probabilities
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # Outputs from all experts
        top1_expert = torch.argmax(routing_probs, dim=1)  # Top-1 routing

        # Select the output of the top-1 expert for each sample
        batch_size = x.size(0)
        outputs = torch.zeros(batch_size, 10).to(x.device)  # Initialize outputs
        for i in range(batch_size):
            outputs[i] = expert_outputs[i, top1_expert[i]]

        return outputs

In [6]:
# Initialize the Mixture of Experts
num_experts = 3
model = MixtureOfExperts(num_experts=num_experts).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")

Epoch 1/10, Loss: 1.4873
Epoch 2/10, Loss: 1.1093
Epoch 3/10, Loss: 0.9409
Epoch 4/10, Loss: 0.8212
Epoch 5/10, Loss: 0.7239
Epoch 6/10, Loss: 0.6400
Epoch 7/10, Loss: 0.5657
Epoch 8/10, Loss: 0.4958
Epoch 9/10, Loss: 0.4320
Epoch 10/10, Loss: 0.3812


In [7]:
# Testing the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")


Accuracy: 67.12%
