In [8]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def get_train_valid_loader(data_dir,
                               batch_size,
                               augment,
                               random_seed,
                               valid_size=0.1,
                               shuffle=True):
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.ToTensor(),
            normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.ToTensor(),
            normalize,
        ])

    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )

    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)


def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transform
    transform = transforms.Compose([
        transforms.Resize((227,227)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    return data_loader

# CIFAR10 dataset 
train_loader, valid_loader = get_train_valid_loader(data_dir = './data', batch_size = 64, augment = False,  random_seed = 1)

test_loader = get_test_loader(data_dir = './data', batch_size = 64)

In [None]:
class AlexNetMoELossFree(nn.Module):
    def __init__(self, num_classes=1000, num_experts=3, expert_hidden=4096):
        super(AlexNetMoELossFree, self).__init__()
        # Convolutional feature extractor (AlexNet style)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.flatten_dim = 256 * 6 * 6

        self.fc = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self.flatten_dim, expert_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(expert_hidden, expert_hidden),
            nn.ReLU(inplace=True),
        )
        
        self.num_experts = num_experts
        self.gating = nn.Linear(expert_hidden, num_experts)
        self.experts = nn.ModuleList([
            nn.Linear(expert_hidden, num_classes) for _ in range(num_experts)
        ])

        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.gamma = 0.001

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)  
        
        gate_scores = self.gating(x) 

        biased_scores = gate_scores + self.expert_bias

        selected_expert = torch.argmax(biased_scores, dim=1) 

        outputs = []
        for i in range(x.size(0)):
            expert_idx = selected_expert[i].item()
            expert_output = self.experts[expert_idx](x[i])
            outputs.append(expert_output)
        output = torch.stack(outputs, dim=0)
        return output, selected_expert

    @torch.no_grad()
    def update_expert_bias(self, selected_expert):

        batch_size = selected_expert.size(0)
        expected_load = batch_size / self.num_experts
        load_counts = torch.zeros(self.num_experts, device=selected_expert.device)
        for idx in selected_expert:
            load_counts[idx] += 1
        
        for i in range(self.num_experts):
            if load_counts[i] > expected_load:
                self.expert_bias[i] -= self.gamma
            elif load_counts[i] < expected_load:
                self.expert_bias[i] += self.gamma


In [11]:
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.005

model = AlexNetMoE(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

# Train the model
total_step = len(train_loader)

In [12]:
total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                    .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs

        print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))

Epoch [1/20], Step [704/704], Loss: 1.8446
Accuracy of the network on the 5000 validation images: 35.86 %
Epoch [2/20], Step [704/704], Loss: 1.5361
Accuracy of the network on the 5000 validation images: 43.48 %
Epoch [3/20], Step [704/704], Loss: 1.3587
Accuracy of the network on the 5000 validation images: 56.56 %
Epoch [4/20], Step [704/704], Loss: 1.0419
Accuracy of the network on the 5000 validation images: 55.48 %
Epoch [5/20], Step [704/704], Loss: 0.4062
Accuracy of the network on the 5000 validation images: 65.52 %
Epoch [6/20], Step [704/704], Loss: 1.2311
Accuracy of the network on the 5000 validation images: 68.46 %
Epoch [7/20], Step [704/704], Loss: 0.5360
Accuracy of the network on the 5000 validation images: 71.38 %
Epoch [8/20], Step [704/704], Loss: 0.6138
Accuracy of the network on the 5000 validation images: 70.0 %
Epoch [9/20], Step [704/704], Loss: 0.6637
Accuracy of the network on the 5000 validation images: 71.68 %
Epoch [10/20], Step [704/704], Loss: 0.4168
Acc

In [13]:
with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs

        print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))

Accuracy of the network on the 10000 test images: 78.12 %


In [14]:
torch.save(model.state_dict(), 'alexnet_moe.pth')

In [31]:
import time

def measure_inference_time(model, dummy_input, num_runs=100, warmup_runs=10, device='cuda'):
    """
    Measures and prints the average inference time per run.
    """
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        # Warm-up runs to stabilize timing
        for _ in range(warmup_runs):
            _ = model(dummy_input)
            if device.startswith('cuda'):
                torch.cuda.synchronize()
        # Timing runs
        start_time = time.time()
        for _ in range(num_runs):
            _ = model(dummy_input)
            if device.startswith('cuda'):
                torch.cuda.synchronize()
        total_time = time.time() - start_time
        avg_time = total_time / num_runs
    print(f"Average inference time over {num_runs} runs: {avg_time:.6f} seconds")
    return avg_time

device = 'cuda' if torch.cuda.is_available() else 'cpu'


data_iter = iter(train_loader)
images, labels = next(data_iter)
images = images.to(device)


In [28]:
torch.set_num_threads(1)

avg_inference_time = measure_inference_time(model, images, num_runs=100, warmup_runs=10, device=device)

Average inference time over 100 runs: 0.367163 seconds


In [29]:
torch.set_num_threads(4)

avg_inference_time = measure_inference_time(model, images, num_runs=100, warmup_runs=10, device=device)

Average inference time over 100 runs: 0.328416 seconds


In [30]:
torch.set_num_threads(8)

avg_inference_time = measure_inference_time(model, images, num_runs=100, warmup_runs=10, device=device)

Average inference time over 100 runs: 0.319836 seconds
