# Kopp et al. (2021) Evaluation
**Authorship:**
Adam Klie, *08/12/2022*
***
**Description:**
This notebook performs a brief evaluation of the trained models on the Kopp et al. (2021) dataset.
***

In [None]:
# General imports
import os
import sys
import glob
import torch
import numpy as np

# EUGENe imports and settings
# The `settings` attributes assigned below are session-wide: later cells read
# `settings.dataset_dir` (test SeqData), `settings.logging_dir` (model
# checkpoints) and `settings.output_dir` (saved predictions) to build paths.
# NOTE(review): `settings.config_dir` is not read directly in this notebook;
# presumably EUGENe uses it to resolve bare config filenames -- confirm.
import eugene as eu
from eugene import dataload as dl
from eugene import models
from eugene import evaluate
from eugene import settings
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/kopp21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/kopp21"
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/kopp21"
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/kopp21"

# EUGENe packages
import seqdata as sd

# kopp21 helpers
# The scripts directory must be appended to `sys.path` before the import
# below, otherwise `kopp21_helpers` cannot be found.
sys.path.append("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/scripts/kopp21")
from kopp21_helpers import load_checkpoint_from_arch_config

# For illustrator editing
# Font type 42 embeds TrueType fonts in PDF/PS output, which keeps figure text
# editable as text in vector-graphics editors.
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# Print versions
# Recorded in the notebook output so the run environment can be reproduced.
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Eugene version: {eu.__version__}")
print(f"SeqData version: {sd.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Load in the test `SeqData`(s)

In [None]:
# Open the test-set SeqData stored as a zarr under the configured dataset directory.
test_zarr_path = os.path.join(settings.dataset_dir, 'kopp21_test.zarr')
sdata = sd.open_zarr(test_zarr_path)

# Last expression of the cell: shows the dataset summary in the notebook.
sdata

# Get test set predictions for each model

In [None]:
from itertools import product

# Architecture configs to evaluate and the number of trained replicates of each.
configs = ["dsfcn.yaml", "dscnn.yaml", "dshybrid.yaml", "kopp21_cnn.yaml"]
trials = 5

# Make sure the destination for the final tables exists up front, so a missing
# directory does not surface only after every model has been evaluated.
os.makedirs(settings.output_dir, exist_ok=True)

for config, trial in product(configs, range(1, trials+1)):
    model_name = config.split('.')[0]
    print(model_name)

    # Locate the checkpoint for this model/trial.
    # glob() order depends on the filesystem, so sort the matches to make the
    # choice deterministic, and fail with a descriptive error instead of a bare
    # IndexError when no checkpoint has been written.
    ckpt_pattern = os.path.join(
        settings.logging_dir, model_name, f"trial_{trial}", "checkpoints", "*"
    )
    ckpt_files = sorted(glob.glob(ckpt_pattern))
    if not ckpt_files:
        raise FileNotFoundError(
            f"No checkpoint found for model '{model_name}', trial {trial} (pattern: {ckpt_pattern})"
        )
    model_file = ckpt_files[0]

    # Load the model
    if "ds" in model_name:
        # "ds*" models are restored through the kopp21 helper, which takes the
        # architecture name: "dshybrid" -> "dsHybrid"; otherwise everything
        # after the "ds" prefix is upper-cased ("dsfcn" -> "dsFCN",
        # "dscnn" -> "dsCNN").
        if "hybrid" in model_name:
            arch_name = "dsHybrid"
        else:
            arch_name = model_name[:2] + model_name[2:].upper()
        best_model = load_checkpoint_from_arch_config(
            ckpt_path=model_file,
            config_path=config,
            arch_name=arch_name
        )
    else:
        # Other models: build the model from its config, then restore a
        # SequenceModule from the checkpoint reusing that model's architecture.
        model = models.load_config(config_path=config)
        best_model = models.SequenceModule.load_from_checkpoint(model_file, arch=model.arch)

    # Set-up transforms
    # Both variables are converted to float32 tensors; the sequence tensor
    # additionally has axes 1 and 2 swapped.
    # NOTE(review): presumably (batch, length, channels) -> (batch, channels,
    # length) for the models -- confirm against the stored `ohe_seq` layout.
    transforms = {
        "target": lambda x: torch.tensor(x, dtype=torch.float32),
        "ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).swapaxes(1, 2)
    }

    # Evaluate
    # NOTE(review): the save step below relies on this call adding variables
    # whose names contain "predictions" to `sdata` -- presumably named using
    # `prefix`; confirm against the EUGENe documentation.
    evaluate.predictions_sequence_module(
        model=best_model,
        sdata=sdata,
        seq_var="ohe_seq",
        target_vars="target",
        gpus=1,
        batch_size=2048,
        num_workers=4,
        prefetch_factor=2,
        in_memory=True,
        transforms=transforms,
        file_label="test",
        name=model_name,
        version=f"trial_{trial}",
        prefix=f"{model_name}_trial_{trial}_"
    )

# Save the predictions
# Gather every prediction variable present in `sdata` and write it, together
# with the region coordinates and the target, as a TSV; then write the full
# SeqData as a zarr (mode="w").
pred_vars = [k for k in sdata.data_vars.keys() if "predictions" in k]
(
    sdata[['chrom', 'chromStart', 'chromEnd', 'target', *pred_vars]]
    .to_dataframe()
    .to_csv(os.path.join(settings.output_dir, "test_predictions_all.tsv"), sep="\t", index=False)
)
sd.to_zarr(sdata, os.path.join(settings.output_dir, "test_predictions_all.zarr"), mode="w")

# DONE!

---

# Scratch