## CIFAR10 Hparam search

In [2]:
import torch
import numpy as np
import os, sys
import shutil
import pickle
import cv2
from tqdm import tqdm
from sklearn.cluster import KMeans
from datetime import datetime

sys.path.insert(0, 'src')
from utils.utils import ensure_dir, read_json, informal_log
from utils.visualizations import plot
from utils.model_utils import prepare_device

import model.metric as module_metric
import model.loss as module_loss
import datasets.datasets as module_data
import model.model as module_model

from src.train import main as train_fn
from predict import predict
from parse_config import ConfigParser


sys.path.insert(0, 'setup')
from setup_cifar10 import setup_cifar10 
# import cv2
# print(cv2.__version__)

In [3]:
config_path = 'configs/train_cifar10_pixel_explainer.json'
debug = False
if debug:
    learning_rates = [1e-6] #, 1e-5, 1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1]
    weight_decays = [0, 1e-1] #, 1e-2, 1e-3]
else:
    learning_rates = [1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1]
    weight_decays = [0, 1e-1, 1e-2, 1e-3]

config_json = read_json(config_path)


### Set up data loaders

In [4]:
dataset_args = config_json['dataset']['args']
train_descriptors_dataset = module_data.KDDataset(split='train', **dataset_args)
test_descriptors_dataset = module_data.KDDataset(split='test', **dataset_args)

dataloader_args = config_json['data_loader']['args']
train_descriptors_dataloader = torch.utils.data.DataLoader(
    train_descriptors_dataset,
    shuffle=True,
    **dataloader_args)
test_descriptors_dataloader = torch.utils.data.DataLoader(
    test_descriptors_dataset,
    shuffle=False,
    **dataloader_args)

device, device_ids = prepare_device(config_json['n_gpu'])
metric_fns = [getattr(module_metric, met) for met in config_json['metrics']]
loss_fn = getattr(module_loss, config_json['loss'])

### Run Hyperparameter search

In [5]:
best = {
    'lr': -1,
    'wd': -1,
    'val_acc': -1
}
n_trials = len(learning_rates) * len(weight_decays)
trial_idx = 1
timestamp = datetime.now().strftime(r'%m%d_%H%M%S')

# Logging
log_path = os.path.join(config_json['trainer']['save_dir'], timestamp, 'log.txt')
ensure_dir(os.path.dirname(log_path))
informal_log("Hyperparameter search", log_path)
informal_log("Learning rates: {}".format(learning_rates), log_path)
informal_log("Weight decays: {}".format(weight_decays), log_path)

# 

# Debug mode
if debug:
    config_json['trainer']['epochs'] = 1
    
for lr in learning_rates:
    for wd in weight_decays:
        # Update config json
        config_json['optimizer']['args'].update({
            'lr': lr,
            'weight_decay': wd
        })
        
        # Create run ID for trial
        itr_timestamp = datetime.now().strftime(r'%m%d_%H%M%S')
        informal_log("[{}] Trial {}/{}: LR = {} WD = {}".format(
            itr_timestamp, trial_idx, n_trials, lr, wd), log_path)
        run_id = os.path.join(timestamp, 'trials', 'lr_{}-wd_{}'.format(lr, wd))
        config = ConfigParser(config_json, run_id=run_id)
        print(config.config['optimizer']['args'])
        
        # Train model
        model = train_fn(
            config=config, 
            train_data_loader=train_descriptors_dataloader,
            val_data_loader=test_descriptors_dataloader)
        
        # Restore model
        model_restore_path = os.path.join(config.save_dir, 'model_best.pth')
        
        model.restore_model(model_restore_path)
        print("restored model")
        # Run on validation set using predict function
        
        trial_path = os.path.dirname(os.path.dirname(model_restore_path))
        output_save_path = os.path.join(trial_path, "val_outputs.pth")
        log_save_path = os.path.join(trial_path, "val_metrics.pth")
        
        validation_data = predict(
            data_loader=test_descriptors_dataloader,
            model=model,
            metric_fns=metric_fns,
            device=device,
            loss_fn=loss_fn,
            output_save_path=output_save_path,
            log_save_path=log_save_path)
       
        # Obtain accuracy and compare to previous best
        print(validation_data['metrics'].keys())
        val_accuracy = validation_data['metrics']['accuracy']
        if val_accuracy > best['val_acc']:
            best.update({
                'lr': lr,
                'wd': wd,
                'val_acc': val_accuracy
            })
            informal_log("Best accuracy of {:.3f} with lr={} and wd={}".format(val_accuracy, lr, wd), log_path)
            informal_log("Trial path: {}".format(trial_path), log_path)
            # Copy model and outputs to 1 directory for easy access
            best_save_dir = os.path.join(os.path.dirname(os.path.dirname(trial_path)), 'best')
            ensure_dir(best_save_dir)
            best_outputs_save_path = os.path.join(best_save_dir, 'outputs.pth')
            best_model_save_path = os.path.join(best_save_dir, 'model.pth')
            torch.save(validation_data['logits'], best_outputs_save_path)
            model.save_model(best_model_save_path)
            informal_log("Saved model and outputs to {}".format(best_save_dir), log_path)
            
            
        trial_idx += 1

