In [1]:
import torch
import os, sys
from datetime import datetime

sys.path.insert(0, 'src')
from train import main as train_fn
from predict import predict
from parse_config import ConfigParser
import datasets.datasets as module_data
from utils.utils import read_json, ensure_dir, informal_log
from utils.model_utils import prepare_device
from model import metric as module_metrics
from model import loss as module_loss

### Load hyperparameter search variables

In [2]:
# Base training config; its optimizer args are overridden per trial by the search loop.
config_path = 'configs/train_ade20k_explainer_KD.json'
config_json = read_json(config_path)

# Flip to True for a quick smoke test on a reduced grid.
debug = False

# Grid of (learning rate, weight decay) values to sweep: a small subset in
# debug mode, the full grid otherwise.
learning_rates, weight_decays = (
    ([1e-6], [0, 1e-1])
    if debug
    else (
        [1e-6, 1e-5, 1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1],
        [0, 1e-1, 1e-2, 1e-3],
    )
)


### Create train and validation datasets once, outside the search loop, so every trial reuses them

In [3]:
# Both splits share the dataset arguments from the config; only `split` differs.
dataset_args = config_json['dataset']['args']
train_dataset, val_dataset = (
    module_data.KDDataset(split=split_name, **dataset_args)
    for split_name in ('train', 'val')
)

# Loader options (taken verbatim from the config) are shared as well.
# Only the training loader shuffles; validation keeps a fixed order.
dataloader_args = config_json['data_loader']['args']
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **dataloader_args)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **dataloader_args)

In [None]:
# Grid search over (learning rate, weight decay).
#
# For every combination this cell:
#   1. writes the pair into the optimizer args of `config_json`,
#   2. trains a model with `train_fn` on the shared dataloaders,
#   3. restores `model_best.pth` from the trial's save directory,
#   4. evaluates on the validation loader with `predict`,
#   5. if validation accuracy improved, copies the model and its validation
#      logits into a shared `best/` directory.
# Progress and the final winner are appended to `log_path`.

# Running record of the best trial seen so far (-1 marks "nothing evaluated yet").
best = {
    'lr': -1,
    'wd': -1,
    'val_acc': -1
}
n_trials = len(learning_rates) * len(weight_decays)
timestamp = datetime.now().strftime(r'%m%d_%H%M%S')

# Logging
# NOTE(review): the log lives under trainer.save_dir/<timestamp>, while trial
# directories are derived from ConfigParser's save_dir — confirm both resolve
# to the same root.
log_path = os.path.join(config_json['trainer']['save_dir'], timestamp, 'log.txt')
ensure_dir(os.path.dirname(log_path))
informal_log("Hyperparameter search", log_path)
informal_log("Learning rates: {}".format(learning_rates), log_path)
informal_log("Weight decays: {}".format(weight_decays), log_path)

# Debug mode: a single epoch per trial keeps the smoke test short.
if debug:
    config_json['trainer']['epochs'] = 1

# Evaluation setup does not depend on the trial, so build it once.
device, device_ids = prepare_device(config_json['n_gpu'])
metric_fns = [getattr(module_metrics, met) for met in config_json['metrics']]
loss_fn = getattr(module_loss, config_json['loss'])

# Flatten the grid so the trial counter comes from enumerate rather than a
# hand-maintained variable. Order matches the original nested loops
# (learning rate outer, weight decay inner).
search_grid = [(lr, wd) for lr in learning_rates for wd in weight_decays]

