# Multi-task Training Implementation

This notebook provides a comprehensive implementation of multi-task joint learning, applied to training models such as `MLP` and `CNN` on datasets like `Permuted MNIST` and `Split CIFAR-10/100`.

While the primary focus of this study is on Continual Learning, this notebook takes a different approach: instead of progressing through the tasks one by one, it trains on the data of all tasks jointly, aggregating the losses of every task before each optimization step.

## Packages and Presets

Imports and logging setup:

In [1]:
import os
import sys
import logging
import numpy as np
from datetime import datetime
import torch
import torch.nn as nn
from torch.optim import SGD
import pandas as pd

import warnings
from omegaconf import DictConfig, OmegaConf

sys.path.append(os.path.abspath("../"))
from modules.JointTrainer import JointTrainer
from modules.mlp import MLP
from modules.cnn import CNN
from utils.data_utils.permuted_mnist import PermutedMNIST
from utils.data_utils.sequential_CIFAR import CL_CIFAR10, CL_CIFAR100

# Silence all warnings to keep the training output readable.
# NOTE(review): this also hides deprecation / data-loader warnings; narrow the
# filter (by category or module) when debugging unexpected behavior.
warnings.filterwarnings("ignore")

# Set up logging: every execution writes a timestamped log file under ../logs
# and mirrors the same records to stderr (shown in the notebook output).
log_dir = "../logs"
os.makedirs(log_dir, exist_ok=True)
log_filename = os.path.join(
    log_dir, f'multitask_training_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.log'
)

# force=True removes and closes any handlers already attached to the root
# logger (e.g. from a previous execution of this cell). Without it,
# basicConfig is a silent no-op on re-run: the FileHandler below would still
# open a new, permanently empty log file and nothing would be written to it.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(log_filename),
        logging.StreamHandler(),
    ],
    force=True,
)

%load_ext autoreload

## Main Training Function for Permuted MNIST

