In [None]:
%load_ext autoreload
%autoreload 2   

In [2]:
import torch
import os
import glob
import sys
import numpy as np
from natsort import natsorted

sys.path.append("../")
sys.path.append("../CL")
import mftma
from utils import factory
from utils.toolkit import count_parameters
from torch.utils.data import DataLoader
from utils.data_manager import DataManager

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
task_id = 8

In [5]:
config = {
    "prefix": "reproduce",
    "dataset": "cifar100",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": "False",
    "shuffle": "True",
    "init_cls": 10,
    "increment": 10,
    "model_name": "ewc",
    "convnet_type": "resnet32",
    "device": [torch.device("cuda:0")],
    "seed": [1993],
}
model = factory.get_model("ewc", config)
model._network.update_fc((task_id + 1) * 10)
model._network.to(device)

IncrementalNet(
  (convnet): CifarResNet(
    (conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (stage_1): Sequential(
      (0): ResNetBasicblock(
        (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResNetBasicblock(
        (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      

In [6]:
weights = natsorted(glob.glob("../CL/weights/*"))
task1_weights = torch.load(weights[task_id])["model_state_dict"]

In [7]:
current_model = list(model._network.state_dict().keys())
saved_model = list(task1_weights.keys())

In [8]:
new_state_dict = {
    key: value for key, value in task1_weights.items() if key in current_model
}

In [9]:
model._network.load_state_dict(new_state_dict)

<All keys matched successfully>

In [10]:
data_manager = DataManager(
    config["dataset"],
    config["shuffle"],
    config["seed"],
    config["init_cls"],
    config["increment"],
)

batch_size = 128
num_workers = 4

train_dataset = data_manager.get_dataset(
    np.arange(0, 10),
    source="train",
    mode="train",
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_dataset = data_manager.get_dataset(
    np.arange(0, (task_id + 1) * 10), source="test", mode="test"
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
model.test_loader = test_loader
model.eval_task()

({'grouped': {'total': 0.72,
   '00-09': 0.0,
   '10-19': 1.1,
   '20-29': 0.2,
   '30-39': 0.1,
   '40-49': 0.1,
   '50-59': 1.2,
   '60-69': 1.2,
   '70-79': 1.0,
   '80-89': 1.6,
   'old': 0,
   'new': 0.72},
  'top1': 0.72,
  'top5': 4.7},
 None)

In [12]:
saved_model

['convnet.conv_1_3x3.weight',
 'convnet.bn_1.weight',
 'convnet.bn_1.bias',
 'convnet.bn_1.running_mean',
 'convnet.bn_1.running_var',
 'convnet.bn_1.num_batches_tracked',
 'convnet.stage_1.0.conv_a.weight',
 'convnet.stage_1.0.bn_a.weight',
 'convnet.stage_1.0.bn_a.bias',
 'convnet.stage_1.0.bn_a.running_mean',
 'convnet.stage_1.0.bn_a.running_var',
 'convnet.stage_1.0.bn_a.num_batches_tracked',
 'convnet.stage_1.0.conv_b.weight',
 'convnet.stage_1.0.bn_b.weight',
 'convnet.stage_1.0.bn_b.bias',
 'convnet.stage_1.0.bn_b.running_mean',
 'convnet.stage_1.0.bn_b.running_var',
 'convnet.stage_1.0.bn_b.num_batches_tracked',
 'convnet.stage_1.1.conv_a.weight',
 'convnet.stage_1.1.bn_a.weight',
 'convnet.stage_1.1.bn_a.bias',
 'convnet.stage_1.1.bn_a.running_mean',
 'convnet.stage_1.1.bn_a.running_var',
 'convnet.stage_1.1.bn_a.num_batches_tracked',
 'convnet.stage_1.1.conv_b.weight',
 'convnet.stage_1.1.bn_b.weight',
 'convnet.stage_1.1.bn_b.bias',
 'convnet.stage_1.1.bn_b.running_mean',
 '