for trial_idx, (lr, wd) in enumerate(search_grid, start=1):
    # Update config json (mutated in place; each trial overwrites the previous values)
    config_json['optimizer']['args'].update({
        'lr': lr,
        'weight_decay': wd
    })

    # Create run ID for trial
    itr_timestamp = datetime.now().strftime(r'%m%d_%H%M%S')
    informal_log("[{}] Trial {}/{}: LR = {} WD = {}".format(
        itr_timestamp, trial_idx, n_trials, lr, wd), log_path)
    run_id = os.path.join(timestamp, 'trials', 'lr_{}-wd_{}'.format(lr, wd))
    config = ConfigParser(config_json, run_id=run_id)
    # Sanity check that the parsed config picked up this trial's overrides.
    print(config.config['optimizer']['args'])

    # Train model
    model = train_fn(
        config=config,
        train_data_loader=train_dataloader,
        val_data_loader=val_dataloader)

    # Restore the best checkpoint of this trial before evaluating
    # (presumably written by the trainer into config.save_dir — verify).
    model_restore_path = os.path.join(config.save_dir, 'model_best.pth')
    model.restore_model(model_restore_path)

    # Trial directory is the parent of config.save_dir (two levels above the checkpoint file).
    trial_path = os.path.dirname(os.path.dirname(model_restore_path))
    output_save_path = os.path.join(trial_path, "val_outputs.pth")
    log_save_path = os.path.join(trial_path, "val_metrics.pth")

    # Run on validation set using predict function
    validation_data = predict(
        data_loader=val_dataloader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=output_save_path,
        log_save_path=log_save_path)

    # Obtain accuracy and compare to previous best
    val_accuracy = validation_data['metrics']['accuracy']
    if val_accuracy > best['val_acc']:
        best.update({
            'lr': lr,
            'wd': wd,
            'val_acc': val_accuracy
        })
        informal_log("Best accuracy of {:.3f} with lr={} and wd={}".format(val_accuracy, lr, wd), log_path)
        informal_log("Trial path: {}".format(trial_path), log_path)
        # Copy model and outputs to 1 directory for easy access
        # (sibling of the `trials/` directory, i.e. two levels above the trial).
        best_save_dir = os.path.join(os.path.dirname(os.path.dirname(trial_path)), 'best')
        ensure_dir(best_save_dir)
        best_outputs_save_path = os.path.join(best_save_dir, 'outputs.pth')
        best_model_save_path = os.path.join(best_save_dir, 'model.pth')
        torch.save(validation_data['logits'], best_outputs_save_path)
        model.save_model(best_model_save_path)
        informal_log("Saved model and outputs to {}".format(best_save_dir), log_path)

# Final summary so the winning configuration is the last entry in the log.
if n_trials > 0:
    informal_log("Search complete. Best accuracy of {:.3f} with lr={} and wd={}".format(
        best['val_acc'], best['lr'], best['wd']), log_path)
else:
    informal_log("Search complete. No trials were run (empty grid).", log_path)

