In [1]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from deepmoe_utils import deepmoe_loss
from deepResNet import resnet18_moe
from ResNet import resnet18



In [2]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
def train(model, device, train_loader, optimizer, criterion, moe=False):
    model.train()
    total_loss, correct = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if not moe:
            output = model(data)
            loss = criterion(output, target)
        else:
            output, gates, emb_y_hat = model(data)
            loss = criterion(output, emb_y_hat, target, gates)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    return total_loss / len(train_loader), correct / len(train_loader.dataset)


In [4]:
def validate(model, device, val_loader, criterion, moe=False):
    model.eval()
    total_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            if not moe:
                output = model(data)
                loss = criterion(output, target)
            else:
                output, gates, emb_y_hat = model(data)
                loss = criterion(output, emb_y_hat, target, gates)
            
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    return total_loss / len(val_loader), correct / len(val_loader.dataset)


In [5]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize models
    model_resnet18 = resnet18(num_classes=100).to(device)
    model_resnet18_moe = resnet18_moe(num_classes=100).to(device)
    
    # Initialize optimizers
    optimizer_resnet18 = optim.Adam(model_resnet18.parameters(), lr=0.001)
    optimizer_resnet18_moe = optim.Adam(model_resnet18_moe.parameters(), lr=0.001)
    
    # Initialize criteria and schedulers
    res_criterion = nn.CrossEntropyLoss()
    moe_criterion = deepmoe_loss(lambda_val=1.0, mu=1.0)
    scheduler_resnet18 = StepLR(optimizer_resnet18, step_size=20, gamma=0.5)
    scheduler_resnet18_moe = StepLR(optimizer_resnet18_moe, step_size=20, gamma=0.5)
    
    num_epochs = 25
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Update moe_criterion parameters for the last 5 epochs
        if epoch >= num_epochs - 5:
            moe_criterion = deepmoe_loss(lambda_val=0.0, mu=0.0)
        
        # Train and validate ResNet-18
        start_time_train_resnet18 = time.time()
        train_loss_resnet18, train_acc_resnet18 = train(model_resnet18, device, train_loader, optimizer_resnet18, res_criterion)
        end_time_train_resnet18 = time.time()
        train_duration_resnet18 = end_time_train_resnet18 - start_time_train_resnet18
        
        start_time_val_resnet18 = time.time()
        val_loss_resnet18, val_acc_resnet18 = validate(model_resnet18, device, val_loader, res_criterion)
        end_time_val_resnet18 = time.time()
        val_duration_resnet18 = end_time_val_resnet18 - start_time_val_resnet18
        
        # Train and validate ResNet-18 MoE
        start_time_train_resnet18_moe = time.time()
        train_loss_resnet18_moe, train_acc_resnet18_moe = train(model_resnet18_moe, device, train_loader, optimizer_resnet18_moe, moe_criterion, True)
        end_time_train_resnet18_moe = time.time()
        train_duration_resnet18_moe = end_time_train_resnet18_moe - start_time_train_resnet18_moe
        
        start_time_val_resnet18_moe = time.time()
        val_loss_resnet18_moe, val_acc_resnet18_moe = validate(model_resnet18_moe, device, val_loader, moe_criterion, True)
        end_time_val_resnet18_moe = time.time()
        val_duration_resnet18_moe = end_time_val_resnet18_moe - start_time_val_resnet18_moe
        
        print(f"ResNet18 - Loss: {train_loss_resnet18:.4f}, Acc: {train_acc_resnet18:.4f} | Val Loss: {val_loss_resnet18:.4f}, Val Acc: {val_acc_resnet18:.4f}")
        print(f"ResNet18 MoE - Loss: {train_loss_resnet18_moe:.4f}, Acc: {train_acc_resnet18_moe:.4f} | Val Loss: {val_loss_resnet18_moe:.4f}, Val Acc: {val_acc_resnet18_moe:.4f}")
        
        print(f"ResNet18 Training Duration: {train_duration_resnet18:.2f} seconds")
        print(f"ResNet18 Validation Duration: {val_duration_resnet18:.2f} seconds")
        print(f"ResNet18 MoE Training Duration: {train_duration_resnet18_moe:.2f} seconds")
        print(f"ResNet18 MoE Validation Duration: {val_duration_resnet18_moe:.2f} seconds")
        
        # Update schedulers
        scheduler_resnet18.step()
        scheduler_resnet18_moe.step()

if __name__ == "__main__":
    main()

