
# Test run

Welcome to the "_Putting All Together_" tutorial of the "_From Zero to Hero_" series. In this part we will summarize the major Avalanche features and how you can put them together for your continual learning experiments.

In [1]:
!pip install avalanche-lib==0.4
!pip install einops



In [2]:
# einops -> seamless matrix operations


In [3]:

import torch
from torch import nn
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` as-is when it is already a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Position-wise MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    The input is normalised first, expanded from ``dim`` to ``hidden_dim`` and
    projected back to ``dim``, so the output has the same shape as the input.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Kept in a list so the stage order (and the resulting `net.<idx>` keys) is explicit.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP stack and return the result."""
        transformed = self.net(x)
        return transformed

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    Args:
        dim: width of the incoming tokens.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout probability applied to the attention weights and,
            when an output projection exists, to its output.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already `dim` wide needs no output projection.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free projection yields q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        """Attend over the token axis of ``x`` (laid out as ``b n d``) and return the projected result."""
        normed = self.norm(x)

        # Split the fused projection into three parts and unfold the head axis of each.
        q, k, v = [
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        ]

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Fold the heads back into a single feature axis.
        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feed-forward) layers ending in a LayerNorm.

    Args:
        dim: token width, unchanged by every layer.
        depth: number of (attention, feed-forward) pairs.
        heads: attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of each feed-forward block.
        dropout: dropout probability handed to both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Apply every layer with a residual connection, then the final normalisation."""
        hidden = x
        for attention, feed_forward in self.layers:
            hidden = hidden + attention(hidden)
            hidden = hidden + feed_forward(hidden)

        return self.norm(hidden)

class ViT(nn.Module):
    """Vision Transformer that can read labelled in-context examples next to the target image.

    When the input carries labels, each sample's token sequence is made of
    ``n_examples`` segments of the form::

        [example patches] [arrow] [one-hot label] [pipe] [target patches]

    Without labels (or with ``n_examples == 0``) the sequence is
    ``[pipe] [target patches]``, where the pipe token sits in the slot that
    ``pool='cls'`` reads.

    Keyword Args:
        image_size: int or ``(height, width)`` of the input images.
        patch_size: int or ``(height, width)`` of a patch; must divide the image size.
        num_classes: size of the output logits.
        dim: token width. Labels are one-hot encoded to this width, so every
            label id must be smaller than ``dim``.
        depth, heads, mlp_dim, dim_head, dropout: forwarded to ``Transformer``.
        pool: ``'cls'`` reads the first output token, ``'mean'`` averages all tokens.
        channels: number of image channels.
        emb_dropout: dropout applied to the embedded sequence.
        n_examples: number of in-context examples placed before each target.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0., n_examples = 3):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Fixed sinusoidal positional table.
        # One example segment is: patches + arrow + label + pipe + patches.
        tokens_per_example = num_patches * 2 + 3
        # The label-free path always needs pipe + patches positions, whatever
        # n_examples is; a shorter table would be silently broadcast over the sequence.
        max_seq_len = max(num_patches + 1, tokens_per_example * n_examples)
        # Non-persistent buffer: it follows .to(device) / .cuda() with the module
        # but stays out of state_dict, so existing checkpoints keep loading.
        self.register_buffer('pos_embedding', self.sinusoidal_embeddings(max_seq_len, dim), persistent = False)

        # NOTE(review): comma_token is never used in forward; kept so the
        # parameter set (and checkpoints) stay unchanged.
        self.comma_token = nn.Parameter(torch.randn(1, 1, dim)) # Token for ','
        self.arrow_token = nn.Parameter(torch.randn(1, 1, dim))  # Token for '->'
        self.pipe_token = nn.Parameter(torch.randn(1, 1, dim))   # Token for '|'
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)
        self.n_examples = n_examples
        self.dim = dim

        # Exemplars captured from the first sufficiently large training batch and
        # reused as the in-context examples at evaluation time. They are plain
        # attributes, so they are not part of state_dict.
        self.saved_imgs = None
        self.saved_labels = None

    def sinusoidal_embeddings(self, n_pos, dim):
        """Return a ``(1, n_pos, dim)`` table of sine/cosine positional encodings.

        Even feature columns hold sines and odd columns hold cosines. Odd
        ``dim`` is supported: it simply has one cosine column fewer.
        """
        position = torch.arange(0, n_pos, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        angles = position * div_term
        sinusoidal_emb = torch.zeros(n_pos, dim)
        sinusoidal_emb[:, 0::2] = torch.sin(angles)
        # For odd `dim` there are dim // 2 cosine columns but dim // 2 + 1 frequencies.
        sinusoidal_emb[:, 1::2] = torch.cos(angles[:, :dim // 2])
        return sinusoidal_emb.unsqueeze(0)

    def _label_tokens(self, labels):
        """One-hot encode integer labels to width ``self.dim`` and add a token axis."""
        return nn.functional.one_hot(labels, num_classes=self.dim).float().unsqueeze(1)

    def _training_segments(self, imgs, labels, target_tokens, pipe, n_examples):
        """Build the token list for training, using other samples of the batch as examples."""
        batch_size = imgs.shape[0]

        # Capture exemplars once. Clone so the whole first batch is not kept alive
        # through a slice view, and wait for a batch that holds enough samples.
        if self.saved_imgs is None and batch_size >= n_examples:
            self.saved_imgs = imgs[:n_examples].detach().clone()
            self.saved_labels = labels[:n_examples].detach().clone()

        arrow = repeat(self.arrow_token, '1 1 d -> b 1 d', b=batch_size)
        label_tokens = self._label_tokens(labels)

        seq = []
        for i in range(n_examples):
            # Rolling along the batch axis pairs every target with a different
            # sample of the same batch. The patch embedding and the one-hot
            # encoding act per sample, so rolling their outputs equals encoding
            # the rolled inputs, without re-running the embedding.
            shift = -(i + 1)

            # example (image, delimiter, label)
            seq.append(torch.roll(target_tokens, shifts=shift, dims=0))
            seq.append(arrow)
            seq.append(torch.roll(label_tokens, shifts=shift, dims=0))

            # target (delimiter, image)
            seq.append(pipe)
            seq.append(target_tokens)
        return seq

    def _eval_segments(self, target_tokens, pipe, n_examples):
        """Build the token list for evaluation, using the exemplars saved during training."""
        if self.saved_imgs is None:
            raise RuntimeError(
                'ViT has no saved exemplars: run at least one training batch with '
                'labels and at least n_examples samples before evaluating with labels.'
            )

        batch_size = target_tokens.shape[0]
        arrow = repeat(self.arrow_token, '1 1 d -> b 1 d', b=batch_size)
        # Embed each exemplar once; it is then shared by every sample of the batch.
        example_tokens = self.to_patch_embedding(self.saved_imgs[:n_examples])
        label_tokens = self._label_tokens(self.saved_labels[:n_examples])

        seq = []
        for i in range(n_examples):
            # example (image, delimiter, label)
            seq.append(example_tokens[i:i + 1].expand(batch_size, -1, -1))
            seq.append(arrow)
            seq.append(label_tokens[i:i + 1].expand(batch_size, -1, -1))

            # target (delimiter, image)
            seq.append(pipe)
            seq.append(target_tokens)
        return seq

    def forward(self, mbatch):
        """Classify the target images in ``mbatch``.

        Args:
            mbatch: either an ``(images, labels)`` tuple/list, or an image
                tensor on its own. Without labels no in-context examples are used.

        Returns:
            The output of the classification head for every sample.

        Raises:
            RuntimeError: when called with labels in eval mode before any
                exemplars were saved by a training batch.
        """
        if isinstance(mbatch, (tuple, list)):
            # mbatch contains labels
            imgs, labels = mbatch[0], mbatch[1]
            n_examples = self.n_examples
        else:
            # mbatch contains images only
            imgs, labels = mbatch, None
            n_examples = 0

        batch_size = imgs.shape[0]
        pipe = repeat(self.pipe_token, '1 1 d -> b 1 d', b=batch_size)
        # The target embedding is identical in every segment, so compute it once.
        target_tokens = self.to_patch_embedding(imgs)

        if n_examples == 0:
            # Without examples the pipe token plays the role of the ViT cls token.
            seq = [pipe, target_tokens]
        elif self.training:
            seq = self._training_segments(imgs, labels, target_tokens, pipe, n_examples)
        else:
            seq = self._eval_segments(target_tokens, pipe, n_examples)

        seq = torch.cat(seq, dim=1)

        # Add positional embeddings, trimmed to the actual sequence length
        x = seq + self.pos_embedding[:, :seq.size(1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # NOTE(review): with examples present, index 0 is the first patch of the
        # first example image rather than a dedicated cls token - confirm intended.
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [4]:
from avalanche.core import SupervisedPlugin
class FewShotMetaLearningPlugin(SupervisedPlugin):
    """Strategy plugin that hands the model ``(inputs, labels)`` instead of bare inputs."""
    # ... other methods ...

    @staticmethod
    def _attach_labels(strategy):
        """Replace slot 0 of ``strategy.mbatch`` with an ``(inputs, labels)`` tuple."""
        features = strategy.mbatch[0]  # Inputs (features)
        targets = strategy.mbatch[1]   # Labels (targets)

        # Slot 0 is what ends up as the single argument of the model's forward;
        # check the API documentation for other available controls.
        strategy.mbatch[0] = (features, targets)

    def before_forward(self, strategy: "SupervisedTemplate", **kwargs):
        """
        Hook called before the forward pass of the model during training.
        Our model takes additional label information, we provide it as a (inputs, labels) tuple
        """
        self._attach_labels(strategy)

    def before_eval_forward(self, strategy: "SupervisedTemplate", **kwargs):
        """
        Hook called before the forward pass of the model during testing.
        The minibatch is packaged exactly as in training.
        """
        self._attach_labels(strategy)



  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from torch.optim import SGD, AdamW
from torch.nn import CrossEntropyLoss
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, \
    loss_metrics, timing_metrics, cpu_usage_metrics, confusion_matrix_metrics, disk_usage_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins import ReplayPlugin, EWCPlugin
# there's a bug in wandb logging of the last iteration, we have to mannually finish it
import wandb

scenario = SplitMNIST(n_experiences=5, seed = 1234)

# MODEL CREATION
# The output size follows the benchmark instead of a hard-coded constant.
model = ViT(
    image_size = 28,
    patch_size = 7,
    num_classes = scenario.n_classes,
    dim = 512,
    depth = 3,
    heads = 4,
    mlp_dim = 256,
    dropout = 0.1,
    emb_dropout = 0.1,
    channels = 1,
    dim_head = 64
)

results = []

# The context manager closes log.txt even if logger setup or training fails.
with open('log.txt', 'a') as log_file:
    try:
        # DEFINE THE EVALUATION PLUGIN and LOGGERS
        # The evaluation plugin manages the metrics computation.
        # It takes as argument a list of metrics, collects their results and returns
        # them to the strategy it is attached to.

        loggers = []

        # log to text file
        loggers.append(TextLogger(log_file))

        # print to stdout
        loggers.append(InteractiveLogger())

        # W&B logger - comment this if you don't have a W&B account
        loggers.append(WandBLogger(project_name="avalanche", run_name="test"))

        eval_plugin = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            timing_metrics(epoch=True, epoch_running=True),
            forgetting_metrics(experience=True, stream=True),
            confusion_matrix_metrics(num_classes=scenario.n_classes, save_image=False,
                                     stream=True),
            loggers=loggers
        )

        # strategy plugins
        replay = ReplayPlugin(mem_size=1000)
        fsml = FewShotMetaLearningPlugin()
        # Built for experimentation; it is not part of the plugin list below.
        ewc = EWCPlugin(ewc_lambda=1)

        # CREATE THE STRATEGY INSTANCE
        cl_strategy = SupervisedTemplate(
            model, AdamW(model.parameters(), lr=0.001, weight_decay=0.0001),
            CrossEntropyLoss(), train_mb_size=256, train_epochs=1, eval_mb_size=256,
            evaluator=eval_plugin,plugins=[replay, fsml])

        # TRAINING LOOP
        print('Starting experiment...')
        for experience in scenario.train_stream:
            print("Start of experience: ", experience.current_experience)
            print("Current Classes: ", experience.classes_in_this_experience)

            # train returns a dictionary which contains all the metric values
            res = cl_strategy.train(experience)
            print('Training completed')

            print('Computing accuracy on the whole test set')
            # test also returns a dictionary which contains all the metric values
            results.append(cl_strategy.eval(scenario.test_stream))
    finally:
        # there's a bug in wandb logging of the last iteration, we have to manually
        # finish it; doing so in `finally` also ends the run when training fails
        wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\saira\.netrc


Starting experiment...
Start of experience:  0
Current Classes:  [4, 5]
-- >> Start of training phase << --
100%|██████████| 44/44 [00:48<00:00,  1.10s/it]
Epoch 0 ended.
	RunningTime_Epoch/train_phase/train_stream/Task000 = 0.0321
	Time_Epoch/train_phase/train_stream/Task000 = 48.2470
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7083
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.9098
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 8/8 [00:02<00:00,  3.24it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9626
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 8/8 [00:02<00:00,  3.04it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on expe

0,1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,▁▂▄█
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▃█
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁█
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningTime_Epoch/train_phase/train_stream/Task000,▃▁▁▁▁▁▁▁▃▂▂▁▁▁▁█▂▁▁▁▁▁▁▆▂▂▁▁▁▁▁▁▂▂▁▁▁▁▁▁
StreamForgetting/eval_phase/test_stream,▁▆▅▃█
Time_Epoch/train_phase/train_stream/Task000,▁▄▅█▅
Top1_Acc_Epoch/train_phase/train_stream/Task000,▁▄▆█▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▅▅▃▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁███▆

0,1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.14248
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.20189
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,-0.0727
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.05779
RunningTime_Epoch/train_phase/train_stream/Task000,0.0592
StreamForgetting/eval_phase/test_stream,0.08237
Time_Epoch/train_phase/train_stream/Task000,135.65283
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.77717
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.82017
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.70538
