# Kopp et al. (2021) Training

**Authorship:**
Adam Klie, *08/07/2022*
***
**Description:**
Notebook to train models on the Kopp et al. (2021) dataset. Alternatively, the same training can be run from the command line with the `kopp21_training_{FCN|CNN|Hybrid|Kopp21CNN}.py` scripts.
***

In [None]:
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload 
%autoreload 2

import os
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu

In [None]:
# Configure EUGENe: where this dataset lives, and where this project's
# outputs, logs and model configs are written/read. Only errors are logged.
_settings = eu.settings
_settings.dataset_dir = "/cellar/users/aklie/data/eugene/kopp21"
_settings.output_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/kopp21"
_settings.logging_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/kopp21"
_settings.config_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/kopp21"
_settings.verbosity = logging.ERROR

# Load in the `SeqData`

In [None]:
# Read the processed JunD training SeqData from the configured dataset dir
_sdata_path = os.path.join(eu.settings.dataset_dir, "jund_train_processed.h5sd")
sdata = eu.dl.read_h5sd(filename=_sdata_path)
sdata

# Model instantiation and initialization 

In [None]:
from pytorch_lightning import seed_everything


def prep_new_model(
    seed,
    arch,
    config
):
    """Build a freshly initialized EUGENe model from a config file.

    Parameters
    ----------
    seed : int
        Seed passed to ``seed_everything`` before the model is built.
        Randomness from model construction and from ``eu.models.init_weights``
        is therefore reproducible for a given seed.
    arch : str
        Architecture name forwarded to ``eu.models.load_config``
        (this notebook uses e.g. "FCN", "CNN", "RNN", "Hybrid", "Kopp21CNN").
    config : str
        Path to the model's YAML config, forwarded as ``model_config``.

    Returns
    -------
    The model returned by ``eu.models.load_config``, after
    ``eu.models.init_weights`` has been applied to it.
    """
    # Seed *before* instantiation. Previously the seed was set only after the
    # model was built, so any parameters that init_weights does not overwrite
    # kept unseeded default initializations and differed between runs.
    seed_everything(seed)

    # Instantiate the model from its config
    model = eu.models.load_config(
        arch=arch,
        model_config=config
    )

    # Initialize the model prior to conv filter initialization
    eu.models.init_weights(model)

    return model

In [None]:
# Sanity check: build each architecture and push one batch through it to make
# sure it accepts the data layout it will be trained on.
model_types = ["FCN", "CNN", "RNN", "Hybrid", "Kopp21CNN"]
model_names = ["dsFCN", "dsCNN", "dsRNN", "dsHybrid", "Kopp21CNN"]
for model_name, model_type in zip(model_names, model_types):
    print(model_name, model_type)
    config_path = os.path.join(eu.settings.config_dir, f"{model_name}.yaml")
    model = prep_new_model(0, model_type, config_path)

    # The RNN is fed untransposed sequences; every other architecture gets them transposed
    needs_transpose = model_type != "RNN"
    sdataloader = sdata.to_dataset(
        transform_kwargs={"transpose": needs_transpose}
    ).to_dataloader()

    # Forward one batch using its elements at index 1 and 2 and report the output shape
    test_seqs = next(iter(sdataloader))
    print(model(test_seqs[1], test_seqs[2]).size())
    print()

In [None]:
# Train 5 models with 5 different random initializations per architecture,
# then store each trial's train/val predictions in `sdata` and write it out.
model_types = ["FCN", "CNN", "Hybrid", "Kopp21CNN"]
model_names = ["dsFCN", "dsCNN", "dsHybrid", "Kopp21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    # Dataloader transform kwargs. This follows the sanity-check cell: only the
    # RNN takes untransposed sequences. Defined here because `t_kwargs` was
    # otherwise never assigned, so this cell raised NameError unless stale
    # kernel state happened to provide it.
    t_kwargs = {"transpose": model_type != "RNN"}

    for trial in range(1, trials + 1):
        print(f"{model_name} trial {trial}")

        # Initialize the model; the trial number doubles as the seed
        model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=model,
            sdata=sdata,
            gpus=1,
            target_keys="target",
            train_key="train_val",
            epochs=30,
            early_stopping_metric="val_loss",
            early_stopping_patience=5,
            transform_kwargs=t_kwargs,
            batch_size=64,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=f"trial_{trial}",
            verbosity=logging.ERROR
        )

        # Get predictions on the training data. This sets dl_num_workers to 0
        # in eu.settings before predicting; the value persists for later calls.
        eu.settings.dl_num_workers = 0
        eu.evaluate.train_val_predictions(
            model,
            sdata=sdata,
            target_keys="target",
            train_key="train_val",
            transform_kwargs=t_kwargs,
            name=model_name,
            version=f"trial_{trial}",
            prefix=f"{model_name}_trial_{trial}_"
        )
        del model

# Persist sdata, including the prediction columns added above
sdata.write_h5sd(os.path.join(eu.settings.output_dir, "train_predictions.h5sd"))

---

## Scratch