In [1]:
"""
In this script we will define a new backbone to be used with the Mammoth framework.

We will need:
- The `register_backbone` function to register our backbone.
- The `MammothBackbone` class to define our backbone.
"""

from mammoth_lite import register_backbone, MammothBackbone, load_runner, train, ReturnTypes

In [2]:
"""
Unlike models and datasets, the `register_backbone` function is not applied to the backbone class itself, but to a function that returns an instance of the backbone class.
This allows us to register multiple versions of the same backbone with different parameters.
"""

from torch import nn
from torch.nn import functional as F

class CustomCNN(MammothBackbone):
    def __init__(self, num_classes: int, num_channels: int = 32, input_size: int = 32):
        """
        All backbones must inherit from the `MammothBackbone` class.
        The constructor should define the layers of the backbone and any necessary parameters.

        All parameters except `num_classes` can be customized when registering the backbone.
        The `num_classes` parameter is mandatory and will be passed automatically when the backbone is loaded.
        """
        super().__init__()

        self.layer1 = nn.Conv2d(in_channels=3, out_channels=num_channels, kernel_size=3, stride=1, padding=1)
        self.layer2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels * 2, kernel_size=3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.classifier = nn.Linear(num_channels * 2 * (input_size // 4) * (input_size // 4), num_classes)

    def forward(self, x, returnt: ReturnTypes = "out"):
        """
        Implement the forward pass of your custom CNN.
        
        In addition to the standard output, you can return intermediate features or logits depending on the `returnt` parameter.
        The `returnt` parameter can be one of the following:
        - "out": return the final output of the model.
        - "features": return intermediate features.
        - "both": return both the final output and intermediate features as a tuple.

        NOTE: You do not need to implement all return types, just the ones you need.
        The only mandatory return type is "out", but implementing "features" is also recommended, as many models use it.
        You can also define custom return types if needed.
        """

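        # Two conv -> ReLU -> max-pool stages; each pooling step halves the spatial size.
        out1 = self.maxpool(F.relu(self.layer1(x)))
        out2 = self.maxpool(F.relu(self.layer2(out1)))

        # Flatten to (batch_size, num_features) before the linear classifier.
        out2 = out2.view(out2.size(0), -1)
        logits = self.classifier(out2)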
        if returnt == "out":
            return logits
        elif returnt == "features":
            return out2
        elif returnt == "both":
            return logits, out2
        else:
            raise ValueError(f"Unknown return type: {returnt}")

@register_backbone(name='custom-cnn-v1')
def custom_cnn_v1(num_classes: int):
    """
    Register the custom CNN backbone with the Mammoth framework.

    The `num_classes` parameter will be passed automatically when the backbone is loaded.
    """
    return CustomCNN(num_classes=num_classes, num_channels=32, input_size=32)

@register_backbone(name='custom-cnn-v2')
def custom_cnn_v2(num_classes: int):
    """
    Register another version of the custom CNN backbone with different parameters.

    NOTE: In the full Mammoth framework, you can add any additional parameters to the function signature and they will be available via CLI.
    """
    return CustomCNN(num_classes=num_classes, num_channels=64, input_size=32)


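The input size of the classifier in `CustomCNN`, `num_channels * 2 * (input_size // 4) * (input_size // 4)`, follows from the layer shapes. The sketch below reproduces that arithmetic in plain Python, without PyTorch, using the standard output-size formula for convolution and pooling layers with dilation 1. The helper names `conv_out` and `flattened_features` are introduced here for illustration only.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(num_channels: int = 32, input_size: int = 32) -> int:
    """Number of features reaching the linear classifier of CustomCNN."""
    size = input_size
    for _ in range(2):  # two conv + max-pool stages
        size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv keeps the size
        size = conv_out(size, kernel=2, stride=2, padding=0)  # 2x2 pool halves it
    # The second convolution outputs num_channels * 2 feature maps.
    return num_channels * 2 * size * size


# The step-by-step result agrees with the closed form used in the constructor.
assert flattened_features(32, 32) == 32 * 2 * (32 // 4) * (32 // 4)

print(flattened_features(32, 32))  # settings of custom-cnn-v1
print(flattened_features(64, 32))  # settings of custom-cnn-v2
```

If `input_size` does not match the actual image resolution of the dataset, the flattened tensor will not match the classifier's expected input and the forward pass will fail, so this value should be kept in sync with the data.
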
In [3]:
"""
Now we can use the `load_runner` function to load a model that uses our custom backbone, selected by name through the `backbone` argument.
"""

model, dataset = load_runner('sgd', 'seq-cifar10', 
                             {'lr': 0.1, 'n_epochs': 1, 'batch_size': 32, 'backbone': 'custom-cnn-v1'})
train(model, dataset)

Loading model:  sgd
- Using CustomCNN as backbone
Using device cuda


  0%|          | 0/313 [00:00<?, ?it/s]

Task 1


Evaluating Task 1: 100%|██████████| 63/63 [00:00<00:00, 168.72it/s, acc_task_1=84.7]


Accuracy for task 1	[Class-IL]: 84.70 	[Task-IL]: 84.70


  0%|          | 0/313 [00:00<?, ?it/s]

Task 2


Evaluating Task 2: 100%|██████████| 126/126 [00:00<00:00, 175.45it/s, acc_task_2=68.4]


Accuracy for task 2	[Class-IL]: 34.20 	[Task-IL]: 69.72


  0%|          | 0/313 [00:00<?, ?it/s]

Task 3


Evaluating Task 3: 100%|██████████| 189/189 [00:01<00:00, 176.29it/s, acc_task_3=58]   


Accuracy for task 3	[Class-IL]: 19.37 	[Task-IL]: 57.12


  0%|          | 0/313 [00:00<?, ?it/s]

Task 4


Evaluating Task 4: 100%|██████████| 252/252 [00:01<00:00, 173.68it/s, acc_task_4=77.2] 


Accuracy for task 4	[Class-IL]: 19.34 	[Task-IL]: 62.41


  0%|          | 0/313 [00:00<?, ?it/s]

Task 5


Evaluating Task 5: 100%|██████████| 315/315 [00:01<00:00, 169.28it/s, acc_task_5=80.3]

Accuracy for task 5	[Class-IL]: 16.06 	[Task-IL]: 60.57





In [4]:
model, dataset = load_runner('sgd', 'seq-cifar10', 
                             {'lr': 0.1, 'n_epochs': 1, 'batch_size': 32, 'backbone': 'custom-cnn-v2'})
train(model, dataset)

Loading model:  sgd
- Using CustomCNN as backbone


  0%|          | 0/313 [00:00<?, ?it/s]

Task 1


Evaluating Task 1: 100%|██████████| 63/63 [00:00<00:00, 91.58it/s, acc_task_1=82]  


Accuracy for task 1	[Class-IL]: 81.95 	[Task-IL]: 81.95


  0%|          | 0/313 [00:00<?, ?it/s]

Task 2


Evaluating Task 2: 100%|██████████| 126/126 [00:01<00:00, 90.14it/s, acc_task_2=64.3]


Accuracy for task 2	[Class-IL]: 32.15 	[Task-IL]: 58.07


  0%|          | 0/313 [00:00<?, ?it/s]

Task 3


Evaluating Task 3: 100%|██████████| 189/189 [00:02<00:00, 90.43it/s, acc_task_3=66.1] 


Accuracy for task 3	[Class-IL]: 22.05 	[Task-IL]: 60.08


  0%|          | 0/313 [00:00<?, ?it/s]

Task 4


Evaluating Task 4: 100%|██████████| 252/252 [00:02<00:00, 89.71it/s, acc_task_4=62.9]


Accuracy for task 4	[Class-IL]: 15.74 	[Task-IL]: 51.40


  0%|          | 0/313 [00:00<?, ?it/s]

Task 5


Evaluating Task 5: 100%|██████████| 315/315 [00:03<00:00, 89.71it/s, acc_task_5=69]   

Accuracy for task 5	[Class-IL]: 13.90 	[Task-IL]: 55.69



