# Multi-task Training Implementation

This notebook implements multi-task joint learning: an `MLP` is trained on `Permuted MNIST`, and a `CNN` is trained on `Split CIFAR-10` and `Split CIFAR-100`.

The primary focus of this study is continual learning, but this notebook takes a different approach. The model is trained jointly on the data of all tasks at once, seen together in mini-batches throughout training, rather than progressing through the tasks one after another.

## Packages and Presets

Imports and logging setup:

In [1]:
import os
import sys
import logging
import numpy as np
from datetime import datetime
import torch
import torch.nn as nn
from torch.optim import SGD

import warnings
from omegaconf import DictConfig, OmegaConf

sys.path.append(os.path.abspath("../"))
from modules.JointTrainer import JointTrainer
from modules.mlp import MLP
from modules.cnn import CNN
from utils.data_utils.permuted_mnist import PermutedMNIST
from utils.data_utils.sequential_CIFAR import CL_CIFAR10, CL_CIFAR100

# Silence warnings to keep the notebook output readable.
# NOTE(review): this hides *all* warnings (including deprecation warnings from
# torch / omegaconf); narrow the filter if something behaves unexpectedly.
warnings.filterwarnings("ignore")

# Set up logging: each run writes to its own timestamped file in ../logs and
# also echoes records to the notebook output.
log_dir = "../logs"
os.makedirs(log_dir, exist_ok=True)
log_filename = os.path.join(
    log_dir, f'multitask_training_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.log'
)

# force=True: without it, basicConfig silently does nothing when the root
# logger already has handlers (e.g. when this cell is re-run, or an imported
# library configured logging first). The FileHandler argument below is built
# either way, so a no-op call would leave an empty log file plus an open file
# handle while records keep going to the old handlers. force=True closes and
# replaces any existing root handlers instead.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        # Explicit UTF-8 so non-ASCII log text does not depend on the OS locale.
        logging.FileHandler(log_filename, encoding="utf-8"),
        logging.StreamHandler(),
    ],
    force=True,
)


## Main Training Function for Permuted MNIST

