# Ray et al. (2013) Training
**Authorship:**
Adam Klie, *08/31/2022*
***
**Description:**
Notebook for training *single-task* and *multitask* models on the Ray et al. (2013) dataset.
See also the `ray13_training_ST.py` script for usage; the script was used for the actual run because training all 244 models takes several hours.
***

In [None]:
# Load the IPython `autoreload` extension only if it has not been loaded yet,
# then select autoreload mode 2 (per the IPython docs: reload all modules
# before executing code), so edits to imported packages are picked up.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import os
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu

In [None]:
# Point EUGENe at the dataset, output, logging and config directories for ray13.
# NOTE(review): these are absolute, user-specific paths; adjust them before
# running this notebook on another machine.
eu.settings.dataset_dir = "/cellar/users/aklie/data/eugene/ray13"
eu.settings.output_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/ray13"
eu.settings.logging_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/ray13"
eu.settings.config_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/ray13"
# Set EUGENe's verbosity to the ERROR level (presumably silences lower-severity messages)
eu.settings.verbosity = logging.ERROR

# Load in the SetA training `SeqData` objects for single-task and multitask models

In [None]:
# Load the processed SetA training data for the single-task (ST) and multitask (MT) models.
# dataset_dir is joined exactly once so the resulting path is also correct when
# dataset_dir is a relative path (joining it twice only works for absolute paths).
sdata_training_ST = eu.dl.read_h5sd(os.path.join(eu.settings.dataset_dir, "norm_setA_processed_ST.h5sd"))
sdata_training_MT = eu.dl.read_h5sd(os.path.join(eu.settings.dataset_dir, "norm_setA_processed_MT.h5sd"))

In [None]:
# Display both loaded training objects (single-task, multitask) for a quick inspection
sdata_training_ST, sdata_training_MT

In [None]:
# Select the target columns for each dataset: every `seqs_annot` column
# whose name contains "RNCMPT"
annot_cols_ST = sdata_training_ST.seqs_annot.columns
annot_cols_MT = sdata_training_MT.seqs_annot.columns

# Boolean masks over the annotation columns
target_mask_ST = annot_cols_ST.str.contains("RNCMPT")
target_mask_MT = annot_cols_MT.str.contains("RNCMPT")

# Column names used as training targets
target_cols_ST = annot_cols_ST[target_mask_ST]
target_cols_MT = annot_cols_MT[target_mask_MT]

In [None]:
# Spot-check a single multitask target column name by position (index 215)
target_cols_MT[215]

# Train single task models

In [None]:
# Instantiation function
from pytorch_lightning import seed_everything
def prep_new_model(
    seed,
    conv_dropout=0,
    fc_dropout=0,
    batchnorm=True
):
    """Build a seeded, freshly initialized single-output DeepBind regression model.

    Parameters
    ----------
    seed : int
        Seed handed to `seed_everything` before the model is constructed and
        again before its weights are initialized.
    conv_dropout : float
        Passed as `dropout_rates` in the convolutional block kwargs.
    fc_dropout : float
        Passed as `dropout_rate` in the fully connected block kwargs.
    batchnorm : bool
        Passed as `batchnorm` to both the convolutional and fully connected kwargs.

    Returns
    -------
    The `eu.models.DeepBind` instance after `eu.models.init_weights` has been applied.
    """
    # Seed before construction so that any parameters drawn inside the
    # constructor are reproducible as well
    seed_everything(seed)

    model = eu.models.DeepBind(
        input_len=41,  # Length of padded sequences
        output_dim=1,  # Single-task model: one output
        strand="ss",
        task="regression",
        conv_kwargs=dict(channels=[4, 16], conv_kernels=[16], dropout_rates=conv_dropout, batchnorm=batchnorm),
        mp_kwargs=dict(kernel_size=8),
        fc_kwargs=dict(hidden_dims=[32], dropout_rate=fc_dropout, batchnorm=batchnorm),
        optimizer="sgd",
        lr=0.0005,
        scheduler_patience=3
    )

    # Re-seed so weight initialization starts from the freshly seeded RNG state,
    # independent of how many random numbers construction consumed
    seed_everything(seed)

    # (Re)initialize the model weights
    eu.models.init_weights(model)

    return model

In [None]:
# Sanity check before training: build one model, print its summary, and push
# a small batch through it to confirm the output shape
model = prep_new_model(0)
print(model.summary())

# Build a dataloader over the first 64 sequences and pull its first batch
first_seqs = sdata_training_ST[:64]
sdataloader = first_seqs.to_dataset().to_dataloader()
test_seqs = next(iter(sdataloader))

