# Latent Replay for Continual Learning on Edge devices with Efficient Architectures
In this Jupyter Notebook, we will explore various deep learning strategies for continual learning using the PyTorch framework. We will use the Avalanche library to handle the continual learning benchmarks, and we will train and evaluate different models and strategies.

## Installation
Install all the necessary libraries.

In [None]:
! pip install avalanche-lib==0.3.1
! pip install micromind
! pip install torchinfo

Importing the necessary libraries and setting up the environment.

In [1]:
# Standard libraries
import json
import os

# PyTorch modules
import torch
import torch.nn as nn
from torch.optim import SGD, Adam

# Torchvision modules
import torchvision
import torchvision.transforms as transforms

# Avalanche modules
from avalanche.benchmarks.classic import SplitMNIST, SplitCIFAR10
from avalanche.models import MobilenetV1
from avalanche.training.storage_policy import ReservoirSamplingBuffer

# Custom Strategy modules
from strategy.joint_training import JointTraining
from strategy.fine_tuning import FineTuning
from strategy.comulative import Comulative
from strategy.replay import Replay
from strategy.latent_replay import LatentReplay

# Model modules
from micromind import PhiNet
from model.phinet_v2 import PhiNet_v2
from model.phinet_v3 import PhiNetV3

# Utility modules
import utility.utils as utils
import utility.evaluations as evaluations

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(0)

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


<torch._C.Generator at 0x28994a29f90>

## Benchmark
In this section, we will prepare the data for the continual learning experiments. We will use two classic benchmark datasets: 
- SplitMNIST;
- SplitCIFAR10.

### MNIST
The SplitMNIST benchmark, split into 5 experiences. Avalanche normalizes SplitMNIST with `mean = 0.1307` and `std = 0.3081`.

In [2]:
# Build the SplitMNIST benchmark with 5 experiences and a fixed seed.
split_mnist = SplitMNIST(n_experiences=5, seed=0, return_task_id=True)

# Expose the benchmark's train and test streams under the names used by the
# training cells.
# NOTE: the CIFAR10 cell below assigns the same two names, so whichever
# benchmark cell was executed last decides which streams are used.
train_stream, test_stream = split_mnist.train_stream, split_mnist.test_stream

### CIFAR10
The SplitCIFAR10 dataset with 5 experiences.

In [None]:
# Preprocessing shared by the training and evaluation sets: convert to a
# tensor, resize to 160x160 and normalize each channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((160, 160)),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Build the SplitCIFAR10 benchmark with 5 experiences and a fixed seed.
split_cifar = SplitCIFAR10(
    n_experiences=5,
    seed=0,
    return_task_id=True,
    train_transform=transform,
    eval_transform=transform,
)

# Expose the benchmark's train and test streams under the names used by the
# training cells (this rebinds the names also set by the MNIST cell).
train_stream, test_stream = split_cifar.train_stream, split_cifar.test_stream

## Training
Set up the necessary components for training our deep learning model.

In [3]:
# Classification loss shared by every strategy in this notebook
criterion = nn.CrossEntropyLoss()

# Mini-batch size used for training and for evaluation
train_mb_size = eval_mb_size = 32

# Ratio used to split each experience into training and validation data
split_ratio = 0.8

# Patience passed to the strategies' early stopping
patience = 3

# Maps a strategy label to the task accuracies collected for that strategy
accs = {}

### Fine Tuning Strategy

In [None]:
# Alternative backbone built for 1x28x28 inputs, kept for reference:
#model1 = PhiNet(input_shape = (1, 28, 28), alpha = 0.5, beta = 1, t_zero = 6,num_layers=7 ,include_top = True, num_classes = 10).to(device)
# PhiNet loaded through `from_pretrained` with the "CIFAR-10" identifier and
# its classifier enabled. The positional values are presumably
# (alpha, beta, t_zero, num_layers, resolution) -- TODO confirm against micromind.
model1 = PhiNet.from_pretrained(
    "CIFAR-10", 3.0, 0.75, 6.0, 7, 160, classifier=True
).to(device)
optimizer1 = Adam(model1.parameters(), lr=0.01, weight_decay=0)

train_epochs = 10

# Assemble the strategy from the model, optimizer and shared training settings.
fine_tuning = FineTuning(
    model=model1,
    optimizer=optimizer1,
    criterion=criterion,
    train_mb_size=train_mb_size,
    eval_mb_size=eval_mb_size,
    train_epochs=train_epochs,
    split_ratio=split_ratio,
    patience=patience,
    device=device,
)

