In [1]:
import torch
from models import MultiTaskCNN,ModSquadCNN
from optimisers_APO import AdaptiveProtectiveOptimizer
from utils import *
import random
import numpy as np
import pandas as pd
import copy 

# One seed shared by every random number generator used in this notebook,
# so a fresh kernel reproduces the same run.
SEED = 20
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

In [2]:
# Training configuration: mini-batch size and number of passes over the tasks
batch_size, epochs = 256, 20

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Load datasets
# get_all_loaders is not imported by name, so it must come from the wildcard
# `from utils import *` at the top of the notebook.
# Each returned collection is indexed by integer task id later on
# (train_loaders[task_id], val_loaders[task_id], test_loaders[0..2]).
# NOTE(review): assumes the three collections are in the same task order — confirm in utils.
train_loaders, val_loaders, test_loaders = get_all_loaders(batch_size=batch_size)

In [4]:
# Initialize model: one 10-way output head per task. Three tasks are trained in
# this notebook (MNIST, FashionMNIST, KMNIST), hence three entries.
output_sizes = [10, 10, 10]
model = ModSquadCNN(output_sizes, num_experts=8, k=6).to(device)

# Parameter names of the per-task heads, derived from the number of heads so the
# list cannot drift out of sync with output_sizes. Order is weight then bias for
# head 0, 1, 2, ... — identical to the previously hand-written list.
task_specific_layer_names = [
    f"task_layers.{head_idx}.{param_kind}"
    for head_idx in range(len(output_sizes))
    for param_kind in ("weight", "bias")
]

