In [3]:
import sys
sys.path.append('../')
import argparse
from tqdm import tqdm

from scipy.stats import entropy
import numpy as np
import torch
import torch.nn.functional as F

from timm.models import create_model

from datasets import get_dataset, build_transform

import models
import utils
from utils import get_free_gpu


# Hardware selection for this notebook.
# NOTE(review): get_free_gpu is imported from the project's utils module; presumably it
# picks (and possibly pins) the least-loaded GPU(s) -- confirm in utils. Its return
# value is only stored here and not read by the cells visible below.
num_gpus = 1
gpu_chosen = get_free_gpu(num_gpus)

# Use CUDA when PyTorch can see a device, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")

In [None]:
# Stand-in for parsed command-line arguments; only `input_size` is set.
# NOTE(review): presumably the input resolution in pixels expected by
# build_transform -- confirm against the datasets module.
args = argparse.Namespace(input_size=224)

In [None]:
def get_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        output: tensor of per-class scores with classes along dim 1.
        target: either a 1-D tensor of class indices, or a tensor with more
            than one dimension (e.g. one-hot / score vectors), in which case
            the argmax along dim 1 is used as the class index.
        topk: iterable of k values to evaluate.

    Returns:
        dict mapping ``"acc{k}"`` to the fraction (Python float) of samples
        whose target class is among the k highest-scoring predictions.
    """
    n_samples = target.size(0)
    largest_k = max(topk)

    # Multi-dimensional targets are collapsed to class indices via argmax.
    if target.ndimension() > 1:
        target = target.max(1)[1]

    # Indices of the `largest_k` best classes per sample, best first, then
    # transposed so that row r holds every sample's rank-r prediction.
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return {
        "acc{}".format(k): hits[:k].reshape(-1).float().sum(0).mul_(1.0 / n_samples).item()
        for k in topk
    }


def predict(teacher_model, student_model, data_loader, device=None):
    """Run both models over ``data_loader`` and collect their outputs on the CPU.

    Both models are moved to ``device`` and switched to eval mode (they are
    left in that state afterwards). Inference runs without gradient tracking.

    Args:
        teacher_model: module called as ``teacher_model(x)`` on each batch.
        student_model: module called as ``student_model(x)`` on each batch.
        data_loader: iterable yielding ``(x, y)`` tensor pairs.
        device: device used for inference. ``None`` (default) selects CUDA
            when available and falls back to the CPU otherwise, so the
            function also works on machines without a GPU.

    Returns:
        Tuple ``(preds_teacher, preds_student, labels)``; each element is the
        concatenation (along dim 0) of the per-batch tensors, stored on the CPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    teacher_model.to(device)
    student_model.to(device)
    teacher_model.eval()
    student_model.eval()

    preds_teacher = []
    preds_student = []
    labels = []
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            # Move results off the device right away so accelerator memory
            # does not grow with the dataset size.
            preds_teacher.append(teacher_model(x).to('cpu'))
            preds_student.append(student_model(x).to('cpu'))
            labels.append(y.to('cpu'))
    return torch.cat(preds_teacher), torch.cat(preds_student), torch.cat(labels)


