# Ray et al. (2013) Training
**Authorship:**
Adam Klie (last updated: *06/08/2023*)
***
**Description:**
Notebook to perform simple training of *single task* and *multitask* models on the Ray et al. (2013) dataset.
For the full single-task run, see the `ray13_training_ST.py` script. Training all 244 models took several hours, so that script was used instead of this notebook.
***

In [None]:
# General imports
import os
import sys
import torch
import numpy as np
import pandas as pd
import pytorch_lightning

# EUGENe imports and settings
import eugene as eu
from eugene import models, train, evaluate, settings
from eugene.models import zoo
# Root of the EUGENe paper project; the output, log, and config dirs all live under it,
# so relocating the project only requires editing this one path.
_eugene_paper_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper"
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/ray13"
settings.output_dir = os.path.join(_eugene_paper_dir, "output", "revision", "ray13")
settings.logging_dir = os.path.join(_eugene_paper_dir, "logs", "revision", "ray13")
settings.config_dir = os.path.join(_eugene_paper_dir, "configs", "ray13")

# EUGENe packages
import seqdata as sd

# Report interpreter and library versions so results can be tied to an environment
print(f"Python version: {sys.version}")
for _label, _module in (
    ("NumPy", np),
    ("Pandas", pd),
    ("Eugene", eu),
    ("SeqData", sd),
    ("PyTorch", torch),
    ("PyTorch Lightning", pytorch_lightning),
):
    print(f"{_label} version: {_module.__version__}")


# Load the SetA training `SeqData` objects for the single-task and multitask models

In [None]:
# Open the processed SetA training stores: the single-task (ST) layout first, then multitask (MT)
sdata_training_ST, sdata_training_MT = (
    sd.open_zarr(os.path.join(settings.dataset_dir, store))
    for store in ("norm_setA_ST.zarr", "norm_setA_MT.zarr")
)

In [None]:
def _rncmpt_columns(sdata):
    """Split a dataset's data-variable names into (all names, RNCMPT mask, RNCMPT names)."""
    names = pd.Index(list(sdata.data_vars))
    is_target = names.str.contains("RNCMPT")
    return names, is_target, names[is_target]


# Prediction targets are the data variables whose names contain "RNCMPT"
ST_vars, target_mask_ST, target_cols_ST = _rncmpt_columns(sdata_training_ST)
MT_vars, target_mask_MT, target_cols_MT = _rncmpt_columns(sdata_training_MT)

# Train single task models

In [None]:
# Instantiation function
from pytorch_lightning import seed_everything
def prep_new_model(
    seed,
    conv_dropout=0,
    dense_dropout=0,
    batchnorm=True,
):
    """Build a freshly initialized single-task DeepBind regression module.

    Parameters
    ----------
    seed : int
        Passed to ``seed_everything`` before the network is built, so the
        weight initialization is reproducible for a given seed.
    conv_dropout : float, optional
        ``dropout_rates`` for the convolutional block (default 0).
    dense_dropout : float, optional
        ``dropout_rates`` for the dense block (default 0).
    batchnorm : bool, optional
        Whether the convolutional and dense blocks use batch normalization.

    Returns
    -------
    models.SequenceModule
        The DeepBind architecture wrapped for MSE regression, trained with
        Adam (lr 5e-4) and a scheduler configured with ``patience=2``.
    """
    # Seed the global RNGs (via Lightning's seed_everything) before building the network
    seed_everything(seed)

    arch = models.zoo.DeepBind(
        input_len=41,  # Length of padded sequences
        output_dim=1,  # Single task: one regression output per sequence
        conv_kwargs=dict(
            input_channels=4,  # presumably one channel per base of the one-hot encoding
            conv_channels=[16],
            conv_kernels=[16],
            dropout_rates=conv_dropout,
            batchnorm=batchnorm,
        ),
        dense_kwargs=dict(hidden_dims=[32], dropout_rates=dense_dropout, batchnorm=batchnorm),
    )

    # Initialize the model prior to conv filter initialization
    models.init_weights(arch)

    return models.SequenceModule(
        arch=arch,
        task="regression",
        loss_fxn="mse",
        optimizer="adam",
        optimizer_lr=0.0005,
        scheduler_kwargs=dict(patience=2),
    )

In [None]:
# Smoke-test the factory: build one single-task module and display it
model = prep_new_model(13, batchnorm=True, conv_dropout=0.5, dense_dropout=0.5)
model

In [None]:
# Train a model on each target prediction! NOTE: this cell is configured for testing
# purposes (only the first `n_test_targets` targets, 5 epochs); see the
# ray13_training_ST.py script for the full training.
n_test_targets = 1


def to_float32_tensor(x):
    """Convert one batch field (sequence or target) to a float32 torch tensor."""
    return torch.tensor(x, dtype=torch.float32)


