# Kopp et al 2021 Training 
**Authorship:**
Adam Klie (last updated: *06/10/2023*)
***
**Description:**
Notebook to train models on the Kopp et al. (2021) dataset. Alternatively, you can run the `kopp21_training.py` script to perform the same training outside of a notebook.
***

In [None]:
# General imports
import os
import sys
import torch
import numpy as np
import pandas as pd
from copy import deepcopy 
import pytorch_lightning
from itertools import product

# EUGENe imports and settings
import eugene as eu
from eugene import dataload as dl
from eugene import models, train, evaluate
from eugene.dataload._augment import RandomRC
from eugene import settings
# Point EUGENe at the dataset, output, logging and config directories used by
# this notebook.
# NOTE(review): these are absolute, user-specific paths; edit them before
# running the notebook on another machine or under another account.
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/kopp21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/kopp21"
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/kopp21"
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/kopp21"

# EUGENe packages
import seqdata as sd

# Report the interpreter and library versions this notebook was run with,
# one line per component, so results can be traced back to an environment.
for version_line in (
    f"Python version: {sys.version}",
    f"NumPy version: {np.__version__}",
    f"Pandas version: {pd.__version__}",
    f"Eugene version: {eu.__version__}",
    f"SeqData version: {sd.__version__}",
    f"PyTorch version: {torch.__version__}",
    f"PyTorch Lightning version: {pytorch_lightning.__version__}",
):
    print(version_line)

# Load in the `SeqData`

In [None]:
# Open the training SeqData from its Zarr store inside the dataset directory
train_zarr_path = os.path.join(settings.dataset_dir, 'kopp21_train.zarr')
sdata = sd.open_zarr(train_zarr_path)
# Leave the object as the cell's last expression so the notebook displays it
sdata

# Model instantiation and initialization 

In [None]:
def prep_new_model(
    config,
    seed,
):
    """Build a model from a config file and initialize its weights.

    Parameters
    ----------
    config
        Config location, forwarded to ``eu.models.load_config`` as
        ``config_path``.
    seed
        Seed forwarded to ``eu.models.load_config``.

    Returns
    -------
    The model produced by ``eu.models.load_config``, after
    ``eu.models.init_weights`` has been applied to it.
    """
    new_model = eu.models.load_config(config_path=config, seed=seed)

    # General weight initialization happens here, before any conv filter
    # initialization that may be applied later.
    eu.models.init_weights(new_model)
    return new_model

In [None]:
# Smoke test: build every architecture once (seed 0) to confirm that each
# config can be loaded and its model instantiated and initialized.
for smoke_config in ("dscnn.yaml", "dshybrid.yaml", "dsfcn.yaml", "kopp21_cnn.yaml"):
    model = prep_new_model(smoke_config, seed=0)

In [None]:
# Model configs to train (file names inside `eu.settings.config_dir`) and the
# number of independently seeded trials to run for each of them.
# The last entry must be "kopp21_cnn.yaml" so that the derived model name is
# 'kopp21_cnn', which is what the augmentation check in the loop compares to.
configs = ["dsfcn.yaml", "dscnn.yaml", "dshybrid.yaml", "kopp21_cnn.yaml"]
trials = 1

for config, trial in product(configs, range(1, trials+1)):
    # The config file name without its extension names the run and its outputs
    model_name = config.split('.')[0]
    print(model_name)

    # Initialize the model; the trial number doubles as the random seed
    model = prep_new_model(os.path.join(eu.settings.config_dir, config), seed=trial)

    # Transforms applied to each variable: everything is cast to float32
    # tensors, and sequences additionally get axes 1 and 2 swapped
    # (presumably to put the channel axis where the models expect it;
    # TODO confirm the stored layout of "ohe_seq").
    transforms = {
        "target": lambda x: torch.tensor(x, dtype=torch.float32)
    }
    if (model_name != 'kopp21_cnn') and (not model_name.startswith('ds')):
        # Any other architecture gets RandomRC applied to its sequences.
        # NOTE(review): the same `transforms` dict is also passed to the
        # evaluation call below, so RandomRC would be applied there as well.
        random_rc = RandomRC()
        def ohe_seq_transform(x):
            x = torch.tensor(x, dtype=torch.float32).swapaxes(1, 2)
            return random_rc(x)
        transforms["ohe_seq"] = ohe_seq_transform
    else:
        # 'kopp21_cnn' and the 'ds*' models are trained without RandomRC
        transforms["ohe_seq"] = lambda x: torch.tensor(x, dtype=torch.float32).swapaxes(1, 2)

    # Fit the model
    eu.train.fit_sequence_module(
        model,
        sdata,
        gpus=1,
        seq_var="ohe_seq",
        target_vars=["target"],
        in_memory=True,
        train_var="train_val",
        epochs=25,
        early_stopping_metric='val_loss_epoch',
        early_stopping_patience=5,
        batch_size=64,
        num_workers=4,
        prefetch_factor=2,
        drop_last=False,
        name=model_name,
        version=f"trial_{trial}",
        transforms=transforms,
        seed=trial,
    )

    # Evaluate the model on train and validation sets
    evaluate.train_val_predictions_sequence_module(
        model,
        sdata,
        seq_var="ohe_seq",
        target_vars=["target"],
        in_memory=True,
        train_var="train_val",
        batch_size=1024,
        num_workers=4,
        prefetch_factor=2,
        name=model_name,
        version=f"trial_{trial}",
        transforms=transforms,
        prefix=f"{model_name}_trial_{trial}_"
    )

    # Drop the reference so the model can be freed before the next run
    del model

# DONE!

---

# Scratch