def evaluate(teacher_model, student_model, teacher_datasets, student_datasets=None, batch_size=100, num_workers=4, device=None):
    """Compare a teacher and a student model on one or more datasets and print metrics.

    For each dataset the following is printed:
      * top-1 accuracy of the teacher and of the student against the labels,
      * fidelity: top-1 agreement between teacher and student predictions,
      * mean relative entropy of each model's softmax output (entropy in bits
        divided by ``log2(num_classes)``),
      * mean ratio between the largest and smallest softmax probability.

    Args:
        teacher_model: model whose outputs serve as the reference.
        student_model: model compared against the teacher.
        teacher_datasets: a single dataset or a tuple of datasets. A tuple is
            evaluated element by element; the first two entries are labelled
            ``"train"`` and ``"test"``, further entries ``"split{i}"``.
        student_datasets: optional counterpart of ``teacher_datasets``.
            NOTE(review): it is normalised like ``teacher_datasets`` but is
            not used for inference -- both models are run on the teacher
            datasets. Confirm whether separate student inputs were intended.
        batch_size: batch size of the DataLoader built for each dataset.
        num_workers: number of DataLoader worker processes.
        device: device used for inference. ``None`` (default) selects CUDA
            when available and the CPU otherwise.

    Returns:
        dict mapping the split label (``""`` when a single dataset is given)
        to the tuple ``(preds_teacher, preds_student, labels)`` from ``predict``.
    """
    if student_datasets is None:
        student_datasets = teacher_datasets

    if not isinstance(teacher_datasets, tuple):
        teacher_datasets = (teacher_datasets, )
        student_datasets = (student_datasets, )

    # Resolve the device here so inference also works on CPU-only machines
    # instead of relying on a hard-coded CUDA default further down.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Teacher model: {teacher_model.__class__.__name__} with size {sum(p.numel() for p in teacher_model.parameters())}")
    print(f"Student model: {student_model.__class__.__name__} with size {sum(p.numel() for p in student_model.parameters())}")
    print()

    split_names = ("train", "test")
    results = {}
    for i, dataset in enumerate(teacher_datasets):
        if len(teacher_datasets) == 1:
            d_type = ""
        elif i < len(split_names):
            d_type = split_names[i]
        else:
            # More than two datasets: fall back to a positional label
            # instead of indexing past the known split names.
            d_type = f"split{i}"
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        print(f'Evaluate on {dataset.__class__.__name__} {d_type} data:')
        preds_teacher, preds_student, labels = predict(teacher_model, student_model, data_loader, device=device)
        num_classes = preds_teacher.shape[1]

        print(f'Results on {dataset.__class__.__name__} {d_type} data:')
        print('Task accuracy teacher:', get_accuracy(preds_teacher, labels)['acc1'])
        print('Task accuracy student:', get_accuracy(preds_student, labels)['acc1'])
        # The student's outputs act as the "target"; get_accuracy reduces them
        # to class indices via argmax, so this measures top-1 agreement.
        print('Task fidelity student:', get_accuracy(preds_teacher, preds_student)['acc1'])

        softmax_teacher = F.softmax(preds_teacher, 1)
        softmax_student = F.softmax(preds_student, 1)
        # scipy operates on NumPy arrays; convert explicitly rather than
        # relying on implicit tensor-to-array coercion.
        max_entropy = np.log2(num_classes)
        print('Mean relative entropy teacher:', np.mean(entropy(softmax_teacher.numpy(), axis=1, base=2) / max_entropy))
        print('Mean relative entropy student:', np.mean(entropy(softmax_student.numpy(), axis=1, base=2) / max_entropy))
        print('Mean max/min teacher:', torch.mean(softmax_teacher.max(1)[0] / softmax_teacher.min(1)[0]).item())
        print('Mean max/min student:', torch.mean(softmax_student.max(1)[0] / softmax_student.min(1)[0]).item())
        print()
        results[d_type] = (preds_teacher, preds_student, labels)
    return results

In [None]:
# Teacher network with a 200-class head.
# NOTE(review): 'googlenet' is presumably registered with timm by the project's
# `models` module -- confirm there.
teacher_model = create_model(
    'googlenet',
    num_classes=200
)
# map_location='cpu' allows checkpoints saved on a GPU to be restored on a
# CPU-only machine; predict() moves the model to the inference device later.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
teacher_model.load_state_dict(torch.load('checkpoints/teacher/checkpoint.pth', map_location='cpu')['model'])


# Student network with the same number of output classes as the teacher.
student_model = create_model(
    'deit_base_patch16_224',
    num_classes=200
)
student_model.load_state_dict(torch.load('checkpoints/student/checkpoint.pth', map_location='cpu')['model'])

# The same transform arguments are passed for both the train and val transforms.
# NOTE(review): presumably get_dataset returns a (train, test) tuple, which
# evaluate() then reports split by split -- confirm in the datasets module.
datasets = get_dataset('cubs', train_transform=build_transform(False, args), val_transform=build_transform(False, args))
results = evaluate(teacher_model, student_model, datasets)