In [1]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from deepmoe_utils import deepmoe_loss
from deepResNet import resnet18_moe
from ResNet import resnet18



In [2]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [02:39<00:00, 1059209.52it/s] 


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
def train(model, device, train_loader, optimizer, criterion, moe=False):
    model.train()
    total_loss, correct = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if not moe:
            output = model(data)
            loss = criterion(output, target)
        else:
            output, gates, emb_y_hat = model(data)
            loss = criterion(output, emb_y_hat, target, gates)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    return total_loss / len(train_loader), correct / len(train_loader.dataset)


In [4]:
def validate(model, device, val_loader, criterion, moe=False):
    model.eval()
    total_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            if not moe:
                output = model(data)
                loss = criterion(output, target)
            else:
                output, gates, emb_y_hat = model(data)
                loss = criterion(output, emb_y_hat, target, gates)
            
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    return total_loss / len(val_loader), correct / len(val_loader.dataset)


In [None]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize models
    model_resnet18_moe = resnet18_moe(num_classes=100, wide=True).to(device)
    
    # Initialize optimizers
    optimizer_resnet18_moe = optim.Adam(model_resnet18_moe.parameters(), lr=0.001)
    
    # Initialize criteria and schedulers
    moe_criterion = deepmoe_loss(lambda_val=0.001, mu=1.0)
    scheduler_resnet18_moe = StepLR(optimizer_resnet18_moe, step_size=20, gamma=0.5)
    
    num_epochs = 50
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Update moe_criterion parameters for the last 5 epochs
        if epoch >= num_epochs - 5:
            moe_criterion = deepmoe_loss(lambda_val=0.0, mu=0.0)
        
        # Train ResNet-18 MoE
        start_time_train_resnet18_moe = time.time()
        train_loss_resnet18_moe, train_acc_resnet18_moe = train(model_resnet18_moe, device, train_loader, optimizer_resnet18_moe, moe_criterion, True)
        end_time_train_resnet18_moe = time.time()
        train_duration_resnet18_moe = end_time_train_resnet18_moe - start_time_train_resnet18_moe
        
        # Validate ResNet-18 MoE
        start_time_val_resnet18_moe = time.time()
        val_loss_resnet18_moe, val_acc_resnet18_moe = validate(model_resnet18_moe, device, val_loader, moe_criterion, True)
        end_time_val_resnet18_moe = time.time()
        val_duration_resnet18_moe = end_time_val_resnet18_moe - start_time_val_resnet18_moe
        
        print(f"ResNet18 MoE - Loss: {train_loss_resnet18_moe:.4f}, Acc: {train_acc_resnet18_moe:.4f} | Val Loss: {val_loss_resnet18_moe:.4f}, Val Acc: {val_acc_resnet18_moe:.4f}")
        print(f"ResNet18 MoE Training Duration: {train_duration_resnet18_moe:.2f} seconds")
        print(f"ResNet18 MoE Validation Duration: {val_duration_resnet18_moe:.2f} seconds")
        
        # Update schedulers
        scheduler_resnet18_moe.step()

if __name__ == "__main__":
    main()

Epoch 1/50
ResNet18 MoE - Loss: 8.4895, Acc: 0.0721 | Val Loss: 8.7148, Val Acc: 0.0988
ResNet18 MoE Training Duration: 183.44 seconds
ResNet18 MoE Validation Duration: 12.24 seconds
Epoch 2/50
ResNet18 MoE - Loss: 7.7881, Acc: 0.1194 | Val Loss: 2100489.2020, Val Acc: 0.1286
ResNet18 MoE Training Duration: 182.97 seconds
ResNet18 MoE Validation Duration: 12.29 seconds
Epoch 3/50
ResNet18 MoE - Loss: 7.3869, Acc: 0.1609 | Val Loss: 4601.8628, Val Acc: 0.1634
ResNet18 MoE Training Duration: 181.58 seconds
ResNet18 MoE Validation Duration: 12.18 seconds
Epoch 4/50
ResNet18 MoE - Loss: 6.9343, Acc: 0.2015 | Val Loss: 11763919277.3472, Val Acc: 0.2084
ResNet18 MoE Training Duration: 180.48 seconds
ResNet18 MoE Validation Duration: 12.24 seconds
Epoch 5/50
ResNet18 MoE - Loss: 6.4722, Acc: 0.2386 | Val Loss: 1487.4166, Val Acc: 0.2428
ResNet18 MoE Training Duration: 180.34 seconds
ResNet18 MoE Validation Duration: 12.31 seconds
Epoch 6/50
ResNet18 MoE - Loss: 6.0446, Acc: 0.2760 | Val Loss: