# Jores et al 2021 Training 
**Authorship:**
Adam Klie (last updated: *06/08/2023*)
***
**Description:**
Notebook to perform simple training of models on the Jores et al. (2021) dataset. You can also use the `jores21_training.py` script if you would rather run the training that way.
***

In [1]:
# General imports
import os
import sys
import torch
import numpy as np
import pandas as pd
from copy import deepcopy 
import pytorch_lightning

# EUGENe imports and settings
import eugene as eu
from eugene import dataload as dl
from eugene import models
from eugene import train
from eugene import settings
# Directory configuration for this analysis. These are absolute, user-specific
# paths, so they must be edited before running the notebook on another system.
# dataset_dir is joined with the zarr and MEME file names in the loading cells below.
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
# NOTE(review): presumably where EUGENe writes model outputs — confirm against the EUGENe settings docs
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/jores21"
# NOTE(review): presumably the root under which training logs/checkpoints are written — confirm
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/jores21"
# NOTE(review): presumably where the bare config names used below (e.g. "cnn.yaml") are resolved — confirm
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

# EUGENe packages
import seqdata as sd
import motifdata as md

# Record the interpreter and package versions used for this run
print(f"Python version: {sys.version}")
for pkg_label, pkg in (
    ("NumPy", np),
    ("Pandas", pd),
    ("Eugene", eu),
    ("SeqData", sd),
    ("MotifData", md),
    ("PyTorch", torch),
    ("PyTorch Lightning", pytorch_lightning),
):
    # Every package listed exposes its version string as __version__
    print(f"{pkg_label} version: {pkg.__version__}")


Python version: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:39:03) 
[GCC 11.3.0]
NumPy version: 1.23.5
Pandas version: 1.5.2
Eugene version: 0.0.8
SeqData version: 0.0.1
MotifData version: 0.0.1
PyTorch version: 2.0.0
PyTorch Lightning version: 2.0.0


# Load in the `leaf`, `proto` and `combined` `SeqData`s 

In [2]:
# Locate the preprocessed leaf and proto training sets inside the dataset directory
leaf_zarr_path = os.path.join(settings.dataset_dir, "jores21_leaf_train.zarr")
proto_zarr_path = os.path.join(settings.dataset_dir, "jores21_proto_train.zarr")

# Open each zarr store as its own SeqData
sdata_leaf = sd.open_zarr(leaf_zarr_path)
sdata_proto = sd.open_zarr(proto_zarr_path)

# Concatenate the two sets, passing "leaf" and "proto" as their respective names
sdata_combined = dl.concat_seqdatas(
    [sdata_leaf, sdata_proto],
    ["leaf", "proto"],
)

# Load in PFMs to initialize the 1st layer of models

In [3]:
# Read the two MEME files that ship with the dataset
cpe_meme_path = os.path.join(settings.dataset_dir, "CPEs.meme")
tf_meme_path = os.path.join(settings.dataset_dir, "TF-clusters.meme")
core_promoter_elements = md.read_meme(cpe_meme_path)
tf_clusters = md.read_meme(tf_meme_path)

# Merge both sets into one: start from a deep copy of the core promoter
# elements (so the loaded set is left untouched) and append every TF-cluster
# motif to it. TODO: move this into a reusable function.
all_motifs = deepcopy(core_promoter_elements)
for tf_motif in tf_clusters:
    all_motifs.add_motif(tf_motif)

# Display the merged set as the cell output
all_motifs

MotifSet with 78 motifs

# Train models

In [4]:
# Function for instantiating a new randomly initialized model
def prep_new_model(
    config,
    seed,
    motifs=None
):
    """Build a model from a config and seed its first conv layer with motifs.

    The model is loaded with ``models.load_config``, all weights are
    initialized with ``kaiming_normal``, and then ``models.init_motif_weights``
    is called on the architecture-specific first convolutional layer.

    Parameters
    ----------
    config : str
        Config file passed to ``models.load_config`` as ``config_path``
        (e.g. ``"cnn.yaml"``).
    seed : int
        Seed passed to ``models.load_config``.
    motifs : optional
        Motif set passed to ``models.init_motif_weights``. Defaults to the
        notebook-level ``all_motifs`` when not given, which matches the
        original behavior.

    Returns
    -------
    The initialized model.

    Raises
    ------
    ValueError
        If ``model.arch_name`` is not one of the architectures this function
        knows the first conv layer of.
    """
    # Fall back to the merged motif set defined earlier in the notebook
    if motifs is None:
        motifs = all_motifs

    # Instantiate the model
    model = models.load_config(config_path=config, seed=seed)

    # Initialize the model prior to conv filter initialization
    models.init_weights(model, initializer="kaiming_normal")

    # Pick the layer holding the first conv filters for this architecture
    if model.arch_name == "Jores21CNN":
        layer_name = "arch.biconv.kernels"
        list_index = 0
    elif model.arch_name in ["CNN", "Hybrid", "DeepSTARR"]:
        layer_name = "arch.conv1d_tower.layers.0"
        list_index = None
    else:
        # Without this branch an unknown architecture would fail below with an
        # UnboundLocalError on layer_name, which hides the real cause
        raise ValueError(
            f"Unsupported architecture '{model.arch_name}' in config '{config}': "
            "expected one of 'Jores21CNN', 'CNN', 'Hybrid', 'DeepSTARR'"
        )

    # Initialize the conv filters
    models.init_motif_weights(
        model=model,
        layer_name=layer_name,
        list_index=list_index,
        initializer="xavier_uniform",
        motifs=motifs,
        convert_to_pwm=False,
        divide_by_bg=True,
        motif_align="left",
        kernel_align="left"
    )

    # Return the model
    return model 

