# Jores et al. (2021) Evaluation
**Authorship:**
Adam Klie, *08/12/2022*
***
**Description:**
Notebook to perform a brief evaluation of trained models on the Jores et al. (2021) dataset by generating and saving each model's test-set predictions.
***

In [None]:
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import os
import glob
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu
import matplotlib.pyplot as plt
import matplotlib

# Save text in PDF/PS figures as Type 42 (TrueType) fonts so it remains
# editable when the figures are opened in Adobe Illustrator
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Global EUGENe settings for this analysis: where the Jores21 data is read
# from, where outputs/logs/configs/figures live, the log level, the batch
# size, and the number of data-loader workers. Applied in the order listed.
jores21_settings = {
    "dataset_dir": "/cellar/users/aklie/data/eugene/jores21",
    "output_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/jores21",
    "logging_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/jores21",
    "config_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/jores21",
    "figure_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/figures/jores21",
    "verbosity": logging.ERROR,
    "batch_size": 128,
    "dl_num_workers": 0,
}
for setting_name, setting_value in jores21_settings.items():
    setattr(eu.settings, setting_name, setting_value)

# Load the `leaf`, `proto`, and `combined` test `SeqData`s

In [None]:
# Locate the preprocessed leaf and proto test sets in the dataset directory
leaf_test_path = os.path.join(eu.settings.dataset_dir, "leaf_processed_test.h5sd")
proto_test_path = os.path.join(eu.settings.dataset_dir, "proto_processed_test.h5sd")

sdata_leaf = eu.dl.read(leaf_test_path)
sdata_proto = eu.dl.read(proto_test_path)

# Join both test sets into one; the keys pair up, in order, with the two
# inputs (presumably used to tag each sequence's origin — TODO confirm)
sdata_combined = eu.dl.concat([sdata_leaf, sdata_proto], keys=["leaf", "proto"])

# Get test set predictions for each model

## Leaf model

In [None]:
# Evaluate each leaf model on the test set.
# Every architecture has one run directory per training trial
# (<logging_dir>/<model_name>/leaf_trial_<n>/checkpoints). Restore each run,
# predict on the leaf test set, and write all predictions to one file.
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5

# Architecture name -> EUGENe class used to restore its checkpoint
model_classes = {
    "CNN": eu.models.CNN,
    "Hybrid": eu.models.Hybrid,
    "Jores21CNN": eu.models.Jores21CNN,
}
for model_name, model_type in zip(model_names, model_types):
    # An unsupported type fails here with a KeyError naming it, instead of a
    # NameError on an unbound model variable further down
    model_class = model_classes[model_type]
    for trial in range(1, trials + 1):
        print(f"{model_name} trial {trial}")
        version = f"leaf_trial_{trial}"
        checkpoint_dir = os.path.join(eu.settings.logging_dir, model_name, version, "checkpoints")
        # glob() returns matches in filesystem order, so sort to make the
        # choice reproducible if a run saved more than one checkpoint
        checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*")))
        if not checkpoint_files:
            raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
        model_file = checkpoint_files[0]
        leaf_model = model_class.load_from_checkpoint(model_file)
        # Presumably stores this run's predictions in sdata_leaf under keys
        # starting with `prefix` — TODO confirm against eu.evaluate.predictions
        eu.evaluate.predictions(
            leaf_model,
            sdata=sdata_leaf,
            target_keys="enrichment",
            name=model_name,
            version=version,
            file_label="test",
            prefix=f"{model_name}_trial_{trial}_",
        )
        # Drop the reference so the model can be freed before the next load
        del leaf_model
sdata_leaf.write_h5sd(os.path.join(eu.settings.output_dir, "leaf_test_predictions.h5sd"))

## Proto model

In [None]:
# Evaluate each proto model on the test set.
# Every architecture has one run directory per training trial
# (<logging_dir>/<model_name>/proto_trial_<n>/checkpoints). Restore each run,
# predict on the proto test set, and write all predictions to one file.
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5

# Architecture name -> EUGENe class used to restore its checkpoint
model_classes = {
    "CNN": eu.models.CNN,
    "Hybrid": eu.models.Hybrid,
    "Jores21CNN": eu.models.Jores21CNN,
}
for model_name, model_type in zip(model_names, model_types):
    # An unsupported type fails here with a KeyError naming it, instead of a
    # NameError on an unbound model variable further down
    model_class = model_classes[model_type]
    for trial in range(1, trials + 1):
        print(f"{model_name} trial {trial}")
        version = f"proto_trial_{trial}"
        checkpoint_dir = os.path.join(eu.settings.logging_dir, model_name, version, "checkpoints")
        # glob() returns matches in filesystem order, so sort to make the
        # choice reproducible if a run saved more than one checkpoint
        checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*")))
        if not checkpoint_files:
            raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
        model_file = checkpoint_files[0]
        proto_model = model_class.load_from_checkpoint(model_file)
        # Presumably stores this run's predictions in sdata_proto under keys
        # starting with `prefix` — TODO confirm against eu.evaluate.predictions
        eu.evaluate.predictions(
            proto_model,
            sdata=sdata_proto,
            target_keys="enrichment",
            name=model_name,
            version=version,
            file_label="test",
            prefix=f"{model_name}_trial_{trial}_",
        )
        # Drop the reference so the model can be freed before the next load
        del proto_model
sdata_proto.write_h5sd(os.path.join(eu.settings.output_dir, "proto_test_predictions.h5sd"))

## Combined model

In [None]:
# Evaluate each combined model on the test set.
# Every architecture has one run directory per training trial
# (<logging_dir>/<model_name>/combined_trial_<n>/checkpoints). Restore each
# run, predict on the combined (leaf + proto) test set, and write all
# predictions to one file.
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5

# Architecture name -> EUGENe class used to restore its checkpoint
model_classes = {
    "CNN": eu.models.CNN,
    "Hybrid": eu.models.Hybrid,
    "Jores21CNN": eu.models.Jores21CNN,
}
for model_name, model_type in zip(model_names, model_types):
    # An unsupported type fails here with a KeyError naming it, instead of a
    # NameError on an unbound model variable further down
    model_class = model_classes[model_type]
    for trial in range(1, trials + 1):
        print(f"{model_name} trial {trial}")
        version = f"combined_trial_{trial}"
        checkpoint_dir = os.path.join(eu.settings.logging_dir, model_name, version, "checkpoints")
        # glob() returns matches in filesystem order, so sort to make the
        # choice reproducible if a run saved more than one checkpoint
        checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*")))
        if not checkpoint_files:
            raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
        model_file = checkpoint_files[0]
        combined_model = model_class.load_from_checkpoint(model_file)
        # Presumably stores this run's predictions in sdata_combined under keys
        # starting with `prefix` — TODO confirm against eu.evaluate.predictions
        eu.evaluate.predictions(
            combined_model,
            sdata=sdata_combined,
            target_keys="enrichment",
            name=model_name,
            version=version,
            file_label="test",
            prefix=f"{model_name}_trial_{trial}_",
        )
        # Drop the reference so the model can be freed before the next load
        del combined_model
sdata_combined.write_h5sd(os.path.join(eu.settings.output_dir, "combined_test_predictions.h5sd"))

---