# Multi-task Training Implementation

This notebook provides a comprehensive implementation of multi-task learning, applied to training models such as `MLP` and `CNN` on datasets like `Permuted MNIST` and `Split CIFAR-10/100`.

While the primary focus of this study is on Continual Learning, this section takes a different approach: it leverages the entire dataset, with all tasks combined and processed together (in sequential mini-batches) rather than through task-by-task progression.

## Packages and Presets

Imports and logging setup:

In [1]:
import os
import sys
import logging
from datetime import datetime
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, ConcatDataset

import warnings
from omegaconf import DictConfig, OmegaConf

sys.path.append(os.path.abspath("../"))
from modules.MTLTrainer import MTLTrainer
from modules.mlp import MLP
from modules.cnn import MultCNN
from utils.data_utils.permuted_mnist import PermutedMNIST
from utils.data_utils.sequential_CIFAR import CL_CIFAR10, CL_CIFAR100

warnings.filterwarnings("ignore")

# Set up logging: every run writes to its own timestamped file under ../logs
# and echoes the same records to the console / notebook output.
log_dir = "../logs"
os.makedirs(log_dir, exist_ok=True)
# The time fields are separated by "-" because ":" is a reserved character in
# Windows file names and cannot be used in the log file name there.
log_filename = os.path.join(
    log_dir, f'multitask_training_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.log'
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(log_filename),  # per-run log file
        logging.StreamHandler(),  # console / notebook output
    ],
)


## Main Training Function for Permuted MNIST