# Train on the stream (with plotting enabled), then run the test pass.
fine_tuning.train(train_stream, test_stream, plotting=True)
b, c = fine_tuning.test(test_stream)  # the two returned values are not used below

# Record the task accuracies under this strategy's label.
a = accs['Fine Tuning'] = fine_tuning.get_tasks_acc()

### Joint Training Strategy

In [None]:
# Alternative backbone built for 1x28x28 inputs, kept for reference:
#model2 = PhiNet(input_shape = (1, 28, 28), alpha = 0.5, beta = 1, t_zero = 6,num_layers=7 ,include_top = True, num_classes = 10).to(device)
# PhiNet loaded through `from_pretrained` with the "CIFAR-10" identifier and
# its classifier enabled. The positional values are presumably
# (alpha, beta, t_zero, num_layers, resolution) -- TODO confirm against micromind.
model2 = PhiNet.from_pretrained(
    "CIFAR-10", 3.0, 0.75, 6.0, 7, 160, classifier=True
).to(device)
optimizer2 = Adam(model2.parameters(), lr=0.01, weight_decay=0)

train_epochs = 10

# Assemble the strategy from the model, optimizer and shared training settings.
joint_training = JointTraining(
    model=model2,
    optimizer=optimizer2,
    criterion=criterion,
    train_mb_size=train_mb_size,
    eval_mb_size=eval_mb_size,
    train_epochs=train_epochs,
    split_ratio=split_ratio,
    patience=patience,
    device=device,
)

# Train on the stream (with plotting enabled), then run the test pass.
joint_training.train(train_stream, test_stream, plotting=True)
b, c = joint_training.test(test_stream)  # the two returned values are not used below

# Record the task accuracies under this strategy's label.
a = accs['Joint Training'] = joint_training.get_tasks_acc()

### Cumulative Strategy

In [None]:
# Alternative backbone built for 1x28x28 inputs, kept for reference:
#model3 = PhiNet(input_shape = (1, 28, 28), alpha = 0.5, beta = 1, t_zero = 6,num_layers=7 ,include_top = True, num_classes = 10).to(device)
# PhiNet loaded through `from_pretrained` with the "CIFAR-10" identifier and
# its classifier enabled. The positional values are presumably
# (alpha, beta, t_zero, num_layers, resolution) -- TODO confirm against micromind.
model3 = PhiNet.from_pretrained(
    "CIFAR-10", 3.0, 0.75, 6.0, 7, 160, classifier=True
).to(device)
optimizer3 = Adam(model3.parameters(), lr=0.01, weight_decay=0)

train_epochs = 10

# Assemble the strategy from the model, optimizer and shared training settings.
comulative = Comulative(
    model=model3,
    optimizer=optimizer3,
    criterion=criterion,
    train_mb_size=train_mb_size,
    eval_mb_size=eval_mb_size,
    train_epochs=train_epochs,
    split_ratio=split_ratio,
    patience=patience,
    device=device,
)

# Train on the stream (with plotting enabled), then run the test pass.
comulative.train(train_stream, test_stream, plotting=True)
b, c = comulative.test(test_stream)  # the two returned values are not used below

# Record the task accuracies under this strategy's label.
a = accs['Comulative'] = comulative.get_tasks_acc()

### Replay Strategy

In [None]:
# Alternative backbone built for 1x28x28 inputs, kept for reference:
#model4 = PhiNet(input_shape = (1, 28, 28), alpha = 0.5, beta = 1, t_zero = 6,num_layers=7 ,include_top = True, num_classes = 10).to(device)
# PhiNet loaded through `from_pretrained` with the "CIFAR-10" identifier and
# its classifier enabled. The positional values are presumably
# (alpha, beta, t_zero, num_layers, resolution) -- TODO confirm against micromind.
model4 = PhiNet.from_pretrained(
    "CIFAR-10", 3.0, 0.75, 6.0, 7, 160, classifier=True
).to(device)
optimizer4 = Adam(model4.parameters(), lr=0.01, weight_decay=0)

# Reservoir-sampling buffer capped at 1500 entries, handed to the strategy
# as its storage policy.
storage_p = ReservoirSamplingBuffer(max_size=1500)

train_epochs = 10

# Assemble the strategy from the model, optimizer, buffer and shared settings.
replay = Replay(
    model=model4,
    optimizer=optimizer4,
    criterion=criterion,
    train_mb_size=train_mb_size,
    eval_mb_size=eval_mb_size,
    train_epochs=train_epochs,
    storage_policy=storage_p,
    split_ratio=split_ratio,
    patience=patience,
    device=device,
)