Epoch 1/25
ResNet18 - Loss: 3.8458, Acc: 0.1081 | Val Loss: 3.6047, Val Acc: 0.1446
ResNet18 MoE - Loss: 12.2819, Acc: 0.0854 | Val Loss: 9.7104, Val Acc: 0.0959
ResNet18 Training Duration: 63.87 seconds
ResNet18 Validation Duration: 5.05 seconds
ResNet18 MoE Training Duration: 84.43 seconds
ResNet18 MoE Validation Duration: 6.50 seconds
Epoch 2/25
ResNet18 - Loss: 3.1764, Acc: 0.2150 | Val Loss: 3.0281, Val Acc: 0.2455
ResNet18 MoE - Loss: 8.7602, Acc: 0.1203 | Val Loss: 8.3087, Val Acc: 0.1346
ResNet18 Training Duration: 66.71 seconds
ResNet18 Validation Duration: 5.33 seconds
ResNet18 MoE Training Duration: 84.55 seconds
ResNet18 MoE Validation Duration: 6.41 seconds
Epoch 3/25
ResNet18 - Loss: 2.6174, Acc: 0.3234 | Val Loss: 2.5170, Val Acc: 0.3478
ResNet18 MoE - Loss: 8.1490, Acc: 0.1436 | Val Loss: 7.9337, Val Acc: 0.1458
ResNet18 Training Duration: 65.56 seconds
ResNet18 Validation Duration: 5.04 seconds
ResNet18 MoE Training Duration: 84.18 seconds
ResNet18 MoE Validation Durat

In [7]:
def main2():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize models
    model_resnet18_moe = resnet18_moe(num_classes=100).to(device)
    
    # Initialize optimizers
    optimizer_resnet18_moe = optim.Adam(model_resnet18_moe.parameters(), lr=0.001)
    
    # Initialize criteria and schedulers
    moe_criterion = deepmoe_loss(lambda_val=0.001, mu=1.0)
    scheduler_resnet18_moe = StepLR(optimizer_resnet18_moe, step_size=20, gamma=0.5)
    
    num_epochs = 25
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Update moe_criterion parameters for the last 5 epochs
        if epoch >= num_epochs - 5:
            moe_criterion = deepmoe_loss(lambda_val=0.0, mu=0.0)
        
        # Train ResNet-18 MoE
        start_time_train_resnet18_moe = time.time()
        train_loss_resnet18_moe, train_acc_resnet18_moe = train(model_resnet18_moe, device, train_loader, optimizer_resnet18_moe, moe_criterion, True)
        end_time_train_resnet18_moe = time.time()
        train_duration_resnet18_moe = end_time_train_resnet18_moe - start_time_train_resnet18_moe
        
        # Validate ResNet-18 MoE
        start_time_val_resnet18_moe = time.time()
        val_loss_resnet18_moe, val_acc_resnet18_moe = validate(model_resnet18_moe, device, val_loader, moe_criterion, True)
        end_time_val_resnet18_moe = time.time()
        val_duration_resnet18_moe = end_time_val_resnet18_moe - start_time_val_resnet18_moe
        
        print(f"ResNet18 MoE - Loss: {train_loss_resnet18_moe:.4f}, Acc: {train_acc_resnet18_moe:.4f} | Val Loss: {val_loss_resnet18_moe:.4f}, Val Acc: {val_acc_resnet18_moe:.4f}")
        print(f"ResNet18 MoE Training Duration: {train_duration_resnet18_moe:.2f} seconds")
        print(f"ResNet18 MoE Validation Duration: {val_duration_resnet18_moe:.2f} seconds")
        
        # Update schedulers
        scheduler_resnet18_moe.step()

if __name__ == "__main__":
    main2()

Epoch 1/25
ResNet18 MoE - Loss: 8.4175, Acc: 0.0745 | Val Loss: 8.8274, Val Acc: 0.0961
ResNet18 MoE Training Duration: 85.05 seconds
ResNet18 MoE Validation Duration: 6.35 seconds
Epoch 2/25
ResNet18 MoE - Loss: 7.7057, Acc: 0.1196 | Val Loss: 19.0168, Val Acc: 0.1375
ResNet18 MoE Training Duration: 84.70 seconds
ResNet18 MoE Validation Duration: 6.48 seconds
Epoch 3/25
ResNet18 MoE - Loss: 7.1358, Acc: 0.1717 | Val Loss: 7.0516, Val Acc: 0.1861
ResNet18 MoE Training Duration: 84.53 seconds
ResNet18 MoE Validation Duration: 6.69 seconds
Epoch 4/25
ResNet18 MoE - Loss: 6.6460, Acc: 0.2166 | Val Loss: 8.5427, Val Acc: 0.2319
ResNet18 MoE Training Duration: 84.87 seconds
ResNet18 MoE Validation Duration: 6.34 seconds
Epoch 5/25
ResNet18 MoE - Loss: 6.1332, Acc: 0.2628 | Val Loss: 116.9042, Val Acc: 0.2500
ResNet18 MoE Training Duration: 83.81 seconds
ResNet18 MoE Validation Duration: 6.38 seconds
Epoch 6/25
ResNet18 MoE - Loss: 5.7152, Acc: 0.3045 | Val Loss: 15.7891, Val Acc: 0.2950
Res

