In [21]:
import torch
import numpy as np
import os, sys
from datetime import datetime

sys.path.insert(0, 'src')
from train import main as train_fn
from predict import predict
from parse_config import ConfigParser
import datasets.datasets as module_data
from utils.utils import read_json, ensure_dir, informal_log, write_lists
from utils.model_utils import prepare_device
from utils.attribute_utils import partition_paths_by_congruency
from model import metric as module_metric
from model import loss as module_loss

### Load hyperparameter search variables

In [2]:
# Training config that every trial of the search starts from.
config_path = 'configs/train_ade20k_explainer_KD.json'

# When True, run a reduced grid (the search cell also forces a single epoch).
debug = False

if not debug:
    # Full (learning rate x weight decay) grid.
    learning_rates = [1e-6, 1e-5, 1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1]
    weight_decays = [0, 1e-1, 1e-2, 1e-3]
else:
    # Smoke-test grid: first LR and first two weight decays of the full grid.
    learning_rates = [1e-6]
    weight_decays = [0, 1e-1]

config_json = read_json(config_path)


### Create train and validation datasets once, outside the search loop

In [3]:
# Build both KD dataset splits once so every hyperparameter trial reuses them.
dataset_args = config_json['dataset']['args']
train_dataset, val_dataset = (
    module_data.KDDataset(split='train', **dataset_args),
    module_data.KDDataset(split='val', **dataset_args),
)

# Shared loader settings; only shuffling differs between the two splits.
# Validation keeps a fixed order (shuffle=False). The congruency cell below
# pairs prediction indices with validation paths, so this order matters.
dataloader_args = config_json['data_loader']['args']
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **dataloader_args)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **dataloader_args)

In [None]:
# Grid search over (learning rate, weight decay).
# For each pair, the cell:
#   - trains with the shared config (only the optimizer args change);
#   - restores that trial's `model_best.pth` checkpoint;
#   - evaluates it on the validation set;
#   - copies the model and validation logits into a shared `best/` directory
#     whenever validation accuracy improves.
best = {
    'lr': -1,
    'wd': -1,
    'val_acc': -1
}
# Every (lr, wd) pair, in the same order as nested lr-outer / wd-inner loops.
trials = [(lr, wd) for lr in learning_rates for wd in weight_decays]
n_trials = len(trials)
timestamp = datetime.now().strftime(r'%m%d_%H%M%S')

# Logging
log_path = os.path.join(config_json['trainer']['save_dir'], timestamp, 'log.txt')
ensure_dir(os.path.dirname(log_path))
informal_log("Hyperparameter search", log_path)
informal_log("Learning rates: {}".format(learning_rates), log_path)
informal_log("Weight decays: {}".format(weight_decays), log_path)

# Debug mode: one epoch per trial is enough to smoke-test the pipeline.
if debug:
    config_json['trainer']['epochs'] = 1

# These depend only on config fields the loop never modifies, so resolve them
# once instead of once per trial.
device, device_ids = prepare_device(config_json['n_gpu'])
metric_fns = [getattr(module_metric, met) for met in config_json['metrics']]
loss_fn = getattr(module_loss, config_json['loss'])

for trial_idx, (lr, wd) in enumerate(trials, start=1):
    # Update config json in place; each trial overwrites both optimizer keys.
    config_json['optimizer']['args'].update({
        'lr': lr,
        'weight_decay': wd
    })

    # Create run ID for trial
    itr_timestamp = datetime.now().strftime(r'%m%d_%H%M%S')
    informal_log("[{}] Trial {}/{}: LR = {} WD = {}".format(
        itr_timestamp, trial_idx, n_trials, lr, wd), log_path)
    run_id = os.path.join(timestamp, 'trials', 'lr_{}-wd_{}'.format(lr, wd))
    config = ConfigParser(config_json, run_id=run_id)

    # Train model
    model = train_fn(
        config=config,
        train_data_loader=train_dataloader,
        val_data_loader=val_dataloader)

    # Restore the `model_best.pth` checkpoint from this trial's save dir
    # (presumably the best-by-monitor checkpoint written by the trainer).
    model_restore_path = os.path.join(config.save_dir, 'model_best.pth')
    model.restore_model(model_restore_path)

    # Run on validation set using predict function.
    # trial_path is the parent directory of config.save_dir.
    trial_path = os.path.dirname(os.path.dirname(model_restore_path))
    output_save_path = os.path.join(trial_path, "val_outputs.pth")
    log_save_path = os.path.join(trial_path, "val_metrics.pth")

    validation_data = predict(
        data_loader=val_dataloader,
        model=model,
        metric_fns=metric_fns,
        device=device,
        loss_fn=loss_fn,
        output_save_path=output_save_path,
        log_save_path=log_save_path)

    # Obtain accuracy and compare to previous best
    metrics = validation_data['metrics']
    if 'accuracy' not in metrics:
        raise KeyError("'accuracy' missing from validation metrics; available: {}".format(
            list(metrics.keys())))
    val_accuracy = metrics['accuracy']
    # Strict '>' keeps the earliest trial when accuracies tie.
    if val_accuracy > best['val_acc']:
        best.update({
            'lr': lr,
            'wd': wd,
            'val_acc': val_accuracy
        })
        informal_log("Best accuracy of {:.3f} with lr={} and wd={}".format(val_accuracy, lr, wd), log_path)
        informal_log("Trial path: {}".format(trial_path), log_path)
        # Copy model and outputs to 1 directory for easy access:
        # <trial_path>/../../best (a sibling of the 'trials' directory).
        best_save_dir = os.path.join(os.path.dirname(os.path.dirname(trial_path)), 'best')
        ensure_dir(best_save_dir)
        best_outputs_save_path = os.path.join(best_save_dir, 'outputs.pth')
        best_model_save_path = os.path.join(best_save_dir, 'model.pth')
        torch.save(validation_data['logits'], best_outputs_save_path)
        model.save_model(best_model_save_path)
        informal_log("Saved model and outputs to {}".format(best_save_dir), log_path)

