In [12]:
import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path = [module_path] + sys.path
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision
import numpy as np
from tqdm import tqdm
from augmentations import get_aug
import models
from tools import AverageMeter, knn_monitor, Logger, file_exist_check
from datasets import get_dataset
from datetime import datetime
from utils.loggers import *
from utils.metrics import mask_classes
from utils.loggers import CsvLogger
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from typing import Tuple
import importlib

import yaml
from arguments import Namespace
from types import SimpleNamespace
import arguments
importlib.reload(arguments)
importlib.reload(models)


<module 'models' from '/juice/scr/ananya/continual/UCL/models/__init__.py'>

In [94]:
# Load config

def get_args(config_path):
    args = {
        'debug': True,
        'debug_subset_size': 8,
        'download': True,
        'data_dir': '../Data/',
        'log_dir': False,
        'ckpt_dir': False,
        'device': 'cuda',
        'eval_from': None,
        'hide_progress': True,
        'cl_default': False,
        'validation': True,
    }
    with open(config_path, 'r') as f:
        for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items():
            args[key] = value

    args['aug_kwargs'] = {
        'name': args['model'].name,
        'image_size': args['dataset'].image_size
    }
    args['dataset_kwargs'] = {
        'dataset':args['dataset'].name,
        'data_dir': '../Data/',
        'download': False,
        'debug_subset_size': False,
    }
    args['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args['dataset'].num_workers,
    }
    args = SimpleNamespace(**args)
    return args

def load_model(path, args, train_loader=None, dataset_copy=None):
    if train_loader is None or dataset_copy is None:
        dataset_copy = get_dataset(args)
        train_loader, _, _ = dataset_copy.get_data_loaders(args)
    state_dict = torch.load(path)['state_dict']
    updated_state_dict = {}
    for key in state_dict:
        updated_state_dict['net.'+key] = state_dict[key]
    model.load_state_dict(updated_state_dict)
    return model

def load_models(T, args, checkpoints_dir, name_base):
    device = 'cuda'
    dataset_copy = get_dataset(args)
    train_loader, _, _ = dataset_copy.get_data_loaders(args)
    models_list = []
    for m in range(T):
        model_path = checkpoints_dir + name_base + str(m) + '.pth'
        models_list.append(load_model(model_path, args, train_loader, dataset_copy))
    return models_list

def get_soup(models_list):
    soup_model = models.get_model(args, device, len(train_loader), dataset.get_transform(args))
    soup_dict = soup_model.state_dict()
    other_dicts_list = [model.state_dict() for model in models_list[-3:]]
    for key in soup_dict.keys():
        if soup_dict[key].dtype == torch.int64:
            new_val = other_dicts_list[0][key]
        else:
            new_val = torch.mean(torch.stack([d[key] for d in other_dicts_list], axis=0), axis=0)
        soup_dict[key] = new_val
    soup_model.load_state_dict(soup_dict)
    return soup_model

def get_accs(backbone):
    mean_acc = 0.0
    dataset = get_dataset(args)
    for t in range(T):
        train_loader, memory_loader, test_loader = dataset.get_data_loaders(args)
        acc, acc_mask = knn_monitor(backbone, dataset, dataset.memory_loaders[t], dataset.test_loaders[t], device, args.cl_default, task_id=t, k=min(args.train.knn_k, len(memory_loader.dataset)))
        print(acc)
        mean_acc += acc
    mean_acc = mean_acc / float(T)
    print(mean_acc)
    return mean_acc
    
class OutputEnsembler(nn.Module):
    def __init__(self, models):
        super(OutputEnsembler, self).__init__()
        self._models = models
        
    def forward(self, x):
        outputs = torch.stack([model(x) for model in self._models], axis=0)
        return torch.mean(outputs, axis=0)



In [95]:
checkpoints_dir = '/u/scr/ananya/continual/UCL/checkpoints/vanilla_simsiam_cifar10/'
name_base = 'finetune_simsiam-c10-experiment-resnet18_'
config_path = '/u/scr/ananya/continual/logs/vanilla_simsiam_cifar10/simsiam_c10.yaml'
args = get_args(config_path)
print(args)
T = 5
models_list = load_models(T, args, checkpoints_dir, name_base)

namespace(aug_kwargs={'name': 'simsiam', 'image_size': 32}, ckpt_dir=False, cl_default=False, data_dir='../Data/', dataloader_kwargs={'drop_last': True, 'pin_memory': True, 'num_workers': 4}, dataset=<arguments.Namespace object at 0x7f6f4c3380a0>, dataset_kwargs={'dataset': 'seq-cifar10', 'data_dir': '../Data/', 'download': False, 'debug_subset_size': False}, debug=True, debug_subset_size=8, device='cuda', download=True, eval=<arguments.Namespace object at 0x7f6f4c338040>, eval_from=None, hide_progress=True, log_dir=False, logger=<arguments.Namespace object at 0x7f6f4c338670>, model=<arguments.Namespace object at 0x7f6f4c338b50>, name='simsiam-c10-experiment-resnet18', seed=None, train=<arguments.Namespace object at 0x7f6f4c3384c0>, validation=True)
Files already downloaded and verified
Files already downloaded and verified


In [98]:
all_accs = [get_accs(m.net.module.backbone) for m in models_list]
print(np.mean(accs, axis=1))

Files already downloaded and verified
Files already downloaded and verified
93.58074222668003
Files already downloaded and verified
Files already downloaded and verified
79.97997997997997
Files already downloaded and verified
Files already downloaded and verified
84.6
Files already downloaded and verified
Files already downloaded and verified
93.74369323915236
Files already downloaded and verified
Files already downloaded and verified
98.02566633761106
89.9860163566847
Files already downloaded and verified
Files already downloaded and verified
93.58074222668003
Files already downloaded and verified
Files already downloaded and verified
79.97997997997997
Files already downloaded and verified
Files already downloaded and verified
84.6
Files already downloaded and verified
Files already downloaded and verified
93.74369323915236
Files already downloaded and verified
Files already downloaded and verified
98.02566633761106
89.9860163566847
Files already downloaded and verified
Files already 

In [None]:
# Soup ensembling
soup_model = get_soup(models_list)
mean_acc = get_accs(soup_model.net.module.backbone)

In [None]:
# Output ensembling
output_ensemble_model = OutputEnsembler(models_list)
mean_acc = get_accs(output_ensemble_model)

array([83.92545173, 89.24941351, 90.39585146, 89.95289182, 89.98601636])