# Smoke test: build each architecture once with seed 0 so that config or
# motif-initialization problems surface before any training starts
for config_name in ("cnn.yaml", "hybrid.yaml", "jores21_cnn.yaml", "deepstarr.yaml"):
    model = prep_new_model(config_name, seed=0)

[rank: 0] Global seed set to 0
[rank: 0] Global seed set to 0
[rank: 0] Global seed set to 0
[rank: 0] Global seed set to 0


In [None]:
# Quick test configuration: one trial of 5 epochs per (training set, architecture)
# pair. NOTE: run jores21_training.py for the full training.
training_sets = {"leaf": sdata_leaf, "proto": sdata_proto, "combined": sdata_combined}
configs = ["cnn.yaml", "hybrid.yaml", "jores21_cnn.yaml", "deepstarr.yaml"]
trials = 1

# Conversions handed to the fit call as `seq_transforms`: both produce float32
# tensors, and the "ohe_seq" one additionally swaps axes 1 and 2
tensor_transforms = {
    "ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).transpose(1, 2),
    "target": lambda x: torch.tensor(x, dtype=torch.float32),
}

for training_set, sdata in training_sets.items():
    for trial in range(1, trials + 1):
        for config in configs:

            # The model name is the config file name up to its first "."
            model_name = config.partition(".")[0]
            print(f"{training_set} {model_name} trial {trial}")

            # Fresh model per run; the trial number doubles as the seed
            model = prep_new_model(config, seed=trial)

            # Fit on this training set; runs are named by model and
            # versioned by training set and trial
            train.fit_sequence_module(
                model,
                sdata,
                seq_key="ohe_seq",
                target_keys=["enrichment"],
                in_memory=True,
                train_key="train_val",
                epochs=5,
                batch_size=128,
                num_workers=4,
                prefetch_factor=2,
                drop_last=False,
                name=model_name,
                version=f"{training_set}_trial_{trial}",
                seq_transforms=tensor_transforms,
                seed=trial,
            )

            # Drop the reference before the next model is built
            del model

# DONE!

---

# Scratch

## Check file generation

In [9]:
# Check the logging directory for all the models
# NOTE(review): this lists logs/fix_full/jores21, which is not the
# settings.logging_dir configured at the top of this notebook (logs/jores21) —
# presumably the output of a full jores21_training.py run; confirm.
!tree -L 3 /cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/fix_full/jores21

/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/fix_full/jores21
├── cnn
│   ├── combined_trial_1
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685745874.carter-gpu-01.865078.40
│   │   └── hparams.yaml
│   ├── combined_trial_2
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685747545.carter-gpu-01.865078.44
│   │   └── hparams.yaml
│   ├── combined_trial_3
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685748986.carter-gpu-01.865078.48
│   │   └── hparams.yaml
│   ├── combined_trial_4
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685750390.carter-gpu-01.865078.52
│   │   └── hparams.yaml
│   ├── combined_trial_5
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685752222.carter-gpu-01.865078.56
│   │   └── hparams.yaml
│   ├── leaf_trial_1
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1685736231.carter-gpu-01.865078.0
│   │   └── hparams.yaml
│   ├── leaf_trial_2
│   │   ├── checkpoints
│   │   ├── events.out.tfevents.1

In [24]:
# Check that the logging directory has the correct number of models using Python
import glob
import os
import pandas as pd

def count_checkpoints_per_run(model_dir):
    """Count checkpoint files for each run directory under ``model_dir``.

    Parameters
    ----------
    model_dir : str
        Directory containing one sub-directory per run.

    Returns
    -------
    dict
        Maps each run directory name to the number of entries found in its
        ``checkpoints`` sub-directory (0 if that sub-directory is empty or
        missing). Non-directory entries of ``model_dir`` are ignored.
    """
    # glob returns entries in arbitrary order, so sort for stable output
    run_dirs = sorted(
        path for path in glob.glob(os.path.join(model_dir, "*")) if os.path.isdir(path)
    )
    return {
        os.path.basename(run_dir): len(glob.glob(os.path.join(run_dir, "checkpoints", "*")))
        for run_dir in run_dirs
    }


# Get the list of models
model_dirs = glob.glob("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/fix_full/jores21/*")
model_dirs = sorted(x for x in model_dirs if os.path.isdir(x))
for model_dir in model_dirs:
    model_type = os.path.basename(model_dir)
    print(f"{model_type}: {len(glob.glob(os.path.join(model_dir, '*')))}")

    # Make sure there is a ckpt file in the checkpoint directory within each run directory.
    # Counting per run (not just in total) keeps a run with several checkpoints
    # from masking another run that has none.
    ckpt_counts = count_checkpoints_per_run(model_dir)
    num_ckpt = sum(ckpt_counts.values())
    print("  ", f"checkpoints: {num_ckpt}")

    # Only printed when something is wrong, so a clean run looks as before
    runs_without_ckpt = [run for run, n_ckpt in ckpt_counts.items() if n_ckpt == 0]
    if runs_without_ckpt:
        print("  ", f"runs without a checkpoint: {runs_without_ckpt}")

hybrid: 15
   checkpoints: 15
jores21_cnn: 15
   checkpoints: 15
deepstarr: 15
   checkpoints: 15
cnn: 15
   checkpoints: 15
