## Save pre-edit logits and metrics to the metadata folder

In [None]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm

In [None]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, list_to_dict, ensure_dir
from utils.model_utils import prepare_device, quick_predict
from parse_config import ConfigParser
from data_loader import data_loaders
from test import predict_with_bump
import model.model as module_arch
import datasets.datasets as module_data
import model.metric as module_metric
import model.loss as module_loss

In [None]:
# Run configuration: which experiment config to load and the run identifier
# handed to ConfigParser below.
config_path = 'configs/copies/cinic10_imagenet_val_pre_edit.json'
run_id = 'pre_edit_validation_set'
# Class-name list for CINIC-10, read with read_lists() in the next cell.
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')

In [None]:
# Build everything the evaluation cell needs: class names, the parsed config,
# the compute device, a validation DataLoader that also yields image paths,
# the loss/metric callables, and the model in eval mode.

# --- Classes and configuration ---
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

config_dict = read_json(config_path)
config = ConfigParser(config_dict, run_id=run_id)
print(config.save_dir)
device, device_ids = prepare_device(config_dict['n_gpu'])

# --- Validation data ---
# dict(...) takes shallow copies so the kwargs below are independent of the
# config mappings they came from.
data_loader_args = dict(config.config["data_loader"]["args"])
dataset_args = dict(config["dataset_args"])

# NOTE(review): if data_loader_args enables shuffling, the order of saved
# logits depends on it; presumably return_paths=True lets downstream code
# realign by path — confirm.
dataset_paths = config_dict['dataset_paths']
val_image_paths = read_lists(dataset_paths['valid_images'])
val_labels = read_lists(dataset_paths['valid_labels'])
val_dataset = module_data.CINIC10Dataset(
    data_dir="",  # image_paths are used as given (empty base dir)
    image_paths=val_image_paths,
    labels=val_labels,
    return_paths=True,
    **dataset_args
)
val_paths_data_loader = torch.utils.data.DataLoader(val_dataset, **data_loader_args)

# --- Loss and metric callables, looked up by name from the config ---
loss_fn = getattr(module_loss, config['loss'])
metric_fns = [getattr(module_metric, metric_name) for metric_name in config['metrics']]

# --- Model ---
layernum = config.config['layernum']
model = config.init_obj(
    'arch',
    module_arch,
    layernum=layernum,
    device=device,
)
model.eval()

In [None]:
# Outputs go into the directory containing config.save_dir (os.path.dirname
# strips its last path component).
save_dir = os.path.dirname(config.save_dir)
logit_save_path, metric_save_path = (
    os.path.join(save_dir, file_name)
    for file_name in ("pre_edit_logits.pth", "pre_edit_metrics.pth")
)

# bump_amount=0 presumably makes the bump a no-op, so this records the
# unedited model's logits/metrics on the validation loader — confirm against
# predict_with_bump. Left as the cell's final expression (notebook display).
predict_with_bump(
    data_loader=val_paths_data_loader,
    model=model,
    loss_fn=loss_fn,
    metric_fns=metric_fns,
    device=device,
    target_class_idx=0,
    bump_amount=0,
    output_save_path=logit_save_path,
    log_save_path=metric_save_path)