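### Test the hyperparameter search

Runs a small hyperparameter search on the `level8` dataset, tunes per-class decision thresholds for the best model, evaluates it on the test loader, and round-trips the model through save/load.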

In [1]:
import torch

from data_loader import get_data_loaders, GOAnnotationsDataset
from loader import ModelLoader
from model_eval import ModelEvaluator
from saver import ModelSaver
from search_hyperparams import HyperparameterSearch
from thresholds import ThresholdFinder
from utils import get_class_weights, get_num_classes

In [2]:
# Configuration: random seed, compute device, dataset location, and the inputs
# (num_classes, dataset, class weights) that the hyperparameter search needs.
SEED = 42
# The search trains MLPs with dropout, so results vary run to run without a seed.
# NOTE(review): this seeds torch only; if search_hyperparams / data_loader draw from
# `random`, numpy, or an external sampler, seed those as well — TODO confirm.
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root_dir = "level8"
# Presumably counts the files ending with _filtered.h in root_dir — confirm in
# utils.get_num_classes.
num_classes = get_num_classes(root_dir)
initial_dataset = GOAnnotationsDataset(root_dir)
# Class weights derived from the dataset, handed to the search (presumably to weight
# the loss for imbalanced labels — verify in search_hyperparams).
class_weights = get_class_weights(initial_dataset)

In [3]:
# Deliberately tiny budget (2 trials x 5 epochs) for a quick test run; raise for a real search.
NUM_TRIALS = 2
MAX_EPOCHS = 5
# Pass the detected `device` rather than hard-coding 'cpu': ThresholdFinder and
# ModelEvaluator below also receive `device`, so every stage now uses the same device.
search = HyperparameterSearch(
    initial_dataset,
    num_classes,
    class_weights,
    num_trials=NUM_TRIALS,
    max_epochs=MAX_EPOCHS,
    device=device,
)

In [4]:
# run_search() yields five values: the best model, its per-trial loss histories,
# the winning hyperparameters, and the test loader tied to that model.
(
    best_model,
    best_train_losses,
    best_val_losses,
    best_trial_params,
    best_model_test_loader,
) = search.run_search()

In [5]:
# Training losses recorded for the best trial (presumably one per epoch; 5 values here,
# matching max_epochs above). Compare with best_val_losses to judge over/underfitting.
best_train_losses

[0.1919339955562637,
 0.1107590017574174,
 0.09990979305335454,
 0.09692828428177606,
 0.09482381954079583]

### Find per-class decision thresholds

In [6]:
# ThresholdFinder is imported in the top import cell (imports are kept in one place so
# the notebook survives Restart & Run All without hidden dependencies).
threshold_finder = ThresholdFinder(best_model, num_classes, device)

In [7]:
# Tune decision thresholds for the best model (the result appears to hold one value per class).
# NOTE(review): best_model_test_loader is also what ModelEvaluator.evaluate is run on below,
# so the reported metrics come from the same data the thresholds were tuned on and are
# likely optimistic. TODO: tune on a validation split (run_search would need to expose one).
optimal_thresholds = threshold_finder.find_optimal_threshold(best_model_test_loader)

In [8]:
# Inspect the tuned thresholds. Most entries are exactly 0.1, which may be the lower bound
# of ThresholdFinder's search range (i.e. those classes never favoured a higher cut-off)
# — NOTE(review): confirm in thresholds.py.
optimal_thresholds

array([0.1 , 0.1 , 0.32, 0.1 , 0.29, 0.14, 0.11, 0.1 , 0.1 , 0.1 , 0.1 ,
       0.13, 0.1 , 0.73, 0.11, 0.1 , 0.1 , 0.1 , 0.15, 0.1 , 0.1 , 0.12,
       0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,
       0.1 , 0.1 , 0.1 , 0.37, 0.1 , 0.12, 0.1 , 0.21, 0.1 , 0.1 , 0.1 ,
       0.21, 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.11, 0.1 , 0.1 , 0.1 , 0.11,
       0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.13, 0.17, 0.2 ,
       0.1 , 0.1 , 0.12, 0.41, 0.21, 0.1 , 0.1 , 0.12, 0.1 , 0.1 , 0.1 ,
       0.1 , 0.1 , 0.31, 0.11, 0.1 , 0.1 , 0.2 , 0.1 , 0.1 , 0.11, 0.1 ,
       0.1 , 0.1 , 0.1 , 0.16, 0.1 , 0.1 , 0.13, 0.1 , 0.21, 0.12, 0.13,
       0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.11, 0.1 , 0.13, 0.1 , 0.14, 0.1 ,
       0.1 , 0.1 , 0.1 , 0.11, 0.79, 0.1 , 0.1 , 0.39, 0.13, 0.1 , 0.1 ,
       0.14, 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.4 , 0.13, 0.1 ,
       0.1 , 0.34, 0.11, 0.1 , 0.1 , 0.22, 0.21, 0.1 , 0.1 , 0.1 , 0.12,
       0.1 , 0.1 , 0.13, 0.1 , 0.1 , 0.13, 0.1 , 0.

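### Evaluate on the test set

Note: the thresholds above were tuned on this same test loader, so these metrics are likely optimistic.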

In [9]:
# Evaluate the best model with the tuned thresholds (ModelEvaluator is imported in the
# top import cell).
# NOTE(review): best_model_test_loader is the same loader the thresholds were tuned on,
# so these metrics are likely optimistic — TODO evaluate on data not used for tuning.
evaluator = ModelEvaluator(best_model, optimal_thresholds, device=device)
eval_metrics = evaluator.evaluate(best_model_test_loader)

In [10]:
# Test metrics. 'accuracy' (~0.001) is far below F1 (~0.13) and MCC (~0.15), which suggests
# it is exact-match (subset) accuracy over all labels rather than per-label accuracy
# — NOTE(review): confirm in model_eval.ModelEvaluator.evaluate.
eval_metrics

{'accuracy': 0.0012165450121654502,
 'precision': 0.07706173952230734,
 'recall': 0.351129363449692,
 'F1-score': 0.12638580931263857,
 'MCC': 0.15190440612519024}

### Save the model & metrics

In [11]:
# ModelSaver is imported in the top import cell.
# Destination handed to ModelSaver; its on-disk layout is defined by saver.py.
save_path = "saved_model"
saver = ModelSaver(save_path)

In [12]:
# Persist the model together with everything needed to reproduce or inspect it.
# Arguments are positional; the trailing comments give the parameter each one fills.
saver.save(
    best_model,          # model
    optimal_thresholds,  # thresholds
    best_trial_params,   # config
    eval_metrics,        # eval_metrics
    best_train_losses,   # train_losses
    best_val_losses,     # val_losses
)

### Load the saved model & thresholds 

In [14]:
# ModelLoader is imported in the top import cell. Note that `loader` below is the
# ModelLoader instance, not the `loader` module (which is never bound by name).
load_path = "saved_model"
loader = ModelLoader(load_path)

In [16]:
# Load the artifacts written by saver.save() above; the result is dict-like (its keys are
# listed in the next cell).
info = loader.load()

In [18]:
# The config passed to save() (best_trial_params) has no top-level key of its own;
# presumably it is stored inside 'metadata' — NOTE(review): verify against saver/loader.
info.keys()

dict_keys(['model', 'thresholds', 'train_losses', 'val_losses', 'metadata'])

In [27]:
# Display the loaded model. Read it from `info` directly: `model` itself is only bound by
# the unpacking cell further down (it ran earlier out of order), so referencing `model`
# here raises NameError on Restart & Run All.
info['model']

MLPClassifier(
  (layers): ModuleList(
    (0): Linear(in_features=512, out_features=2576, bias=True)
    (1): BatchNorm1d(2576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4571852064190821, inplace=False)
    (4): Linear(in_features=2576, out_features=1288, bias=True)
    (5): BatchNorm1d(1288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.4571852064190821, inplace=False)
    (8): Linear(in_features=1288, out_features=644, bias=True)
    (9): BatchNorm1d(644, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.4571852064190821, inplace=False)
    (12): Linear(in_features=644, out_features=322, bias=True)
    (13): BatchNorm1d(322, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
    (15): Dropout(p=0.4571852064190821, inplace=False)
  )
  (out): Linear(in_features=322, out_features=161, bias=T

In [23]:
# Unpack every artifact returned by ModelLoader.load(), including train_losses, which was
# previously left out. All keys are looked up before any name is bound, so a missing key
# raises KeyError without leaving a partially updated namespace.
model, thresholds, train_losses, val_losses, metadata = (
    info[key] for key in ("model", "thresholds", "train_losses", "val_losses", "metadata")
)