In [None]:
# Load IPython's autoreload extension and use mode 2, so imported project
# modules are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# Expose only the GPU with index 1 to CUDA for this kernel.
# NOTE(review): hard-coded device index — adjust for the machine being used.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import sys
sys.path.append('..')
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

In [None]:
%matplotlib inline
# Notebook-wide matplotlib defaults; applied through plt.rcParams.update below.
params = {
    # Tick labels are black, while axis labels and axes edges are white.
    # NOTE(review): white labels/edges together with the white axes face set
    # below look intended for a dark notebook theme — confirm.
    "ytick.color" : "black",
    "xtick.color" : "black",
    "axes.labelcolor" : "w",
    "axes.edgecolor" : "w",
    
    # Wide default figure, grid switched on (grey level 0.7), white axes face.
    "figure.figsize": (14, 5),
    "axes.grid": True,
    "grid.color": '0.7',
    "axes.facecolor": 'w',
    
    # Font sizes for axis labels and tick labels.
    "axes.labelsize": 'medium',
    "xtick.labelsize": 'medium',
    "ytick.labelsize": 'medium'
    }
plt.rcParams.update(params)

## Загружаем сетки

In [None]:
from models.resnet import resnet18

def _best_ckpt_epoch(ckpt: Path):
    """Return the epoch encoded in a "best" checkpoint name, or None if it is not one."""
    parts = ckpt.name.split("-")
    if parts[-1] != 'best.pth':
        return None
    try:
        # The epoch is the second "-"-separated token: <prefix>-<epoch>-...-best.pth
        return int(parts[1])
    except (IndexError, ValueError):
        return None


def search_best_ckpt(m_folder: Path):
    """Return the "best" checkpoint with the highest epoch found under ``m_folder``.

    ``*.pth`` files are searched recursively. A file counts as a "best"
    checkpoint when the last ``-``-separated token of its name is ``best.pth``;
    its epoch is the integer in the second ``-``-separated token, i.e. names
    shaped like ``<prefix>-<epoch>-best.pth``.

    Only "best" checkpoints are parsed for an epoch, so other ``.pth`` files in
    the folder (for example ``optimizer.pth``) no longer abort the search.
    Files ending in ``best.pth`` whose epoch cannot be parsed are skipped.

    Args:
        m_folder: directory of a single model run.

    Returns:
        Path of the "best" checkpoint with the largest epoch. On equal epochs
        the first one yielded by ``rglob`` wins.

    Raises:
        ValueError: if the folder contains no usable "best" checkpoint.
    """
    best_ckpt, best_epoch = None, None
    for ckpt in m_folder.rglob('*.pth'):
        epoch = _best_ckpt_epoch(ckpt)
        if epoch is None:
            continue
        # Strict comparison keeps the first checkpoint seen among equal epochs.
        if best_epoch is None or epoch > best_epoch:
            best_ckpt, best_epoch = ckpt, epoch

    if best_ckpt is None:
        raise ValueError("No correct ckpts folder (need at least one ckpt with 'best' in name)")
    return best_ckpt



def load_models_from_folder(folder_path: str, bp_filt_size=None, merge_conv_bp=False, num_classes=100) -> list:
    """Load one ResNet-18 per sub-folder of ``folder_path`` from its best checkpoint.

    Sub-folders are visited in sorted order. For each one the checkpoint chosen
    by ``search_best_ckpt`` is loaded on the CPU into a fresh ``resnet18`` and
    the model is switched to eval mode.

    Args:
        folder_path: directory that contains one sub-directory per model.
        bp_filt_size: forwarded as the first positional argument of ``resnet18``.
        merge_conv_bp: accepted for interface compatibility; not used by this function.
        num_classes: forwarded to ``resnet18``.

    Returns:
        A list of dicts ``{'model': <model in eval mode>, 'name': <sub-folder name>}``.

    Raises:
        ValueError: propagated from ``search_best_ckpt`` when a sub-folder has
            no usable checkpoint.
    """
    folder_path = Path(folder_path)
    models = []
    for m_folder in sorted(folder_path.iterdir()):
        if not m_folder.is_dir():
            continue

        # NOTE: utils.most_recent_folder / most_recent_weights also exist in the project.
        best_ckpt_path = search_best_ckpt(m_folder)
        print(best_ckpt_path)

        model = resnet18(bp_filt_size, num_classes=num_classes)
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        weights = torch.load(best_ckpt_path, map_location='cpu')
        model.load_state_dict(weights)

        # Distillation is switched off when the model supports it. Models without
        # the method are skipped silently; a failing call is reported instead of
        # being swallowed by a bare ``except`` (which would also hide Ctrl-C).
        if hasattr(model, 'disable_distillation'):
            try:
                model.disable_distillation()
            except Exception as exc:
                print(f"Could not disable distillation for {m_folder.name}: {exc!r}")

        models.append({'model': model.eval(), 'name': m_folder.name})

    return models

In [None]:
# Directory (relative to the working directory) with one sub-folder per trained model.
folder = 'checkpoint/resnet18_tiny_imagenet_x2_data_rc_aug_0.9_aug_mode_pad_log_no_w'