Hyperparameter search
Learning rates: [1e-06, 1e-05, 0.0001, 0.001, 0.05, 0.01, 0.5, 0.1]
Weight decays: [0, 0.1, 0.01, 0.001]
[0523_164052] Trial 1/32: LR = 1e-06 WD = 0
OrderedDict([('lr', 1e-06), ('weight_decay', 0), ('amsgrad', False)])
Created LinearLayers model with 19216 trainable parameters
Training from scratch.
Checkpoint save directory: saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/KD_baseline_explainer/hparam_search/0523_164052/trials/lr_1e-06-wd_0/models
    epoch          : 1
    val_TP         : [  0   8 227   0  74   3  53   3   0   1  20   6  12   0  40   1]
    val_TN         : [3989 4049 3098 4301 3256 4089 3524 4196 4120 4331 3973 3578 4166 4178
 3926 3862]
    val_FPs        : [ 155   80  168   13 1066  127  614  168   69   65  200  545  170  234
  291   29]
    val_FNs        : [298 305 949 128  46 223 251  75 253  45 249 313  94  30 185 550]
    val_accuracy   : 0.10085547050877983
    val_per_class_accuracy: [0.89801891 0.91332733 0.7485

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 2
    val_TP         : [  0   8 228   0  74   3  53   3   0   1  22   6  12   0  42   1]
    val_TN         : [3985 4048 3095 4301 3263 4090 3525 4197 4120 4333 3972 3577 4167 4181
 3926 3861]
    val_FPs        : [ 159   81  171   13 1059  126  613  167   69   63  201  546  169  231
  291   30]
    val_FNs        : [298 305 948 128  46 223 251  75 253  45 247 313  94  30 183 550]
    val_accuracy   : 0.10198108959927961
    val_per_class_accuracy: [0.89711842 0.91310221 0.74808645 0.96825754 0.75123818 0.92143179
 0.80549302 0.94552004 0.92751013 0.97568663 0.89914453 0.80661864
 0.94079244 0.94124268 0.89329131 0.86942819]
    val_per_class_accuracy_mean: 0.8877476361999099
    val_precision  : [0.         0.08988764 0.57142857 0.         0.06531333 0.02325581
 0.07957958 0.01764706 0.         0.015625   0.09865471 0.01086957
 0.06629834 0.         0.12612613 0.03225806]
    val_precision_mean: 0.07480898741281994
    val_recall     : [0.         0.02555911 0.193

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 3
    val_TP         : [  0   9 236   0  73   3  53   3   0   1  23   6  12   0  43   1]
    val_TN         : [3984 4047 3089 4303 3269 4094 3527 4199 4120 4334 3971 3576 4169 4183
 3925 3861]
    val_FPs        : [ 160   82  177   11 1053  122  611  165   69   62  202  547  167  229
  292   30]
    val_FNs        : [298 304 940 128  47 223 251  75 253  45 246 313  94  30 182 550]
    val_accuracy   : 0.10423232778027916
    val_per_class_accuracy: [0.89689329 0.91310221 0.7485367  0.96870779 0.7523638  0.92233228
 0.80594327 0.94597028 0.92751013 0.97591175 0.89914453 0.80639352
 0.94124268 0.94169293 0.89329131 0.86942819]
    val_per_class_accuracy_mean: 0.8880290409725349
    val_precision  : [0.         0.0989011  0.57142857 0.         0.06483126 0.024
 0.07981928 0.01785714 0.         0.01587302 0.10222222 0.01084991
 0.06703911 0.         0.12835821 0.03225806]
    val_precision_mean: 0.07583986741827622
    val_recall     : [0.         0.02875399 0.20068027

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 4
    val_TP         : [  0   9 248   0  73   3  53   3   0   1  23   6  12   0  45   1]
    val_TN         : [3985 4042 3081 4303 3280 4095 3529 4202 4121 4336 3972 3579 4170 4186
 3922 3862]
    val_FPs        : [ 159   87  185   11 1042  121  609  162   68   60  201  544  166  226
  295   29]
    val_FNs        : [298 304 928 128  47 223 251  75 253  45 246 313  94  30 180 550]
    val_accuracy   : 0.10738406123367852
    val_per_class_accuracy: [0.89711842 0.91197659 0.74943719 0.96870779 0.75484016 0.92255741
 0.80639352 0.94664566 0.92773525 0.976362   0.89936965 0.80706889
 0.94146781 0.9423683  0.89306619 0.86965331]
    val_per_class_accuracy_mean: 0.8884230076542098
    val_precision  : [0.         0.09375    0.57274827 0.         0.06547085 0.02419355
 0.08006042 0.01818182 0.         0.01639344 0.10267857 0.01090909
 0.06741573 0.         0.13235294 0.03333333]
    val_precision_mean: 0.07609300120334103
    val_recall     : [0.         0.02875399 0.210

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 5
    val_TP         : [  0  11 256   0  73   3  53   3   0   1  23   6  12   0  45   1]
    val_TN         : [3986 4040 3078 4303 3290 4093 3527 4202 4123 4338 3973 3578 4172 4189
 3919 3864]
    val_FPs        : [ 158   89  188   11 1032  123  611  162   66   58  200  545  164  223
  298   27]
    val_FNs        : [298 302 920 128  47 223 251  75 253  45 246 313  94  30 180 550]
    val_accuracy   : 0.10963529941467808
    val_per_class_accuracy: [0.89734354 0.91197659 0.75056281 0.96870779 0.7570914  0.92210716
 0.80594327 0.94664566 0.9281855  0.97681225 0.89959478 0.80684376
 0.94191805 0.94304367 0.89239081 0.87010356]
    val_per_class_accuracy_mean: 0.8887044124268348
    val_precision  : [0.         0.11       0.57657658 0.         0.06606335 0.02380952
 0.07981928 0.01818182 0.         0.01694915 0.10313901 0.01088929
 0.06818182 0.         0.13119534 0.03571429]
    val_precision_mean: 0.07753246509106303
    val_recall     : [0.         0.03514377 0.217

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 6
    val_TP         : [  0  12 264   0  73   3  53   3   0   1  24   8  12   0  47   1]
    val_TN         : [3987 4035 3077 4304 3300 4095 3528 4204 4124 4341 3970 3577 4175 4192
 3916 3864]
    val_FPs        : [ 157   94  189   10 1022  121  610  160   65   55  203  546  161  220
  301   27]
    val_FNs        : [298 301 912 128  47 223 251  75 253  45 245 311  94  30 178 550]
    val_accuracy   : 0.11278703286807744
    val_per_class_accuracy: [0.89756866 0.91107609 0.75213868 0.96893291 0.75934264 0.92255741
 0.80616839 0.9470959  0.92841063 0.97748762 0.89914453 0.80706889
 0.94259343 0.94371905 0.89216569 0.87010356]
    val_per_class_accuracy_mean: 0.8890983791085096
    val_precision  : [0.         0.11320755 0.58278146 0.         0.06666667 0.02419355
 0.07993967 0.01840491 0.         0.01785714 0.10572687 0.01444043
 0.06936416 0.         0.13505747 0.03571429]
    val_precision_mean: 0.07895963515455245
    val_recall     : [0.         0.03833866 0.224

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 7
    val_TP         : [  0  12 272   0  73   3  53   3   0   1  25   8  12   0  47   1]
    val_TN         : [3986 4033 3073 4305 3310 4095 3529 4206 4124 4343 3966 3577 4177 4195
 3915 3864]
    val_FPs        : [ 158   96  193    9 1012  121  609  158   65   53  207  546  159  217
  302   27]
    val_FNs        : [298 301 904 128  47 223 251  75 253  45 244 311  94  30 178 550]
    val_accuracy   : 0.11481314723097703
    val_per_class_accuracy: [0.89734354 0.91062584 0.75303917 0.96915804 0.76159388 0.92255741
 0.80639352 0.94754615 0.92841063 0.97793787 0.89846916 0.80706889
 0.94304367 0.94439442 0.89194057 0.87010356]
    val_per_class_accuracy_mean: 0.8893516434038722
    val_precision  : [0.         0.11111111 0.58494624 0.         0.06728111 0.02419355
 0.08006042 0.01863354 0.         0.01851852 0.10775862 0.01444043
 0.07017544 0.         0.13467049 0.03571429]
    val_precision_mean: 0.0792189843262182
    val_recall     : [0.         0.03833866 0.2312

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 8
    val_TP         : [  0  12 285   0  73   3  53   3   0   1  26   8  12   0  48   1]
    val_TN         : [3987 4032 3070 4305 3320 4095 3530 4209 4124 4344 3962 3580 4177 4200
 3913 3865]
    val_FPs        : [ 157   97  196    9 1002  121  608  155   65   52  211  543  159  212
  304   26]
    val_FNs        : [298 301 891 128  47 223 251  75 253  45 243 311  94  30 177 550]
    val_accuracy   : 0.11819000450247637
    val_per_class_accuracy: [0.89756866 0.91040072 0.75529041 0.96915804 0.76384511 0.92255741
 0.80661864 0.94822152 0.92841063 0.97816299 0.89779379 0.80774426
 0.94304367 0.94552004 0.89171544 0.87032868]
    val_per_class_accuracy_mean: 0.8897737505628095
    val_precision  : [0.         0.11009174 0.59251559 0.         0.06790698 0.02419355
 0.08018154 0.01898734 0.         0.01886792 0.10970464 0.01451906
 0.07017544 0.         0.13636364 0.03703704]
    val_precision_mean: 0.08003402998698775
    val_recall     : [0.         0.03833866 0.242

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 9
    val_TP         : [  0  12 291   0  72   3  53   3   0   1  26   8  12   0  47   1]
    val_TN         : [3987 4031 3066 4305 3324 4098 3534 4215 4126 4344 3960 3574 4179 4202
 3907 3865]
    val_FPs        : [157  98 200   9 998 118 604 149  63  52 213 549 157 210 310  26]
    val_FNs        : [298 301 885 128  48 223 251  75 253  45 243 311  94  30 178 550]
    val_accuracy   : 0.11909049977487618
    val_per_class_accuracy: [0.89756866 0.9101756  0.75574066 0.96915804 0.76452049 0.92323278
 0.80751914 0.94957226 0.92886087 0.97816299 0.89734354 0.80639352
 0.94349392 0.94597028 0.89013958 0.87032868]
    val_per_class_accuracy_mean: 0.8898863124718595
    val_precision  : [0.         0.10909091 0.59266802 0.         0.06728972 0.02479339
 0.08066971 0.01973684 0.         0.01886792 0.10878661 0.01436266
 0.07100592 0.         0.13165266 0.03703704]
    val_precision_mean: 0.0797475876411537
    val_recall     : [0.         0.03833866 0.24744898 0.         0

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 10
    val_TP         : [  0  12 300   0  72   3  54   3   0   1  26   9  12   0  47   1]
    val_TN         : [3985 4030 3062 4305 3331 4101 3538 4216 4126 4345 3958 3574 4180 4207
 3905 3865]
    val_FPs        : [159  99 204   9 991 115 600 148  63  51 215 549 156 205 312  26]
    val_FNs        : [298 301 876 128  48 223 250  75 253  45 243 310  94  30 178 550]
    val_accuracy   : 0.12156686177397569
    val_per_class_accuracy: [0.89711842 0.90995047 0.75686628 0.96915804 0.76609635 0.92390815
 0.80864475 0.94979739 0.92886087 0.97838811 0.89689329 0.80661864
 0.94371905 0.9470959  0.88968933 0.87032868]
    val_per_class_accuracy_mean: 0.8901958577217469
    val_precision  : [0.         0.10810811 0.5952381  0.         0.06773283 0.02542373
 0.08256881 0.01986755 0.         0.01923077 0.10788382 0.01612903
 0.07142857 0.         0.13091922 0.03703704]
    val_precision_mean: 0.08009797301339247
    val_recall     : [0.         0.03833866 0.25510204 0.        

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 11
    val_TP         : [  0  12 314   0  72   3  54   3   0   1  26   9  12   0  47   1]
    val_TN         : [3987 4027 3057 4305 3342 4101 3541 4219 4126 4346 3954 3575 4184 4211
 3903 3864]
    val_FPs        : [157 102 209   9 980 115 597 145  63  50 219 548 152 201 314  27]
    val_FNs        : [298 301 862 128  48 223 250  75 253  45 243 310  94  30 178 550]
    val_accuracy   : 0.12471859522737505
    val_per_class_accuracy: [0.89756866 0.9092751  0.75889239 0.96915804 0.76857271 0.92390815
 0.80932013 0.95047276 0.92886087 0.97861324 0.8959928  0.80684376
 0.94461954 0.9479964  0.88923908 0.87010356]
    val_per_class_accuracy_mean: 0.8905898244034218
    val_precision  : [0.         0.10526316 0.60038241 0.         0.06844106 0.02542373
 0.08294931 0.02027027 0.         0.01960784 0.10612245 0.01615799
 0.07317073 0.         0.13019391 0.03571429]
    val_precision_mean: 0.08023107150841009
    val_recall     : [0.         0.03833866 0.2670068  0.        

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


    epoch          : 12
    val_TP         : [  0  13 321   0  72   3  54   3   0   0  26   9  12   0  48   2]
    val_TN         : [3985 4024 3054 4305 3350 4101 3545 4220 4127 4347 3952 3574 4186 4215
 3903 3863]
    val_FPs        : [159 105 212   9 972 115 593 144  62  49 221 549 150 197 314  28]
    val_FNs        : [298 300 855 128  48 223 250  75 253  46 243 310  94  30 177 549]
    val_accuracy   : 0.12674470959027465
    val_per_class_accuracy: [0.89711842 0.90882485 0.75979289 0.96915804 0.77037371 0.92390815
 0.81022062 0.95069788 0.929086   0.97861324 0.89554255 0.80661864
 0.94506979 0.94889689 0.88946421 0.87010356]
    val_per_class_accuracy_mean: 0.8908430886987844
    val_precision  : [0.         0.11016949 0.60225141 0.         0.06896552 0.02542373
 0.08346213 0.02040816 0.         0.         0.10526316 0.01612903
 0.07407407 0.         0.13259669 0.06666667]
    val_precision_mean: 0.08158812855454463
    val_recall     : [0.         0.04153355 0.27295918 0.        

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
