# Multi-task Training Implementation

This notebook provides a comprehensive implementation of multi-task learning, applied to training models such as `MLP` and `CNN` on datasets like `Permuted MNIST` and `Split CIFAR-10/100`.

While the primary focus of this study is on Continual Learning, this section takes a different approach by leveraging the entire dataset, processed at once altogether/in sequential chunks rather than task-by-task progression.

## Packages and Presets

Imports setup:

In [50]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, ConcatDataset

import warnings
from omegaconf import DictConfig, OmegaConf

sys.path.append(os.path.abspath("../"))
from modules.MTLTrainer import MTLTrainer
from modules.mlp import MLP
from modules.cnn import CNN
from modules.subspace_sgd import SubspaceSGD
from utils.data_utils.permuted_mnist import PermutedMNIST
from utils.data_utils.sequential_CIFAR import CL_CIFAR10, CL_CIFAR100

warnings.filterwarnings("ignore")

## Main Training Function for Permuted MNIST

In [51]:
# Load the YAML config as a plain text file
with open("../configs/permuted_mnist.yaml", "r") as f:
    config_str = f.read()

# Replace Hydra-style placeholders and convert backslashes to forward slashes (e.g. Windows-like, change for Linux/MacOS)
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_pmnist = OmegaConf.create(config_str)

