# Kopp et al. (2021) Evaluation
**Authorship:**
Adam Klie, *08/12/2022*
***
**Description:**
Notebook to briefly evaluate trained models on the Kopp et al. (2021) dataset by generating test-set predictions from every model architecture and training trial.
***

In [None]:
# Load IPython's autoreload extension only if it is not already loaded in this
# kernel; `%autoreload 2` then reloads imported modules before each cell runs,
# so edits to library code (e.g. `eugene`) are picked up without a restart.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import os
import glob
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu
import matplotlib.pyplot as plt
import matplotlib

# Embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3 in
# PDF/PS exports, so text in saved figures stays editable downstream.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

In [None]:
# Point EUGENe at this dataset's inputs and at the paper project's per-dataset
# output, log, config and figure folders. Paths are assembled from shared roots
# with plain "/" joins so the resulting strings match the absolute paths exactly.
_dataset_tag = "kopp21"
_data_root = "/cellar/users/aklie/data/eugene"
_paper_root = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper"

eu.settings.dataset_dir = f"{_data_root}/{_dataset_tag}"
eu.settings.output_dir = f"{_paper_root}/output/{_dataset_tag}"
eu.settings.logging_dir = f"{_paper_root}/logs/{_dataset_tag}"
eu.settings.config_dir = f"{_paper_root}/configs/{_dataset_tag}"
eu.settings.figure_dir = f"{_paper_root}/figures/{_dataset_tag}"

# Only surface errors from EUGENe's logging.
eu.settings.verbosity = logging.ERROR

# Load in the test `SeqData`(s)

In [None]:
# Load the processed test split (`jund_test_processed.h5sd`) from the dataset
# directory; the next cell runs every trained model on it.
_test_sdata_path = os.path.join(eu.settings.dataset_dir, "jund_test_processed.h5sd")
sdata_test = eu.dl.read_h5sd(filename=_test_sdata_path)
sdata_test

# Get test-set predictions for each model and trial

In [None]:
# Predict on the test set with every trained model and trial. Each call to
# `eu.evaluate.predictions` is given `sdata_test` and a per-model/per-trial
# prefix; the annotated `sdata_test` is then written to the output directory
# (presumably the predictions are stored on it under that prefix — confirm
# against eugene's `evaluate.predictions` docs).
model_types = ["FCN", "CNN", "Hybrid", "Kopp21CNN"]
model_names = ["dsFCN", "dsCNN", "dsHybrid", "Kopp21CNN"]
trials = 5

# Architecture name -> EUGENe model class used to restore its checkpoints.
model_classes = {
    "FCN": eu.models.FCN,
    "CNN": eu.models.CNN,
    "Hybrid": eu.models.Hybrid,
    "Kopp21CNN": eu.models.Kopp21CNN,
}

# Validate the configuration before any (slow) checkpoint loading starts, so a
# typo fails immediately instead of after several models have been evaluated.
if len(model_names) != len(model_types):
    raise ValueError(
        f"model_names ({len(model_names)}) and model_types ({len(model_types)}) "
        "must have the same length"
    )
_unknown_types = sorted(set(model_types) - set(model_classes))
if _unknown_types:
    raise ValueError(
        f"Unknown model type(s) {_unknown_types}; expected one of {sorted(model_classes)}"
    )


def _find_checkpoint(model_name, trial):
    """Return the checkpoint file for one model/trial training run.

    Looks in ``<logging_dir>/<model_name>/trial_<trial>/checkpoints/``.

    Raises:
        FileNotFoundError: if that directory contains no files.

    If several files are present, the lexicographically first is used (glob
    order alone is arbitrary) and a note is printed so the choice is visible.
    """
    pattern = os.path.join(
        eu.settings.logging_dir, model_name, f"trial_{trial}", "checkpoints", "*"
    )
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint found matching {pattern}")
    if len(candidates) > 1:
        print(f"  {len(candidates)} checkpoints found, using {candidates[0]}")
    return candidates[0]


for model_name, model_type in zip(model_names, model_types):
    model_class = model_classes[model_type]
    for trial in range(1, trials + 1):
        print(f"{model_name} trial {trial}")
        model = model_class.load_from_checkpoint(_find_checkpoint(model_name, trial))

        eu.evaluate.predictions(
            model,
            sdata=sdata_test,
            target_keys="target",
            name=model_name,
            version=f"trial_{trial}",
            file_label="test",
            prefix=f"{model_name}_trial_{trial}_",
        )
        # Drop the reference so the previous model can be freed before the
        # next checkpoint is loaded.
        del model
sdata_test.write_h5sd(os.path.join(eu.settings.output_dir, "test_predictions.h5sd"))

---

# Scratch