In [3]:
# Define the main training function
def main(config: DictConfig) -> None:
    """
    Train and evaluate an MLP jointly on all Permuted MNIST tasks.

    All tasks are handed to ``JointTrainer`` at once (multi-task joint
    training). The final per-task and average test accuracies are logged
    afterwards.

    Args:
        config: Configuration object loaded from YAML. Keys read here:
            ``hydra.run.dir``; ``data.{num_tasks, seed, batch_size, data_root,
            num_workers}``; ``model.{input_dim, output_dim, hidden_dim}``;
            ``optimizer.{lr, momentum, weight_decay, nesterov}``;
            ``training.{device, num_epochs, log_interval, eval_freq,
            checkpoint_freq, seed}``; ``wandb.{enabled, project}``.

    Raises:
        Exception: Any error from ``JointTrainer.train_and_evaluate`` is logged
            together with its traceback and then re-raised.
    """
    # Output directory, passed to JointTrainer as its save_dir.
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create permuted MNIST dataset
    pmnist = PermutedMNIST(
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    pmnist.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Initialize model
    model = MLP(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.to(config.training.device)

    criterion = nn.CrossEntropyLoss()

    # Joint training needs no custom SubspaceSGD; plain SGD is sufficient.
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=config.data.num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        # NOTE(review): Permuted MNIST is often framed as domain-IL (all tasks
        # share the same 10 labels); confirm how JointTrainer applies
        # task_il=True to this MLP, which is built with a single output_dim.
        task_il=True,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks. The test losses are not
        # used by this report.
        test_accuracies, _ = trainer.train_and_evaluate(cl_dataset=pmnist)
    except Exception:
        # logging.exception writes the full traceback to the log file; the
        # previous str(e) dropped it (and is empty for message-less errors).
        logging.exception("Training failed")
        raise
    finally:
        torch.cuda.empty_cache()

    # test_accuracies[task_id] is indexed like a per-task history; [-1] takes
    # the value from the most recent evaluation.
    final_accuracies = [
        test_accuracies[task_id][-1] for task_id in range(config.data.num_tasks)
    ]
    final_avg_accuracy = np.mean(final_accuracies)

    logging.info("\n=== Final Results ===")
    for task_number, accuracy in enumerate(final_accuracies, start=1):
        logging.info(f"Task {task_number} Accuracy: {accuracy:.2f}%")
    logging.info(f"Final Average Accuracy: {final_avg_accuracy:.2f}%")

Run the training and evaluation for Permuted MNIST

In [4]:
# Load the YAML config as plain text, so that the Hydra-only placeholder
# `${hydra:runtime.cwd}` can be substituted before OmegaConf parses it.
# Read as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
# would mis-decode any non-ASCII characters in the YAML.
with open("../configs/permuted_mnist.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Convert Windows backslashes to forward slashes. Inside a double-quoted YAML
# string, backslashes would otherwise be read as escape sequences. On
# Linux/macOS this replace is a no-op, so nothing needs to change per OS.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Every `${hydra:runtime.cwd}` occurrence (including those in data.data_root
# and hydra.run.dir) was replaced above, so no per-field substitution is needed.
config_pmnist = OmegaConf.create(config_str)

# In a notebook kernel __name__ is always "__main__". The guard only matters if
# this notebook is exported to a script and imported.
if __name__ == "__main__":
    main(config=config_pmnist)

2025-01-07 09:32:57,642 - INFO - Saving results to c:/Users/rufat/cf-tiny-subspaces/notebooks/../results/permuted_mnist/subspace-None/k-10/batch_size-128/hidden_dim-100/lr-0.01/seed-42
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrasadlii[0m ([33mml-projects[0m). Use [1m`wandb login --relogin`[0m to force relogin


2025-01-07 09:33:48,620 - INFO - Starting joint training on 10 tasks
Epoch 1/5: 100%|██████████| 469/469 [00:24<00:00, 18.92it/s, loss=6.8284, acc=78.59%] 
2025-01-07 09:34:26,871 - INFO - 
Epoch 1/5
2025-01-07 09:34:26,876 - INFO - Average Loss: 1.4087
2025-01-07 09:34:26,877 - INFO - Average Accuracy: 56.59%
2025-01-07 09:34:26,878 - INFO - Time: 31.89s
Evaluating Task 0: 100%|██████████| 79/79 [00:04<00:00, 16.71it/s]
Evaluating Task 1: 100%|██████████| 79/79 [00:04<00:00, 16.36it/s]
Evaluating Task 2: 100%|██████████| 79/79 [00:04<00:00, 15.85it/s]
Evaluating Task 3: 100%|██████████| 79/79 [00:04<00:00, 17.21it/s]
Evaluating Task 4: 100%|██████████| 79/79 [00:04<00:00, 16.56it/s]
Evaluating Task 5: 100%|██████████| 79/79 [00:04<00:00, 16.30it/s]
Evaluating Task 6: 100%|██████████| 79/79 [00:04<00:00, 17.14it/s]
Evaluating Task 7: 100%|██████████| 79/79 [00:04<00:00, 16.13it/s]
Evaluating Task 8: 100%|██████████| 79/79 [00:04<00:00, 16.26it/s]
Evaluating Task 9: 100%|██████████| 79/

## Main Training Function for Split CIFAR-10


In [2]:
# Define the main training function
def main(config: DictConfig) -> None:
    """
    Train and evaluate a CNN jointly on all Split CIFAR-10 tasks.

    All tasks are handed to ``JointTrainer`` at once (multi-task joint
    training). The final per-task and average test accuracies are logged
    afterwards.

    Args:
        config: Configuration object loaded from YAML. Keys read here:
            ``hydra.run.dir``; ``data.{classes_per_task, num_tasks, seed,
            batch_size, data_root, num_workers}``; ``model.width``;
            ``optimizer.{lr, momentum, weight_decay, nesterov}``;
            ``training.{device, num_epochs, log_interval, eval_freq,
            checkpoint_freq, seed}``; ``wandb.{enabled, project}``.

    Raises:
        Exception: Any error from ``JointTrainer.train_and_evaluate`` is logged
            together with its traceback and then re-raised.
    """
    # Output directory, passed to JointTrainer as its save_dir.
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create Split CIFAR-10 dataset
    split_cifar10 = CL_CIFAR10(
        classes_per_task=config.data.classes_per_task,
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    split_cifar10.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # The CNN is built from num_tasks and classes_per_task (presumably one
    # output head per task -- confirm in modules/cnn.py).
    model = CNN(
        width=config.model.width,
        num_tasks=config.data.num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    criterion = nn.CrossEntropyLoss()

    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=config.data.num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        task_il=True,  # Split CIFAR-10 is evaluated in the task-incremental setting
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks. The test losses are not
        # used by this report.
        test_accuracies, _ = trainer.train_and_evaluate(cl_dataset=split_cifar10)
    except Exception:
        # logging.exception writes the full traceback to the log file; the
        # previous str(e) dropped it (and is empty for message-less errors).
        logging.exception("Training failed")
        raise
    finally:
        torch.cuda.empty_cache()

    # test_accuracies[task_id] is indexed like a per-task history; [-1] takes
    # the value from the most recent evaluation.
    final_accuracies = [
        test_accuracies[task_id][-1] for task_id in range(config.data.num_tasks)
    ]
    avg_accuracy = np.mean(final_accuracies)

    logging.info("\n=== Final Results ===")
    for task_number, accuracy in enumerate(final_accuracies, start=1):
        logging.info(f"Task {task_number} Accuracy: {accuracy:.2f}%")
    logging.info(f"Average Accuracy across all tasks: {avg_accuracy:.2f}%")



Run the training and evaluation for Split CIFAR-10

In [3]:
# Load the YAML config as plain text, so that the Hydra-only placeholder
# `${hydra:runtime.cwd}` can be substituted before OmegaConf parses it.
# Read as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
# would mis-decode any non-ASCII characters in the YAML.
with open("../configs/split_cifar10.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Convert Windows backslashes to forward slashes. Inside a double-quoted YAML
# string, backslashes would otherwise be read as escape sequences. On
# Linux/macOS this replace is a no-op, so nothing needs to change per OS.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Every `${hydra:runtime.cwd}` occurrence (including those in data.data_root
# and hydra.run.dir) was replaced above, so no per-field substitution is needed.
config_scifar10 = OmegaConf.create(config_str)

# In a notebook kernel __name__ is always "__main__". The guard only matters if
# this notebook is exported to a script and imported.
if __name__ == "__main__":
    main(config=config_scifar10)

2025-01-07 09:00:07,691 - INFO - Saving results to c:/Users/rufat/cf-tiny-subspaces/notebooks/../results/split_cifar10/subspace-None/k-10/batch_size-128/width-32/lr-0.001/seed-42


Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrasadlii[0m ([33mml-projects[0m). Use [1m`wandb login --relogin`[0m to force relogin


2025-01-07 09:01:49,312 - INFO - Starting joint training on 5 tasks
Epoch 1/5: 100%|██████████| 79/79 [00:09<00:00,  8.36it/s, loss=3.3124, acc=63.12%]
2025-01-07 09:03:30,740 - INFO - 
Epoch 1/5
2025-01-07 09:03:30,744 - INFO - Average Loss: 0.6732
2025-01-07 09:03:30,744 - INFO - Average Accuracy: 59.27%
2025-01-07 09:03:30,745 - INFO - Time: 98.30s
Evaluating Task 0: 100%|██████████| 16/16 [00:15<00:00,  1.02it/s]
Evaluating Task 1: 100%|██████████| 16/16 [00:16<00:00,  1.02s/it]
Evaluating Task 2: 100%|██████████| 16/16 [00:17<00:00,  1.12s/it]
Evaluating Task 3: 100%|██████████| 16/16 [00:17<00:00,  1.08s/it]
Evaluating Task 4: 100%|██████████| 16/16 [00:16<00:00,  1.06s/it]
2025-01-07 09:04:54,941 - INFO - Best model saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\split_cifar10\subspace-None\k-10\batch_size-128\width-32\lr-0.001\seed-42\models\model_best.pt
2025-01-07 09:04:54,943 - INFO - Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\split_

## Main Training Function for Split CIFAR-100

In [None]:
# Define the main training function
def main(config: DictConfig) -> None:
    """
    Train and evaluate a CNN jointly on all Split CIFAR-100 tasks.

    All tasks are handed to ``JointTrainer`` at once (multi-task joint
    training). The final per-task and average test accuracies are logged
    afterwards.

    Args:
        config: Configuration object loaded from YAML. Keys read here:
            ``hydra.run.dir``; ``data.{classes_per_task, num_tasks, seed,
            batch_size, data_root, num_workers}``; ``model.width``;
            ``optimizer.{lr, momentum, weight_decay, nesterov}``;
            ``training.{device, num_epochs, log_interval, eval_freq,
            checkpoint_freq, seed}``; ``wandb.{enabled, project}``.

    Raises:
        Exception: Any error from ``JointTrainer.train_and_evaluate`` is logged
            together with its traceback and then re-raised.
    """
    # Output directory, passed to JointTrainer as its save_dir.
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create Split CIFAR-100 dataset
    split_cifar100 = CL_CIFAR100(
        classes_per_task=config.data.classes_per_task,
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    split_cifar100.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # The CNN is built from num_tasks and classes_per_task (presumably one
    # output head per task -- confirm in modules/cnn.py).
    model = CNN(
        width=config.model.width,
        num_tasks=config.data.num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    criterion = nn.CrossEntropyLoss()

    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=config.data.num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        task_il=True,  # Split CIFAR-100 is evaluated in the task-incremental setting
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks. The test losses are not
        # used by this report.
        test_accuracies, _ = trainer.train_and_evaluate(cl_dataset=split_cifar100)
    except Exception:
        # logging.exception writes the full traceback to the log file; the
        # previous str(e) dropped it (and is empty for message-less errors).
        logging.exception("Training failed")
        raise
    finally:
        torch.cuda.empty_cache()

    # test_accuracies[task_id] is indexed like a per-task history; [-1] takes
    # the value from the most recent evaluation.
    final_accuracies = [
        test_accuracies[task_id][-1] for task_id in range(config.data.num_tasks)
    ]
    avg_accuracy = np.mean(final_accuracies)

    logging.info("\n=== Final Results ===")
    for task_number, accuracy in enumerate(final_accuracies, start=1):
        logging.info(f"Task {task_number} Accuracy: {accuracy:.2f}%")
    logging.info(f"Average Accuracy across all tasks: {avg_accuracy:.2f}%")

Run the training and evaluation for Split CIFAR-100

In [None]:
# Load the YAML config as plain text, so that the Hydra-only placeholder
# `${hydra:runtime.cwd}` can be substituted before OmegaConf parses it.
# Read as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
# would mis-decode any non-ASCII characters in the YAML.
with open("../configs/split_cifar100.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Convert Windows backslashes to forward slashes. Inside a double-quoted YAML
# string, backslashes would otherwise be read as escape sequences. On
# Linux/macOS this replace is a no-op, so nothing needs to change per OS.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Every `${hydra:runtime.cwd}` occurrence (including those in data.data_root
# and hydra.run.dir) was replaced above, so no per-field substitution is needed.
config_scifar100 = OmegaConf.create(config_str)

# In a notebook kernel __name__ is always "__main__". The guard only matters if
# this notebook is exported to a script and imported.
if __name__ == "__main__":
    main(config=config_scifar100)