# Jores et al. (2021) Evaluation
**Authorship:**
Adam Klie (last updated: *06/08/2023*)
***
**Description:**
Notebook to briefly evaluate the trained models on the Jores et al. (2021) test sets.
***

In [None]:
# General imports
import os
import sys
import glob
import torch
import numpy as np

# EUGENe imports and settings
import eugene as eu
from eugene import dataload as dl
from eugene import models
from eugene import evaluate
from eugene import settings
# Where the preprocessed test-set zarr stores are read from
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
# Where per-test-set prediction tables and zarr stores are written
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/fix_full/jores21"
# Root of the training runs; checkpoints are looked up under
# <logging_dir>/<model_name>/<test_set>_trial_<n>/checkpoints
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/fix_full/jores21"
# NOTE(review): presumably the directory that `models.load_config` resolves
# bare config file names (e.g. "cnn.yaml") against -- confirm
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

# EUGENe packages
import seqdata as sd

# For illustrator editing: font type 42 embeds TrueType fonts so that text in
# saved PDF/PS figures stays editable as text
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# Print versions so the environment used for this evaluation is recorded in
# the notebook output
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Eugene version: {eu.__version__}")
print(f"SeqData version: {sd.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Load in the `leaf`, `proto` and `combined` test `SeqData`s 

In [None]:
# Open the preprocessed leaf and proto test sets from the dataset directory,
# then build the combined test set by concatenating the two.
sdata_leaf, sdata_proto = (
    sd.open_zarr(os.path.join(settings.dataset_dir, zarr_name))
    for zarr_name in ("jores21_leaf_test.zarr", "jores21_proto_test.zarr")
)
sdata_combined = dl.concat_seqdatas([sdata_leaf, sdata_proto], ["leaf", "proto"])

# Get test set predictions for each model

In [None]:
# Predict with each model that was trained.
#
# For every (test set, architecture, trial) combination, load the checkpoint
# saved by the corresponding training run, predict on the test set, and then
# write all predictions for that test set to a TSV table and a zarr store.
test_sets = {"leaf": sdata_leaf, "proto": sdata_proto, "combined": sdata_combined}
configs = ["cnn.yaml", "hybrid.yaml", "jores21_cnn.yaml", "deepstarr.yaml"]
trials = 5
for test_set, sdata in test_sets.items():

    # Make an output directory for this dataset if it doesn't exist
    # (exist_ok avoids the check-then-create race of os.path.exists + makedirs)
    test_set_out_dir = os.path.join(settings.output_dir, test_set)
    os.makedirs(test_set_out_dir, exist_ok=True)

    # Iterate over the models
    for config in configs:
        model_name = config.split(".")[0]

        # Iterate over the trials
        for trial in range(1, trials + 1):

            # Print the model name
            print(f"{test_set} {model_name} trial {trial}")

            # Grab the best model from that training run. Fail with an explicit
            # message when the run left no checkpoint instead of a bare
            # IndexError, and sort so the pick is deterministic if more than
            # one file is present.
            # NOTE(review): assumes each run keeps a single (best) checkpoint -- confirm
            version = f"{test_set}_trial_{trial}"
            checkpoint_dir = os.path.join(settings.logging_dir, model_name, version, "checkpoints")
            checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "*")))
            if not checkpoints:
                raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
            model_file = checkpoints[0]

            # Rebuild the architecture from its config, then restore the
            # trained weights from the checkpoint
            model = models.load_config(config_path=config)
            best_model = models.SequenceModule.load_from_checkpoint(model_file, arch=model.arch)
            evaluate.predictions_sequence_module(
                model=best_model,
                sdata=sdata,
                seq_var="ohe_seq",
                target_vars="enrichment",
                gpus=1,
                batch_size=2048,
                num_workers=4,
                prefetch_factor=2,
                in_memory=True,
                # Convert arrays to float32 tensors; the sequence tensor also has
                # axes 1 and 2 swapped (presumably length <-> channel -- confirm)
                transforms={
                    "ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).transpose(1, 2),
                    "target": lambda x: torch.tensor(x, dtype=torch.float32),
                },
                file_label="test",
                name=model_name,
                version=version,
                prefix=f"{model_name}_trial_{trial}_"
            )

    # Save the predictions: every data variable whose name contains
    # "predictions", alongside the sequence id and the measured target
    pred_vars = [k for k in sdata.data_vars.keys() if "predictions" in k]
    target_vars = ["enrichment"]
    sdata[["id", *target_vars, *pred_vars]].to_dataframe().to_csv(
        os.path.join(test_set_out_dir, f"jores21_{test_set}_test_predictions.tsv"),
        sep="\t",
        index=False,
    )
    sd.to_zarr(sdata, os.path.join(test_set_out_dir, f"jores21_{test_set}_test_predictions.zarr"), mode="w")

# DONE!

---

# Scratch

In [None]:
# Sanity-check each saved predictions store: number of sequences, value counts
# of a few annotation columns, the first ids, and how many prediction columns
# were written.
for zarr in ["jores21_leaf_test_predictions.zarr", "jores21_proto_test_predictions.zarr", "jores21_combined_test_predictions.zarr"]:
    print(zarr)
    # The second "_"-separated token of the file name is the test set
    # (leaf / proto / combined), which is also the output sub-directory
    system = zarr.split("_")[1]
    sdata = sd.open_zarr(os.path.join(settings.output_dir, system, zarr))
    # `sizes` is the name -> length mapping; mapping-style access on `dims`
    # is deprecated in recent xarray releases
    print(sdata.sizes["_sequence"])
    print(np.unique(sdata["set"].values, return_counts=True))
    if "train_val" in sdata.data_vars:
        print(np.unique(sdata["train_val"].values, return_counts=True))
    else:
        print("No train_val column found")
    print(np.unique(sdata["sp"].values, return_counts=True))
    print(sdata["id"].values[:5])
    print(np.unique(sdata["batch"].values, return_counts=True))
    # Count prediction columns in THIS store rather than reusing a variable
    # list left over from an earlier cell, which would either echo the old
    # count or raise KeyError
    pred_cols = [k for k in sdata.data_vars if "predictions" in k]
    print(f"Found {len(pred_cols)} predictions columns")