# Correlation matrices visualization of CIFAR-100 models

In [1]:
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn

from mdistiller.models import cifar_model_dict
from mdistiller.dataset import get_dataset
from mdistiller.engine.utils import load_checkpoint
from mdistiller.engine.cfg import CFG as cfg

In [2]:
# visualize the difference between the teacher's output logits and the student's
def get_output_metric(model, val_loader, num_classes=100):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for i, (data, labels) in tqdm(enumerate(val_loader)):
            outputs, _ = model(data)
            preds = outputs
            all_preds.append(preds.data.cpu().numpy())
            all_labels.append(labels.data.cpu().numpy())
    
    all_preds = np.concatenate(all_preds, 0)
    all_labels = np.concatenate(all_labels, 0)
    matrix = np.zeros((num_classes, num_classes))
    cnt = np.zeros((num_classes, 1))
    for p, l in zip(all_preds, all_labels):
        cnt[l, 0] += 1
        matrix[l] += p
    matrix /= cnt
    return matrix

def get_tea_stu_diff(tea, stu, mpath, max_diff):
    cfg.defrost()
    cfg.DISTILLER.STUDENT = stu
    cfg.DISTILLER.TEACHER = tea
    cfg.DATASET.TYPE = 'cifar100'
    cfg.freeze()
    train_loader, val_loader, num_data, num_classes = get_dataset(cfg)
    model = cifar_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
    model.load_state_dict(load_checkpoint(mpath)["model"])
    tea_model = cifar_model_dict[cfg.DISTILLER.TEACHER][0](num_classes=num_classes)
    tea_model.load_state_dict(load_checkpoint(cifar_model_dict[cfg.DISTILLER.TEACHER][1])["model"])
    print("load model successfully!")
    ms = get_output_metric(model, val_loader)
    mt = get_output_metric(tea_model, val_loader)
    diff = np.abs((ms - mt)) / max_diff
    for i in range(100):
        diff[i, i] = 0
    print('max(diff):', diff.max())
    print('mean(diff):', diff.mean())
    seaborn.heatmap(diff, vmin=0, vmax=1.0, cmap="PuBuGn")
    plt.show()

In [3]:
# set a common max-value of the difference for fair comparsion between different methods
MAX_DIFF = 3.0

In [4]:
# KD baseline
mpath = "../../download_ckpts/best"
get_tea_stu_diff("resnet32x4", "resnet8x4", mpath, MAX_DIFF)

Files already downloaded and verified
Files already downloaded and verified


RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "module.student.conv1.weight", "module.student.bn1.weight", "module.student.bn1.bias", "module.student.bn1.running_mean", "module.student.bn1.running_var", "module.student.bn1.num_batches_tracked", "module.student.layer1.0.conv1.weight", "module.student.layer1.0.bn1.weight", "module.student.layer1.0.bn1.bias", "module.student.layer1.0.bn1.running_mean", "module.student.layer1.0.bn1.running_var", "module.student.layer1.0.bn1.num_batches_tracked", "module.student.layer1.0.conv2.weight", "module.student.layer1.0.bn2.weight", "module.student.layer1.0.bn2.bias", "module.student.layer1.0.bn2.running_mean", "module.student.layer1.0.bn2.running_var", "module.student.layer1.0.bn2.num_batches_tracked", "module.student.layer1.0.downsample.0.weight", "module.student.layer1.0.downsample.1.weight", "module.student.layer1.0.downsample.1.bias", "module.student.layer1.0.downsample.1.running_mean", "module.student.layer1.0.downsample.1.running_var", "module.student.layer1.0.downsample.1.num_batches_tracked", "module.student.layer2.0.conv1.weight", "module.student.layer2.0.bn1.weight", "module.student.layer2.0.bn1.bias", "module.student.layer2.0.bn1.running_mean", "module.student.layer2.0.bn1.running_var", "module.student.layer2.0.bn1.num_batches_tracked", "module.student.layer2.0.conv2.weight", "module.student.layer2.0.bn2.weight", "module.student.layer2.0.bn2.bias", "module.student.layer2.0.bn2.running_mean", "module.student.layer2.0.bn2.running_var", "module.student.layer2.0.bn2.num_batches_tracked", "module.student.layer2.0.downsample.0.weight", "module.student.layer2.0.downsample.1.weight", "module.student.layer2.0.downsample.1.bias", "module.student.layer2.0.downsample.1.running_mean", "module.student.layer2.0.downsample.1.running_var", "module.student.layer2.0.downsample.1.num_batches_tracked", "module.student.layer3.0.conv1.weight", "module.student.layer3.0.bn1.weight", "module.student.layer3.0.bn1.bias", "module.student.layer3.0.bn1.running_mean", "module.student.layer3.0.bn1.running_var", "module.student.layer3.0.bn1.num_batches_tracked", "module.student.layer3.0.conv2.weight", "module.student.layer3.0.bn2.weight", "module.student.layer3.0.bn2.bias", "module.student.layer3.0.bn2.running_mean", "module.student.layer3.0.bn2.running_var", "module.student.layer3.0.bn2.num_batches_tracked", "module.student.layer3.0.downsample.0.weight", "module.student.layer3.0.downsample.1.weight", "module.student.layer3.0.downsample.1.bias", "module.student.layer3.0.downsample.1.running_mean", "module.student.layer3.0.downsample.1.running_var", "module.student.layer3.0.downsample.1.num_batches_tracked", "module.student.fc.weight", "module.student.fc.bias", "module.teacher.conv1.weight", "module.teacher.bn1.weight", "module.teacher.bn1.bias", "module.teacher.bn1.running_mean", "module.teacher.bn1.running_var", "module.teacher.bn1.num_batches_tracked", "module.teacher.layer1.0.conv1.weight", "module.teacher.layer1.0.bn1.weight", "module.teacher.layer1.0.bn1.bias", "module.teacher.layer1.0.bn1.running_mean", "module.teacher.layer1.0.bn1.running_var", "module.teacher.layer1.0.bn1.num_batches_tracked", "module.teacher.layer1.0.conv2.weight", "module.teacher.layer1.0.bn2.weight", "module.teacher.layer1.0.bn2.bias", "module.teacher.layer1.0.bn2.running_mean", "module.teacher.layer1.0.bn2.running_var", "module.teacher.layer1.0.bn2.num_batches_tracked", "module.teacher.layer1.0.downsample.0.weight", "module.teacher.layer1.0.downsample.1.weight", "module.teacher.layer1.0.downsample.1.bias", "module.teacher.layer1.0.downsample.1.running_mean", "module.teacher.layer1.0.downsample.1.running_var", "module.teacher.layer1.0.downsample.1.num_batches_tracked", "module.teacher.layer1.1.conv1.weight", "module.teacher.layer1.1.bn1.weight", "module.teacher.layer1.1.bn1.bias", "module.teacher.layer1.1.bn1.running_mean", "module.teacher.layer1.1.bn1.running_var", "module.teacher.layer1.1.bn1.num_batches_tracked", "module.teacher.layer1.1.conv2.weight", "module.teacher.layer1.1.bn2.weight", "module.teacher.layer1.1.bn2.bias", "module.teacher.layer1.1.bn2.running_mean", "module.teacher.layer1.1.bn2.running_var", "module.teacher.layer1.1.bn2.num_batches_tracked", "module.teacher.layer1.2.conv1.weight", "module.teacher.layer1.2.bn1.weight", "module.teacher.layer1.2.bn1.bias", "module.teacher.layer1.2.bn1.running_mean", "module.teacher.layer1.2.bn1.running_var", "module.teacher.layer1.2.bn1.num_batches_tracked", "module.teacher.layer1.2.conv2.weight", "module.teacher.layer1.2.bn2.weight", "module.teacher.layer1.2.bn2.bias", "module.teacher.layer1.2.bn2.running_mean", "module.teacher.layer1.2.bn2.running_var", "module.teacher.layer1.2.bn2.num_batches_tracked", "module.teacher.layer1.3.conv1.weight", "module.teacher.layer1.3.bn1.weight", "module.teacher.layer1.3.bn1.bias", "module.teacher.layer1.3.bn1.running_mean", "module.teacher.layer1.3.bn1.running_var", "module.teacher.layer1.3.bn1.num_batches_tracked", "module.teacher.layer1.3.conv2.weight", "module.teacher.layer1.3.bn2.weight", "module.teacher.layer1.3.bn2.bias", "module.teacher.layer1.3.bn2.running_mean", "module.teacher.layer1.3.bn2.running_var", "module.teacher.layer1.3.bn2.num_batches_tracked", "module.teacher.layer1.4.conv1.weight", "module.teacher.layer1.4.bn1.weight", "module.teacher.layer1.4.bn1.bias", "module.teacher.layer1.4.bn1.running_mean", "module.teacher.layer1.4.bn1.running_var", "module.teacher.layer1.4.bn1.num_batches_tracked", "module.teacher.layer1.4.conv2.weight", "module.teacher.layer1.4.bn2.weight", "module.teacher.layer1.4.bn2.bias", "module.teacher.layer1.4.bn2.running_mean", "module.teacher.layer1.4.bn2.running_var", "module.teacher.layer1.4.bn2.num_batches_tracked", "module.teacher.layer2.0.conv1.weight", "module.teacher.layer2.0.bn1.weight", "module.teacher.layer2.0.bn1.bias", "module.teacher.layer2.0.bn1.running_mean", "module.teacher.layer2.0.bn1.running_var", "module.teacher.layer2.0.bn1.num_batches_tracked", "module.teacher.layer2.0.conv2.weight", "module.teacher.layer2.0.bn2.weight", "module.teacher.layer2.0.bn2.bias", "module.teacher.layer2.0.bn2.running_mean", "module.teacher.layer2.0.bn2.running_var", "module.teacher.layer2.0.bn2.num_batches_tracked", "module.teacher.layer2.0.downsample.0.weight", "module.teacher.layer2.0.downsample.1.weight", "module.teacher.layer2.0.downsample.1.bias", "module.teacher.layer2.0.downsample.1.running_mean", "module.teacher.layer2.0.downsample.1.running_var", "module.teacher.layer2.0.downsample.1.num_batches_tracked", "module.teacher.layer2.1.conv1.weight", "module.teacher.layer2.1.bn1.weight", "module.teacher.layer2.1.bn1.bias", "module.teacher.layer2.1.bn1.running_mean", "module.teacher.layer2.1.bn1.running_var", "module.teacher.layer2.1.bn1.num_batches_tracked", "module.teacher.layer2.1.conv2.weight", "module.teacher.layer2.1.bn2.weight", "module.teacher.layer2.1.bn2.bias", "module.teacher.layer2.1.bn2.running_mean", "module.teacher.layer2.1.bn2.running_var", "module.teacher.layer2.1.bn2.num_batches_tracked", "module.teacher.layer2.2.conv1.weight", "module.teacher.layer2.2.bn1.weight", "module.teacher.layer2.2.bn1.bias", "module.teacher.layer2.2.bn1.running_mean", "module.teacher.layer2.2.bn1.running_var", "module.teacher.layer2.2.bn1.num_batches_tracked", "module.teacher.layer2.2.conv2.weight", "module.teacher.layer2.2.bn2.weight", "module.teacher.layer2.2.bn2.bias", "module.teacher.layer2.2.bn2.running_mean", "module.teacher.layer2.2.bn2.running_var", "module.teacher.layer2.2.bn2.num_batches_tracked", "module.teacher.layer2.3.conv1.weight", "module.teacher.layer2.3.bn1.weight", "module.teacher.layer2.3.bn1.bias", "module.teacher.layer2.3.bn1.running_mean", "module.teacher.layer2.3.bn1.running_var", "module.teacher.layer2.3.bn1.num_batches_tracked", "module.teacher.layer2.3.conv2.weight", "module.teacher.layer2.3.bn2.weight", "module.teacher.layer2.3.bn2.bias", "module.teacher.layer2.3.bn2.running_mean", "module.teacher.layer2.3.bn2.running_var", "module.teacher.layer2.3.bn2.num_batches_tracked", "module.teacher.layer2.4.conv1.weight", "module.teacher.layer2.4.bn1.weight", "module.teacher.layer2.4.bn1.bias", "module.teacher.layer2.4.bn1.running_mean", "module.teacher.layer2.4.bn1.running_var", "module.teacher.layer2.4.bn1.num_batches_tracked", "module.teacher.layer2.4.conv2.weight", "module.teacher.layer2.4.bn2.weight", "module.teacher.layer2.4.bn2.bias", "module.teacher.layer2.4.bn2.running_mean", "module.teacher.layer2.4.bn2.running_var", "module.teacher.layer2.4.bn2.num_batches_tracked", "module.teacher.layer3.0.conv1.weight", "module.teacher.layer3.0.bn1.weight", "module.teacher.layer3.0.bn1.bias", "module.teacher.layer3.0.bn1.running_mean", "module.teacher.layer3.0.bn1.running_var", "module.teacher.layer3.0.bn1.num_batches_tracked", "module.teacher.layer3.0.conv2.weight", "module.teacher.layer3.0.bn2.weight", "module.teacher.layer3.0.bn2.bias", "module.teacher.layer3.0.bn2.running_mean", "module.teacher.layer3.0.bn2.running_var", "module.teacher.layer3.0.bn2.num_batches_tracked", "module.teacher.layer3.0.downsample.0.weight", "module.teacher.layer3.0.downsample.1.weight", "module.teacher.layer3.0.downsample.1.bias", "module.teacher.layer3.0.downsample.1.running_mean", "module.teacher.layer3.0.downsample.1.running_var", "module.teacher.layer3.0.downsample.1.num_batches_tracked", "module.teacher.layer3.1.conv1.weight", "module.teacher.layer3.1.bn1.weight", "module.teacher.layer3.1.bn1.bias", "module.teacher.layer3.1.bn1.running_mean", "module.teacher.layer3.1.bn1.running_var", "module.teacher.layer3.1.bn1.num_batches_tracked", "module.teacher.layer3.1.conv2.weight", "module.teacher.layer3.1.bn2.weight", "module.teacher.layer3.1.bn2.bias", "module.teacher.layer3.1.bn2.running_mean", "module.teacher.layer3.1.bn2.running_var", "module.teacher.layer3.1.bn2.num_batches_tracked", "module.teacher.layer3.2.conv1.weight", "module.teacher.layer3.2.bn1.weight", "module.teacher.layer3.2.bn1.bias", "module.teacher.layer3.2.bn1.running_mean", "module.teacher.layer3.2.bn1.running_var", "module.teacher.layer3.2.bn1.num_batches_tracked", "module.teacher.layer3.2.conv2.weight", "module.teacher.layer3.2.bn2.weight", "module.teacher.layer3.2.bn2.bias", "module.teacher.layer3.2.bn2.running_mean", "module.teacher.layer3.2.bn2.running_var", "module.teacher.layer3.2.bn2.num_batches_tracked", "module.teacher.layer3.3.conv1.weight", "module.teacher.layer3.3.bn1.weight", "module.teacher.layer3.3.bn1.bias", "module.teacher.layer3.3.bn1.running_mean", "module.teacher.layer3.3.bn1.running_var", "module.teacher.layer3.3.bn1.num_batches_tracked", "module.teacher.layer3.3.conv2.weight", "module.teacher.layer3.3.bn2.weight", "module.teacher.layer3.3.bn2.bias", "module.teacher.layer3.3.bn2.running_mean", "module.teacher.layer3.3.bn2.running_var", "module.teacher.layer3.3.bn2.num_batches_tracked", "module.teacher.layer3.4.conv1.weight", "module.teacher.layer3.4.bn1.weight", "module.teacher.layer3.4.bn1.bias", "module.teacher.layer3.4.bn1.running_mean", "module.teacher.layer3.4.bn1.running_var", "module.teacher.layer3.4.bn1.num_batches_tracked", "module.teacher.layer3.4.conv2.weight", "module.teacher.layer3.4.bn2.weight", "module.teacher.layer3.4.bn2.bias", "module.teacher.layer3.4.bn2.running_mean", "module.teacher.layer3.4.bn2.running_var", "module.teacher.layer3.4.bn2.num_batches_tracked", "module.teacher.fc.weight", "module.teacher.fc.bias". 

In [None]:
# Our DKD
mpath = "../../download_ckpts/best"
get_tea_stu_diff("resnet32x4", "resnet8x4", mpath, MAX_DIFF)