# Experiment options are inferred from substrings of the folder name.
bp_filt_size = 3 if "lpf3" in folder else None
merge_conv_bp = "merge_conv_bp" in folder
num_classes = 200 if "tiny_imagenet" in folder else 100

# Each entry of distil_models is a dict with the keys 'model' and 'name'.
distil_models = load_models_from_folder(folder, bp_filt_size=bp_filt_size, merge_conv_bp=merge_conv_bp, num_classes=num_classes)

In [None]:
# Number of models that were loaded.
len(distil_models)

In [None]:
# Inspect the first loaded entry (a dict holding the model and its folder name).
distil_models[0]

## Валидация сеток

In [None]:
from datasets.tiny_imagenet import get_tiny_imagenet_test_dataloader

# Validation settings shared by the cells below.
device = 'cuda'
batch_size = 128
workers = 4
# Passed as D to the shift validation further down; half of it is used as the
# image padding of the shift data loader.
shift_diag = 2*8
shuffle = False

# Seed NumPy and PyTorch.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Validation loader built by get_tiny_imagenet_test_dataloader with the batch
# size, shuffle flag and worker count configured above (no img_pad here).
val_loader = get_tiny_imagenet_test_dataloader(
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=workers,
)

In [None]:
# Take element 0 of the first batch (presumably the image tensor — confirm)
# and display its shape.
image = next(iter(val_loader))[0]
image.shape

In [None]:
from logger import set_logging
# Set up logging through the project's logger helper.
set_logging() 

In [None]:
from validate_utils import AverageMeter, accuracy, agreement
from validate_utils import validate, validate_shift, validate_diagonal

from torch import nn

# Cross-entropy loss moved to the GPU; passed to validate() in the cells below.
criterion = nn.CrossEntropyLoss().cuda()

In [None]:
# Validate the first loaded model on val_loader with logging enabled; the model
# is moved to the GPU first.
validate(val_loader, distil_models[0]['model'].cuda(), criterion, print_log=True)

## Оценим работу антиалиасинга
Код был взят из скриптов автора статьи: https://github.com/adobe/antialiased-cnns/blob/master/main.py

In [None]:
import pandas as pd
from tqdm.auto import tqdm

def stability_score(diag_probs):
    """Score how much the model's confidence fluctuates across image shifts.

    For every row, the differences between neighbouring columns are taken and
    their variance is computed; the score is the mean of these per-row
    variances. Lower values mean steadier predictions.

    diag_probs:    2-D array of model prediction confidences at different
                   image shifts (presumably one row per image and one column
                   per shift — confirm against the caller).
    """
    step_changes = diag_probs[:, 1:] - diag_probs[:, :-1]
    per_row_variance = np.var(step_changes, axis=1)
    return np.mean(per_row_variance)

def get_metrics(val_loader, model, out_dir, D=32, epochs_shift=5, name=''):
    """Run the diagonal-shift and random-shift validations for one model.

    The model is moved to the notebook-level ``device``. When it exposes both
    ``disable_distillation`` and ``free_feature_list``, its feature list is
    freed and distillation is disabled for the duration of the validation, then
    re-enabled afterwards — also when validation raises, so the model is never
    left with distillation switched off by accident.

    Args:
        val_loader: data loader handed to both validation routines.
        model: network to evaluate.
        out_dir: output location forwarded to ``validate_diagonal``.
        D: shift size forwarded to both validation routines.
        epochs_shift: forwarded positionally to ``validate_shift``.
        name: label forwarded to both validation routines.

    Returns:
        Tuple ``(prob, top1, consist)``: the first two values returned by
        ``validate_diagonal`` and the value returned by ``validate_shift``.
    """
    model.to(device)
    with_distil = hasattr(model, 'disable_distillation') and hasattr(model, 'free_feature_list')
    if with_distil:
        model.free_feature_list()
        model.disable_distillation()

    try:
        prob, top1, _ = validate_diagonal(val_loader, model, out_dir, D=D, print_log=False, name=name)
        consist = validate_shift(val_loader, model, epochs_shift, D=D, print_log=False, name=name)
    finally:
        # Restore distillation even if one of the validations failed.
        if with_distil:
            model.enable_distillation()

    return prob, top1, consist

def construct_and_save_tabel(data_dict, out_dir: Path = None):
    """Build the metrics table and optionally write it to disk as CSV.

    Args:
        data_dict: mapping ``row name -> (prob, top1, consist)``; every key
            becomes one row of the table.
        out_dir: directory for ``validate_tabel.csv``. Nothing is written when
            it is ``None`` or empty. Missing parent directories are created, so
            nested paths such as ``stat_exp_checkpoint/<run name>`` work.

    Returns:
        ``pandas.DataFrame`` indexed by row name with the columns
        ``prob``, ``top1`` and ``consist``.
    """
    df = pd.DataFrame.from_dict(
        data_dict, orient='index', columns=["prob", "top1", "consist"]
    )

    if out_dir:
        out_dir = Path(out_dir)
        # parents=True: the output path may be nested and not exist yet.
        out_dir.mkdir(parents=True, exist_ok=True)
        df.to_csv(out_dir / 'validate_tabel.csv')

    return df

    
