In [None]:
"""
In this script we will define a new custom dataset to be used with the Mammoth framework.

We will need:
- The `register_dataset` function to register our dataset.
- The `ContinualDataset` base class to inherit from.
- The `set_default_from_args` function to set the default values of command-line arguments (e.g., the backbone) for our dataset.
- The `store_masked_loaders` function to convert the datasets into chunks for continual learning.

In addition, we will use the `base_path` function to get the path where the dataset files will be stored,
  and the `load_runner` and `train` functions to run our training process.
"""

from mammoth import register_dataset, ContinualDataset, load_runner, train, base_path, set_default_from_args, store_masked_loaders

In [None]:
"""
Before defining a Continual Learning dataset to use in Mammoth, we need a data source. 
In Mammoth this is usually done by creating a "joint" dataset, which is a dataset that contains all the data from all tasks.
This dataset will be then split into tasks later on.
We will use the CIFAR10 dataset as our data source in this example.

The source dataset SHOULD be a subclass of `torch.utils.data.Dataset` (or implement the required `__len__` and `__getitem__` methods).

In addition, the dataset MUST define:
- `data` and `targets` attributes, which contain the training/testing data and labels respectively.
- `not_aug_transform` attribute, which is a transformation that does not apply any data augmentation.
- `__getitem__` method, which returns a tuple of (image, label, not_aug_image) where:
    - `image` is the transformed image (with data augmentation applied).
    - `label` is the label of the image.
    - `not_aug_image` is the original image without any data augmentation applied.

The `not_aug_image` is used by rehearsal methods to store the original image without any data augmentation applied.
Returning this extra element is also the main reason why we cannot simply use the `torchvision.datasets.CIFAR10` dataset directly, as its `__getitem__` returns only the transformed image and the label.
"""

from torch.nn import CrossEntropyLoss
from torchvision.datasets import CIFAR10
from torchvision import transforms
from PIL import Image

class MammothCIFAR10(CIFAR10):
    """
    Overrides the CIFAR10 dataset to change the getitem function.

    The CIFAR10 dataset already contains the `data` and `targets` attributes, so we do not need to redefine them.
    Unlike torchvision's CIFAR10, `__getitem__` returns a tuple `(image, label, not_aug_image)`.
    """

    def __init__(self, root, is_train=True, transform=None, target_transform=None) -> None:
        """
        Implementing the constructor is not strictly necessary, but it is usually required to load the data and targets in more practical scenarios where data does not simply come from torchvision.

        Args:
            root: directory where the CIFAR10 files are stored (they are downloaded there if missing).
            is_train: if True, load the training split; otherwise, load the test split.
            transform: transformation (e.g., data augmentation) applied to the returned `image`.
            target_transform: optional transformation applied to the returned `label`
                (kept for parity with torchvision's CIFAR10; defaults to no transformation).
        """
        # `_check_integrity` reads `self.root`, so it must be set before it is called below.
        # Requesting a download only when the files fail the integrity check is just a trick
        # to avoid printing debug messages when the dataset is already available.
        self.root = root
        # Transformation for `not_aug_image`: no augmentation, only the conversion to a tensor.
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(root, train=is_train, transform=transform,
                         target_transform=target_transform, download=not self._check_integrity())

    def __getitem__(self, index: int):
        """
        Gets the requested element from the dataset.

        Args:
            index: position of the element in the dataset.

        Returns:
            A tuple `(image, label, not_aug_image)` where `image` went through `transform` (if any),
            `label` went through `target_transform` (if any) and `not_aug_image` only went through
            `not_aug_transform`.
        """
        img, target = self.data[index], self.targets[index]

        # In order to apply data augmentation, we need to convert the image from a numpy array to a PIL Image.
        # torchvision stores CIFAR10 images as uint8 H x W x 3 arrays, from which Pillow infers the RGB mode;
        # passing `mode` explicitly is deprecated in recent Pillow releases.
        img = Image.fromarray(img)
        original_img = img.copy()  # if you do not copy the image, the original image will be modified by the data augmentation transformations.

        # Apply the not_aug_transform to get the original image without any data augmentation.
        not_aug_img = self.not_aug_transform(original_img)

        # Apply the transform to get the augmented image.
        if self.transform is not None:
            img = self.transform(img)

        # Apply the label transformation, as torchvision's CIFAR10 does.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, not_aug_img

In [None]:
@register_dataset(name='custom-cifar10')
class CustomSeqCifar10(ContinualDataset):
    """
    This is the main class that defines a custom Continual Learning dataset in Mammoth.
    It MUST inherit from `ContinualDataset` and implement the required attributes and methods.

    The required attributes are:
    - NAME: name of the dataset.
    - SETTING: setting of the dataset ('class-il', 'domain-il', ...).
    - SIZE: size of the images in the dataset. This is usually a tuple of (height, width).
    - N_CLASSES_PER_TASK: number of classes for each task. It can also be a list of integers, where each integer represents the number of classes of the corresponding task.
    - N_TASKS: number of tasks.
    - MEAN: tuple of means for each channel of the dataset.
    - STD: tuple of standard deviations for each channel of the dataset.
    - TRANSFORM: torchvision transform to apply to the dataset during *training*.
    - TEST_TRANSFORM: torchvision transform to apply to the dataset during *testing*.

    In addition, it MUST implement the `get_data_loaders` method that builds the train and test datasets and the `get_backbone` method that returns the name of the backbone architecture to use for training.
    These datasets will be those that we defined earlier (`MammothCIFAR10`), which return `(image, label, not_aug_image)` tuples.
    The `get_data_loaders` method almost always ends with a call to `store_masked_loaders`, which will convert the datasets into chunks for continual learning.
    """

    NAME = 'custom-cifar10'
    SETTING = 'class-il'
    SIZE = (32, 32)
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)
    # Training transform: expects a PIL image (it ends with `ToTensor`).
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(MEAN, STD)])
    TEST_TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])

    def get_data_loaders(self):
        """
        Builds the train and test `MammothCIFAR10` datasets and splits them into tasks.

        Returns:
            The train and test loaders produced by `store_masked_loaders`.
        """
        train_dataset = MammothCIFAR10(base_path() + 'CIFAR10', is_train=True, transform=self.TRANSFORM)
        test_dataset = MammothCIFAR10(base_path() + 'CIFAR10', is_train=False, transform=self.TEST_TRANSFORM)

        return store_masked_loaders(train_dataset, test_dataset, self)

    @set_default_from_args("backbone")
    def get_backbone():
        """
        The name of a registered backbone (see `create_a_backbone.ipynb` for more details).
        """
        return "resnet18"

    def get_loss(self):
        """
        Returns the loss function used for training (cross-entropy).
        """
        return CrossEntropyLoss()

    def get_transform(self):
        """
        Returns the training transform for samples that are already tensors.

        The `not_aug_image` that rehearsal methods store has already been converted to a tensor by
        `not_aug_transform`, while `TRANSFORM` expects a PIL image (`ToTensor` rejects tensors).
        The tensor is therefore converted back to a PIL image before applying `TRANSFORM`.
        """
        return transforms.Compose([transforms.ToPILImage(), self.TRANSFORM])

In [None]:
"""
Now we can use the `load_runner` function to build the model and our custom dataset, and the `train` function to run the training.
"""

# Hyper-parameters for the run, handed to `load_runner` together with the model and dataset names.
run_args = {'lr': 0.1, 'n_epochs': 1, 'batch_size': 32}
model, dataset = load_runner('sgd', 'custom-cifar10', run_args)
train(model, dataset)