informal_log("Search finished. Best accuracy of {:.3f} with lr={} and wd={}".format(
    best['val_acc'], best['lr'], best['wd']), log_path)

## Post-processing of results before survey processing

### Derive probabilities and predictions from the best trial's outputs

In [20]:
# Logits written to `best/outputs.pth` by the best trial of a previous search run.
# NOTE(review): hardcoded to a specific run directory; update after re-running the search.
best_output_path = 'saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/KD_baseline_explainer/hparam_search/0523_164052/best/outputs.pth'
best_output_dir = os.path.dirname(best_output_path)

# map_location='cpu' lets a tensor saved from GPU evaluation load on a CPU-only
# machine; .numpy() below requires a CPU tensor anyway.
logits = torch.load(best_output_path, map_location='cpu')

# assumes logits are (n_samples, n_classes): softmax/argmax run over dim/axis 1.
outputs = logits.numpy()
softmax = torch.softmax(logits, dim=1).numpy()
predictions = np.argmax(softmax, axis=1)

data = {
    'outputs': outputs,
    'probabilities': softmax,
    'predictions': predictions
}
data_save_path = os.path.join(best_output_dir, 'val_outputs_predictions.pth')
torch.save(data, data_save_path)
print("Saved outputs, probabilities, and predictions to {}".format(data_save_path))

Saved outputs, probabilities, and predictions to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/KD_baseline_explainer/hparam_search/0523_164052/best/val_outputs_predictions.pth


### Obtain congruent and incongruent paths

In [22]:
# Explainer predictions come from the previous cell (in memory); the original
# model's predictions and the validation image paths are loaded from disk.
explainer_predictions = predictions
model_predictions_path = 'saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/val_outputs_predictions.pth'
model_predictions = torch.load(model_predictions_path)['predictions']

paths_path = 'data/ade20k/full_ade20k_imagelabels.pth'
val_paths = torch.load(paths_path)['val']

# The two prediction arrays and the path list must be index-aligned.
assert explainer_predictions.shape == model_predictions.shape
assert len(explainer_predictions) == len(val_paths)

# Split validation paths by whether the explainer and the model agree.
congruency = partition_paths_by_congruency(
    explainer_predictions=explainer_predictions,
    model_predictions=model_predictions,
    paths=val_paths)
congruent_paths, incongruent_paths = congruency['congruent'], congruency['incongruent']
print("{} congruent paths and {} incongruent paths".format(
    len(congruent_paths), len(incongruent_paths)))

# Save each list next to the best outputs, congruent first.
congruent_paths_save_path = os.path.join(best_output_dir, 'congruent_paths.txt')
incongruent_paths_save_path = os.path.join(best_output_dir, 'incongruent_paths.txt')
for path_list, list_save_path in (
        (congruent_paths, congruent_paths_save_path),
        (incongruent_paths, incongruent_paths_save_path)):
    write_lists(path_list, list_save_path)
print("Wrote congruent paths to {} and incongruent paths to {}".format(
    congruent_paths_save_path, incongruent_paths_save_path))


4442it [00:00, 2446310.18it/s]

2764 congruent paths and 1678 incongruent paths
Wrote congruent paths to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/KD_baseline_explainer/hparam_search/0523_164052/best/congruent_paths.txt and incongruent paths to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/KD_baseline_explainer/hparam_search/0523_164052/best/incongruent_paths.txt.