# Dynamically resolve Hydra-style paths
config_pmnist.data.data_root = config_pmnist.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_pmnist.hydra.run.dir = config_pmnist.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config: DictConfig) -> None:
    """
    Main training function for Permuted MNIST multi-task learning.

    Args:
        config: Configuration object loaded from YAML.

    """
    # Set the save directory
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving results to {save_dir}")

    # Create permuted MNIST datasets but combine them into one loader
    pmnist = PermutedMNIST(num_tasks=config.data.num_tasks, seed=config.data.seed)
    pmnist.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Combine all task datasets into one
    combined_train_dataset = ConcatDataset([
        pmnist.train_loaders[task_id].dataset for task_id in range(config.data.num_tasks)
    ])
    combined_test_dataset = ConcatDataset([
        pmnist.test_loaders[task_id].dataset for task_id in range(config.data.num_tasks)
    ])

    # Create unified dataloaders
    train_loader = DataLoader(
        combined_train_dataset,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        combined_test_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=True
    )

    # Initialize model
    model = MLP(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim, 
        hidden_dim=config.model.hidden_dim,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize trainer
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        subspace_type=config.training.subspace_type,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate
        train_metrics, val_metrics = trainer.train_and_evaluate(
            train_loader=train_loader,
            val_loader=val_loader,
        )

        print("Training completed successfully!")
        print(f"Final training accuracy: {train_metrics['accuracies'][-1]:.2f}%")
        print(f"Final validation accuracy: {val_metrics['accuracies'][-1]:.2f}%")

    except Exception as e:
        print(f"Training failed: {str(e)}")
        raise
    finally:
        torch.cuda.empty_cache()


Run the training and evaluation for Permuted MNIST

In [52]:
if __name__ == "__main__":
    main(config=config_pmnist)

Saving results to c:/Users/rufat/cf-tiny-subspaces/notebooks/../results/permuted_mnist/subspace-None/k-10/batch_size-128/hidden_dim-100/lr-0.01/seed-42


Epoch 0:   0%|          | 0/4688 [00:08<?, ?it/s]


Training failed: SGD.step() got an unexpected keyword argument 'fp16'


TypeError: SGD.step() got an unexpected keyword argument 'fp16'

## Main Training Function for Split CIFAR-10


In [None]:
# Load the YAML config as a plain text file
with open("../configs/split_cifar10.yaml", "r") as f:
    config_str = f.read()

# Replace Hydra-style placeholders and convert backslashes to forward slashes (e.g. Windows-like, change for Linux/MacOS)
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_scifar10 = OmegaConf.create(config_str)

# Dynamically resolve Hydra-style paths
config_scifar10.data.data_root = config_scifar10.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_scifar10.hydra.run.dir = config_scifar10.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config):
    """
    Main training function for Split CIFAR multitask learning.
    
    Args:
        config: Configuration object loaded from YAML.
    """
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving results to {save_dir}")

    # Create sequential CIFAR-10 dataset
    split_cifar10 = CL_CIFAR10(
        classes_per_task=config.data.classes_per_task,
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    split_cifar10.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Initialize model
    model = CNN(
        width=config.model.width,
        num_tasks=config.data.num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize trainer
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        subspace_type=config.training.subspace_type,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        avg_train_accuracy, avg_val_accuracy = [], []
        # Train and evaluate each task sequentially
        for task_id in range(config.data.num_tasks):
            print(f"\n=== Training Task {task_id + 1}/{config.data.num_tasks} ===")
            
            # Get task-specific dataloaders
            train_loader = split_cifar10.train_loaders[task_id]
            val_loader = split_cifar10.test_loaders[task_id]

            # Train and evaluate the task
            train_metrics, val_metrics = trainer.train_and_evaluate(
                train_loader=train_loader,
                val_loader=val_loader,
            )

            avg_train_accuracy.append(train_metrics['accuracies'][-1])
            avg_val_accuracy.append(val_metrics["accuracies"][-1])

            print(f"Task {task_id + 1} Training Accuracy: {train_metrics['accuracies'][-1]:.2f}%")
            print(f"Task {task_id + 1} Validation Accuracy: {val_metrics['accuracies'][-1]:.2f}%")

        # Calculate and print average accuracies
        if avg_train_accuracy and avg_val_accuracy:
            print(f"\n=== Average Multi-task Training Accuracy: {np.mean(avg_train_accuracy):.2f}% ===")
            print(f"=== Average Multi-task Validation Accuracy: {np.mean(avg_val_accuracy):.2f}% ===")

    except Exception as e:
        print(f"Training failed: {str(e)}")
        raise
    finally:
        torch.cuda.empty_cache()


    print("All tasks trained successfully!")


Run the training and evaluation for Split CIFAR-10

In [None]:
if __name__ == "__main__":
    main(config=config_scifar10)

## Main training function for Split CIFAR-100.

In [None]:
# Load the YAML config as a plain text file
with open("../configs/split_cifar100.yaml", "r") as f:
    config_str = f.read()

# Replace Hydra-style placeholders and convert backslashes to forward slashes (e.g. Windows-like, change for Linux/MacOS)
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_scifar100 = OmegaConf.create(config_str)

# Dynamically resolve Hydra-style paths
config_scifar100.data.data_root = config_scifar100.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_scifar100.hydra.run.dir = config_scifar100.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config):
    """
    Main training function for Split CIFAR multitask learning.
    
    Args:
        config: Configuration object loaded from YAML.
    """
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving results to {save_dir}")

    # Create sequential CIFAR-10 dataset
    split_cifar10 = CL_CIFAR100(
        classes_per_task=config.data.classes_per_task,
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    split_cifar10.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Initialize model
    model = CNN(
        width=config.model.width,
        num_tasks=config.data.num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize trainer
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        subspace_type=config.training.subspace_type,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        avg_train_accuracy, avg_val_accuracy = [], []
        # Train and evaluate each task sequentially
        for task_id in range(config.data.num_tasks):
            print(f"\n=== Training Task {task_id + 1}/{config.data.num_tasks} ===")
            
            # Get task-specific dataloaders
            train_loader = split_cifar10.train_loaders[task_id]
            val_loader = split_cifar10.test_loaders[task_id]

            # Train and evaluate the task
            train_metrics, val_metrics = trainer.train_and_evaluate(
                train_loader=train_loader,
                val_loader=val_loader,
            )

            avg_train_accuracy.append(train_metrics['accuracies'][-1])
            avg_val_accuracy.append(val_metrics["accuracies"][-1])

            print(f"Task {task_id + 1} Training Accuracy: {train_metrics['accuracies'][-1]:.2f}%")
            print(f"Task {task_id + 1} Validation Accuracy: {val_metrics['accuracies'][-1]:.2f}%")

        # Calculate and print average accuracies
        if avg_train_accuracy and avg_val_accuracy:
            print(f"\n=== Average Multi-task Training Accuracy: {np.mean(avg_train_accuracy):.2f}% ===")
            print(f"=== Average Multi-task Validation Accuracy: {np.mean(avg_val_accuracy):.2f}% ===")
        
    except Exception as e:
        print(f"Training failed: {str(e)}")
        raise
    finally:
        torch.cuda.empty_cache()


    print("All tasks trained successfully!")


Run the training and evaluation for Split CIFAR-100

In [None]:
if __name__ == "__main__":
    main(config=config_scifar100)