In [None]:
!pip install renate

Collecting renate
  Downloading Renate-0.5.1-py3-none-any.whl (169 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.8/169.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting torch<1.13.2,>=1.10.0 (from renate)
  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Collecting boto3<1.34.3,>=1.26.0 (from renate)
  Downloading boto3-1.34.2-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
Collecting sagemaker<2.200.2,>=2.112.0 (from renate)
  Downloading sagemaker-2.200.1-py2.py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting syne-tune[aws,gpsearchers]<0.10.1,>=0.6.0 (from renate)
  Downloading syne_tune-0.10.0-py3-non

In [None]:
from typing import Callable, Dict, Optional

import torch
from torchmetrics import Accuracy
from torchvision.transforms import transforms

from renate import defaults
from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
from renate.benchmark.models.mlp import MultiLayerPerceptron
from renate.benchmark.scenarios import ClassIncrementalScenario, Scenario
from renate.models import RenateModule


def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> Scenario:
    """Builds the class-incremental MNIST scenario for a single data chunk.

    Args:
        data_path: Forwarded as the first argument of ``TorchVisionDataModule``.
        chunk_id: Forwarded to ``ClassIncrementalScenario`` to pick the chunk.
        seed: Forwarded to ``TorchVisionDataModule``; defaults to ``defaults.SEED``.

    Returns:
        A ``ClassIncrementalScenario`` over MNIST with two class groups:
        digits 0-4 and digits 5-9.
    """
    # 10% validation split (val_size=0.1).
    mnist_module = TorchVisionDataModule(
        data_path, dataset_name="MNIST", val_size=0.1, seed=seed
    )

    # One tuple of class labels per chunk.
    digit_groups = ((0, 1, 2, 3, 4), (5, 6, 7, 8, 9))

    return ClassIncrementalScenario(
        data_module=mnist_module,
        groupings=digit_groups,
        chunk_id=chunk_id,
    )


def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    """Returns a model instance, either freshly built or restored from a saved state.

    Args:
        model_state_url: Location of a serialized model state readable by
            ``torch.load``. When ``None``, a new ``MultiLayerPerceptron`` is created.

    Returns:
        A ``MultiLayerPerceptron`` with 784 inputs, 10 outputs and two hidden layers
        of size 128 when no state is given; otherwise the model rebuilt by
        ``MultiLayerPerceptron.from_state_dict`` from the loaded state.
    """
    if model_state_url is None:
        return MultiLayerPerceptron(
            num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128
        )

    # Map all tensors to CPU while deserializing so a state saved from a GPU run can
    # still be loaded on a machine without CUDA.
    # NOTE: torch.load unpickles the file; only point this at state files you trust.
    state_dict = torch.load(model_state_url, map_location="cpu")
    return MultiLayerPerceptron.from_state_dict(state_dict)


def train_transform() -> Callable:
    """Returns the transform applied during training: flatten the input tensor."""

    def _flatten(sample):
        # Collapse all dimensions of the sample into one.
        return torch.flatten(sample)

    return transforms.Lambda(_flatten)


def loss_fn() -> torch.nn.Module:
    """Returns a cross-entropy loss that yields one value per sample (no reduction)."""
    per_sample_loss = torch.nn.CrossEntropyLoss(reduction="none")
    return per_sample_loss


def metrics_fn() -> Dict:
    """Returns the metrics to compute, keyed by name: multiclass accuracy over 10 classes."""
    accuracy = Accuracy(task="multiclass", num_classes=10)
    return {"accuracy": accuracy}

In [None]:
from renate.training import run_training_job


# Hyperparameters handed to `run_training_job`; every entry is a single constant value.
config_space = dict(
    optimizer="SGD",
    momentum=0.0,
    weight_decay=0.0,
    learning_rate=0.1,
    alpha=0.5,
    batch_size=64,
    batch_memory_frac=0.5,
    memory_size=500,
    loss_normalization=0,
    loss_weight=0.5,
    early_stopping=True,
)

if __name__ == "__main__":
    # Settings common to both training jobs.
    # NOTE(review): nothing in this notebook writes `renate_config.py`; presumably the
    # definitions from the previous cell must be saved under that name — confirm.
    shared_job_args = dict(
        config_space=config_space,
        mode="max",
        metric="val_accuracy",
        updater="ER",
        max_epochs=50,
        config_file="renate_config.py",
        backend="local",  # both jobs run on the local machine
    )

    # First job: train on chunk 0 and store the resulting state on disk.
    run_training_job(
        chunk_id=0,
        output_state_url="./state_dump_first_model/",
        **shared_job_args,
    )

    # The state in `./state_dump_first_model/` can be inspected at this point, but it
    # must not be deleted: the second job reads it as its input.

    # Second job: continue from the first job's state on chunk 1 and store the
    # updated state in a separate folder.
    run_training_job(
        chunk_id=1,
        input_state_url="./state_dump_first_model/",
        output_state_url="./state_dump_second_model/",
        **shared_job_args,
    )