In [1]:
# `_configs` is imported while the kernel is still in its starting directory;
# `%cd ..` then moves to the parent directory (the cell output shows the new
# working directory), presumably so the project packages imported in the next
# cell resolve from there — confirm.
# NOTE(review): `%cd ..` is relative, so re-running this cell without restarting
# the kernel moves up one more level; run it once per kernel session.
from _configs import OFA_MODEL_PATH
%cd ..

/home/fsahli/vvcastro/continual-nas


In [2]:
from _constants import DATASET_N_CLASSES

from continual_learning.continual_trainers import GrowingDataContinualTrainer
from search_space.base_ofa import OFAEvaluator
from search_space import get_search_space
from tools.metrics import binary_accuracy
from _utils import set_seed

import torch.nn as nn

# OFA supernet family shared by the search space, the pretrained evaluator and
# the trainer in `main`.
OFA_FAMILY: str = "mobilenetv3"

# Batch size and data-loader worker count forwarded to `trainer.train`.
# Tune both for the machine running the notebook.
BATCH_SIZE: int = 32
NUM_WORKERS: int = 10

def main(
    dataset: str,
    n_tasks: int,
    optimiser_name: str,
    learning_rate: float,
    weight_decay: float,
    epochs_per_task: int,
    random_seed: int,
    show_progress: bool,
    n_arch_samples: int = 8,
    arch_sample_index: int = 2,
    experiment_dir: str = "tmp",
) -> None:
    """
    Sample an OFA sub-network and train it with a growing-data continual trainer.

    Args:
        dataset: Dataset name. Used as the key into ``DATASET_N_CLASSES`` and
            passed to the trainer to load the data.
        n_tasks: Number of tasks, forwarded to the trainer as ``num_tasks``.
        optimiser_name: Optimiser identifier forwarded to
            ``set_experiment_settings`` as ``optim_name``.
        learning_rate: Optimiser learning rate (``optim_params["lr"]``).
        weight_decay: Optimiser weight decay (``optim_params["weight_decay"]``).
        epochs_per_task: Training epochs for each task.
        random_seed: Seed passed to ``set_seed`` and to the trainer.
        show_progress: Forwarded to ``trainer.train``.
        n_arch_samples: Number of architectures drawn from the search space.
            Defaults to 8 (the value previously hard-coded).
        arch_sample_index: Index of the drawn architecture to train. Defaults
            to 2 (previously hard-coded). With a fixed ``random_seed``, keeping
            both sampling parameters fixed presumably selects the same
            architecture every run — confirm that sampling uses seeded RNGs.
        experiment_dir: Forwarded to the trainer as ``experiment_dir``.

    Raises:
        ValueError: If ``arch_sample_index`` is not in ``[0, n_arch_samples)``.
    """
    if not 0 <= arch_sample_index < n_arch_samples:
        # Fail fast, before the pretrained OFA supernet is loaded.
        raise ValueError(
            f"arch_sample_index must be in [0, {n_arch_samples}), "
            f"got {arch_sample_index}"
        )

    print(
        f"Running with D: {dataset}, n_tasks: {n_tasks}, "
        f"optimiser_name: {optimiser_name}, learning_rate: {learning_rate}, "
        f"weight_decay: {weight_decay}, epochs_per_task: {epochs_per_task}, "
        f"random_seed: {random_seed}"
    )
    set_seed(random_seed)

    # Step 1: build the search space and the pretrained OFA supernet.
    search_space = get_search_space(family=OFA_FAMILY, fixed=False)
    ofa_net = OFAEvaluator(
        family=OFA_FAMILY,
        model_path=OFA_MODEL_PATH,
        data_classes=DATASET_N_CLASSES[dataset],
        pretrained=True,
    )

    # Step 2: sample an architecture and extract its sub-network.
    sampled_architectures = search_space.sample(n_samples=n_arch_samples)
    sampled_architecture = sampled_architectures[arch_sample_index]
    print(f"Sampled dir: {sampled_architecture['direction']}")
    base_model = ofa_net.get_architecture_model(sampled_architecture)

    # Step 3: build the continual trainer around the sampled model.
    trainer = GrowingDataContinualTrainer(
        capacity_tau=0.08,
        expand_is_frozen=False,
        distill_on_expand=False,
        weights_from_ofa=True,
        dataset_name=dataset,
        search_space_family=OFA_FAMILY,
        experiment_dir=experiment_dir,
        experiment_name=f"ND{n_tasks}-ep{epochs_per_task}@{dataset}",
        model_definition=sampled_architecture,
        base_model=base_model,
        num_tasks=n_tasks,
        random_seed=random_seed,
    )

    # Step 4: load the data and the stats used for normalisation.
    trainer.load_dataset(dataset_name=dataset)

    # Step 5: configure loss, optimiser and metrics.
    trainer.set_experiment_settings(
        loss_fn=nn.CrossEntropyLoss(),
        epochs_per_task=epochs_per_task,
        optim_name=optimiser_name,
        optim_params={"lr": learning_rate, "weight_decay": weight_decay},
        model_size_metrics={
            # Total parameter count of the sampled sub-network.
            "model_size": sum(p.numel() for p in base_model.parameters()),
        },
        # NOTE(review): a metric named `binary_accuracy` is paired with
        # `nn.CrossEntropyLoss` and `data_classes=DATASET_N_CLASSES[dataset]`;
        # confirm it is correct for multi-class outputs.
        training_metrics={"accuracy": binary_accuracy},
        augment=False,
    )

    # Step 6: train. `with_random_metrics` and `evaluate_after_task` are both
    # disabled for this run.
    trainer.train(
        task_epochs=epochs_per_task,
        show_progress=show_progress,
        with_random_metrics=False,
        evaluate_after_task=False,
        num_workers=NUM_WORKERS,
        batch_size=BATCH_SIZE,
    )