def validate_tabel(val_loader, mobilenet, mobilenet_antialiased, distil_models):
    """Collect shift metrics for the two reference models and every distilled model.

    Rows ``orig`` and ``aliased`` hold the metrics of ``mobilenet`` and
    ``mobilenet_antialiased``; each entry of ``distil_models`` (a dict with the
    keys ``'model'`` and ``'name'``) adds a row named ``'distil'`` plus
    characters 3..7 of its name. The table is returned without being written
    to disk, because no output directory is passed on.
    """
    rows = {
        'orig': get_metrics(val_loader, mobilenet, './mobilenet', name='orig'),
        'aliased': get_metrics(val_loader, mobilenet_antialiased, './mobilenet_antialiased', name='aliased'),
    }

    for entry in tqdm(distil_models):
        model_name = entry['name']
        # Names are expected to look like exp_dw{int}_temp{int}.
        row_key = 'distil' + model_name[3:8]
        rows[row_key] = get_metrics(val_loader, entry['model'], f'./distil/{model_name}', name=model_name)

    return construct_and_save_tabel(rows)


def statisical_validation_model(val_loader, models, D=4, model_name='model'):
    """Evaluate several models and summarise their shift metrics.

    ``get_metrics`` is run for every entry of ``models`` (dicts with the keys
    ``'model'`` and ``'name'``). The resulting table has one row per model name
    plus a ``mean`` and a ``std`` row computed column-wise with NumPy. It is
    built and saved by ``construct_and_save_tabel`` under
    ``./stat_exp_{model_name}``.

    Args:
        val_loader: data loader forwarded to ``get_metrics``.
        models: sequence of ``{'model': ..., 'name': ...}`` dicts.
        D: shift size forwarded to ``get_metrics``.
        model_name: suffix of the output directory name.

    Returns:
        The table returned by ``construct_and_save_tabel``.
    """
    out_dir = Path(f'./stat_exp_{model_name}')

    collected = []
    for named_model in tqdm(models):
        collected.append(
            get_metrics(
                val_loader,
                named_model['model'],
                out_dir / named_model["name"],
                name=named_model["name"],
                D=D,
            )
        )
    metrics = np.array(collected)

    data_dict = {}
    for row, named_model in zip(tqdm(metrics), models):
        data_dict[named_model["name"]] = row
    data_dict['mean'] = metrics.mean(0).tolist()
    data_dict['std'] = metrics.std(0).tolist()

    return construct_and_save_tabel(data_dict, out_dir)
    

In [None]:
# Loader for the shift-based validation: one image per batch, with
# img_pad set to half of shift_diag.
val_shift_data_loader = get_tiny_imagenet_test_dataloader(
    batch_size=1,
    shuffle=shuffle,
    num_workers=workers,
    img_pad=shift_diag // 2
)

### Сводная таблица метрик (`statisical_validation_model`)

In [None]:
# Accuracy of every loaded model on the un-shifted validation loader.
acc_by_model = dict()
for distil_model in tqdm(distil_models):
    acc = validate(val_loader, distil_model['model'].cuda(), criterion,
                   print_log=False, print_acc=False)
    acc_by_model[distil_model['name']] = round(acc.item(), 3)
    print(acc_by_model[distil_model['name']], end = ' ')
    # Models that expose free_feature_list get it called after validation.
    if hasattr(distil_model['model'], 'free_feature_list'):
        distil_model['model'].free_feature_list()

# acc_dict ends up as a DataFrame (one 'acc' row per model); it is joined to
# the metrics table in a later cell.
acc_dict = pd.DataFrame.from_dict(
        acc_by_model, orient='index', columns=["acc"]
    )
# Both statistics are computed BEFORE the summary rows are appended; otherwise
# the 'mean' row would be counted as an extra sample in the standard deviation.
# NOTE(review): pandas' std defaults to ddof=1, whereas the NumPy std used for
# the shift metrics defaults to ddof=0 — confirm which one is intended.
acc_mean = float(acc_dict['acc'].mean())
acc_std = float(acc_dict['acc'].std())
acc_dict.loc['mean'] = acc_mean
acc_dict.loc['std'] = acc_std

In [None]:
# Shift metrics for all loaded models, plus 'mean' and 'std' rows.
# NOTE(review): folder contains a '/', so the output directory
# ./stat_exp_<folder> is a nested path.
tab = statisical_validation_model(val_shift_data_loader, distil_models, D=shift_diag, model_name=folder)

In [None]:
# Attach the accuracy column(s) to the shift-metrics table (joined on the index).
# Columns left over from a previous run of this cell are dropped first, so
# re-running the cell does not fail on overlapping column names.
tab = tab.drop(columns=acc_dict.columns, errors='ignore').join(acc_dict)

In [None]:
# Display the combined table.
tab

In [None]:
# Show the shift size used for this run.
shift_diag

In [None]:
# Show the checkpoint folder the results belong to.
folder

In [None]:
# Re-inspect the first loaded entry after the validation runs.
distil_models[0]