In [40]:
import torch
import os, sys
from datetime import datetime

sys.path.insert(0, 'src')
from train import main as train_fn
from parse_config import ConfigParser
import datasets.datasets as module_data
from utils.utils import read_json, ensure_dir, informal_log

### Load hyperparameter search variables

In [33]:
# Base training config; its optimizer args are overwritten for each trial below.
config_path = 'configs/train_ade20k_explainer_KD.json'
config_json = read_json(config_path)

# Search grid axes. Trials run in the order the values are listed here.
learning_rates = [
    1e-6, 1e-5, 1e-4, 1e-3,
    5e-2, 1e-2, 5e-1, 1e-1,
]
weight_decays = [
    0, 1e-1, 1e-2, 1e-3,
]

### Create train and validation datasets once, outside of the search loop

In [35]:
# Both splits are built from the same config-provided keyword arguments;
# only the `split` value differs ('train' first, then 'val').
dataset_args = config_json['dataset']['args']
train_dataset, val_dataset = (
    module_data.KDDataset(split=split_name, **dataset_args)
    for split_name in ('train', 'val')
)

# Loader settings also come from the config. Shuffling is fixed here:
# on for training, off for validation.
# NOTE(review): `dataloader_args` must not itself contain a 'shuffle' key,
# otherwise the explicit keyword below is passed twice and Python raises
# TypeError -- confirm against the config file.
dataloader_args = config_json['data_loader']['args']
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **dataloader_args)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **dataloader_args)

In [38]:
# The search log is written to `log.txt` inside the trainer's save directory.
log_path = os.path.join(config_json['trainer']['save_dir'], 'log.txt')
ensure_dir(os.path.dirname(log_path))

# Header lines recording which grid this search covers.
for message in (
    "Hyperparameter search",
    "Learning rates: {}".format(learning_rates),
    "Weight decays: {}".format(weight_decays),
):
    informal_log(message, log_path)

Hyperparameter search
Learning rates: [1e-06, 1e-05, 0.0001, 0.001, 0.05, 0.01, 0.5, 0.1]
Weight decays: [0, 0.1, 0.01, 0.001]


In [17]:
# Builds the training split from explicit file paths rather than from the
# config's dataset args.
# The class is referenced through `module_data` because that alias is the only
# import of the datasets module in this notebook; the bare name `KDDataset` is
# undefined on a fresh kernel and raises NameError under Restart & Run All.
ds = module_data.KDDataset(
    input_features_path='data/ade20k/frequency_filtered_one_hot_attributes.pth',
    labels_path='saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/outputs_predictions.pth',
    split='train',
    out_type='probabilities')

In [39]:
# Record of the best configuration found so far; -1 is presumably a
# "not evaluated yet" placeholder (TODO confirm once evaluation is added).
# NOTE: nothing in the loop below updates `best` yet -- the train/evaluate
# steps are still placeholders.
best = dict(lr=-1, wd=-1, val_acc=-1)
n_trials = len(learning_rates) * len(weight_decays)
trial_idx = 1

# Walk the full grid: learning rate is the outer axis, weight decay the inner.
for lr, wd in ((rate, decay) for rate in learning_rates for decay in weight_decays):
    # Overwrite the optimizer hyperparameters in the shared config in place.
    optimizer_args = config_json['optimizer']['args']
    optimizer_args.update({'lr': lr, 'weight_decay': wd})

    # Log a timestamped header line for this trial.
    timestamp = datetime.now().strftime(r'%m%d_%H%M%S')
    trial_header = "[{}] Trial {}/{}: LR = {} WD = {}".format(
        timestamp, trial_idx, n_trials, lr, wd)
    informal_log(trial_header, log_path)

    # TODO: Create directory for trial
    # TODO: Train
    # TODO: Restore model
    # TODO: Run on validation set using predict function
    # TODO: Obtain accuracy
    trial_idx += 1

OrderedDict([('lr', 1e-06), ('weight_decay', 0), ('amsgrad', False)])
OrderedDict([('lr', 1e-06), ('weight_decay', 0.1), ('amsgrad', False)])
OrderedDict([('lr', 1e-06), ('weight_decay', 0.01), ('amsgrad', False)])
OrderedDict([('lr', 1e-06), ('weight_decay', 0.001), ('amsgrad', False)])
OrderedDict([('lr', 1e-05), ('weight_decay', 0), ('amsgrad', False)])
OrderedDict([('lr', 1e-05), ('weight_decay', 0.1), ('amsgrad', False)])
OrderedDict([('lr', 1e-05), ('weight_decay', 0.01), ('amsgrad', False)])
OrderedDict([('lr', 1e-05), ('weight_decay', 0.001), ('amsgrad', False)])
OrderedDict([('lr', 0.0001), ('weight_decay', 0), ('amsgrad', False)])
OrderedDict([('lr', 0.0001), ('weight_decay', 0.1), ('amsgrad', False)])
OrderedDict([('lr', 0.0001), ('weight_decay', 0.01), ('amsgrad', False)])
OrderedDict([('lr', 0.0001), ('weight_decay', 0.001), ('amsgrad', False)])
OrderedDict([('lr', 0.001), ('weight_decay', 0), ('amsgrad', False)])
OrderedDict([('lr', 0.001), ('weight_decay', 0.1), ('amsgra