In [2]:
# Define the main training function
def main(config: DictConfig) -> None:
    """
    Main training function for Permuted MNIST multitask joint training.

    Builds the Permuted MNIST tasks, an MLP, a standard SGD optimizer and a
    ``JointTrainer``, trains on all tasks jointly, writes the final average
    accuracy (with forgetting fixed to 0) to ``forgetting_metrics_pmnist.csv``
    in the current working directory, and logs the final per-task accuracies.

    Args:
        config: Configuration object loaded from YAML. Reads the ``data``,
            ``model``, ``optimizer``, ``training`` and ``wandb`` sections plus
            ``hydra.run.dir`` (the results directory).

    Raises:
        ValueError: If ``data.num_tasks`` is not positive or
            ``data.batch_size`` is smaller than ``data.num_tasks`` (the
            per-task batch size would be 0).
        Exception: Any error raised during training/evaluation is logged with
            its traceback and re-raised.
    """
    num_tasks = config.data.num_tasks
    batch_size = config.data.batch_size

    # During training the losses of all tasks are aggregated before every
    # optimization step, so each task's loader gets an equal share of the
    # configured batch size. Validate before creating any directories.
    if num_tasks < 1 or batch_size < num_tasks:
        raise ValueError(
            f"data.batch_size ({batch_size}) must be >= data.num_tasks "
            f"({num_tasks}) and data.num_tasks must be positive, because the "
            "batch is split evenly across tasks."
        )
    per_task_batch_size = batch_size // num_tasks
    if batch_size % num_tasks:
        # Integer division drops the remainder, so the joint batch is smaller
        # than configured (e.g. 128 samples over 10 tasks -> 12 per task).
        logging.warning(
            f"data.batch_size ({batch_size}) is not divisible by data.num_tasks "
            f"({num_tasks}); using {per_task_batch_size} samples per task, i.e. "
            f"an effective joint batch size of {per_task_batch_size * num_tasks}."
        )

    # Set the save directory
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create permuted MNIST dataset
    pmnist = PermutedMNIST(num_tasks=num_tasks, seed=config.data.seed)
    pmnist.setup_tasks(
        batch_size=per_task_batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Initialize model
    model = MLP(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer.
    # Note that we do not need the customized SubspaceSGD anymore;
    # the standard SGD will do the job.
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize joint trainer. cl_batch_size receives the full configured
    # batch size, while the per-task loaders above use per_task_batch_size.
    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        cl_batch_size=batch_size,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks (test losses are not used here)
        test_accuracies, _ = trainer.train_and_evaluate(cl_dataset=pmnist)

        # Final metric per task = last recorded test accuracy for that task
        final_accuracies = [
            test_accuracies[task_id][-1] for task_id in range(num_tasks)
        ]
        final_avg_accuracy = np.mean(final_accuracies)

        # Joint training has no sequential task order, so forgetting is
        # recorded as a constant 0 next to the average accuracy.
        metrics_csv = "forgetting_metrics_pmnist.csv"
        pd.DataFrame(
            {"average_accuracy": [final_avg_accuracy], "forgetting": [0]}
        ).to_csv(metrics_csv, index=False)
        print(f"saved to {os.getcwd()}/{metrics_csv}")

        # Log final results
        logging.info("\n=== Final Results ===")
        for task_id, accuracy in enumerate(final_accuracies):
            logging.info(f"Task {task_id + 1} Accuracy: {accuracy:.2f}%")
        logging.info(f"Final Average Accuracy: {final_avg_accuracy:.2f}%")

    except Exception as e:
        # logging.exception also records the traceback in the log file.
        logging.exception(f"Training failed: {str(e)}")
        raise
    finally:
        # Free cached GPU memory held by PyTorch, whether training succeeded or not.
        torch.cuda.empty_cache()

Run the training and evaluation for Permuted MNIST

In [3]:
# Load the YAML config as a plain text file
with open("../configs/permuted_mnist.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Hydra would normally resolve ${hydra:runtime.cwd}; outside Hydra we substitute
# the notebook's working directory textually. Backslashes (Windows paths) are
# converted to forward slashes, presumably to keep the YAML/OmegaConf values
# escape-free; on Linux/macOS this is normally a no-op, nothing needs changing.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf. Every ${hydra:runtime.cwd} occurrence
# (including those in data.data_root and hydra.run.dir) was already replaced
# above, so no per-field patching is needed; the remaining interpolations are
# resolved by OmegaConf when accessed.
config_pmnist = OmegaConf.create(config_str)

if __name__ == "__main__":
    main(config=config_pmnist)

2025-01-07 12:39:33,382 - INFO - Saving results to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/../results/permuted_mnist/subspace-None/k-10/batch_size-128/hidden_dim-100/lr-0.01/seed-42
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjanschlegel[0m ([33mjanschlegel-eth-z-rich[0m). Use [1m`wandb login --relogin`[0m to force relogin


2025-01-07 12:39:35,741 - INFO - Starting joint training on 10 tasks
Epoch 1/5: 100%|██████████| 4680/4680 [00:38<00:00, 121.02it/s, loss=2.2274, acc=94.17%]
2025-01-07 12:40:15,433 - INFO - 
Epoch 1/5
2025-01-07 12:40:15,434 - INFO - Average Loss: 0.4584
2025-01-07 12:40:15,434 - INFO - Average Accuracy: 85.84%
2025-01-07 12:40:15,434 - INFO - Time: 39.30s
Evaluating Task 0: 100%|██████████| 834/834 [00:00<00:00, 1385.36it/s]
Evaluating Task 1: 100%|██████████| 834/834 [00:00<00:00, 1528.52it/s]
Evaluating Task 2: 100%|██████████| 834/834 [00:00<00:00, 1506.37it/s]
Evaluating Task 3: 100%|██████████| 834/834 [00:00<00:00, 1507.31it/s]
Evaluating Task 4: 100%|██████████| 834/834 [00:00<00:00, 1389.35it/s]
Evaluating Task 5: 100%|██████████| 834/834 [00:00<00:00, 1543.55it/s]
Evaluating Task 6: 100%|██████████| 834/834 [00:00<00:00, 1451.03it/s]
Evaluating Task 7: 100%|██████████| 834/834 [00:00<00:00, 1530.34it/s]
Evaluating Task 8: 100%|██████████| 834/834 [00:00<00:00, 1476.59it/s]
E

saved to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/forgetting_metrics_pmnist.csv


## Main Training Function for Split CIFAR-10


In [4]:
# Re-import edited project modules (e.g. modules/, utils/) automatically before each cell runs.
%autoreload 2
# Define the main training function
def main(config) -> None:
    """
    Main training function for Split CIFAR-10 multitask joint training.

    Builds the Split CIFAR-10 tasks, a CNN, a standard SGD optimizer and a
    ``JointTrainer``, trains on all tasks jointly, writes the final average
    accuracy (with forgetting fixed to 0) to ``forgetting_metrics_cifar10.csv``
    in the current working directory, and logs the final per-task accuracies.

    Args:
        config: Configuration object loaded from YAML. Reads the ``data``,
            ``model``, ``optimizer``, ``training`` and ``wandb`` sections plus
            ``hydra.run.dir`` (the results directory).

    Raises:
        ValueError: If ``data.num_tasks`` is not positive or
            ``data.batch_size`` is smaller than ``data.num_tasks`` (the
            per-task batch size would be 0).
        Exception: Any error raised during training/evaluation is logged with
            its traceback and re-raised.
    """
    num_tasks = config.data.num_tasks
    batch_size = config.data.batch_size

    # During training the losses of all tasks are aggregated before every
    # optimization step, so each task's loader gets an equal share of the
    # configured batch size. Validate before creating any directories.
    if num_tasks < 1 or batch_size < num_tasks:
        raise ValueError(
            f"data.batch_size ({batch_size}) must be >= data.num_tasks "
            f"({num_tasks}) and data.num_tasks must be positive, because the "
            "batch is split evenly across tasks."
        )
    per_task_batch_size = batch_size // num_tasks
    if batch_size % num_tasks:
        # Integer division drops the remainder, so the joint batch is smaller
        # than configured (e.g. 128 samples over 10 tasks -> 12 per task).
        logging.warning(
            f"data.batch_size ({batch_size}) is not divisible by data.num_tasks "
            f"({num_tasks}); using {per_task_batch_size} samples per task, i.e. "
            f"an effective joint batch size of {per_task_batch_size * num_tasks}."
        )

    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create Split CIFAR-10 dataset
    split_cifar10 = CL_CIFAR10(
        classes_per_task=config.data.classes_per_task,
        num_tasks=num_tasks,
        seed=config.data.seed,
    )
    split_cifar10.setup_tasks(
        batch_size=per_task_batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Initialize the model
    model = CNN(
        width=config.model.width,
        num_tasks=num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize joint trainer. cl_batch_size receives the full configured
    # batch size, while the per-task loaders above use per_task_batch_size.
    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        cl_batch_size=batch_size,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks (test losses are not used here)
        test_accuracies, _ = trainer.train_and_evaluate(
            cl_dataset=split_cifar10
        )

        # Final metric per task = last recorded test accuracy for that task
        final_accuracies = [
            test_accuracies[task_id][-1] for task_id in range(num_tasks)
        ]
        avg_accuracy = np.mean(final_accuracies)

        # Joint training has no sequential task order, so forgetting is
        # recorded as a constant 0 next to the average accuracy.
        metrics_csv = "forgetting_metrics_cifar10.csv"
        pd.DataFrame({"average_accuracy": [avg_accuracy], "forgetting": [0]}).to_csv(
            metrics_csv, index=False
        )
        print(f"saved to {os.getcwd()}/{metrics_csv}")

        logging.info("\n=== Final Results ===")
        for task_id, accuracy in enumerate(final_accuracies):
            logging.info(f"Task {task_id + 1} Accuracy: {accuracy:.2f}%")
        logging.info(f"Average Accuracy across all tasks: {avg_accuracy:.2f}%")

    except Exception as e:
        # logging.exception also records the traceback in the log file.
        logging.exception(f"Training failed: {str(e)}")
        raise
    finally:
        # Free cached GPU memory held by PyTorch, whether training succeeded or not.
        torch.cuda.empty_cache()



Run the training and evaluation for Split CIFAR-10

In [5]:
# Load the YAML config as a plain text file
with open("../configs/split_cifar10.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Hydra would normally resolve ${hydra:runtime.cwd}; outside Hydra we substitute
# the notebook's working directory textually. Backslashes (Windows paths) are
# converted to forward slashes, presumably to keep the YAML/OmegaConf values
# escape-free; on Linux/macOS this is normally a no-op, nothing needs changing.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf. Every ${hydra:runtime.cwd} occurrence
# (including those in data.data_root and hydra.run.dir) was already replaced
# above, so no per-field patching is needed; the remaining interpolations are
# resolved by OmegaConf when accessed.
config_scifar10 = OmegaConf.create(config_str)

if __name__ == "__main__":
    main(config=config_scifar10)

2025-01-07 12:43:16,234 - INFO - Saving results to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/../results/split_cifar10/subspace-None/k-10/batch_size-128/width-32/lr-0.001/seed-42


Files already downloaded and verified
Files already downloaded and verified


2025-01-07 12:43:52,098 - INFO - Starting joint training on 5 tasks
Epoch 1/5: 100%|██████████| 390/390 [00:02<00:00, 179.85it/s, loss=2.9637, acc=68.80%]
2025-01-07 12:43:54,654 - INFO - 
Epoch 1/5
2025-01-07 12:43:54,655 - INFO - Average Loss: 0.6219
2025-01-07 12:43:54,655 - INFO - Average Accuracy: 66.30%
2025-01-07 12:43:54,655 - INFO - Time: 2.40s
Evaluating Task 0: 100%|██████████| 80/80 [00:00<00:00, 567.64it/s]
Evaluating Task 1: 100%|██████████| 80/80 [00:00<00:00, 548.12it/s]
Evaluating Task 2: 100%|██████████| 80/80 [00:00<00:00, 554.66it/s]
Evaluating Task 3: 100%|██████████| 80/80 [00:00<00:00, 624.92it/s]
Evaluating Task 4: 100%|██████████| 80/80 [00:00<00:00, 581.72it/s]
2025-01-07 12:43:55,361 - INFO - Best model saved: /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/../results/split_cifar10/subspace-None/k-10/batch_size-128/width-32/lr-0.001/seed-42/models/model_best.pt
2025-01-07 12:43:55,362 - INFO - Checkpoint saved: /home/jan

saved to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/forgetting_metrics_cifar10.csv


## Main Training Function for Split CIFAR-100

In [6]:
# Re-import edited project modules (e.g. modules/, utils/) automatically before each cell runs.
%autoreload 2
# Define the main training function
def main(config) -> None:
    """
    Main training function for Split CIFAR-100 multitask joint training.

    Builds the Split CIFAR-100 tasks, a CNN, a standard SGD optimizer and a
    ``JointTrainer``, trains on all tasks jointly, writes the final average
    accuracy (with forgetting fixed to 0) to ``forgetting_metrics_cifar100.csv``
    in the current working directory, and logs the final per-task accuracies.

    Args:
        config: Configuration object loaded from YAML. Reads the ``data``,
            ``model``, ``optimizer``, ``training`` and ``wandb`` sections plus
            ``hydra.run.dir`` (the results directory).

    Raises:
        ValueError: If ``data.num_tasks`` is not positive or
            ``data.batch_size`` is smaller than ``data.num_tasks`` (the
            per-task batch size would be 0).
        Exception: Any error raised during training/evaluation is logged with
            its traceback and re-raised.
    """
    num_tasks = config.data.num_tasks
    batch_size = config.data.batch_size

    # During training the losses of all tasks are aggregated before every
    # optimization step, so each task's loader gets an equal share of the
    # configured batch size. Validate before creating any directories.
    if num_tasks < 1 or batch_size < num_tasks:
        raise ValueError(
            f"data.batch_size ({batch_size}) must be >= data.num_tasks "
            f"({num_tasks}) and data.num_tasks must be positive, because the "
            "batch is split evenly across tasks."
        )
    per_task_batch_size = batch_size // num_tasks
    if batch_size % num_tasks:
        # Integer division drops the remainder, so the joint batch is smaller
        # than configured (e.g. 128 samples over 10 tasks -> 12 per task).
        logging.warning(
            f"data.batch_size ({batch_size}) is not divisible by data.num_tasks "
            f"({num_tasks}); using {per_task_batch_size} samples per task, i.e. "
            f"an effective joint batch size of {per_task_batch_size * num_tasks}."
        )

    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create Split CIFAR-100 dataset
    split_cifar100 = CL_CIFAR100(
        classes_per_task=config.data.classes_per_task,
        num_tasks=num_tasks,
        seed=config.data.seed,
    )
    split_cifar100.setup_tasks(
        batch_size=per_task_batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )
    # Record the per-task batch size in the log (replaces a bare debug print).
    logging.info(f"Per-task batch size: {per_task_batch_size}")

    # Initialize the model
    model = CNN(
        width=config.model.width,
        num_tasks=num_tasks,
        classes_per_task=config.data.classes_per_task,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize joint trainer. cl_batch_size receives the full configured
    # batch size, while the per-task loaders above use per_task_batch_size.
    trainer = JointTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_tasks=num_tasks,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        cl_batch_size=batch_size,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate jointly on all tasks (test losses are not used here)
        test_accuracies, _ = trainer.train_and_evaluate(
            cl_dataset=split_cifar100
        )

        # Final metric per task = last recorded test accuracy for that task
        final_accuracies = [
            test_accuracies[task_id][-1] for task_id in range(num_tasks)
        ]
        avg_accuracy = np.mean(final_accuracies)

        # Joint training has no sequential task order, so the average accuracy
        # is written together with a constant forgetting of 0.
        metrics_csv = "forgetting_metrics_cifar100.csv"
        pd.DataFrame({"average_accuracy": [avg_accuracy], "forgetting": [0]}).to_csv(
            metrics_csv, index=False
        )
        print(f"saved to {os.getcwd()}/{metrics_csv}")

        logging.info("\n=== Final Results ===")
        for task_id, accuracy in enumerate(final_accuracies):
            logging.info(f"Task {task_id + 1} Accuracy: {accuracy:.2f}%")
        logging.info(f"Average Accuracy across all tasks: {avg_accuracy:.2f}%")

    except Exception as e:
        # logging.exception also records the traceback in the log file.
        logging.exception(f"Training failed: {str(e)}")
        raise
    finally:
        # Free cached GPU memory held by PyTorch, whether training succeeded or not.
        torch.cuda.empty_cache()

Run the training and evaluation for Split CIFAR-100

In [7]:
# Load the YAML config as a plain text file
with open("../configs/split_cifar100.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Hydra would normally resolve ${hydra:runtime.cwd}; outside Hydra we substitute
# the notebook's working directory textually. Backslashes (Windows paths) are
# converted to forward slashes, presumably to keep the YAML/OmegaConf values
# escape-free; on Linux/macOS this is normally a no-op, nothing needs changing.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf. Every ${hydra:runtime.cwd} occurrence
# (including those in data.data_root and hydra.run.dir) was already replaced
# above, so no per-field patching is needed; the remaining interpolations are
# resolved by OmegaConf when accessed.
config_scifar100 = OmegaConf.create(config_str)

if __name__ == "__main__":
    main(config=config_scifar100)

2025-01-07 12:44:07,989 - INFO - Saving results to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/../results/split_cifar100/subspace-None/k-10/batch_size-128/width-32/lr-0.001/seed-42


Files already downloaded and verified
Files already downloaded and verified


2025-01-07 12:45:20,000 - INFO - Starting joint training on 10 tasks


12


Epoch 1/5: 100%|██████████| 390/390 [00:03<00:00, 98.37it/s, loss=19.8749, acc=28.33%] 
2025-01-07 12:45:24,788 - INFO - 
Epoch 1/5
2025-01-07 12:45:24,789 - INFO - Average Loss: 2.1819
2025-01-07 12:45:24,789 - INFO - Average Accuracy: 20.08%
2025-01-07 12:45:24,790 - INFO - Time: 4.47s
Evaluating Task 0: 100%|██████████| 84/84 [00:00<00:00, 603.07it/s]
Evaluating Task 1: 100%|██████████| 84/84 [00:00<00:00, 673.24it/s]
Evaluating Task 2: 100%|██████████| 84/84 [00:00<00:00, 704.73it/s]
Evaluating Task 3: 100%|██████████| 84/84 [00:00<00:00, 723.27it/s]
Evaluating Task 4: 100%|██████████| 84/84 [00:00<00:00, 716.56it/s]
Evaluating Task 5: 100%|██████████| 84/84 [00:00<00:00, 676.20it/s]
Evaluating Task 6: 100%|██████████| 84/84 [00:00<00:00, 695.25it/s]
Evaluating Task 7: 100%|██████████| 84/84 [00:00<00:00, 637.41it/s]
Evaluating Task 8: 100%|██████████| 84/84 [00:00<00:00, 772.60it/s]
Evaluating Task 9: 100%|██████████| 84/84 [00:00<00:00, 665.79it/s]
2025-01-07 12:45:26,038 - INFO 

saved to /home/janhsc/Documents/ETH/Master 3. Semester/DeepLearning/cf-tiny-subspaces/notebooks/forgetting_metrics_cifar100.csv