In [2]:
# Load the YAML config as a plain text file. The encoding is passed explicitly
# so the file is decoded the same way on every platform (the default encoding
# of open() depends on the locale).
with open("../configs/permuted_mnist.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Replace the Hydra-style `${hydra:runtime.cwd}` placeholder with the current
# working directory. Backslashes are converted to forward slashes, which matters
# for Windows paths; on Linux/macOS the path normally contains none, so the
# conversion changes nothing and needs no adjustment.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_pmnist = OmegaConf.create(config_str)

# Read the two path fields back from the config and re-assign them.
# NOTE(review): the placeholder was already substituted in the raw text above,
# so these `.replace` calls look like a no-op safety net; confirm before removing.
config_pmnist.data.data_root = config_pmnist.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_pmnist.hydra.run.dir = config_pmnist.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config: DictConfig) -> None:
    """
    Train a single MLP jointly on all Permuted MNIST tasks.

    The per-task train and test datasets are concatenated, so one shuffled
    training loader and one validation loader cover every task.

    Args:
        config: Configuration object loaded from YAML.

    Raises:
        Exception: Any error raised while training is logged and re-raised.
    """
    # Output directory for this run
    out_dir = config.hydra.run.dir
    os.makedirs(out_dir, exist_ok=True)
    logging.info(f"Saving results to {out_dir}")

    data_cfg = config.data
    train_cfg = config.training
    opt_cfg = config.optimizer

    # Build the per-task permuted MNIST loaders
    pmnist = PermutedMNIST(num_tasks=data_cfg.num_tasks, seed=data_cfg.seed)
    pmnist.setup_tasks(
        batch_size=data_cfg.batch_size,
        data_root=data_cfg.data_root,
        num_workers=data_cfg.num_workers,
    )

    # Merge the datasets behind the per-task loaders into joint datasets
    task_ids = range(data_cfg.num_tasks)
    joint_train = ConcatDataset([pmnist.train_loaders[t].dataset for t in task_ids])
    joint_test = ConcatDataset([pmnist.test_loaders[t].dataset for t in task_ids])

    # Both loaders share everything except the shuffle flag
    loader_kwargs = dict(
        batch_size=data_cfg.batch_size,
        num_workers=data_cfg.num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(joint_train, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(joint_test, shuffle=False, **loader_kwargs)

    # Model, loss and optimizer
    model = MLP(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        hidden_dim=config.model.hidden_dim,
    )
    model.to(train_cfg.device)

    criterion = nn.CrossEntropyLoss()

    optimizer = SGD(
        model.parameters(),
        lr=opt_cfg.lr,
        momentum=opt_cfg.momentum,
        weight_decay=opt_cfg.weight_decay,
        nesterov=opt_cfg.nesterov,
    )

    # Trainer wiring; no learning-rate scheduler is used
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=out_dir,
        num_epochs=train_cfg.num_epochs,
        log_interval=train_cfg.log_interval,
        eval_freq=train_cfg.eval_freq,
        checkpoint_freq=train_cfg.checkpoint_freq,
        seed=train_cfg.seed,
        subspace_type=train_cfg.subspace_type,
        scheduler=None,
        device=train_cfg.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        train_metrics, val_metrics = trainer.train_and_evaluate(
            train_loader=train_loader,
            val_loader=val_loader,
        )

        # Report the last recorded accuracy of each split
        logging.info("Training completed successfully!")
        final_train_acc = train_metrics["accuracies"][-1]
        logging.info(f"Final Multi-task training accuracy: {final_train_acc:.2f}%")
        final_val_acc = val_metrics["accuracies"][-1]
        logging.info(f"Final Multi-task validation accuracy: {final_val_acc:.2f}%")

    except Exception as err:
        logging.error(f"Training failed: {err!s}")
        raise
    finally:
        # Runs on success and on failure
        torch.cuda.empty_cache()


Run the training and evaluation for Permuted MNIST

In [3]:
# Entry point for this cell: train and evaluate on Permuted MNIST with the
# config loaded above. `main` is redefined in each section of this notebook,
# so this calls whichever definition was executed last.
if __name__ == "__main__":
    main(config=config_pmnist)

Saving results to c:/Users/rufat/cf-tiny-subspaces/notebooks/../results/permuted_mnist/subspace-None/k-10/batch_size-128/hidden_dim-100/lr-0.01/seed-42


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrasadlii[0m ([33mml-projects[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 4688/4688 [00:39<00:00, 119.48it/s, loss=1.4277, acc=55.50%]
Evaluating: 100%|██████████| 782/782 [00:07<00:00, 105.77it/s, loss=0.5957, acc=81.95%]
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\permuted_mnist\subspace-None\k-10\batch_size-128\hidden_dim-100\lr-0.01\seed-42\models\model_0.pt
Epoch 1: 100%|██████████| 4688/4688 [00:32<00:00, 144.43it/s, loss=0.5190, acc=84.38%]
Evaluating: 100%|██████████| 782/782 [00:07<00:00, 104.77it/s, loss=0.4398, acc=87.01%]
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\permuted_mnist\subspace-None\k-10\batch_size-128\hidden_dim-100\lr-0.01\seed-42\models\model_1.pt
Epoch 2: 100%|██████████| 4688/4688 [00:32<00:00, 145.65it/s, loss=0.4247, acc=87.37%]
Evaluating: 100%|██████████| 782/782 [00:07<00:00, 101.50it/s, loss=0.3788, acc=88.72%]
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\permuted_mnist\subspace-None\k-10

Training completed successfully!
Final training accuracy: 90.30%
Final validation accuracy: 91.22%


## Main Training Function for Split CIFAR-10


In [2]:
# Load the YAML config as a plain text file. The encoding is passed explicitly
# so the file is decoded the same way on every platform (the default encoding
# of open() depends on the locale).
with open("../configs/split_cifar10.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Replace the Hydra-style `${hydra:runtime.cwd}` placeholder with the current
# working directory. Backslashes are converted to forward slashes, which matters
# for Windows paths; on Linux/macOS the path normally contains none, so the
# conversion changes nothing and needs no adjustment.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_scifar10 = OmegaConf.create(config_str)

# Read the two path fields back from the config and re-assign them.
# NOTE(review): the placeholder was already substituted in the raw text above,
# so these `.replace` calls look like a no-op safety net; confirm before removing.
config_scifar10.data.data_root = config_scifar10.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_scifar10.hydra.run.dir = config_scifar10.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config):
    """
    Train a single MultCNN jointly on all Split CIFAR-10 tasks.

    The per-task train and test datasets are concatenated into one training
    loader (shuffled) and one validation loader (not shuffled).

    Args:
        config: Configuration object loaded from YAML.

    Raises:
        Exception: Any error raised while training is logged and re-raised.
    """
    run_dir = config.hydra.run.dir
    os.makedirs(run_dir, exist_ok=True)
    logging.info(f"Saving results to {run_dir}")

    data_cfg = config.data
    train_cfg = config.training
    opt_cfg = config.optimizer

    # Build the per-task CIFAR-10 splits
    cifar_tasks = CL_CIFAR10(
        classes_per_task=data_cfg.classes_per_task,
        num_tasks=data_cfg.num_tasks,
        seed=data_cfg.seed,
    )
    cifar_tasks.setup_tasks(
        batch_size=data_cfg.batch_size,
        data_root=data_cfg.data_root,
        num_workers=data_cfg.num_workers,
    )

    # Merge the datasets behind the per-task loaders into joint datasets
    task_ids = range(data_cfg.num_tasks)
    joint_train = ConcatDataset([cifar_tasks.train_loaders[t].dataset for t in task_ids])
    joint_test = ConcatDataset([cifar_tasks.test_loaders[t].dataset for t in task_ids])

    # Both loaders share everything except the shuffle flag
    shared_loader_args = {
        "batch_size": data_cfg.batch_size,
        "num_workers": data_cfg.num_workers,
        "pin_memory": True,
    }
    train_loader = DataLoader(joint_train, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(joint_test, shuffle=False, **shared_loader_args)

    # One output unit per class across all tasks
    model = MultCNN(
        width=config.model.width,
        num_classes=data_cfg.num_tasks * data_cfg.classes_per_task,
    )
    model.to(train_cfg.device)

    criterion = nn.CrossEntropyLoss()

    optimizer = SGD(
        model.parameters(),
        lr=opt_cfg.lr,
        momentum=opt_cfg.momentum,
        weight_decay=opt_cfg.weight_decay,
        nesterov=opt_cfg.nesterov,
    )

    # Trainer wiring; no learning-rate scheduler is used
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=run_dir,
        num_epochs=train_cfg.num_epochs,
        log_interval=train_cfg.log_interval,
        eval_freq=train_cfg.eval_freq,
        checkpoint_freq=train_cfg.checkpoint_freq,
        seed=train_cfg.seed,
        subspace_type=train_cfg.subspace_type,
        scheduler=None,
        device=train_cfg.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        train_metrics, val_metrics = trainer.train_and_evaluate(
            train_loader=train_loader,
            val_loader=val_loader,
        )

        # Report the last recorded accuracy of each split
        logging.info("Training completed successfully!")
        last_train_acc = train_metrics["accuracies"][-1]
        logging.info(f"Final Multi-task training accuracy: {last_train_acc:.2f}%")
        last_val_acc = val_metrics["accuracies"][-1]
        logging.info(f"Final Multi-task validation accuracy: {last_val_acc:.2f}%")

    except Exception as err:
        logging.error(f"Training failed: {err!s}")
        raise
    finally:
        # Runs on success and on failure
        torch.cuda.empty_cache()


Run the training and evaluation for Split CIFAR-10

In [3]:
# Entry point for this cell: train and evaluate on Split CIFAR-10 with the
# config loaded above. `main` is redefined in each section of this notebook,
# so this calls whichever definition was executed last.
if __name__ == "__main__":
    main(config=config_scifar10)

Saving results to c:/Users/rufat/cf-tiny-subspaces/notebooks/../results/split_cifar10/subspace-None/k-10/batch_size-32/width-32/lr-0.001/seed-42
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrasadlii[0m ([33mml-projects[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 1563/1563 [00:55<00:00, 28.41it/s, loss=0.7087, acc=58.02%]
Evaluating: 100%|██████████| 313/313 [00:19<00:00, 15.88it/s, loss=0.6829, acc=57.25%]
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\split_cifar10\subspace-None\k-10\batch_size-32\width-32\lr-0.001\seed-42\models\model_0.pt
Epoch 1: 100%|██████████| 1563/1563 [00:50<00:00, 31.17it/s, loss=0.6534, acc=61.82%] 
Evaluating: 100%|██████████| 313/313 [00:19<00:00, 15.91it/s, loss=0.6477, acc=63.26%]
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\split_cifar10\subspace-None\k-10\batch_size-32\width-32\lr-0.001\seed-42\models\model_1.pt
Epoch 2: 100%|██████████| 1563/1563 [00:48<00:00, 32.32it/s, loss=0.6401, acc=63.25%] 
Evaluating: 100%|██████████| 313/313 [01:30<00:00,  3.46it/s, loss=0.6411, acc=63.67%] 
INFO:root:Checkpoint saved: c:\Users\rufat\cf-tiny-subspaces\notebooks\..\results\split_cifar10\subspace-None\k-10\batch_size-32\wid

Training completed successfully!
Final training accuracy: 64.40%
Final validation accuracy: 62.10%


## Main Training Function for Split CIFAR-100

In [None]:
# Load the YAML config as a plain text file. The encoding is passed explicitly
# so the file is decoded the same way on every platform (the default encoding
# of open() depends on the locale).
with open("../configs/split_cifar100.yaml", "r", encoding="utf-8") as f:
    config_str = f.read()

# Replace the Hydra-style `${hydra:runtime.cwd}` placeholder with the current
# working directory. Backslashes are converted to forward slashes, which matters
# for Windows paths; on Linux/macOS the path normally contains none, so the
# conversion changes nothing and needs no adjustment.
cwd = os.getcwd().replace("\\", "/")
config_str = config_str.replace("${hydra:runtime.cwd}", cwd)

# Load the updated config into OmegaConf
config_scifar100 = OmegaConf.create(config_str)

# Read the two path fields back from the config and re-assign them.
# NOTE(review): the placeholder was already substituted in the raw text above,
# so these `.replace` calls look like a no-op safety net; confirm before removing.
config_scifar100.data.data_root = config_scifar100.data.data_root.replace("${hydra:runtime.cwd}", cwd)
config_scifar100.hydra.run.dir = config_scifar100.hydra.run.dir.replace("${hydra:runtime.cwd}", cwd)


# Define the main training function
def main(config) -> None:
    """
    Main training function for Split CIFAR-100 multi-task learning.

    The per-task train and test datasets are concatenated into one training
    loader (shuffled) and one validation loader (not shuffled), and a single
    MultCNN is trained on them.

    Args:
        config: Configuration object loaded from YAML.

    Raises:
        Exception: Any error raised while training is logged and re-raised.
    """
    # Set the save directory
    save_dir = config.hydra.run.dir
    os.makedirs(save_dir, exist_ok=True)
    logging.info(f"Saving results to {save_dir}")

    # Create sequential CIFAR-100 dataset
    split_cifar100 = CL_CIFAR100(
        classes_per_task=config.data.classes_per_task,
        num_tasks=config.data.num_tasks,
        seed=config.data.seed,
    )
    split_cifar100.setup_tasks(
        batch_size=config.data.batch_size,
        data_root=config.data.data_root,
        num_workers=config.data.num_workers,
    )

    # Combine all task datasets into one
    combined_train_dataset = ConcatDataset([
        split_cifar100.train_loaders[task_id].dataset for task_id in range(config.data.num_tasks)
    ])
    combined_test_dataset = ConcatDataset([
        split_cifar100.test_loaders[task_id].dataset for task_id in range(config.data.num_tasks)
    ])

    # Create unified dataloaders
    train_loader = DataLoader(
        combined_train_dataset,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        combined_test_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=True,
    )

    # Initialize model with one output unit per class across all tasks
    model = MultCNN(
        width=config.model.width,
        num_classes=config.data.num_tasks * config.data.classes_per_task,
    )
    model.to(config.training.device)

    # Initialize loss function
    criterion = nn.CrossEntropyLoss()

    # Initialize optimizer
    optimizer = SGD(
        model.parameters(),
        lr=config.optimizer.lr,
        momentum=config.optimizer.momentum,
        weight_decay=config.optimizer.weight_decay,
        nesterov=config.optimizer.nesterov,
    )

    # Initialize trainer (no learning-rate scheduler is used)
    trainer = MTLTrainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        save_dir=save_dir,
        num_epochs=config.training.num_epochs,
        log_interval=config.training.log_interval,
        eval_freq=config.training.eval_freq,
        checkpoint_freq=config.training.checkpoint_freq,
        seed=config.training.seed,
        subspace_type=config.training.subspace_type,
        scheduler=None,
        device=config.training.device,
        use_wandb=config.wandb.enabled,
        wandb_project=config.wandb.project,
        wandb_config=OmegaConf.to_container(config, resolve=True),
    )

    try:
        # Train and evaluate
        train_metrics, val_metrics = trainer.train_and_evaluate(
            train_loader=train_loader,
            val_loader=val_loader,
        )

        # Report the last recorded accuracy of each split
        logging.info("Training completed successfully!")
        logging.info(f"Final Multi-task training accuracy: {train_metrics['accuracies'][-1]:.2f}%")
        logging.info(f"Final Multi-task validation accuracy: {val_metrics['accuracies'][-1]:.2f}%")

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        raise
    finally:
        # Runs on success and on failure
        torch.cuda.empty_cache()

Run the training and evaluation for Split CIFAR-100

In [None]:
# Entry point for this cell: train and evaluate on Split CIFAR-100 with the
# config loaded above. `main` is redefined in each section of this notebook,
# so this calls whichever definition was executed last.
if __name__ == "__main__":
    main(config=config_scifar100)