# Initialize the optimizer, passing in the model and the task-specific layer names.
# NOTE(review): the positional 3 is presumably the number of tasks — confirm
# against AdaptiveProtectiveOptimizer before tying it to len(output_sizes).
optimizer = AdaptiveProtectiveOptimizer(model, task_specific_layer_names, 3, lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [5]:
def update_task_priority(task_losses):
    """Rank task indices from highest loss to lowest.

    Args:
        task_losses: indexable collection where ``task_losses[i]`` is the loss
            of task ``i`` for ``i`` in ``range(len(task_losses))``.

    Returns:
        list[int]: task indices ordered by descending loss. Tasks with equal
        loss keep their ascending-index order because the sort is stable.
    """
    ranking = list(range(len(task_losses)))
    ranking.sort(key=task_losses.__getitem__, reverse=True)
    return ranking


In [6]:
def validate_model(model, val_loaders, device=None):
    """Compute per-task classification accuracy over a list of loaders.

    Args:
        model: callable as ``model(inputs, task_id)`` returning per-class scores
            with classes along dimension 1.
        val_loaders: sequence of iterables, one per task, each yielding
            ``(inputs, targets)`` batches. The position in the sequence is used
            as the task id.
        device: device the batches are moved to. When omitted, the device of the
            model's first parameter is used (CPU if the model has no
            parameters), so the function no longer relies on a notebook global.

    Returns:
        dict[int, float]: accuracy in percent for each task id. A task whose
        loader yields no samples reports 0.0 instead of dividing by zero.

    Note:
        The model is switched to eval mode and is left in eval mode on return.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    val_accuracies = {}

    with torch.no_grad():
        for task_id, dataloader in enumerate(val_loaders):
            correct = 0
            seen = 0
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs, task_id)
                preds = outputs.argmax(dim=1)
                correct += (preds == targets).sum().item()
                seen += targets.size(0)
            # Guard against an empty loader: no samples means no measurable accuracy
            val_accuracies[task_id] = (100 * correct / seen) if seen else 0.0

    return val_accuracies

In [7]:
import os  # cell-local: only used to create the checkpoint directory below

# Every checkpoint written by this cell goes here. Create the directory up front
# so the first torch.save() cannot abort the run on a fresh checkout.
checkpoint_dir = './trained_models'
os.makedirs(checkpoint_dir, exist_ok=True)

# Best average validation accuracy seen so far; gates the "global" checkpoint
best_val_accuracy = 0.0
task_names = ["MNIST", "FashionMNIST", "KMNIST"]
original_task_ids = list(range(len(train_loaders)))
# Order in which tasks are visited inside an epoch; re-ranked periodically below
sorted_task_ids = list(range(len(train_loaders)))

# Best validation accuracy seen so far for each task; gates the per-task checkpoints
BEST = [0] * len(task_names)

# One list of per-epoch metric dicts for each task
metrics_list = [[] for _ in range(len(task_names))]

# Training loop
for epoch in range(epochs):
    task_losses = {task_id: 0.0 for task_id in original_task_ids}
    pro_per = {task_id: 0.0 for task_id in original_task_ids}

    # Tasks are trained one after another, each for a full pass over its loader
    for task_id in sorted_task_ids:
        dataloader = train_loaders[task_id]
        val_loader = val_loaders[task_id]
        model.train()
        total_loss = 0.0

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs, task_id)
            loss = criterion(outputs, targets)

            # Backward pass
            loss.backward()

            # Update parameters (the custom optimizer is told which task and epoch this is)
            optimizer.step(task_id, epoch)

            total_loss += loss.item()

        # Store the mean per-batch training loss for this task
        task_losses[task_id] = total_loss / len(dataloader)
        task_related_percent, protected_percent, unclaimed_percent = optimizer.calculate_percentages(task_id)
        pro_per[task_id] = protected_percent
        val_acc, val_loss = evaluate_model_task(model, val_loader, criterion, task_id)

        # Checkpoint the whole model whenever this task reaches a new best
        # validation accuracy (the snapshot is taken mid-epoch, right after
        # this task's pass).
        if val_acc > BEST[task_id]:
            BEST[task_id] = val_acc
            torch.save(model, f'{checkpoint_dir}/Prop_MOE_{task_names[task_id]}')

        print(f"Epoch [{epoch+1}/{epochs}], Task Name [{task_names[task_id]}], Loss: {task_losses[task_id]:.3f} Val ACC: {val_acc:.2f} Task-related: {task_related_percent:.2f}%, Protected: {protected_percent:.2f}%, Unclaimed: {unclaimed_percent:.2f}%")

        metrics_list[task_id].append({
            'Epoch': epoch + 1,
            'Task': task_names[task_id],
            'Training Loss': task_losses[task_id],
            'Validation Loss': val_loss,
            'Validation Accuracy': val_acc,
            'Task-related %': task_related_percent,
            'Protected %': protected_percent,
            'Unclaimed %': unclaimed_percent
        })

    # Validation step across all tasks at the end of the epoch
    print('-----------------------------------')
    avg_val_accuracy = evaluate_model_avg(model, val_loaders, sorted_task_ids)

    # Save the best model based on the value returned by evaluate_model_avg
    if avg_val_accuracy > best_val_accuracy:
        best_val_accuracy = avg_val_accuracy
        torch.save(model, f'{checkpoint_dir}/Prop_MOE_global')

    print('=================================')
    # Re-rank tasks by training loss (highest first). With a zero-based epoch
    # this fires after the 1st, 11th, 21st, ... epoch.
    # NOTE(review): if "every 10 epochs" was intended, the condition would be
    # (epoch + 1) % 10 == 0 — confirm before changing.
    if epoch % 10 == 0:
        sorted_task_ids = update_task_priority(list(task_losses.values()))
        print(sorted_task_ids)
    print('=================================')


Epoch [1/20], Task Name [MNIST], Loss: 2.293 Val ACC: 37.19 Task-related: 15.92%, Protected: 0.00%, Unclaimed: 84.08%
Epoch [1/20], Task Name [FashionMNIST], Loss: 2.275 Val ACC: 23.81 Task-related: 16.44%, Protected: 15.92%, Unclaimed: 67.64%
Epoch [1/20], Task Name [KMNIST], Loss: 2.232 Val ACC: 43.68 Task-related: 10.27%, Protected: 32.36%, Unclaimed: 57.37%
-----------------------------------
Task 0 Accuracy: 26.23%


KeyboardInterrupt: 

In [None]:
# Evaluate the checkpoint with the best average validation accuracy on the test loaders
sorted_task_ids = list(range(len(train_loaders)))
# The checkpoint is a whole pickled module (written with torch.save(model, ...)):
# - weights_only=False is required on PyTorch >= 2.6, where the default became
#   True and rejects full-module pickles (the keyword exists since PyTorch 1.13).
# - map_location=device lets a GPU-trained checkpoint load on a CPU-only machine.
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load('./trained_models/Prop_MOE_global', map_location=device, weights_only=False)
avg_accuracy = evaluate_model_avg(model, test_loaders, sorted_task_ids)
print(avg_accuracy)

In [None]:
# Evaluate the best-on-validation MNIST checkpoint on the MNIST test loader (task 0).
# The checkpoint is a whole pickled module, so weights_only=False is required on
# PyTorch >= 2.6; map_location=device allows loading a GPU checkpoint on CPU.
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load('./trained_models/Prop_MOE_MNIST', map_location=device, weights_only=False)
# Names kept as val_acc/val_loss for compatibility; the values are test-set metrics here
val_acc, val_loss = evaluate_model_task(model, test_loaders[0], criterion, 0)
print(val_acc)

In [None]:
# Evaluate the best-on-validation FashionMNIST checkpoint on its test loader (task 1).
# The checkpoint is a whole pickled module, so weights_only=False is required on
# PyTorch >= 2.6; map_location=device allows loading a GPU checkpoint on CPU.
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load('./trained_models/Prop_MOE_FashionMNIST', map_location=device, weights_only=False)
# Names kept as val_acc/val_loss for compatibility; the values are test-set metrics here
val_acc, val_loss = evaluate_model_task(model, test_loaders[1], criterion, 1)
print(val_acc)

In [None]:
# Evaluate the best-on-validation KMNIST checkpoint on its test loader (task 2).
# The checkpoint is a whole pickled module, so weights_only=False is required on
# PyTorch >= 2.6; map_location=device allows loading a GPU checkpoint on CPU.
# Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
model = torch.load('./trained_models/Prop_MOE_KMNIST', map_location=device, weights_only=False)
# Names kept as val_acc/val_loss for compatibility; the values are test-set metrics here
val_acc, val_loss = evaluate_model_task(model, test_loaders[2], criterion, 2)
print(val_acc)