In [3]:
# Model encoding (49 digits), written as one digit string and converted to a
# list of ints.
# NOTE(review): every entry is 0/1 except position 45, which is 9. This looks
# like a typo for a binary flag; confirm the intended value before using it.
# NOTE(review): `model_encoding` is not referenced by `main` or the run cell
# shown in this notebook.
model_encoding = [
    int(bit)
    for bit in (
        "1111001000"
        "0000000010"
        "1000110111"
        "1000100010"
        "001009101"
    )
]


Check the implementation with a short run: CIFAR-10 split into 10 tasks, training for one epoch per task.

In [3]:
# Short end-to-end run of `main` to check that the pipeline works.
main(
    # Data / continual-learning setup
    dataset="cifar10",
    n_tasks=10,
    epochs_per_task=1,
    # Optimiser
    optimiser_name="adam",
    learning_rate=7.5e-4,
    weight_decay=1e-5,
    # Run control
    random_seed=42,
    show_progress=True,
)


Running with D: cifar10, n_tasks: 10, optimiser_name: adam, learning_rate: 0.00075, weight_decay: 1e-05, epochs_per_task: 1, random_seed: 42
Sampled dir: [1, 1, 0]


Files already downloaded and verified


Training: 100%|██████████| 157/157 [00:06<00:00, 25.15it/s]



Training metrics:
	Loss: 0.6985 ± 0.4288
	accuracy: 0.7566 ± 0.1530


Training: 100%|██████████| 157/157 [00:05<00:00, 26.79it/s]



Training metrics:
	Loss: 0.4539 ± 0.1682
	accuracy: 0.8370 ± 0.0705


Training: 100%|██████████| 157/157 [00:05<00:00, 26.63it/s]



Training metrics:
	Loss: 0.3834 ± 0.1666
	accuracy: 0.8702 ± 0.0627


Training: 100%|██████████| 157/157 [00:05<00:00, 26.72it/s]



Training metrics:
	Loss: 0.3706 ± 0.1745
	accuracy: 0.8806 ± 0.0568


Training: 100%|██████████| 157/157 [00:05<00:00, 26.74it/s]



Training metrics:
	Loss: 0.3122 ± 0.1715
	accuracy: 0.9001 ± 0.0560


Training: 100%|██████████| 157/157 [00:06<00:00, 25.53it/s]



Training metrics:
	Loss: 0.3587 ± 0.2073
	accuracy: 0.8830 ± 0.0622


Training: 100%|██████████| 157/157 [00:06<00:00, 25.36it/s]



Training metrics:
	Loss: 0.3282 ± 0.1618
	accuracy: 0.8893 ± 0.0566


Training: 100%|██████████| 157/157 [00:06<00:00, 24.26it/s]



Training metrics:
	Loss: 0.3207 ± 0.1584
	accuracy: 0.8945 ± 0.0572


Training: 100%|██████████| 157/157 [00:06<00:00, 24.15it/s]



Training metrics:
	Loss: 0.3213 ± 0.1742
	accuracy: 0.8945 ± 0.0570


Training: 100%|██████████| 157/157 [00:06<00:00, 24.15it/s]



Training metrics:
	Loss: 0.3004 ± 0.1431
	accuracy: 0.8969 ± 0.0577


: 