In [20]:
import numpy as np
import pickle
import gzip
import cca_core
from CKA import linear_CKA, kernel_CKA

import torch
import math


import torch.nn as nn
import torch.nn.functional as F

from utils import factory
from utils.data_manager import DataManager
from torch.utils.data import DataLoader

### Cargamos los argumentos

In [1]:
import json
import argparse
from trainer import train


import json
import argparse

def main(dataset="cifar10", model_name="finetune", convnet_type="resnet32",
         second_task_freeze_stage=0, logfilename="experiment_freeze"):
    # Simulando la obtención de argumentos desde argparse
    class Args:
        pass
    
    args = Args()
    args.dataset = dataset
    args.model_name = model_name
    args.convnet_type = convnet_type
    args.second_task_freeze_stage = second_task_freeze_stage
    args.logfilename = logfilename
    args.config = "./exps/finetune.json"
    
    param = load_json(args.config)
    args = vars(args)  # Convirtiendo los argumentos a un diccionario.
    param.update(args)
    return param  # Puedes hacer lo que necesites con los parámetros aquí

def load_json(settings_path):
    with open(settings_path) as data_file:
        param = json.load(data_file)
    return param

# Ejecución del programa con valores predeterminados
args = main()

print(args)

  from .autonotebook import tqdm as notebook_tqdm


{'prefix': 'reproduce', 'dataset': 'cifar10', 'memory_size': 2000, 'memory_per_class': 20, 'fixed_memory': False, 'shuffle': True, 'init_cls': 5, 'increment': 5, 'model_name': 'finetune', 'convnet_type': 'resnet32', 'device': ['0'], 'seed': [1993], 'second_task_freeze_stage': 0, 'logfilename': 'experiment_freeze', 'config': './exps/finetune.json'}


### Cargemos los modelos

In [4]:
model_task_0 = torch.load('logs/finetune/experiment_freeze_stage/w0_stage_0.pth')

model_task_1 = torch.load('logs/finetune/experiment_freeze_stage/w1_stage_0.pth')

In [5]:
del model_task_0['fc.weight']; del model_task_0['fc.bias']
del model_task_1['fc.weight']; del model_task_1['fc.bias']

### Cargamos los parametros del modelos en la red

In [6]:
model_1 = factory.get_model(args["model_name"], args)
model_2 = factory.get_model(args["model_name"], args)


In [7]:
model_1._network.load_state_dict(model_task_0)
model_2._network.load_state_dict(model_task_1)

<All keys matched successfully>

### Cargamos el entrenamiento

In [22]:
# self._total_classes

data_manager = DataManager(
    args["dataset"],
    args["shuffle"],
    args["seed"],
    args["init_cls"],
    args["increment"],
    args["second_task_freeze_stage"]
)

test_dataset = data_manager.get_dataset(
    np.arange(0, 10), source="test", mode="test"
)

test_loader = DataLoader(
    test_dataset, batch_size=10_000, shuffle=False, num_workers=8
)

Files already downloaded and verified
Files already downloaded and verified


### Extractor from stage

In [25]:
for i, (_, inputs, targets) in enumerate(test_loader):
    #inputs, targets = inputs.to('cuda'), targets.to('cuda')
    m1_vectors_stage = model_1._network.extract_vector_stage(inputs)
    m2_vectors_stage = model_2._network.extract_vector_stage(inputs)
    print(i)


In [39]:
m1_vectors_stage[2].shape

torch.Size([10000, 64, 8, 8])

In [42]:
# 'fmaps': [x_1, x_2, x_3],
x_1 = m1_vectors_stage[1]
x_2 = m2_vectors_stage[1]

# Promedios
avg_x1 = x_1.mean(dim=(2, 3)).detach().numpy()  # [bs, 16]
avg_x2 = x_2.mean(dim=(2, 3)).detach().numpy()

avg_x1.shape


(10000, 32)

In [38]:
# CKA
print('Linear CKA: {}'.format(linear_CKA(avg_x1, avg_x2)))
print('RBF Kernel CKA: {}'.format(kernel_CKA(avg_x1, avg_x2)))

Linear CKA: 0.8057130688829804
RBF Kernel CKA: 0.8479895821510914


In [41]:
# CKA
print('Linear CKA: {}'.format(linear_CKA(avg_x1, avg_x2)))
print('RBF Kernel CKA: {}'.format(kernel_CKA(avg_x1, avg_x2)))

Linear CKA: 0.3665605739498378
RBF Kernel CKA: 0.3906476069310203


In [43]:
# CKA
print('Linear CKA: {}'.format(linear_CKA(avg_x1, avg_x2)))
print('RBF Kernel CKA: {}'.format(kernel_CKA(avg_x1, avg_x2)))

Linear CKA: 0.7855459658871228
RBF Kernel CKA: 0.793218121828177