for i, target_col in enumerate(target_cols_ST[:n_test_targets]):
    print(f"Training DeepBind SingleTask model on {target_col}")

    # Initialize the model; the loop index doubles as the seed so each target is reproducible
    model = prep_new_model(seed=i, conv_dropout=0.5, dense_dropout=0.5, batchnorm=True)

    # Fit the model
    train.fit_sequence_module(
        model,
        sdata_training_ST,
        seq_var="ohe_seq",
        target_vars=target_col,
        in_memory=True,
        train_var="train_val",
        epochs=5,
        batch_size=100,
        num_workers=4,
        prefetch_factor=2,
        drop_last=False,
        early_stopping_patience=3,
        name="DeepBind_ST",
        version=target_col,
        transforms={"ohe_seq": to_float32_tensor, "target": to_float32_tensor},
        seed=i,
    )

    # Predict on the train/val split; presumably this adds "_ST"-suffixed prediction
    # variables to sdata_training_ST, which is saved below.
    evaluate.train_val_predictions_sequence_module(
        model,
        sdata=sdata_training_ST,
        seq_var="ohe_seq",
        target_vars=target_col,
        in_memory=True,
        train_var="train_val",
        batch_size=1024,
        num_workers=4,
        prefetch_factor=2,
        name="DeepBind_ST",
        version=target_col,
        transforms={"ohe_seq": to_float32_tensor, "target": to_float32_tensor},
        suffix="_ST",
    )

# Save the predictions!
# NOTE(review): in test mode only `n_test_targets` targets have predictions. If the full
# ray13_training_ST.py run writes to this same path, mode="w" will overwrite its output.
sd.to_zarr(sdata_training_ST, os.path.join(settings.output_dir, "norm_setA_predictions_ST.zarr"), mode="w")

# Train the multitask model

In [None]:
# Version tag for the multitask run; it is embedded in the DeepBind_MT run version and the saved predictions filename
model_version: int = 0

In [None]:
# Define the architecture to be trained: one network with an output per RNCMPT target
arch = models.zoo.DeepBind(
    input_len=41,  # Length of padded sequences
    output_dim=len(target_cols_MT),  # Number of multitask outputs
    conv_kwargs=dict(
        input_channels=4,
        conv_channels=[1024],
        conv_kernels=[16],
        dropout_rates=0.25,
        # batchnorm is an on/off switch (as in dense_kwargs below); 0.25 here was a typo
        batchnorm=True,
    ),
    dense_kwargs=dict(hidden_dims=[512], dropout_rates=0.25, batchnorm=True),
)

# Initialize the model prior to conv filter initialization
models.init_weights(arch)

# Wrap the model in a SequenceModule (MSE regression, Adam at lr 5e-4, scheduler patience 2)
model = models.SequenceModule(
    arch=arch,
    task="regression",
    loss_fxn="mse",
    optimizer="adam",
    optimizer_lr=0.0005,
    scheduler_kwargs=dict(patience=2),
)

In [None]:
# Arguments shared by the multitask fit and prediction calls
mt_shared_kwargs = dict(
    seq_var="ohe_seq",
    target_vars=target_cols_MT,
    in_memory=True,
    train_var="train_val",
    num_workers=4,
    prefetch_factor=2,
    name="DeepBind_MT",
    version=f"v{model_version}",
)


def _mt_float_transforms():
    """Return a fresh transforms dict that casts sequences and targets to float32 tensors."""
    return {
        "ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32),
        "target": lambda x: torch.tensor(x, dtype=torch.float32),
    }


# Fit the model
train.fit_sequence_module(
    model,
    sdata_training_MT,
    epochs=100,
    batch_size=1024,
    drop_last=False,
    early_stopping_patience=5,
    transforms=_mt_float_transforms(),
    seed=42,
    **mt_shared_kwargs,
)

# Get training predictions
evaluate.train_val_predictions_sequence_module(
    model,
    sdata=sdata_training_MT,
    batch_size=1024,
    transforms=_mt_float_transforms(),
    suffix="_MT",
    **mt_shared_kwargs,
)

# Save the predictions!
mt_predictions_path = os.path.join(settings.output_dir, f"norm_setA_predictions_v{model_version}_MT.zarr")
sd.to_zarr(sdata_training_MT, mt_predictions_path, mode="w")

# DONE!

---

# Scratch

In [None]:
# Double check we predicted on all the columns: for each saved store print the number of
# sequences and data variables, then how many data variables contain "RNCMPT"
for store_name in ["norm_setA_predictions_ST.zarr", f"norm_setA_predictions_v{model_version}_MT.zarr"]:
    sdata = sd.open_zarr(os.path.join(settings.output_dir, store_name))
    keys = pd.Index(sdata.data_vars.keys())
    # Assumes sdata is an xarray.Dataset: .sizes is the dim -> length mapping
    # (mapping-style access through .dims is deprecated in recent xarray)
    print(store_name, sdata.sizes["_sequence"], len(sdata.data_vars))
    print(np.sum(keys.str.contains("RNCMPT")))