In [1]:
# Enable IPython's autoreload so that edits to imported .py modules (e.g. the
# local `mftma` package imported in a later cell) take effect without a kernel
# restart. Mode 2 reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import random
from itertools import islice

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam
from torchvision import datasets, transforms, models

# Notebook-wide seed, applied to every RNG the experiment touches
# (Python `random`, NumPy, and PyTorch).
seed = 0
np.random.seed(seed)
random.seed(seed)
# PyTorch keeps its own generator. Without this call, model weight init and any
# torch-based sampling differ between runs. `torch.manual_seed` seeds all devices
# (CPU and CUDA).
# NOTE(review): bit-exact GPU reproducibility would additionally need
# deterministic cuDNN settings; not enabled here.
torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
import sys
import torchvision

# Make the parent directory importable so the local `mftma` package resolves.
# NOTE(review): "../" is relative to the kernel's current working directory, so
# this assumes the notebook is launched from its own folder. `append` places the
# path last on the search path, so an installed `mftma` (if any) would shadow
# the local copy.
sys.path.append("../")
# These imports must come after the sys.path change above.
from mftma.manifold_analysis_correlation import manifold_analysis_corr
from mftma.utils.make_manifold_data import make_manifold_data
from mftma.utils.activation_extractor import extractor
from mftma.utils.analyze_pytorch import analyze

In [4]:
from avalanche.training import EWC

from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.benchmarks.classic import SplitCIFAR10, SplitCIFAR100
from avalanche.models.pytorchcv_wrapper import resnet, vgg
from avalanche.models import SimpleMLP
from avalanche.training import Naive
from avalanche.checkpointing import maybe_load_checkpoint, save_checkpoint

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# --- Experiment configuration (tune here) --------------------------------------
N_CLASSES = 100            # single shared output head over all CIFAR-100 classes
N_EXPERIENCES = 10         # SplitCIFAR100 divides the 100 classes into this many experiences
N_TRAIN_EXPERIENCES = 3    # train only on the first few experiences, then stop
EWC_LAMBDA = 1e6           # NOTE(review): very strong EWC penalty — confirm this is intended
TRAIN_MB_SIZE = 2048       # the logged 3/3 progress bars show only 3 optimizer steps per epoch at this size
TRAIN_EPOCHS = 100
EVAL_MB_SIZE = 32
LR = 0.01
MOMENTUM = 0.9
EVALUATE = False           # True: evaluate on the whole test stream after each experience

model = torchvision.models.resnet18(num_classes=N_CLASSES).to(device)

# CL benchmark creation. Reuses the notebook-wide `seed` so the class split is
# reproducible together with the other RNGs.
split_CIFAR = SplitCIFAR100(n_experiences=N_EXPERIENCES, seed=seed, return_task_id=True)
train_stream = split_CIFAR.train_stream
test_stream = split_CIFAR.test_stream

# Prepare for training & testing
optimizer = SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
criterion = CrossEntropyLoss()

# Continual learning strategy: EWC adds a penalty (scaled by EWC_LAMBDA) that
# discourages changing weights important to previously seen experiences.
cl_strategy = EWC(
    model,
    optimizer,
    criterion,
    ewc_lambda=EWC_LAMBDA,
    train_mb_size=TRAIN_MB_SIZE,
    train_epochs=TRAIN_EPOCHS,
    eval_mb_size=EVAL_MB_SIZE,
    device=device,
)

# Train on the first N_TRAIN_EXPERIENCES experiences. After each one, save a
# checkpoint named "<index>_checkpoint.pth" in the current working directory.
# `results` collects the per-experience eval metrics only when EVALUATE is True.
results = []
for c, train_exp in enumerate(islice(train_stream, N_TRAIN_EXPERIENCES)):
    cl_strategy.train(train_exp)
    save_checkpoint(cl_strategy, f"{c}_checkpoint.pth")
    if EVALUATE:
        results.append(cl_strategy.eval(test_stream))

Files already downloaded and verified
Files already downloaded and verified
-- >> Start of training phase << --
100%|██████████| 3/3 [00:04<00:00,  1.66s/it]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 4.5355
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.0384
100%|██████████| 3/3 [00:00<00:00,  3.16it/s]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 2.7223
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.2486
100%|██████████| 3/3 [00:01<00:00,  2.59it/s]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.8926
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.3360
100%|██████████| 3/3 [00:01<00:00,  2.79it/s]
Epoch 3 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.6974
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.3992
100%|██████████| 3/3 [00:01<00:00,  2.06it/s]
Epoch 4 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 1.5713
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4368
100%|██████████| 3/3 [0

KeyboardInterrupt: 