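## CIFAR-10 hyperparameter search

Grid search over learning rate and weight decay for the SIFT-descriptor explainer (`configs/train_cifar10_sift_explainer.json`).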

In [1]:
import copy
import itertools
import os, sys
import pickle
import shutil
from datetime import datetime

import torch
import numpy as np
import cv2
from tqdm import tqdm
from sklearn.cluster import KMeans

sys.path.insert(0, 'src')
from utils.utils import ensure_dir, read_json, informal_log
from utils.visualizations import plot
from utils.model_utils import prepare_device

import model.metric as module_metric
import model.loss as module_loss
import datasets.datasets as module_data
import model.model as module_model

from train import main as train_fn
from predict import predict
from parse_config import ConfigParser


sys.path.insert(0, 'setup')
from setup_cifar10 import setup_cifar10 
# import cv2
# print(cv2.__version__)

In [2]:
# Experiment configuration: the JSON config to train from and the
# (learning rate, weight decay) grid to search. With `debug = True` the grid
# shrinks to 1 learning rate x 2 weight decays.
config_path = 'configs/train_cifar10_sift_explainer.json'
debug = False

learning_rates = [1e-6] if debug else [1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1]
weight_decays = [0, 1e-1] if debug else [0, 1e-1, 1e-2, 1e-3]

config_json = read_json(config_path)


### Set up data loaders

In [3]:
# Both splits are built from the same dataset / loader arguments in the config.
dataset_args = config_json['dataset']['args']
train_descriptors_dataset = module_data.KDDataset(split='train', **dataset_args)
test_descriptors_dataset = module_data.KDDataset(split='test', **dataset_args)

# Only the training split is shuffled; the test split is iterated in a fixed order.
dataloader_args = config_json['data_loader']['args']
train_descriptors_dataloader = torch.utils.data.DataLoader(
    train_descriptors_dataset, shuffle=True, **dataloader_args)
test_descriptors_dataloader = torch.utils.data.DataLoader(
    test_descriptors_dataset, shuffle=False, **dataloader_args)

# Device, metric functions and loss function are all resolved by name from the config.
device, device_ids = prepare_device(config_json['n_gpu'])
metric_fns = list(map(lambda metric_name: getattr(module_metric, metric_name),
                      config_json['metrics']))
loss_fn = getattr(module_loss, config_json['loss'])

### Run Hyperparameter search

In [None]:
# Grid search over (learning rate, weight decay). Each trial trains a model,
# restores its best checkpoint, evaluates it with `predict`, and, if it beats
# the best accuracy so far, copies the model and logits into `<search>/best/`.
#
# NOTE(review): the loader used as "validation" below is the *test* split, so
# hyperparameters are selected on test data and the selected accuracy is
# optimistically biased. Consider holding out part of the training split.
best = {
    'lr': -1,
    'wd': -1,
    'val_acc': -1
}
n_trials = len(learning_rates) * len(weight_decays)
timestamp = datetime.now().strftime(r'%m%d_%H%M%S')

# Logging: one plain-text log per search, under the search's timestamped dir.
log_path = os.path.join(config_json['trainer']['save_dir'], timestamp, 'log.txt')
ensure_dir(os.path.dirname(log_path))
informal_log("Hyperparameter search", log_path)
informal_log("Learning rates: {}".format(learning_rates), log_path)
informal_log("Weight decays: {}".format(weight_decays), log_path)

# Debug mode: train each trial for a single epoch.
if debug:
    config_json['trainer']['epochs'] = 1

# Same order as nested loops: learning rate outer, weight decay inner.
trial_grid = itertools.product(learning_rates, weight_decays)
for trial_idx, (lr, wd) in enumerate(trial_grid, start=1):
    # Write this trial's optimizer settings into the shared config, then hand
    # ConfigParser its own deep copy so the next iteration's update cannot
    # alter a config object that was already handed out for this trial.
    config_json['optimizer']['args'].update({
        'lr': lr,
        'weight_decay': wd
    })

    # Create run ID for trial (unique per (lr, wd) pair within this search).
    itr_timestamp = datetime.now().strftime(r'%m%d_%H%M%S')
    informal_log("[{}] Trial {}/{}: LR = {} WD = {}".format(
        itr_timestamp, trial_idx, n_trials, lr, wd), log_path)
    run_id = os.path.join(timestamp, 'trials', 'lr_{}-wd_{}'.format(lr, wd))
    config = ConfigParser(copy.deepcopy(config_json), run_id=run_id)
    print(config.config['optimizer']['args'])

    # Train model
    model = train_fn(
        config=config,
        train_data_loader=train_descriptors_dataloader,
        val_data_loader=test_descriptors_dataloader)

    # Evaluate the best checkpoint saved during training rather than the
    # final-epoch weights.
    model_restore_path = os.path.join(config.save_dir, 'model_best.pth')
    model.restore_model(model_restore_path)

    # config.save_dir is `<trial>/models` (see the checkpoint directory printed
    # during training), so two dirname() calls on the checkpoint give `<trial>`.
    trial_path = os.path.dirname(os.path.dirname(model_restore_path))
    output_save_path = os.path.join(trial_path, "val_outputs.pth")
    log_save_path = os.path.join(trial_path, "val_metrics.pth")

    validation_data = predict(
        data_loader=test_descriptors_dataloader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=output_save_path,
        log_save_path=log_save_path)

    # Obtain accuracy and compare to previous best (ties keep the earlier trial).
    val_accuracy = validation_data['metrics']['accuracy']
    if val_accuracy > best['val_acc']:
        best.update({
            'lr': lr,
            'wd': wd,
            'val_acc': val_accuracy
        })
        informal_log("Best accuracy of {:.3f} with lr={} and wd={}".format(val_accuracy, lr, wd), log_path)
        informal_log("Trial path: {}".format(trial_path), log_path)
        # Copy model and outputs to one directory for easy access:
        # `<search>/trials/<trial>` -> `<search>/best`.
        best_save_dir = os.path.join(os.path.dirname(os.path.dirname(trial_path)), 'best')
        ensure_dir(best_save_dir)
        best_outputs_save_path = os.path.join(best_save_dir, 'outputs.pth')
        best_model_save_path = os.path.join(best_save_dir, 'model.pth')
        torch.save(validation_data['logits'], best_outputs_save_path)
        model.save_model(best_model_save_path)
        informal_log("Saved model and outputs to {}".format(best_save_dir), log_path)

# Summarise the winning configuration once the whole grid has run.
if best['val_acc'] < 0:
    informal_log("Search complete. No trials were run.", log_path)
else:
    informal_log("Search complete. Best accuracy of {:.3f} with lr={} and wd={}".format(
        best['val_acc'], best['lr'], best['wd']), log_path)

Hyperparameter search
Learning rates: [0.0001, 0.001, 0.05, 0.01, 0.5, 0.1]
Weight decays: [0, 0.1, 0.01, 0.001]
[0526_132630] Trial 1/24: LR = 0.0001 WD = 0
OrderedDict([('lr', 0.0001), ('weight_decay', 0), ('amsgrad', False)])
Created LinearLayers model with 510 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/sift_descriptor_histogram/cifar128_128/hparam_search/0526_132630/trials/lr_0.0001-wd_0/models
    epoch          : 1
    val_TP         : [  0 297  21  25 657   0   2  29   0   8]
    val_TN         : [9000 6563 8935 8875 3195 9000 8943 8620 9000 8908]
    val_FPs        : [   0 2437   65  125 5805    0   57  380    0   92]
    val_FNs        : [1000  703  979  975  343 1000  998  971 1000  992]
    val_accuracy   : 0.1039
    val_per_class_accuracy: [0.9    0.686  0.8956 0.89   0.3852 0.9    0.8945 0.8649 0.9    0.8916]
    val_per_class_accuracy_mean: 0.8207800000000001
    val_precision  : [0.         0.10863204 0.244186

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 2
    val_TP         : [  0 358  44  59 626   0   2  43   0  26]
    val_TN         : [9000 6540 8861 8720 3813 9000 8902 8526 9000 8796]
    val_FPs        : [   0 2460  139  280 5187    0   98  474    0  204]
    val_FNs        : [1000  642  956  941  374 1000  998  957 1000  974]
    val_accuracy   : 0.1158
    val_per_class_accuracy: [0.9    0.6898 0.8905 0.8779 0.4439 0.9    0.8904 0.8569 0.9    0.8822]
    val_per_class_accuracy_mean: 0.82316
    val_precision  : [0.         0.12704045 0.24043716 0.1740413  0.10768966 0.
 0.02       0.08317215 0.         0.11304348]
    val_precision_mean: 0.08654241969951267
    val_recall     : [0.    0.358 0.044 0.059 0.626 0.    0.002 0.043 0.    0.026]
    val_recall_mean: 0.11580000000000001
    val_predicted_class_distribution: [   0 2818  183  339 5813    0  100  517    0  230]
    val_f1         : [0.         0.18753274 0.07438715 0.08812547 0.18376633 0.
 0.00363636 0.05669084 0.         0.04227642]
    val_f1_mean 

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 3
    val_TP         : [  0 405  74  91 591   0   3  47   0  52]
    val_TN         : [9000 6452 8747 8567 4496 9000 8857 8475 9000 8669]
    val_FPs        : [   0 2548  253  433 4504    0  143  525    0  331]
    val_FNs        : [1000  595  926  909  409 1000  997  953 1000  948]
    val_accuracy   : 0.1263
    val_per_class_accuracy: [0.9    0.6857 0.8821 0.8658 0.5087 0.9    0.886  0.8522 0.9    0.8721]
    val_per_class_accuracy_mean: 0.8252600000000001
    val_precision  : [0.         0.13714866 0.22629969 0.17366412 0.11599607 0.
 0.02054795 0.08216783 0.         0.13577023]
    val_precision_mean: 0.08915945656474318
    val_recall     : [0.    0.405 0.074 0.091 0.591 0.    0.003 0.047 0.    0.052]
    val_recall_mean: 0.12630000000000002
    val_predicted_class_distribution: [   0 2953  327  524 5095    0  146  572    0  383]
    val_f1         : [0.         0.20490767 0.11152977 0.11942257 0.19392945 0.
 0.0052356  0.05979644 0.         0.07519884]
    v

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 4
    val_TP         : [  0 428 114 131 560   0  14  64   0  84]
    val_TN         : [9000 6577 8606 8389 5083 9000 8810 8414 8999 8517]
    val_FPs        : [   0 2423  394  611 3917    0  190  586    1  483]
    val_FNs        : [1000  572  886  869  440 1000  986  936 1000  916]
    val_accuracy   : 0.1395
    val_per_class_accuracy: [0.9    0.7005 0.872  0.852  0.5643 0.9    0.8824 0.8478 0.8999 0.8601]
    val_per_class_accuracy_mean: 0.8279
    val_precision  : [0.         0.15012276 0.22440945 0.17654987 0.12508376 0.
 0.06862745 0.09846154 0.         0.14814815]
    val_precision_mean: 0.09914029770279609
    val_recall     : [0.    0.428 0.114 0.131 0.56  0.    0.014 0.064 0.    0.084]
    val_recall_mean: 0.1395
    val_predicted_class_distribution: [   0 2851  508  742 4477    0  204  650    1  567]
    val_f1         : [0.         0.22227993 0.15119363 0.15040184 0.20449151 0.
 0.02325581 0.07757576 0.         0.10721123]
    val_f1_mean    : 0.0936409

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 5
    val_TP         : [  0 441 148 163 520   0  22  74   1 131]
    val_TN         : [9000 6748 8467 8246 5558 9000 8738 8406 8998 8339]
    val_FPs        : [   0 2252  533  754 3442    0  262  594    2  661]
    val_FNs        : [1000  559  852  837  480 1000  978  926  999  869]
    val_accuracy   : 0.15
    val_per_class_accuracy: [0.9    0.7189 0.8615 0.8409 0.6078 0.9    0.876  0.848  0.8999 0.847 ]
    val_per_class_accuracy_mean: 0.8299999999999998
    val_precision  : [0.         0.16375789 0.21732746 0.17775354 0.13124685 0.
 0.07746479 0.11077844 0.33333333 0.16540404]
    val_precision_mean: 0.13770663452233436
    val_recall     : [0.    0.441 0.148 0.163 0.52  0.    0.022 0.074 0.001 0.131]
    val_recall_mean: 0.14999999999999997
    val_predicted_class_distribution: [   0 2693  681  917 3962    0  284  668    3  792]
    val_f1         : [0.         0.23883022 0.17608566 0.17005738 0.20959291 0.
 0.03426791 0.08872902 0.00199402 0.14620536]
    val

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 6
    val_TP         : [  0 432 187 200 463   0  35  94   1 179]
    val_TN         : [9000 6896 8323 8075 6091 9000 8663 8398 8998 8147]
    val_FPs        : [   0 2104  677  925 2909    0  337  602    2  853]
    val_FNs        : [1000  568  813  800  537 1000  965  906  999  821]
    val_accuracy   : 0.1591
    val_per_class_accuracy: [0.9    0.7328 0.851  0.8275 0.6554 0.9    0.8698 0.8492 0.8999 0.8326]
    val_per_class_accuracy_mean: 0.8318199999999999
    val_precision  : [0.         0.170347   0.21643519 0.17777778 0.13730724 0.
 0.09408602 0.13505747 0.33333333 0.17344961]
    val_precision_mean: 0.14377936406854
    val_recall     : [0.    0.432 0.187 0.2   0.463 0.    0.035 0.094 0.001 0.179]
    val_recall_mean: 0.1591
    val_predicted_class_distribution: [   0 2536  864 1125 3372    0  372  696    3 1032]
    val_f1         : [0.         0.24434389 0.20064378 0.18823529 0.21180238 0.
 0.05102041 0.11084906 0.00199402 0.1761811 ]
    val_f1_mean    : 

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 7
    val_TP         : [  0 417 208 229 410   0  52 106   5 217]
    val_TN         : [9000 7059 8202 7894 6521 9000 8572 8380 8993 8023]
    val_FPs        : [   0 1941  798 1106 2479    0  428  620    7  977]
    val_FNs        : [1000  583  792  771  590 1000  948  894  995  783]
    val_accuracy   : 0.1644
    val_per_class_accuracy: [0.9    0.7476 0.841  0.8123 0.6931 0.9    0.8624 0.8486 0.8998 0.824 ]
    val_per_class_accuracy_mean: 0.8328800000000001
    val_precision  : [0.         0.17684478 0.20675944 0.17153558 0.14191762 0.
 0.10833333 0.14600551 0.41666667 0.18174204]
    val_precision_mean: 0.1549804979325412
    val_recall     : [0.    0.417 0.208 0.229 0.41  0.    0.052 0.106 0.005 0.217]
    val_recall_mean: 0.1644
    val_predicted_class_distribution: [   0 2358 1006 1335 2889    0  480  726   12 1194]
    val_f1         : [0.         0.24836212 0.20737787 0.19614561 0.21085112 0.
 0.07027027 0.12282735 0.00988142 0.19781222]
    val_f1_mean    

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 8
    val_TP         : [  0 396 231 248 379   0  74 114  21 258]
    val_TN         : [8999 7251 8086 7807 6863 8998 8477 8367 8980 7893]
    val_FPs        : [   1 1749  914 1193 2137    2  523  633   20 1107]
    val_FNs        : [1000  604  769  752  621 1000  926  886  979  742]
    val_accuracy   : 0.1721
    val_per_class_accuracy: [0.8999 0.7647 0.8317 0.8055 0.7242 0.8998 0.8551 0.8481 0.9001 0.8151]
    val_per_class_accuracy_mean: 0.8344199999999999
    val_precision  : [0.         0.18461538 0.20174672 0.17210271 0.15063593 0.
 0.1239531  0.15261044 0.51219512 0.18901099]
    val_precision_mean: 0.1686870397564508
    val_recall     : [0.    0.396 0.231 0.248 0.379 0.    0.074 0.114 0.021 0.258]
    val_recall_mean: 0.17209999999999998
    val_predicted_class_distribution: [   1 2145 1145 1441 2516    2  597  747   41 1365]
    val_f1         : [0.         0.2518283  0.21538462 0.20319541 0.21558589 0.
 0.09267376 0.13050944 0.04034582 0.21818182]
    va

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 9
    val_TP         : [  1 388 253 262 327   1 102 117  37 290]
    val_TN         : [8994 7321 7982 7741 7207 8990 8387 8404 8948 7804]
    val_FPs        : [   6 1679 1018 1259 1793   10  613  596   52 1196]
    val_FNs        : [999 612 747 738 673 999 898 883 963 710]
    val_accuracy   : 0.1778
    val_per_class_accuracy: [0.8995 0.7709 0.8235 0.8003 0.7534 0.8991 0.8489 0.8521 0.8985 0.8094]
    val_per_class_accuracy_mean: 0.8355600000000001
    val_precision  : [0.14285714 0.18771166 0.19905586 0.1722551  0.15424528 0.09090909
 0.14265734 0.16409537 0.41573034 0.19515478]
    val_precision_mean: 0.18646719623855698
    val_recall     : [0.001 0.388 0.253 0.262 0.327 0.001 0.102 0.117 0.037 0.29 ]
    val_recall_mean: 0.1778
    val_predicted_class_distribution: [   7 2067 1271 1521 2120   11  715  713   89 1486]
    val_f1         : [0.0019861  0.25301598 0.22280934 0.20785403 0.20961538 0.00197824
 0.11895044 0.13660245 0.06795225 0.23330652]
    val_f1_m

### Check a trained explainer's accuracy on validation and training data

In [8]:
# Evaluate one trained explainer checkpoint on both splits.
# NOTE: this loads a hard-coded checkpoint, not the `best/model.pth` written by
# the hyperparameter search above; edit the path to check a different run.
#
# Previously evaluated checkpoints:
#   Trained on soft labels with 75 clusters:
#     saved/cifar10/resnet18/explainer/sift_descriptor_histogram/hparam_search/0525_103130/best/model.pth
#   Trained on hard GT labels with 75 clusters:
#     saved/cifar10/resnet18/explainer/sift_descriptor_histogram/true_labels/75_clusters/models/model_best.pth
# Current checkpoint (judging by its path: 512-unit MLP, GT labels, SIFT 32x32, 75 means).
model_restore_path = 'saved/cifar10/resnet18/explainer/sift_descriptor_histogram/mlp_512/true_labels/sift32_32_75means/models/model_best.pth'

# NOTE(review): architecture args come from the search config; they must match
# the architecture the checkpoint was trained with — confirm for each path.
model_args = config_json['arch']['args']
model = module_model.LinearLayers(
    checkpoint_path=model_restore_path,
    **model_args)

model.eval()
model = model.to(device)


def evaluate_split(data_loader, model):
    """Run `predict` for `model` over one data loader and return its results dict.

    Both save paths are None, so no output or metrics files are requested.
    """
    return predict(
        data_loader=data_loader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=None,
        log_save_path=None)


# Validation split first as a sanity check, then the training split for comparison.
validation_results = evaluate_split(test_descriptors_dataloader, model)
print("Validation accuracy: {}".format(validation_results['metrics']['accuracy']))

training_results = evaluate_split(train_descriptors_dataloader, model)
print("Training accuracy: {}".format(training_results['metrics']['accuracy']))

100%|███████████████████████████████████████████| 40/40 [00:00<00:00, 80.11it/s]


Validation accuracy: 0.2436


100%|████████████████████████████████████████| 196/196 [00:00<00:00, 296.49it/s]


Training accuracy: 0.2519