In [8]:
def main3():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize models
    model_resnet18_moe = resnet18_moe(num_classes=100).to(device)
    
    # Initialize optimizers
    optimizer_resnet18_moe = optim.Adam(model_resnet18_moe.parameters(), lr=0.001)
    
    # Initialize criteria and schedulers
    moe_criterion = deepmoe_loss(lambda_val=0.001, mu=0.1)
    scheduler_resnet18_moe = StepLR(optimizer_resnet18_moe, step_size=20, gamma=0.5)
    
    num_epochs = 25
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Update moe_criterion parameters for the last 5 epochs
        if epoch >= num_epochs - 5:
            moe_criterion = deepmoe_loss(lambda_val=0.0, mu=0.0)
        
        # Train ResNet-18 MoE
        start_time_train_resnet18_moe = time.time()
        train_loss_resnet18_moe, train_acc_resnet18_moe = train(model_resnet18_moe, device, train_loader, optimizer_resnet18_moe, moe_criterion, True)
        end_time_train_resnet18_moe = time.time()
        train_duration_resnet18_moe = end_time_train_resnet18_moe - start_time_train_resnet18_moe
        
        # Validate ResNet-18 MoE
        start_time_val_resnet18_moe = time.time()
        val_loss_resnet18_moe, val_acc_resnet18_moe = validate(model_resnet18_moe, device, val_loader, moe_criterion, True)
        end_time_val_resnet18_moe = time.time()
        val_duration_resnet18_moe = end_time_val_resnet18_moe - start_time_val_resnet18_moe
        
        print(f"ResNet18 MoE - Loss: {train_loss_resnet18_moe:.4f}, Acc: {train_acc_resnet18_moe:.4f} | Val Loss: {val_loss_resnet18_moe:.4f}, Val Acc: {val_acc_resnet18_moe:.4f}")
        print(f"ResNet18 MoE Training Duration: {train_duration_resnet18_moe:.2f} seconds")
        print(f"ResNet18 MoE Validation Duration: {val_duration_resnet18_moe:.2f} seconds")
        
        # Update schedulers
        scheduler_resnet18_moe.step()

if __name__ == "__main__":
    main3()

Epoch 1/25
ResNet18 MoE - Loss: 4.6307, Acc: 0.0672 | Val Loss: 4.8587, Val Acc: 0.0812
ResNet18 MoE Training Duration: 85.42 seconds
ResNet18 MoE Validation Duration: 6.79 seconds
Epoch 2/25
ResNet18 MoE - Loss: 4.3023, Acc: 0.1044 | Val Loss: 383.2595, Val Acc: 0.1008
ResNet18 MoE Training Duration: 85.33 seconds
ResNet18 MoE Validation Duration: 6.51 seconds
Epoch 3/25
ResNet18 MoE - Loss: 4.1571, Acc: 0.1291 | Val Loss: 278.5084, Val Acc: 0.1323
ResNet18 MoE Training Duration: 85.56 seconds
ResNet18 MoE Validation Duration: 6.37 seconds
Epoch 4/25
ResNet18 MoE - Loss: 4.1732, Acc: 0.1473 | Val Loss: 155489358.8796, Val Acc: 0.1557
ResNet18 MoE Training Duration: 84.99 seconds
ResNet18 MoE Validation Duration: 6.46 seconds
Epoch 5/25
ResNet18 MoE - Loss: 3.9890, Acc: 0.1751 | Val Loss: 27.2217, Val Acc: 0.1950
ResNet18 MoE Training Duration: 84.58 seconds
ResNet18 MoE Validation Duration: 6.45 seconds
Epoch 6/25
ResNet18 MoE - Loss: 3.8610, Acc: 0.1947 | Val Loss: 159503.7253, Val A