# Train on the stream (with plotting enabled), then run the test pass.
replay.train(train_stream, test_stream, plotting=True)
b, c = replay.test(test_stream)  # the two returned values are not used below

# Record the task accuracies under this strategy's label.
a = accs['ExpReplay'] = replay.get_tasks_acc()

### Latent Replay Strategy

In [4]:
# Alternative pretrained backbone, kept for reference:
#model4 = PhiNet.from_pretrained("CIFAR-10", 3.0, 0.75, 6.0, 7, 160, classifier=False).to(device)
# Headless PhiNet (include_top=False) for 1x28x28 inputs, initialised from a
# local checkpoint.
model4 = PhiNet(
    input_shape=(1, 28, 28),
    alpha=0.5,
    beta=1,
    t_zero=6,
    num_layers=7,
    include_top=False,
    num_classes=10,
).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model4.load_state_dict(
    torch.load("TestModel/7_Layers/Adam.pth", map_location=device)
)

# Wrap the backbone in PhiNetV3 with the latent layer set to 9; the
# `replace_bn_with_brn` flag presumably swaps BatchNorm for Batch
# Renormalization -- TODO confirm in model/phinet_v3.
# Note: this reuses the names `model4` / `optimizer4` from the Replay cell.
model4 = PhiNetV3(model4, latent_layer_num=9, replace_bn_with_brn=True).to(device)
optimizer4 = Adam(model4.parameters(), lr=0.01, weight_decay=0)

train_epochs = 1

# 21 new samples and 107 replay samples per step (128 in total); presumably
# they are merged into a single mini-batch -- TODO confirm in LatentReplay.
# `rm_size_MB` is presumably the replay-memory budget in megabytes.
latent_replay = LatentReplay(
    model=model4,
    optimizer=optimizer4,
    criterion=criterion,
    train_mb_size=21,
    replay_mb_size=107,
    train_epochs=train_epochs,
    eval_mb_size=eval_mb_size,
    rm_size_MB=3,
    manual_mb=True,
    split_ratio=split_ratio,
    patience=patience,
    device=device,
)

# Unlike the other strategy cells, training runs without the test stream and
# with plotting disabled; the explicit test pass is left commented out.
latent_replay.train(train_stream, plotting=False)
#b,c = latent_replay.test(test_stream)

# Record the task accuracies under this strategy's label.
a = accs['LatentReplay'] = latent_replay.get_tasks_acc()

# Report the MAC count of the wrapped model for a 1x28x28 input.
mac = evaluations.get_MAC(model4, (1,28,28))
print(f"MAC: {mac}")

# Report the measured inference time (mean and spread, printed in ms).
mean, std = evaluations.measure_inference_time(model4, (1,28,28))
print(f"Average inference time: {mean:.3f} +/- {std:.3f} ms")

Start of the training process...
Training of the experience with class:  [1, 4]
Train dataset size:  10067
Validation dataset size:  2517
Validation loss decreased (inf --> 0.049161).  Saving model ...
Epoch: 1/1, Train Loss: 11.0085, Train Accuracy: 98.41%
Early stopping reset.
Number of element in the replay memory: 4053
-----------------------------------------------------------------------------------
Training of the experience with class:  [5, 7]
Train dataset size:  9348
Validation dataset size:  2338
Validation loss decreased (inf --> 0.059193).  Saving model ...
Epoch: 1/1, Train Loss: 7.4946, Train Accuracy: 98.55%
Early stopping reset.
Number of element in the replay memory: 4052
-----------------------------------------------------------------------------------
Training of the experience with class:  [9, 3]
Train dataset size:  9664
Validation dataset size:  2416
Validation loss decreased (inf --> 0.220718).  Saving model ...
Epoch: 1/1, Train Loss: 15.9549, Train Accuracy: 

In [None]:
# Save the accuracies dictionary to a JSON file.
# Create the output directory first: opening the file for writing raises
# FileNotFoundError when './results' does not exist yet.
os.makedirs('./results', exist_ok=True)
with open('./results/result.json', 'w') as file:
    json.dump(accs, file)

In [None]:
# Reload the stored accuracies, so the plots can be produced without
# re-running the training cells.
with open('./results/result.json', 'r') as file:
    accs = json.load(file)

plotter = utils.TaskAccuracyPlotter()

# Add one set of curves per strategy, labelled with the strategy's key.
for strategy_label in accs:
    _ = plotter.plot_task_accuracy(
        accs[strategy_label],
        label=strategy_label,
        plot_task_acc=True,
        plot_avg_acc=True,
        plot_encountered_avg=True,
    )
plotter.show_figures()