# Jores et al. 2021 Training
**Authorship:**
Adam Klie, *08/11/2022*
***
**Description:**
Notebook for simple training of models on the Jores et al. (2021) dataset. The same training can also be run as a script with `jores21_training.py`.
***

In [None]:
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import os
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu

In [None]:
# Point EUGENe at the Jores21 dataset, output, log and config locations.
# Assigned in the same order as the attribute names appear below.
jores21_paths = {
    "dataset_dir": "/cellar/users/aklie/data/eugene/jores21",
    "output_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/jores21",
    "logging_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/jores21",
    "config_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/jores21",
}
for setting_name, setting_path in jores21_paths.items():
    setattr(eu.settings, setting_name, setting_path)

# Set EUGENe's verbosity to ERROR (presumably hides lower-level messages)
eu.settings.verbosity = logging.ERROR

# Load in the `leaf`, `proto` and `combined` `SeqData`s 

In [None]:
# Read the preprocessed leaf and protoplast training sets from the dataset dir
leaf_train_path = os.path.join(eu.settings.dataset_dir, "leaf_processed_train.h5sd")
proto_train_path = os.path.join(eu.settings.dataset_dir, "proto_processed_train.h5sd")
sdata_leaf = eu.dl.read(leaf_train_path)
sdata_proto = eu.dl.read(proto_train_path)

# Concatenate both systems into one SeqData, tagged with the keys "leaf"/"proto"
sdata_combined = eu.dl.concat([sdata_leaf, sdata_proto], keys=["leaf", "proto"])

# Display all three datasets as the cell output
sdata_leaf, sdata_proto, sdata_combined

In [None]:
# Motifs that prep_new_model passes to eu.models.init_from_motifs:
# core promoter elements (CPEs) plus TF-cluster motifs
cpe_meme_path = os.path.join(eu.settings.dataset_dir, 'CPEs.meme')
tf_cluster_meme_path = os.path.join(eu.settings.dataset_dir, 'TF-clusters.meme')
core_promoter_elements = eu.dl.motif.MinimalMEME(cpe_meme_path)
tf_groups = eu.dl.motif.MinimalMEME(tf_cluster_meme_path)

# Merge into one name -> motif mapping; on a name clash the TF-cluster motif wins
all_motifs = dict(core_promoter_elements.motifs)
all_motifs.update(tf_groups.motifs)

# Show how many motifs are available for initialization
len(all_motifs)

In [None]:
from pytorch_lightning import seed_everything

# Function for instantiating a new randomly initialized model
def prep_new_model(
    seed,
    arch,
    config
):
    # Instantiate the model
    model = eu.models.load_config(
        arch=arch,
        model_config=config
    )
    
    seed_everything(seed)
    
    # Initialize the model prior to conv filter initialization
    eu.models.init_weights(model)

    # Initialize the conv filters
    if arch == "Jores21CNN":
        module_name, module_number, kernel_name, kernel_number = "biconv", None, "kernels", 0, 
    elif arch in ["CNN", "Hybrid"]:
        module_name, module_number, kernel_name, kernel_number = "convnet", 0, None, None
    eu.models.init_from_motifs(
        model, 
        all_motifs, 
        module_name=module_name,
        module_number=module_number,
        kernel_name=kernel_name,
        kernel_number=kernel_number
    )

    # Return the model
    return model 

In [None]:
# Smoke test: build one Jores21CNN to confirm prep_new_model runs end to end
jores21_cnn_config = os.path.join(eu.settings.config_dir, "Jores21CNN.yaml")
test_model = prep_new_model(seed=0, arch="Jores21CNN", config=jores21_cnn_config)

# Train leaf models

In [None]:
# Train 5 models per architecture, each from a different random initialization.
# Trials, epochs and data-loader workers match the proto and combined cells,
# so all three systems get the same training budget.
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")
        # Shared by fit and the prediction call so both refer to the same run
        version = f"leaf_trial_{trial}"

        # Initialize the model; the trial number doubles as the seed
        leaf_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=leaf_model,
            sdata=sdata_leaf,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=25,
            batch_size=128,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=version,
            weights_summary=None,
            verbosity=logging.ERROR
        )

        # Get predictions on the training and validation data
        eu.evaluate.train_val_predictions(
            leaf_model,
            sdata=sdata_leaf,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=version,
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del leaf_model

# Save train and validation predictions
sdata_leaf.write_h5sd(os.path.join(eu.settings.output_dir, "leaf_train_predictions.h5sd"))

# Train proto models

In [None]:
# Train 5 models per architecture, each from a different random initialization
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")
        # Shared by fit and the prediction call so both refer to the same run
        version = f"proto_trial_{trial}"

        # Initialize the model; the trial number doubles as the seed
        proto_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=proto_model,
            sdata=sdata_proto,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=25,
            batch_size=128,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=version,
            verbosity=logging.ERROR,
            weights_summary=None
        )

        # Get predictions on the training and validation data.
        # Use `target_keys` (not `target`): it is the keyword passed to
        # eu.train.fit above and to the leaf cell's prediction call.
        eu.evaluate.train_val_predictions(
            proto_model,
            sdata=sdata_proto,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=version,
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del proto_model

# Save train and validation predictions
sdata_proto.write_h5sd(os.path.join(eu.settings.output_dir, "proto_train_predictions.h5sd"))

# Train combined models

In [None]:
# Train 5 models per architecture, each from a different random initialization
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")
        # Shared by fit and the prediction call so both refer to the same run
        version = f"combined_trial_{trial}"

        # Initialize the model; the trial number doubles as the seed
        combined_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=combined_model,
            sdata=sdata_combined,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=25,
            batch_size=128,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=version,
            verbosity=logging.ERROR,
            weights_summary=None
        )

        # Get predictions on the training and validation data.
        # Use `target_keys` (not `target`): it is the keyword passed to
        # eu.train.fit above and to the leaf cell's prediction call.
        eu.evaluate.train_val_predictions(
            combined_model,
            sdata=sdata_combined,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=version,
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del combined_model

# Save train and validation predictions
sdata_combined.write_h5sd(os.path.join(eu.settings.output_dir, "combined_train_predictions.h5sd"))

---