# Elements 1 and 2 of the batch are the two model inputs
# (presumably forward and reverse-strand encodings; verify against the dataset class)
test_output = model(test_seqs[1], test_seqs[2])
print(test_output.size())

In [None]:
# Fit one DeepBind single-task model per target column; the loop index doubles as the seed
for i, target_col in enumerate(target_cols_ST):
    print(f"Training DeepBind SingleTask model on {target_col}")

    # Fresh model for this target
    model = prep_new_model(seed=i, conv_dropout=0.5, fc_dropout=0.5, batchnorm=True)

    # Arguments common to training and prediction for this target
    st_common = dict(
        sdata=sdata_training_ST,
        target_keys=target_col,
        train_key="train_val",
        num_workers=0,
        name="DeepBind_ST",
        version=target_col,
    )

    # Fit with early stopping on the validation loss
    eu.train.fit(
        model=model,
        gpus=1,
        epochs=5,
        early_stopping_metric="val_loss",
        early_stopping_patience=3,
        batch_size=64,
        seed=i,
        verbosity=logging.ERROR,
        **st_common,
    )

    # Predict on the training data with the fitted model
    eu.evaluate.train_val_predictions(
        model,
        batch_size=1024,
        suffix="_ST",
        **st_common,
    )

    # Drop the reference to the model before the next iteration
    del model
#sdata_training_ST.write_h5sd(os.path.join(eu.settings.output_dir, "DeepBind_ST", "norm_training_predictions_ST.h5sd"))

# Train multi-task model

In [None]:
# Define the version for saving: used as `v{model_version}` in the run version
# passed to training/prediction and in the saved predictions filename
model_version = 0

In [None]:
# Hyperparameters shared by the convolutional and fully connected blocks
conv_dropout, fc_dropout = 0.25, 0.25
batchnorm = True

# Convolutional layer kwargs
conv_kwargs = {
    "channels": [4, 1024],
    "conv_kernels": [16],
    "dropout_rates": conv_dropout,
    "batchnorm": batchnorm,
}

# Fully connected layer kwargs
fc_kwargs = {
    "hidden_dims": [512],
    "dropout_rate": fc_dropout,
    "batchnorm": batchnorm,
}

# Multitask DeepBind regression model with one output per target column
model = eu.models.DeepBind(
    input_len=41,  # Length of padded sequences
    output_dim=len(target_cols_MT),  # Number of multitask outputs
    strand="ss",  # Strand information to include, only forward strand
    task="regression",  # Task type, regression in this case
    optimizer="adam",  # Optimizer to use
    optimizer_kwargs={},  # Default optimizer kwargs
    lr=0.0005,  # Learning rate to start with
    scheduler_patience=2,  # Number of epochs to wait before reducing learning rate
    conv_kwargs=conv_kwargs,
    fc_kwargs=fc_kwargs,
)

# Show the model summary together with the version it will be saved under
model.summary(), model_version

In [None]:
# Arguments common to training and prediction for this model version
mt_common = dict(
    sdata=sdata_training_MT,
    target_keys=target_cols_MT,
    train_key="train_val",
    batch_size=1024,
    num_workers=0,
    name="DeepBind_MT",
    version=f"v{model_version}",
)

# Fit the multitask model with early stopping on the validation loss
eu.train.fit(
    model=model,
    gpus=1,
    epochs=100,
    early_stopping_metric="val_loss",
    early_stopping_patience=5,
    seed=42,
    verbosity=logging.ERROR,
    **mt_common,
)

# Predict on the training data with the fitted model
eu.evaluate.train_val_predictions(
    model,
    suffix="_MT",
    **mt_common,
)

In [None]:
# Write the multitask SeqData, now holding the predictions, to a versioned file
# under the DeepBind_MT output directory
predictions_file = f"norm_training_predictions_v{model_version}_MT.h5sd"
predictions_path = os.path.join(eu.settings.output_dir, "DeepBind_MT", predictions_file)
sdata_training_MT.write_h5sd(predictions_path)

In [None]:
# Double check we predicted on all the columns: count the `seqs_annot` columns
# whose name contains "RNCMPT".
# NOTE(review): presumably this count now includes the added prediction columns
# as well as the targets; confirm the expected number.
np.sum(sdata_training_MT.seqs_annot.columns.str.contains("RNCMPT"))

In [None]:
# Bump the version so a subsequent training run is saved under a new name
model_version += 1

---