In [None]:
# Mount Google Drive into the Colab VM. force_remount=True asks Colab to
# remount even if a previous session already mounted it at this path.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Project folder inside the mounted Drive.
# NOTE(review): not referenced by any of the cells shown here — confirm it is still needed.
BASE_DIR = "/content/drive/MyDrive/LoCA/"

In [None]:
!git clone https://github.com/Bilgecelik/LoCA.git

In [None]:
!pip install -r "LoCA/requirements.txt"
!pip install "git+https://github.com/ContinualAI/avalanche.git"
!pip install avalanche-lib[l2p]

In [None]:
# CL Setup
# Continual-learning experiment: Learning to Prompt (L2P) on Permuted MNIST,
# with metrics sent to an interactive console logger and to Weights & Biases.
import os
import argparse
from getpass import getpass
from typing import Dict

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
import wandb
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive
from avalanche.training.supervised.l2p import LearningToPrompt
from avalanche.logging import WandBLogger, InteractiveLogger
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, \
    loss_metrics
from avalanche.training.plugins import EvaluationPlugin

# ---- Configuration ----------------------------------------------------------
SEED = 1             # seeds the benchmark permutations and torch
N_EXPERIENCES = 3
# Keyword arguments forwarded to LearningToPrompt. They are also logged to W&B,
# so every run records the hyperparameters it was trained with.
L2P_CONFIG = {
    "model_name": "vit_large_patch16_224",
    "train_epochs": 1,
    "num_classes": 10,  # total # of classes in all tasks
    "use_vit": True,
    "lr": 0.03,
    "pool_size": 20,
    "prompt_length": 5,
    "top_k": 5,
    "sim_coefficient": 0.5,  # default in avalanche is 0.1, paper is 0.5, not sensitive
}

# ---- Credentials ------------------------------------------------------------
# Never hardcode API keys in a notebook. The key previously written here is
# readable by anyone with the file and should be revoked. It was also not
# passed to wandb anywhere in this cell. Read the key from the environment,
# or prompt for it without echoing it.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
wandb.login(key=WANDB_API_KEY)

# ---- Device & reproducibility -----------------------------------------------
# Fall back to CPU so the cell still runs on a runtime without a GPU
# (a ViT-Large backbone will be very slow there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.manual_seed(SEED)

# NOTE(review): PermutedMNIST presumably yields 1x28x28 images with its default
# transforms, while a `*_patch16_224` ViT backbone presumably expects 3x224x224
# input — confirm whether resize / channel-repeat train/eval transforms are needed.
benchmark = PermutedMNIST(n_experiences=N_EXPERIENCES, seed=SEED)

# set criteria, optimizer (config hp's), model (checkpoint)
criterion = CrossEntropyLoss()

# plot with wandb
wandb_logger = WandBLogger(
    project_name="LOCA Trials",
    run_name="Notebook_trials",
    log_artifacts=True,
    config={
        # Log a readable dataset name rather than the benchmark object itself.
        "dataset": "PermutedMNIST",
        "n_experiences": N_EXPERIENCES,
        "seed": SEED,
        "strategy": "L2P",
        **L2P_CONFIG,
    },
)

# evaluation plugin
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    loggers=[wandb_logger, InteractiveLogger()],
    strict_checks=False
)

# define strategy (reuses the criterion defined above instead of building a second one)
strategy = LearningToPrompt(
    criterion=criterion,
    device=device,
    evaluator=eval_plugin,
    **L2P_CONFIG,
)

# TRAINING LOOP
print('Starting experiment...')
results = []  # one strategy.eval(...) result per training experience
try:
    for experience in benchmark.train_stream:
        # train returns a dictionary which contains all the metric values
        res = strategy.train(experience)
        print('Training completed')

        print('Evaluating on experiences until current one.')
        # test also returns a dictionary which contains all the metric values.
        # The slice covers experiences 0..current (inclusive).
        eval_res = strategy.eval(benchmark.test_stream[:experience.current_experience + 1])
        results.append(eval_res)
        print(eval_res)
finally:
    # Mark the W&B run as finished even if training or evaluation raises.
    wandb.finish()