Hyperparameter search
Learning rates: [0.0001, 0.001, 0.05, 0.01, 0.5, 0.1]
Weight decays: [0, 0.1, 0.01, 0.001]
[0627_150603] Trial 1/24: LR = 0.0001 WD = 0
OrderedDict([('lr', 0.0001), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0/models
    epoch          : 1
    val_TP         : [350 150  42  75 257 130 287 112 267 392]
    val_TN         : [7974 8138 8675 8643 8115 8506 7876 8637 8046 7452]
    val_FPs        : [1010  872  323  365  878  492 1143  377  936 1542]
    val_FNs        : [666 840 960 917 750 872 694 874 751 614]
    val_accuracy   : 0.2062
    val_RMSE       : 4.029354787059586
    val_per_class_accuracy: [0.8324 0.8288 0.8717 0.8718 0.8372 0.8636 0.8163 0.8749 0.8313 0.7844]
    val_per_class_accuracy_mean: 0.84124
    val_precision  : [0.25735294 0.14677104 0.115068

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.72it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 96.58it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
Best accuracy of 0.392 with lr=0.0001 and wd=0
Trial path: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0
Saved model and outputs to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/best
[0627_150853] Trial 2/24: LR = 0.0001 WD = 0.1
OrderedDict([('lr', 0.0001), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.1/models





    epoch          : 1
    val_TP         : [350 150  42  75 257 131 288 113 267 392]
    val_TN         : [7973 8137 8675 8643 8115 8507 7876 8639 8048 7452]
    val_FPs        : [1011  873  323  365  878  491 1143  375  934 1542]
    val_FNs        : [666 840 960 917 750 871 693 873 751 614]
    val_accuracy   : 0.2065
    val_RMSE       : 4.030446625375406
    val_per_class_accuracy: [0.8323 0.8287 0.8717 0.8718 0.8372 0.8638 0.8164 0.8752 0.8315 0.7844]
    val_per_class_accuracy_mean: 0.8413
    val_precision  : [0.25716385 0.14662757 0.11506849 0.17045455 0.22643172 0.21061093
 0.20125786 0.23155738 0.22231474 0.20268873]
    val_precision_mean: 0.1984175809663194
    val_recall     : [0.34448819 0.15151515 0.04191617 0.07560484 0.25521351 0.13073852
 0.29357798 0.11460446 0.26227898 0.38966203]
    val_recall_mean: 0.20595998256297582
    val_predicted_class_distribution: [1361 1023  365  440 1135  622 1431  488 1201 1934]
    val_f1         : [0.29448885 0.1490313  0.06144843 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 93.22it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 97.35it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
Best accuracy of 0.394 with lr=0.0001 and wd=0.1
Trial path: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.1
Saved model and outputs to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/best
[0627_151146] Trial 3/24: LR = 0.0001 WD = 0.01
OrderedDict([('lr', 0.0001), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.01/models





    epoch          : 1
    val_TP         : [350 150  42  75 257 130 287 112 267 392]
    val_TN         : [7974 8138 8675 8643 8115 8506 7876 8637 8046 7452]
    val_FPs        : [1010  872  323  365  878  492 1143  377  936 1542]
    val_FNs        : [666 840 960 917 750 872 694 874 751 614]
    val_accuracy   : 0.2062
    val_RMSE       : 4.029354787059586
    val_per_class_accuracy: [0.8324 0.8288 0.8717 0.8718 0.8372 0.8636 0.8163 0.8749 0.8313 0.7844]
    val_per_class_accuracy_mean: 0.84124
    val_precision  : [0.25735294 0.14677104 0.11506849 0.17045455 0.22643172 0.20900322
 0.2006993  0.22903885 0.22194514 0.20268873]
    val_precision_mean: 0.19794539711464076
    val_recall     : [0.34448819 0.15151515 0.04191617 0.07560484 0.25521351 0.12974052
 0.29255861 0.11359026 0.26227898 0.38966203]
    val_recall_mean: 0.20565682548629355
    val_predicted_class_distribution: [1360 1022  365  440 1135  622 1430  489 1203 1934]
    val_f1         : [0.29461279 0.14910537 0.06144843

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.29it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 91.77it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_151437] Trial 4/24: LR = 0.0001 WD = 0.001
OrderedDict([('lr', 0.0001), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.001/models





    epoch          : 1
    val_TP         : [350 150  42  75 257 130 287 112 267 392]
    val_TN         : [7974 8138 8675 8643 8115 8506 7876 8637 8046 7452]
    val_FPs        : [1010  872  323  365  878  492 1143  377  936 1542]
    val_FNs        : [666 840 960 917 750 872 694 874 751 614]
    val_accuracy   : 0.2062
    val_RMSE       : 4.029354787059586
    val_per_class_accuracy: [0.8324 0.8288 0.8717 0.8718 0.8372 0.8636 0.8163 0.8749 0.8313 0.7844]
    val_per_class_accuracy_mean: 0.84124
    val_precision  : [0.25735294 0.14677104 0.11506849 0.17045455 0.22643172 0.20900322
 0.2006993  0.22903885 0.22194514 0.20268873]
    val_precision_mean: 0.19794539711464076
    val_recall     : [0.34448819 0.15151515 0.04191617 0.07560484 0.25521351 0.12974052
 0.29255861 0.11359026 0.26227898 0.38966203]
    val_recall_mean: 0.20565682548629355
    val_predicted_class_distribution: [1360 1022  365  440 1135  622 1430  489 1203 1934]
    val_f1         : [0.29461279 0.14910537 0.06144843

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.04it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.0001-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 87.33it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_151728] Trial 5/24: LR = 0.001 WD = 0
OrderedDict([('lr', 0.001), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0/models





    epoch          : 1
    val_TP         : [453 340  85  91 288 270 448 297 511 470]
    val_TN         : [8203 8342 8787 8793 8342 8492 7940 8411 8010 7933]
    val_FPs        : [ 781  668  211  215  651  506 1079  603  972 1061]
    val_FNs        : [563 650 917 901 719 732 533 689 507 536]
    val_accuracy   : 0.3253
    val_RMSE       : 3.762592191561557
    val_per_class_accuracy: [0.8656 0.8682 0.8872 0.8884 0.863  0.8762 0.8388 0.8708 0.8521 0.8403]
    val_per_class_accuracy_mean: 0.86506
    val_precision  : [0.36709887 0.33730159 0.28716216 0.29738562 0.30670927 0.34793814
 0.29338572 0.33       0.34457181 0.3069889 ]
    val_precision_mean: 0.3218542079040716
    val_recall     : [0.44586614 0.34343434 0.08483034 0.09173387 0.28599801 0.26946108
 0.45667686 0.30121704 0.50196464 0.46719682]
    val_recall_mean: 0.3248379141716584
    val_predicted_class_distribution: [1234 1008  296  306  939  776 1527  900 1483 1531]
    val_f1         : [0.40266667 0.34034034 0.13097072 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.19it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.38it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
Best accuracy of 0.410 with lr=0.001 and wd=0
Trial path: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0
Saved model and outputs to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/best
[0627_151909] Trial 6/24: LR = 0.001 WD = 0.1
OrderedDict([('lr', 0.001), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1/models





    epoch          : 1
    val_TP         : [454 342  85  89 287 269 449 296 511 472]
    val_TN         : [8204 8343 8789 8795 8345 8493 7932 8416 8009 7928]
    val_FPs        : [ 780  667  209  213  648  505 1087  598  973 1066]
    val_FNs        : [562 648 917 903 720 733 532 690 507 534]
    val_accuracy   : 0.3254
    val_RMSE       : 3.7609174412635014
    val_per_class_accuracy: [0.8658 0.8685 0.8874 0.8884 0.8632 0.8762 0.8381 0.8712 0.852  0.84  ]
    val_per_class_accuracy_mean: 0.8650800000000001
    val_precision  : [0.36790924 0.33894945 0.28911565 0.29470199 0.30695187 0.34754522
 0.29231771 0.3310962  0.34433962 0.30689207]
    val_precision_mean: 0.3219819012928048
    val_recall     : [0.44685039 0.34545455 0.08483034 0.08971774 0.28500497 0.26846307
 0.45769623 0.30020284 0.50196464 0.46918489]
    val_recall_mean: 0.3249369654801092
    val_predicted_class_distribution: [1234 1009  294  302  935  774 1536  894 1484 1538]
    val_f1         : [0.40355556 0.34217109 

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.65it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.89it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
Best accuracy of 0.414 with lr=0.001 and wd=0.1
Trial path: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1
Saved model and outputs to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/best
[0627_152121] Trial 7/24: LR = 0.001 WD = 0.01
OrderedDict([('lr', 0.001), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.01/models





    epoch          : 1
    val_TP         : [453 341  85  91 288 270 448 297 512 470]
    val_TN         : [8204 8342 8787 8793 8342 8493 7940 8411 8010 7933]
    val_FPs        : [ 780  668  211  215  651  505 1079  603  972 1061]
    val_FNs        : [563 649 917 901 719 732 533 689 506 536]
    val_accuracy   : 0.3255
    val_RMSE       : 3.7615289444586226
    val_per_class_accuracy: [0.8657 0.8683 0.8872 0.8884 0.863  0.8763 0.8388 0.8708 0.8522 0.8403]
    val_per_class_accuracy_mean: 0.8651
    val_precision  : [0.36739659 0.33795837 0.28716216 0.29738562 0.30670927 0.3483871
 0.29338572 0.33       0.34501348 0.3069889 ]
    val_precision_mean: 0.3220387210205803
    val_recall     : [0.44586614 0.34444444 0.08483034 0.09173387 0.28599801 0.26946108
 0.45667686 0.30121704 0.50294695 0.46719682]
    val_recall_mean: 0.3250371560997805
    val_predicted_class_distribution: [1233 1009  296  306  939  775 1527  900 1484 1531]
    val_f1         : [0.40284571 0.34117059 0.13097072 0.

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 93.84it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.51it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_152304] Trial 8/24: LR = 0.001 WD = 0.001
OrderedDict([('lr', 0.001), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.001/models





    epoch          : 1
    val_TP         : [453 340  85  91 288 270 448 297 511 470]
    val_TN         : [8203 8342 8787 8793 8342 8492 7940 8411 8010 7933]
    val_FPs        : [ 781  668  211  215  651  506 1079  603  972 1061]
    val_FNs        : [563 650 917 901 719 732 533 689 507 536]
    val_accuracy   : 0.3253
    val_RMSE       : 3.762592191561557
    val_per_class_accuracy: [0.8656 0.8682 0.8872 0.8884 0.863  0.8762 0.8388 0.8708 0.8521 0.8403]
    val_per_class_accuracy_mean: 0.86506
    val_precision  : [0.36709887 0.33730159 0.28716216 0.29738562 0.30670927 0.34793814
 0.29338572 0.33       0.34457181 0.3069889 ]
    val_precision_mean: 0.3218542079040716
    val_recall     : [0.44586614 0.34343434 0.08483034 0.09173387 0.28599801 0.26946108
 0.45667686 0.30121704 0.50196464 0.46719682]
    val_recall_mean: 0.3248379141716584
    val_predicted_class_distribution: [1234 1008  296  306  939  776 1527  900 1483 1531]
    val_f1         : [0.40266667 0.34034034 0.13097072 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 95.50it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.72it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_152442] Trial 9/24: LR = 0.05 WD = 0
OrderedDict([('lr', 0.05), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0/models





    epoch          : 1
    val_TP         : [153 280 170  87 596  74 103 442 400 122]
    val_TN         : [8843 8526 8576 8739 5640 8821 8848 7025 8567 8842]
    val_FPs        : [ 141  484  422  269 3353  177  171 1989  415  152]
    val_FNs        : [863 710 832 905 411 928 878 544 618 884]
    val_accuracy   : 0.2427
    val_RMSE       : 3.4802873444587874
    val_per_class_accuracy: [0.8996 0.8806 0.8746 0.8826 0.6236 0.8895 0.8951 0.7467 0.8967 0.8964]
    val_per_class_accuracy_mean: 0.8485400000000001
    val_precision  : [0.52040816 0.36649215 0.28716216 0.24438202 0.15092428 0.29482072
 0.37591241 0.18181818 0.49079755 0.44525547]
    val_precision_mean: 0.33579731072988617
    val_recall     : [0.15059055 0.28282828 0.16966068 0.08770161 0.591857   0.0738523
 0.1049949  0.44827586 0.39292731 0.12127237]
    val_recall_mean: 0.24239608614396682
    val_predicted_class_distribution: [ 294  764  592  356 3949  251  274 2431  815  274]
    val_f1         : [0.23358779 0.31927024

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 91.76it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 91.82it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_152602] Trial 10/24: LR = 0.05 WD = 0.1
OrderedDict([('lr', 0.05), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.1/models





    epoch          : 1
    val_TP         : [444 294 197  72 597  92 191 353 136 328]
    val_TN         : [8314 8455 8518 8778 5549 8787 8534 8396 8828 8545]
    val_FPs        : [ 670  555  480  230 3444  211  485  618  154  449]
    val_FNs        : [572 696 805 920 410 910 790 633 882 678]
    val_accuracy   : 0.2704
    val_RMSE       : 3.5266981725120736
    val_per_class_accuracy: [0.8758 0.8749 0.8715 0.885  0.6146 0.8879 0.8725 0.8749 0.8964 0.8873]
    val_per_class_accuracy_mean: 0.8540800000000001
    val_precision  : [0.39856373 0.34628975 0.29098966 0.2384106  0.14773571 0.30363036
 0.28254438 0.36354274 0.46896552 0.42213642]
    val_precision_mean: 0.3262808872772517
    val_recall     : [0.43700787 0.2969697  0.19660679 0.07258065 0.59285005 0.09181637
 0.19469929 0.35801217 0.13359528 0.32604374]
    val_recall_mean: 0.2700181898766436
    val_predicted_class_distribution: [1114  849  677  302 4041  303  676  971  290  777]
    val_f1         : [0.41690141 0.31973899 

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 86.97it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 94.49it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_152707] Trial 11/24: LR = 0.05 WD = 0.01
OrderedDict([('lr', 0.05), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.01/models





    epoch          : 1
    val_TP         : [308  90 129 257 592  61 108 281 148 493]
    val_TN         : [8614 8898 8661 7159 5761 8869 8850 8685 8845 8125]
    val_FPs        : [ 370  112  337 1849 3232  129  169  329  137  869]
    val_FNs        : [708 900 873 735 415 941 873 705 870 513]
    val_accuracy   : 0.2467
    val_RMSE       : 3.3649665674416442
    val_per_class_accuracy: [0.8922 0.8988 0.879  0.7416 0.6353 0.893  0.8958 0.8966 0.8993 0.8618]
    val_per_class_accuracy_mean: 0.8493400000000001
    val_precision  : [0.45427729 0.44554455 0.27682403 0.12203229 0.15481172 0.32105263
 0.3898917  0.46065574 0.51929825 0.36196769]
    val_precision_mean: 0.3506355885321646
    val_recall     : [0.30314961 0.09090909 0.12874251 0.25907258 0.58788481 0.06087824
 0.11009174 0.28498986 0.1453831  0.49005964]
    val_recall_mean: 0.24611611900963007
    val_predicted_class_distribution: [ 678  202  466 2106 3824  190  277  610  285 1362]
    val_f1         : [0.36363636 0.15100671

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 91.69it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 94.83it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_152813] Trial 12/24: LR = 0.05 WD = 0.001
OrderedDict([('lr', 0.05), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.001/models





    epoch          : 1
    val_TP         : [157 221 176 259 545  97 151 305 458 376]
    val_TN         : [8834 8779 8464 7145 6424 8793 8754 8676 8428 8448]
    val_FPs        : [ 150  231  534 1863 2569  205  265  338  554  546]
    val_FNs        : [859 769 826 733 462 905 830 681 560 630]
    val_accuracy   : 0.2745
    val_RMSE       : 3.3007726368230816
    val_per_class_accuracy: [0.8991 0.9    0.864  0.7404 0.6969 0.889  0.8905 0.8981 0.8886 0.8824]
    val_per_class_accuracy_mean: 0.8549000000000001
    val_precision  : [0.51140065 0.48893805 0.24788732 0.12205467 0.17501606 0.32119205
 0.36298077 0.47433904 0.45256917 0.40780911]
    val_precision_mean: 0.35641868890060147
    val_recall     : [0.15452756 0.22323232 0.1756487  0.26108871 0.54121152 0.09680639
 0.15392457 0.30933063 0.44990177 0.37375746]
    val_recall_mean: 0.27394296201627955
    val_predicted_class_distribution: [ 307  452  710 2122 3114  302  416  643 1012  922]
    val_f1         : [0.23733938 0.3065187

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.78it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.05-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 95.83it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153004] Trial 13/24: LR = 0.01 WD = 0
OrderedDict([('lr', 0.01), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0/models





    epoch          : 1
    val_TP         : [445 462 199 134 457 216 366 383 563 483]
    val_TN         : [8418 8406 8584 8725 7625 8703 8377 8457 8206 8207]
    val_FPs        : [ 566  604  414  283 1368  295  642  557  776  787]
    val_FNs        : [571 528 803 858 550 786 615 603 455 523]
    val_accuracy   : 0.3708
    val_RMSE       : 3.515622277776724
    val_per_class_accuracy: [0.8863 0.8868 0.8783 0.8859 0.8082 0.8919 0.8743 0.884  0.8769 0.869 ]
    val_per_class_accuracy_mean: 0.87416
    val_precision  : [0.44015826 0.43339587 0.32463295 0.32134293 0.25041096 0.42270059
 0.36309524 0.40744681 0.42046303 0.38031496]
    val_precision_mean: 0.3763961595258346
    val_recall     : [0.43799213 0.46666667 0.19860279 0.13508065 0.45382324 0.21556886
 0.37308869 0.38843813 0.55304519 0.48011928]
    val_recall_mean: 0.37024256216617013
    val_predicted_class_distribution: [1011 1066  613  417 1825  511 1008  940 1339 1270]
    val_f1         : [0.43907252 0.44941634 0.24643963 

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 93.09it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.24it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153122] Trial 14/24: LR = 0.01 WD = 0.1
OrderedDict([('lr', 0.01), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.1/models





    epoch          : 1
    val_TP         : [449 462 196 133 461 214 361 379 570 485]
    val_TN         : [8417 8412 8586 8746 7578 8720 8385 8463 8194 8209]
    val_FPs        : [ 567  598  412  262 1415  278  634  551  788  785]
    val_FNs        : [567 528 806 859 546 788 620 607 448 521]
    val_accuracy   : 0.371
    val_RMSE       : 3.5153520449593665
    val_per_class_accuracy: [0.8866 0.8874 0.8782 0.8879 0.8039 0.8934 0.8746 0.8842 0.8764 0.8694]
    val_per_class_accuracy_mean: 0.8742000000000001
    val_precision  : [0.44192913 0.43584906 0.32236842 0.33670886 0.24573561 0.43495935
 0.36281407 0.40752688 0.4197349  0.38188976]
    val_precision_mean: 0.37895160496662716
    val_recall     : [0.44192913 0.46666667 0.19560878 0.13407258 0.45779543 0.21357285
 0.36799185 0.38438134 0.55992141 0.48210736]
    val_recall_mean: 0.370404740407439
    val_predicted_class_distribution: [1016 1060  608  395 1876  492  995  930 1358 1270]
    val_f1         : [0.44192913 0.45073171 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 89.90it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 95.87it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153243] Trial 15/24: LR = 0.01 WD = 0.01
OrderedDict([('lr', 0.01), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.01/models





    epoch          : 1
    val_TP         : [445 461 200 134 458 214 365 379 567 483]
    val_TN         : [8423 8405 8583 8724 7616 8706 8378 8459 8207 8205]
    val_FPs        : [ 561  605  415  284 1377  292  641  555  775  789]
    val_FNs        : [571 529 802 858 549 788 616 607 451 523]
    val_accuracy   : 0.3706
    val_RMSE       : 3.516646129481896
    val_per_class_accuracy: [0.8868 0.8866 0.8783 0.8858 0.8074 0.892  0.8743 0.8838 0.8774 0.8688]
    val_per_class_accuracy_mean: 0.8741199999999999
    val_precision  : [0.44234592 0.43245779 0.32520325 0.32057416 0.24959128 0.4229249
 0.36282306 0.40578158 0.42250373 0.37971698]
    val_precision_mean: 0.37639226602484205
    val_recall     : [0.43799213 0.46565657 0.1996008  0.13508065 0.45481629 0.21357285
 0.37206932 0.38438134 0.55697446 0.48011928]
    val_recall_mean: 0.3700263675279757
    val_predicted_class_distribution: [1006 1066  615  418 1835  506 1006  934 1342 1272]
    val_f1         : [0.44015826 0.44844358 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.91it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 88.25it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153401] Trial 16/24: LR = 0.01 WD = 0.001
OrderedDict([('lr', 0.01), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.001/models





    epoch          : 1
    val_TP         : [445 462 200 134 458 216 365 383 563 483]
    val_TN         : [8419 8405 8583 8725 7625 8703 8377 8459 8206 8207]
    val_FPs        : [ 565  605  415  283 1368  295  642  555  776  787]
    val_FNs        : [571 528 802 858 549 786 616 603 455 523]
    val_accuracy   : 0.3709
    val_RMSE       : 3.5153235981912108
    val_per_class_accuracy: [0.8864 0.8867 0.8783 0.8859 0.8083 0.8919 0.8742 0.8842 0.8769 0.869 ]
    val_per_class_accuracy_mean: 0.8741800000000002
    val_precision  : [0.44059406 0.43298969 0.32520325 0.32134293 0.25082147 0.42270059
 0.36246276 0.40831557 0.42046303 0.38031496]
    val_precision_mean: 0.3765208301043364
    val_recall     : [0.43799213 0.46666667 0.1996008  0.13508065 0.45481629 0.21556886
 0.37206932 0.38843813 0.55304519 0.48011928]
    val_recall_mean: 0.37033973063212566
    val_predicted_class_distribution: [1010 1067  615  417 1826  511 1007  938 1339 1270]
    val_f1         : [0.43928924 0.44919786

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 93.15it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.01-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 87.80it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153519] Trial 17/24: LR = 0.5 WD = 0
OrderedDict([('lr', 0.5), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0/models





    epoch          : 1
    val_TP         : [ 94 150 161  82 590  72  79 343 650 205]
    val_TN         : [8840 8821 8437 8774 5717 8815 8880 8307 7098 8737]
    val_FPs        : [ 144  189  561  234 3276  183  139  707 1884  257]
    val_FNs        : [922 840 841 910 417 930 902 643 368 801]
    val_accuracy   : 0.2426
    val_RMSE       : 3.6064941425156927
    val_per_class_accuracy: [0.8934 0.8971 0.8598 0.8856 0.6307 0.8887 0.8959 0.865  0.7748 0.8942]
    val_per_class_accuracy_mean: 0.84852
    val_precision  : [0.39495798 0.44247788 0.22299169 0.25949367 0.15261252 0.28235294
 0.36238532 0.32666667 0.25651144 0.44372294]
    val_precision_mean: 0.31441730563598835
    val_recall     : [0.09251969 0.15151515 0.16067864 0.08266129 0.58589871 0.07185629
 0.08053007 0.34787018 0.63850688 0.20377734]
    val_recall_mean: 0.2415814232177099
    val_predicted_class_distribution: [ 238  339  722  316 3866  255  218 1050 2534  462]
    val_f1         : [0.14992026 0.22573363 0.18677494

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 87.03it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.87it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153708] Trial 18/24: LR = 0.5 WD = 0.1
OrderedDict([('lr', 0.5), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.1/models





    epoch          : 1
    val_TP         : [ 49 248  72  23  46  34 639 214 635  12]
    val_TN         : [8790 8124 8764 8936 8710 8920 5835 8713 6377 8803]
    val_FPs        : [ 194  886  234   72  283   78 3184  301 2605  191]
    val_FNs        : [967 742 930 969 961 968 342 772 383 994]
    val_accuracy   : 0.1972
    val_RMSE       : 4.140398531542586
    val_per_class_accuracy: [0.8839 0.8372 0.8836 0.8959 0.8756 0.8954 0.6474 0.8927 0.7012 0.8815]
    val_per_class_accuracy_mean: 0.8394400000000001
    val_precision  : [0.20164609 0.21869489 0.23529412 0.24210526 0.13981763 0.30357143
 0.16714622 0.41553398 0.19598765 0.0591133 ]
    val_precision_mean: 0.2178910570094248
    val_recall     : [0.04822835 0.25050505 0.07185629 0.02318548 0.04568024 0.03393214
 0.65137615 0.21703854 0.6237721  0.01192843]
    val_recall_mean: 0.19775027602453848
    val_predicted_class_distribution: [ 243 1134  306   95  329  112 3823  515 3240  203]
    val_f1         : [0.07783956 0.23352166 

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 22
    val_TP         : [136 236  36  77 191 180 508  70 678  40]
    val_TN         : [7900 7531 8901 8546 8535 8636 7517 8894 6803 8889]
    val_FPs        : [1084 1479   97  462  458  362 1502  120 2179  105]
    val_FNs        : [880 754 966 915 816 822 473 916 340 966]
    val_accuracy   : 0.2152
    val_RMSE       : 3.8995512562344916
    val_per_class_accuracy: [0.8036 0.7767 0.8937 0.8623 0.8726 0.8816 0.8025 0.8964 0.7481 0.8929]
    val_per_class_accuracy_mean: 0.8430399999999999
    val_precision  : [0.11147541 0.13760933 0.27067669 0.14285714 0.29429892 0.33210332
 0.25273632 0.36842105 0.23731187 0.27586207]
    val_precision_mean: 0.2423352121917708
    val_recall     : [0.13385827 0.23838384 0.03592814 0.07762097 0.18967229 0.17964072
 0.51783894 0.07099391 0.66601179 0.03976143]
    val_recall_mean: 0.21497103039555365
    val_predicted_class_distribution: [1220 1715  133  539  649  542 2010  190 2857  145]
    val_f1         : [0.1216458  0.1744916

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 89.74it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 95.11it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_153857] Trial 19/24: LR = 0.5 WD = 0.01
OrderedDict([('lr', 0.5), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.01/models





    epoch          : 1
    val_TP         : [668 332 134  60 574  80  94 295  30 162]
    val_TN         : [7237 7981 8649 8851 5701 8835 8838 8699 8909 8729]
    val_FPs        : [1747 1029  349  157 3292  163  181  315   73  265]
    val_FNs        : [348 658 868 932 433 922 887 691 988 844]
    val_accuracy   : 0.2429
    val_RMSE       : 3.87780865953956
    val_per_class_accuracy: [0.7905 0.8313 0.8783 0.8911 0.6275 0.8915 0.8932 0.8994 0.8939 0.8891]
    val_per_class_accuracy_mean: 0.8485800000000001
    val_precision  : [0.27660455 0.24393828 0.27743271 0.2764977  0.14847387 0.32921811
 0.34181818 0.48360656 0.29126214 0.3793911 ]
    val_precision_mean: 0.30482432012312766
    val_recall     : [0.65748031 0.33535354 0.13373253 0.06048387 0.57000993 0.07984032
 0.09582059 0.29918864 0.02946955 0.1610338 ]
    val_recall_mean: 0.24224130836172794
    val_predicted_class_distribution: [2415 1361  483  217 3866  243  275  610  103  427]
    val_f1         : [0.38939085 0.28243301 

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 87.30it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 93.76it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_154000] Trial 20/24: LR = 0.5 WD = 0.001
OrderedDict([('lr', 0.5), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.001/models





    epoch          : 1
    val_TP         : [673 333 117  49 580 100 100 365  96  85]
    val_TN         : [7268 8126 8703 8857 5716 8761 8827 8477 8891 8872]
    val_FPs        : [1716  884  295  151 3277  237  192  537   91  122]
    val_FNs        : [343 657 885 943 427 902 881 621 922 921]
    val_accuracy   : 0.2498
    val_RMSE       : 3.8489089363090936
    val_per_class_accuracy: [0.7941 0.8459 0.882  0.8906 0.6296 0.8861 0.8927 0.8842 0.8987 0.8957]
    val_per_class_accuracy_mean: 0.84996
    val_precision  : [0.28170783 0.27362366 0.28398058 0.245      0.15037594 0.29673591
 0.34246575 0.40465632 0.51336898 0.41062802]
    val_precision_mean: 0.3202542995706709
    val_recall     : [0.66240157 0.33636364 0.11676647 0.04939516 0.57596822 0.0998004
 0.1019368  0.37018256 0.09430255 0.08449304]
    val_recall_mean: 0.249161041190992
    val_predicted_class_distribution: [2389 1217  412  200 3857  337  292  902  187  207]
    val_f1         : [0.39530103 0.3017671  0.16548798 0.

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 87.04it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.5-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 88.60it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_154108] Trial 21/24: LR = 0.1 WD = 0
OrderedDict([('lr', 0.1), ('weight_decay', 0)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0/models





    epoch          : 1
    val_TP         : [128 163 122 280 560  88 110 197  95 640]
    val_TN         : [8849 8816 8682 6789 6349 8815 8861 8825 8849 7548]
    val_FPs        : [ 135  194  316 2219 2644  183  158  189  133 1446]
    val_FNs        : [888 827 880 712 447 914 871 789 923 366]
    val_accuracy   : 0.2383
    val_RMSE       : 3.394716483007086
    val_per_class_accuracy: [0.8977 0.8979 0.8804 0.7069 0.6909 0.8903 0.8971 0.9022 0.8944 0.8188]
    val_per_class_accuracy_mean: 0.8476600000000001
    val_precision  : [0.48669202 0.45658263 0.27853881 0.11204482 0.17478152 0.32472325
 0.41044776 0.51036269 0.41666667 0.30680729]
    val_precision_mean: 0.3477647458137781
    val_recall     : [0.12598425 0.16464646 0.12175649 0.28225806 0.55610725 0.08782435
 0.11213048 0.19979716 0.09332024 0.6361829 ]
    val_recall_mean: 0.2380007646396906
    val_predicted_class_distribution: [ 263  357  438 2499 3204  271  268  386  228 2086]
    val_f1         : [0.20015637 0.2420193  0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 95.70it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.74it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_154204] Trial 22/24: LR = 0.1 WD = 0.1
OrderedDict([('lr', 0.1), ('weight_decay', 0.1)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.1/models





    epoch          : 1
    val_TP         : [636 165  93  47 605  55 218 226 212 270]
    val_TN         : [7594 8717 8755 8855 5504 8853 8047 8797 8735 8670]
    val_FPs        : [1390  293  243  153 3489  145  972  217  247  324]
    val_FNs        : [380 825 909 945 402 947 763 760 806 736]
    val_accuracy   : 0.2527
    val_RMSE       : 3.5525061576301313
    val_per_class_accuracy: [0.823  0.8882 0.8848 0.8902 0.6109 0.8908 0.8265 0.9023 0.8947 0.894 ]
    val_per_class_accuracy_mean: 0.85054
    val_precision  : [0.31391905 0.36026201 0.27678571 0.235      0.14777723 0.275
 0.18319328 0.51015801 0.46187364 0.45454545]
    val_precision_mean: 0.3218514394061821
    val_recall     : [0.62598425 0.16666667 0.09281437 0.04737903 0.60079444 0.05489022
 0.22222222 0.22920892 0.20825147 0.26838966]
    val_recall_mean: 0.25166012633158574
    val_predicted_class_distribution: [2026  458  336  200 4094  200 1190  443  459  594]
    val_f1         : [0.41814596 0.22790055 0.13901345 0.07

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 90.23it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.1
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 88.25it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_154255] Trial 23/24: LR = 0.1 WD = 0.01
OrderedDict([('lr', 0.1), ('weight_decay', 0.01)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.01/models





    epoch          : 1
    val_TP         : [379  66 107 244 581  75  97 265 224 479]
    val_TN         : [8441 8938 8736 7373 5766 8808 8823 8739 8743 8150]
    val_FPs        : [ 543   72  262 1635 3227  190  196  275  239  844]
    val_FNs        : [637 924 895 748 426 927 884 721 794 527]
    val_accuracy   : 0.2517
    val_RMSE       : 3.4020141093181846
    val_per_class_accuracy: [0.882  0.9004 0.8843 0.7617 0.6347 0.8883 0.892  0.9004 0.8967 0.8629]
    val_per_class_accuracy_mean: 0.8503400000000001
    val_precision  : [0.41106291 0.47826087 0.2899729  0.12985631 0.15257353 0.28301887
 0.33105802 0.49074074 0.4838013  0.36205593]
    val_precision_mean: 0.3412401370500445
    val_recall     : [0.3730315  0.06666667 0.10678643 0.24596774 0.57696127 0.0748503
 0.0988787  0.26876268 0.22003929 0.47614314]
    val_recall_mean: 0.25080877088920167
    val_predicted_class_distribution: [ 922  138  369 1879 3808  265  293  540  463 1323]
    val_f1         : [0.39112487 0.11702128 

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 89.71it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.01
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.05it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])
[0627_154340] Trial 24/24: LR = 0.1 WD = 0.001
OrderedDict([('lr', 0.1), ('weight_decay', 0.001)])
Created LinearLayers model with 30730 trainable parameters
Training from scratch.
Checkpoint save directory: saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.001/models





    epoch          : 1
    val_TP         : [384 182 104  51 611  80 105 463 273  47]
    val_TN         : [8389 8729 8748 8862 5544 8793 8820 6744 8721 8950]
    val_FPs        : [ 595  281  250  146 3449  205  199 2270  261   44]
    val_FNs        : [632 808 898 941 396 922 876 523 745 959]
    val_accuracy   : 0.23
    val_RMSE       : 3.411290078547997
    val_per_class_accuracy: [0.8773 0.8911 0.8852 0.8913 0.6155 0.8873 0.8925 0.7207 0.8994 0.8997]
    val_per_class_accuracy_mean: 0.8459999999999999
    val_precision  : [0.39223698 0.39308855 0.29378531 0.25888325 0.15049261 0.28070175
 0.34539474 0.1694109  0.51123596 0.51648352]
    val_precision_mean: 0.33117135662617914
    val_recall     : [0.37795276 0.18383838 0.10379242 0.05141129 0.60675273 0.07984032
 0.10703364 0.46957404 0.26817289 0.04671968]
    val_recall_mean: 0.22950881410603804
    val_predicted_class_distribution: [ 979  463  354  197 4060  285  304 2733  534   91]
    val_f1         : [0.38496241 0.25051617 0

100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 88.05it/s]


Saving validation results to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.1-wd_0.001
restored model


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 91.76it/s]

dict_keys(['TP', 'TN', 'FPs', 'FNs', 'accuracy', 'RMSE', 'per_class_accuracy', 'per_class_accuracy_mean', 'precision', 'precision_mean', 'recall', 'recall_mean', 'predicted_class_distribution', 'f1', 'f1_mean'])





### Check best model's results on training data

In [7]:
# Trained on soft labels with 75 clusters
# model_restore_path = 'saved/cifar10/resnet18/explainer/sift_descriptor_histogram/hparam_search/0525_103130/best/model.pth'
# Trained on hard GT labels with 75 clusters
# model_restore_path = 'saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1/models/model_best.pth'
# Variable
restore_dir = 'saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1'
model_restore_path = os.path.join(restore_dir, 'models/model_best.pth')
config_json = read_json(os.path.join(restore_dir, 'models/config.json'))

device, device_ids = prepare_device(config_json['n_gpu'])
metric_fns = [getattr(module_metric, met) for met in config_json['metrics']]
loss_fn = getattr(module_loss, config_json['loss'])

model_args = config_json['arch']['args']
model = module_model.LinearLayers(
    checkpoint_path=model_restore_path,
    **model_args)

print(model)

model.eval()
model = model.to(device)

# Dataloaders
dataset_args = config_json['dataset']['args']
train_descriptors_dataset = module_data.KDDataset(split='train', **dataset_args)
test_descriptors_dataset = module_data.KDDataset(split='test', **dataset_args)

dataloader_args = config_json['data_loader']['args']
train_descriptors_dataloader = torch.utils.data.DataLoader(
    train_descriptors_dataset,
    shuffle=True,
    **dataloader_args)
test_descriptors_dataloader = torch.utils.data.DataLoader(
    test_descriptors_dataset,
    shuffle=False,
    **dataloader_args)



# Evaluate on validation set as a sanity check
validation_results = predict(
        data_loader=test_descriptors_dataloader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=None,
        log_save_path=None)

print("Validation accuracy: {}".format(validation_results['metrics']['accuracy']))

training_results = predict(
        data_loader=train_descriptors_dataloader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=None,
        log_save_path=None)

print("Training accuracy: {}".format(training_results['metrics']['accuracy']))

LinearLayers(
  (layers): Sequential(
    (0): Linear(in_features=3072, out_features=10, bias=True)
  )
)
Trainable parameters: 30730


100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 92.79it/s]


Validation accuracy: 0.4136


100%|████████████████████████████████████████████████████████████████████████████████| 196/196 [00:00<00:00, 236.18it/s]

Training accuracy: 0.42728





In [8]:
# Save outputs of explainer on train and val
save_path = os.path.join(restore_dir, 'outputs_predictions.pth')

val_outputs = validation_results['logits']
val_probabilities = torch.softmax(val_outputs, dim=1)
val_predictions = torch.argmax(val_outputs, dim=1)
val_outputs_predictions = {
    'outputs': val_outputs.cpu().numpy(),
    'probabilities': val_probabilities.cpu().numpy(),
    'predictions': val_predictions.cpu().numpy()
}

train_outputs = training_results['logits']
train_probabilities = torch.softmax(train_outputs, dim=1)
train_predictions = torch.argmax(train_outputs, dim=1)

train_outputs_predictions = {
    'outputs': train_outputs.cpu().numpy(),
    'probabilities': train_probabilities.cpu().numpy(),
    'predictions': train_predictions.cpu().numpy()
}

outputs_predictions = {
    'train': train_outputs_predictions,
    'test': val_outputs_predictions
}
if not os.path.exists(save_path):
    torch.save(outputs_predictions, save_path)
    print("Saved outputs, probabilities and predictions for train/val to {}".format(save_path))
else:
    print("File exists at {}".format(save_path))

Saved outputs, probabilities and predictions for train/val to saved/cifar10/resnet18/explainer/pixels/model_soft_labels/hparam_search/cross_entropy/0627_150603/trials/lr_0.001-wd_0.1/outputs_predictions.pth
