In [23]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

from utils.dataset import cifar10_dataset, cifar100_dataset
from models import resnet, vgg, mobilenet, googlenet, densenet

import numpy as np 
import seaborn as sns
from sklearn.metrics import normalized_mutual_info_score
from scipy.stats import pearsonr

import os 

# Re-import edited project modules (e.g. utils, models) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Device used by load_model (map_location and .to); CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
def load_model(model, state_dict_path, print_layers=False):
    """Restore saved weights into ``model`` and place it on ``device``.

    Args:
        model: freshly constructed network whose architecture matches the
            checkpoint (``load_state_dict`` is called with its default strict
            key matching).
        state_dict_path: path to a checkpoint file holding a state dict.
        print_layers: when True, print every key stored in the checkpoint,
            one per line, after the model has been moved.

    Returns:
        The same ``model`` object, with weights loaded and moved to ``device``.
    """
    # torch.load unpickles the file: only point this at checkpoints you trust.
    checkpoint = torch.load(state_dict_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.to(device)

    if not print_layers:
        return model

    for layer_name in checkpoint:
        print(layer_name)
    return model


# fetch the inputs of each model's final `linear` layer for the first `batches` batches
def fetch_features(model, model_ic, data_loader, batches=1, eval_mode=True):
    """Collect the inputs fed to ``model.linear`` and ``model_ic.linear``.

    A forward hook on each model's ``linear`` submodule records the tensor
    passed into it while the first ``batches`` batches of ``data_loader`` are
    run through both models.

    Args:
        model: network without IC layers; must expose a ``linear`` submodule.
        model_ic: network with IC layers; must expose a ``linear`` submodule.
        data_loader: iterable of ``(images, labels)`` pairs.
        batches: number of batches to run through both models.
        eval_mode: run the forward passes in ``eval()`` mode. The train/eval
            flag of every submodule is restored afterwards either way.

    Returns:
        ``(non_IC_features, IC_features)``: two lists holding one 1-D numpy
        array per input feature of the respective ``linear`` layer. Each array
        contains that feature's value for every sample seen, in loader order.
        Both lists are empty when no batch was processed.
    """
    non_ic_batches = []
    ic_batches = []

    def _make_hook(store):
        # Record the detached CPU copy of the hooked module's input. Features
        # lie on the last axis (nn.Linear semantics), so flatten to (rows, features).
        def hook(module, data_inputs, data_outputs):
            array = data_inputs[0].detach().cpu().numpy()
            store.append(array.reshape(-1, array.shape[-1]))
        return hook

    def _device_of(net):
        # Send inputs to wherever the model's parameters live (CPU if it has none).
        first_param = next(net.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def _per_feature(batch_arrays):
        # Stack batches to (samples, features), then split into per-feature columns.
        if not batch_arrays:
            return []
        stacked = np.concatenate(batch_arrays, axis=0)
        return [np.ascontiguousarray(column) for column in stacked.T]

    # Remember every submodule's mode so eval() does not leak out of this call
    # (train mode would also update BatchNorm running statistics).
    saved_modes = [(m, m.training) for net in (model, model_ic) for m in net.modules()]

    handler = model.linear.register_forward_hook(_make_hook(non_ic_batches))
    handler_ic = model_ic.linear.register_forward_hook(_make_hook(ic_batches))
    try:
        if eval_mode:
            model.eval()
            model_ic.eval()
        device_non_ic = _device_of(model)
        device_ic = _device_of(model_ic)
        with torch.no_grad():
            for i, (images, _labels) in enumerate(data_loader):
                if i == batches:
                    break
                model(images.to(device_non_ic))
                model_ic(images.to(device_ic))
    finally:
        handler.remove()
        handler_ic.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    return _per_feature(non_ic_batches), _per_feature(ic_batches)

# compute pairwise dependence between features, based on metric func
def compute_dependence(func, features):
    """Build the matrix of absolute pairwise dependence scores.

    Args:
        func: metric called as ``func(feature_a, feature_b)``. It may return
            either a scalar (e.g. ``normalized_mutual_info_score``) or a
            sequence whose first element is the statistic (e.g. ``pearsonr``,
            which returns ``(statistic, pvalue)``).
        features: sequence of equal-length 1-D arrays, one per feature.

    Returns:
        ``(n, n)`` float array whose ``[i, j]`` entry is
        ``abs(score(features[i], features[j]))``; every pair, including the
        diagonal, is evaluated.
    """
    num_feats = len(features)
    mutual_mat = np.empty([num_feats, num_feats])
    for i, feature_1 in enumerate(features):
        for j, feature_2 in enumerate(features):
            result = func(feature_1, feature_2)
            # Tuple-like results carry the statistic first; scalars are used as-is.
            score = result if np.ndim(result) == 0 else result[0]
            mutual_mat[i, j] = abs(score)
    return mutual_mat

def run_experiment(model, model_ic, data_loader, batches=20, metric=pearsonr):
    """Compare feature dependence of a model without and with IC layers.

    Args:
        model: network without IC layers (must expose ``.linear``).
        model_ic: network with IC layers (must expose ``.linear``).
        data_loader: iterable of ``(images, labels)`` batches.
        batches: number of batches whose ``linear``-layer inputs are collected.
        metric: pairwise dependence function passed to ``compute_dependence``.

    Returns:
        dict with keys ``"res_non_ic"`` and ``"res_ic"``, each the absolute
        pairwise-dependence matrix of the respective model's features.
    """
    non_IC_features, IC_features = fetch_features(model, model_ic, data_loader, batches)
    return {
        "res_non_ic": compute_dependence(metric, non_IC_features),
        "res_ic": compute_dependence(metric, IC_features),
    }
    
    
    

In [25]:
# List the checkpoint sub-directories available under res/models.
%ls res/models

# Root of all checkpoints; the experiment cells join "<arch>/<file>.pth" onto it.
model_dir = "./res/models"

[1m[36mdensenet40[m[m/ [1m[36mgooglenet[m[m/  [1m[36mresnet110[m[m/  [1m[36mvgg16[m[m/


In [None]:
# ResNet-110 series: dependence between the final-layer input features of the
# plain and the IC variant, on CIFAR-10 and CIFAR-100.
# np.mean averages the full matrix, diagonal included (|r(x, x)| = 1 for any
# non-constant feature).

for num_classes, dataset_fn in ((10, cifar10_dataset), (100, cifar100_dataset)):
    dataset_name = f"cifar{num_classes}"
    print(f"Experiment: model: resnet110; dataset: {dataset_name}")

    model = load_model(
        resnet.resnet110(num_classes=num_classes),
        os.path.join(model_dir, f"resnet110/{dataset_name}_resnet110_best.pth"),
    )
    model_ic = load_model(
        resnet.resnet110_ic(num_classes=num_classes),
        os.path.join(model_dir, f"resnet110/{dataset_name}_resnet110_ic_best.pth"),
    )

    _, test_loader = dataset_fn()

    res = run_experiment(model, model_ic, test_loader)
    res_non_ic = res["res_non_ic"]
    res_ic = res["res_ic"]

    # NOTE(review): sns.heatmap without ax= draws on the currently active axes,
    # so every heatmap from this cell lands on the same figure; create a new
    # figure per heatmap to view them separately.
    print("results of model w/o IC")
    sns.heatmap(data=res_non_ic, square=True)
    print(np.mean(res_non_ic))

    print("results of model with IC")
    sns.heatmap(data=res_ic, square=True)
    print(np.mean(res_ic))

Experiment: model: resnet110; dataset: cifar10
INFO: Creating resnet110 model
INFO: Creating resnet110 model with IC layer
INFO: Loading CIFAR10 training dataset
Files already downloaded and verified
INFO: Loading CIFAR10 test dataset
Files already downloaded and verified


In [None]:
# ResNet-110 series, p=0.5 checkpoints, CIFAR-10 only.

print("Experiment: model: resnet110_p05; dataset: cifar10")

model = load_model(
    resnet.resnet110(num_classes=10),
    os.path.join(model_dir, "resnet110/cifar10_resnet110_best_p05.pth"),
)

# The IC model must load its own checkpoint, not the baseline's.
# NOTE(review): filename follows the "<arch>_ic_best" pattern of the other
# cells with the "_p05" suffix — confirm it exists under res/models/resnet110.
model_ic = load_model(
    resnet.resnet110_ic(num_classes=10),
    os.path.join(model_dir, "resnet110/cifar10_resnet110_ic_best_p05.pth"),
)

_, test_loader = cifar10_dataset()

res = run_experiment(model, model_ic, test_loader)
res_non_ic = res["res_non_ic"]
res_ic = res["res_ic"]

# NOTE(review): both heatmaps draw on the currently active axes (same figure).
print("results of model w/o IC")
sns.heatmap(data=res_non_ic, square=True)
print(np.mean(res_non_ic))

print("results of model with IC")
sns.heatmap(data=res_ic, square=True)
print(np.mean(res_ic))

In [None]:
# VGG16 series: same plain-vs-IC comparison on CIFAR-10 and CIFAR-100.
# VGG16 checkpoints need VGG16 models: a ResNet-110 cannot load them.
# NOTE(review): constructor names vgg.vgg16 / vgg.vgg16_ic are assumed to
# mirror resnet.resnet110 / resnet.resnet110_ic — confirm in models/vgg.py.
# fetch_features also requires the classifier to be exposed as `.linear`.

for num_classes, dataset_fn in ((10, cifar10_dataset), (100, cifar100_dataset)):
    dataset_name = f"cifar{num_classes}"
    print(f"Experiment: model: vgg16; dataset: {dataset_name}")

    model = load_model(
        vgg.vgg16(num_classes=num_classes),
        os.path.join(model_dir, f"vgg16/{dataset_name}_vgg16_best.pth"),
    )
    model_ic = load_model(
        vgg.vgg16_ic(num_classes=num_classes),
        os.path.join(model_dir, f"vgg16/{dataset_name}_vgg16_ic_best.pth"),
    )

    _, test_loader = dataset_fn()

    res = run_experiment(model, model_ic, test_loader)
    res_non_ic = res["res_non_ic"]
    res_ic = res["res_ic"]

    # NOTE(review): heatmaps share the currently active axes; give each its own figure.
    print("results of model w/o IC")
    sns.heatmap(data=res_non_ic, square=True)
    print(np.mean(res_non_ic))

    print("results of model with IC")
    sns.heatmap(data=res_ic, square=True)
    print(np.mean(res_ic))

In [None]:
# DenseNet-40 series: same plain-vs-IC comparison on CIFAR-10 and CIFAR-100.
# DenseNet-40 checkpoints need DenseNet-40 models: a ResNet-110 cannot load them.
# NOTE(review): constructor names densenet.densenet40 / densenet.densenet40_ic
# are assumed to mirror the resnet module — confirm in models/densenet.py.
# fetch_features also requires the classifier to be exposed as `.linear`.

for num_classes, dataset_fn in ((10, cifar10_dataset), (100, cifar100_dataset)):
    dataset_name = f"cifar{num_classes}"
    print(f"Experiment: model: densenet40; dataset: {dataset_name}")

    model = load_model(
        densenet.densenet40(num_classes=num_classes),
        os.path.join(model_dir, f"densenet40/{dataset_name}_densenet40_best.pth"),
    )
    model_ic = load_model(
        densenet.densenet40_ic(num_classes=num_classes),
        os.path.join(model_dir, f"densenet40/{dataset_name}_densenet40_ic_best.pth"),
    )

    _, test_loader = dataset_fn()

    res = run_experiment(model, model_ic, test_loader)
    res_non_ic = res["res_non_ic"]
    res_ic = res["res_ic"]

    # NOTE(review): heatmaps share the currently active axes; give each its own figure.
    print("results of model w/o IC")
    sns.heatmap(data=res_non_ic, square=True)
    print(np.mean(res_non_ic))

    print("results of model with IC")
    sns.heatmap(data=res_ic, square=True)
    print(np.mean(res_ic))

In [None]:
# GoogLeNet series: same plain-vs-IC comparison on CIFAR-10 and CIFAR-100.
# GoogLeNet checkpoints need GoogLeNet models: a ResNet-110 cannot load them.
# NOTE(review): constructor names googlenet.googlenet / googlenet.googlenet_ic
# are assumed to mirror the resnet module — confirm in models/googlenet.py.
# fetch_features also requires the classifier to be exposed as `.linear`.

for num_classes, dataset_fn in ((10, cifar10_dataset), (100, cifar100_dataset)):
    dataset_name = f"cifar{num_classes}"
    print(f"Experiment: model: googlenet; dataset: {dataset_name}")

    model = load_model(
        googlenet.googlenet(num_classes=num_classes),
        os.path.join(model_dir, f"googlenet/{dataset_name}_googlenet_best.pth"),
    )
    model_ic = load_model(
        googlenet.googlenet_ic(num_classes=num_classes),
        os.path.join(model_dir, f"googlenet/{dataset_name}_googlenet_ic_best.pth"),
    )

    _, test_loader = dataset_fn()

    res = run_experiment(model, model_ic, test_loader)
    res_non_ic = res["res_non_ic"]
    res_ic = res["res_ic"]

    # NOTE(review): heatmaps share the currently active axes; give each its own figure.
    print("results of model w/o IC")
    sns.heatmap(data=res_non_ic, square=True)
    print(np.mean(res_non_ic))

    print("results of model with IC")
    sns.heatmap(data=res_ic, square=True)
    print(np.mean(res_ic))