# Online Continual Learning with Naive Single Head Model

## 1. Imports and Installations

In [1]:
# Ensure necessary libraries are installed and imported
# !pip install avalanche-lib==0.4.0 torch==2.1.2 matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, timing_metrics, forgetting_metrics, cpu_usage_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EarlyStoppingPlugin
from avalanche.benchmarks.scenarios import OnlineCLScenario
import matplotlib.pyplot as plt

# Create model function
from slim_resnet18 import SlimResNet18

def create_model():
    return SlimResNet18(nclasses=10, input_size=(3, 32, 32))


  from .autonotebook import tqdm as notebook_tqdm


## 2. Benchmark and Scenario Setup

In [2]:
# Create the benchmark
benchmark = SplitCIFAR10(n_experiences=10, seed=1, fixed_class_order=list(range(10)))


Files already downloaded and verified
Files already downloaded and verified


## 3. Model and Strategy Definition

In [3]:
# Create model
model = create_model()

# Create optimizer and criterion
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Create evaluation plugin with desired metrics
interactive_logger = InteractiveLogger()
evaluation_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=[interactive_logger]
)

# Define the strategy
cl_strategy = Naive(
    model,
    optimizer,
    criterion,
    train_mb_size=10,
    train_epochs=1,
    eval_mb_size=10,
    evaluator=evaluation_plugin,
    eval_every=1,
)


## 4. Training and Evaluation

In [4]:
# Store results for plotting
accuracy_results = []
average_accuracy_results = []

for experience in benchmark.train_stream:
    print("Start of experience ", experience.current_experience)
    print("Current Classes ", experience.classes_in_this_experience)
    
    ocl_scenario = OnlineCLScenario(
        original_streams=benchmark.streams.values(),
        experiences=experience,
        experience_size=10,
        access_task_boundaries=True,
        shuffle=False
    )
    
    for i, minibatch in enumerate(ocl_scenario.train_stream):
        print("Minibatch: ", i+1)
        cl_strategy.train(minibatch)
        eval_results = cl_strategy.eval(benchmark.test_stream)
        
        # Collect and print accuracy
        #minibatch_accuracy = eval_results['Top1_Acc_Stream/eval_phase/test_stream']
        minibatch_accuracy = eval_results['Top1_Acc_Stream/eval_phase/test_stream/Task000']
        accuracy_results.append(minibatch_accuracy)
        print(f"Minibatch accuracy: {minibatch_accuracy:.4f}")
        break


Start of experience  0
Current Classes  [0]
Minibatch:  1
-- >> Start of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from train stream --
100%|██████████| 1/1 [00:00<00:00,  6.71it/s]
> Eval on experience 0 (Task 0) from train stream ended.
	Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000 = 0.0000
-- >> End of eval phase << --
	Top1_Acc_Stream/eval_phase/train_stream/Task000 = 0.0000
-- Starting training on experience 0 (Task 0) from train stream --
0it [00:00, ?it/s]-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from train stream --
100%|██████████| 1/1 [00:00<00:00,  3.97it/s]
> Eval on experience 0 (Task 0) from train stream ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0000
	Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000 = 0.0000
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.0000
-- >> End of eval phase << --
	Top1_Acc_Stream/eval_phase/train_stream/Task000 = 0.0000
1it [00:00, 

KeyboardInterrupt: 

 73%|███████▎  | 73/100 [00:22<00:03,  9.00it/s]

## 5. Plotting Results

In [None]:
# Plot accuracy results
plt.figure(figsize=(10, 5))
plt.plot(accuracy_results, label='Minibatch Accuracy')
plt.xlabel('Minibatch')
plt.ylabel('Accuracy')
plt.title('Minibatch Accuracy Over Time')
plt.legend()
plt.show()
