# Jores et al. (2021) Training
**Authorship:**
Adam Klie, *05/18/2023*
***
**Description:**
Notebook to perform simple training of models on the Jores et al. (2021) dataset. Alternatively, you can run the same training with the `jores21_training.py` script.
***

In [None]:
# Load the IPython autoreload extension once per kernel (skipped if already loaded) and
# set mode 2 so imported modules (e.g. eugene) are reloaded before each cell executes.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

In [1]:
import logging
import os
from copy import deepcopy

import numpy as np
import pandas as pd
import torch

import eugene as eu
import motifdata as md
import seqdata as sd
from eugene import settings

In [2]:
# Configure where EUGENe reads the Jores et al. (2021) data from and where this
# paper's outputs, logs and model configs live.
_paper_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper"
settings.dataset_dir = "/cellar/users/aklie/data/eugene/jores21"
settings.output_dir = f"{_paper_dir}/output/revision/jores21"
settings.logging_dir = f"{_paper_dir}/logs/revision/jores21"
settings.config_dir = f"{_paper_dir}/configs/jores21"
# Set EUGENe's verbosity to the ERROR level
settings.verbosity = logging.ERROR

# Load the `leaf`, `proto`, and `combined` `SeqData` objects

In [3]:
# Load in the preprocessed training data
sdata_leaf = sd.read(os.path.join(settings.dataset_dir, "leaf_processed_train.h5sd"))
sdata_proto = sd.read(os.path.join(settings.dataset_dir, "proto_processed_train.h5sd"))
sdata_combined = sd.concat([sdata_leaf, sdata_proto], keys=["leaf", "proto"])
sdata_leaf, sdata_proto, sdata_combined

(SeqData object with = 65004 seqs
 seqs = (65004,)
 names = (65004,)
 rev_seqs = None
 ohe_seqs = (65004, 4, 170)
 ohe_rev_seqs = None
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None
 seqsm: None
 uns: None,
 SeqData object with = 68213 seqs
 seqs = (68213,)
 names = (68213,)
 rev_seqs = None
 ohe_seqs = (68213, 4, 170)
 ohe_rev_seqs = None
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None
 seqsm: None
 uns: None,
 SeqData object with = 133217 seqs
 seqs = (133217,)
 names = (133217,)
 rev_seqs = None
 ohe_seqs = (133217, 4, 170)
 ohe_rev_seqs = None
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None
 seqsm: None
 uns: None)

In [4]:
# Grab motifs
core_promoter_elements = md.read_meme(os.path.join(settings.dataset_dir, "CPEs.meme"))
tf_clusters = md.read_meme(os.path.join(settings.dataset_dir, "TF-clusters.meme"))

# Smush them together, make function in the future
all_motifs = deepcopy(core_promoter_elements)
for motif in tf_clusters:
    all_motifs.add_motif(motif)
all_motifs

MotifSet with 78 motifs

In [5]:
import yaml
import importlib

In [6]:
# Import eugene.models and force a reload so the latest local edits to the package are used
from eugene import models
importlib.reload(models)

<module 'eugene.models' from '/cellar/users/aklie/projects/ML4GLand/EUGENe/eugene/models/__init__.py'>

In [7]:
def load_config(config_path):
    """Build a EUGENe module (architecture wrapped in a training module) from a YAML config.

    Expected layout of the config file:

    - ``module``: name of a module class exported by ``eugene.models``.
    - ``model``: a mapping with ``arch_name`` (name of an architecture class in
      ``eugene.models``) and, optionally, ``arch`` (keyword arguments for that class;
      omitted or empty means "use the architecture's defaults").
    - every other top-level key is forwarded as a keyword argument to the module class,
      after the instantiated architecture.

    Parameters
    ----------
    config_path : str or path-like
        Path to the YAML config file.

    Returns
    -------
    The instantiated module wrapping the architecture.

    Raises
    ------
    ValueError
        If the file is not a YAML mapping or lacks the ``module`` / ``model`` sections.
    """
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # An empty file loads as None; fail with a readable message instead of an
    # AttributeError on `.pop` further down.
    if not isinstance(config, dict):
        raise ValueError(f"Config {config_path} must be a YAML mapping, got {type(config).__name__}")
    missing = [key for key in ("module", "model") if key not in config]
    if missing:
        raise ValueError(f"Config {config_path} is missing required section(s): {missing}")

    eugene_models = importlib.import_module("eugene.models")
    module_name = config.pop("module")
    model_params = config.pop("model")
    arch_type = getattr(eugene_models, model_params["arch_name"])
    # A missing or null `arch` section means "no extra constructor arguments"
    arch = arch_type(**(model_params.get("arch") or {}))
    module_type = getattr(eugene_models, module_name)
    # Remaining top-level keys configure the module itself
    return module_type(arch, **config)

In [24]:
# Build the CNN module from the jores21 config directory set in the settings cell
model = load_config(os.path.join(settings.config_dir, "cnn.yaml"))

In [25]:
# Initialize the filters of the first convolution (arch.conv1d_tower.layers.0, a
# Conv1d(4, 256, kernel_size=8) per the hparams output below) from the merged motif set.
# NOTE(review): 78 motifs vs 256 filters — confirm how init_motif_weights maps motifs to filters.
models.init_motif_weights(
    model,
    layer_name="arch.conv1d_tower.layers.0",
    motifs=all_motifs
)

In [26]:
# Double check that layer
# Double check that layer: show the first filter's weights, transposed so rows are the
# 8 kernel positions and columns are the 4 input channels.
# NOTE(review): get_layer comes from a private module (`_utils`) and may move between versions.
from eugene.models.base._utils import get_layer
get_layer(model, 'arch.conv1d_tower.layers.0').weight[0].T

tensor([[0.5100, 1.5060, 0.4780, 1.5060],
        [0.6300, 1.5940, 0.7960, 0.9820],
        [0.9960, 1.2120, 0.7880, 1.0040],
        [0.4940, 2.6200, 0.3020, 0.5820],
        [0.0400, 0.0080, 0.0080, 3.9440],
        [3.8720, 0.0000, 0.0000, 0.1280],
        [0.0080, 0.0560, 0.0240, 3.9120],
        [3.9680, 0.0000, 0.0080, 0.0240]], grad_fn=<PermuteBackward0>)

In [27]:
# Sequence length the model expects; should match the last axis of ohe_seqs (170)
model.input_len

170

In [28]:
# Shape of a 128-sequence slice of the one-hot sequences: (n_seqs, 4, length)
sdata_leaf.ohe_seqs[:128].shape

(128, 4, 170)

In [29]:
# Smoke test: predictions from the not-yet-trained model on the first 128 leaf sequences
model.predict(sdata_leaf.ohe_seqs[:128])

HBox(children=(FloatProgress(value=0.0, description='Predicting on batches', max=1.0, style=ProgressStyle(desc…




tensor([[-0.2599],
        [-0.0426],
        [ 0.0548],
        [-0.2019],
        [-0.1531],
        [-0.1705],
        [-0.1595],
        [-0.0726],
        [-0.0935],
        [-0.2503],
        [-0.0025],
        [-0.1523],
        [-0.1488],
        [-0.1104],
        [-0.2665],
        [-0.1787],
        [-0.0803],
        [-0.0677],
        [-0.1033],
        [-0.0659],
        [ 0.0411],
        [-0.2203],
        [-0.0598],
        [-0.2831],
        [-0.0691],
        [-0.1327],
        [ 0.0020],
        [-0.1014],
        [-0.2648],
        [-0.1480],
        [-0.0881],
        [-0.1245],
        [-0.0588],
        [-0.2496],
        [-0.0493],
        [-0.2572],
        [-0.2992],
        [-0.1821],
        [-0.1141],
        [-0.1116],
        [-0.1799],
        [-0.2547],
        [-0.1192],
        [ 0.1012],
        [-0.0022],
        [-0.0492],
        [-0.0300],
        [-0.2739],
        [ 0.1164],
        [-0.1017],
        [-0.0298],
        [ 0.0466],
        [ 0.

In [30]:
from eugene import dataload as dl

In [31]:
# Split the leaf sequences by the `train_val` column: True -> training, False -> validation.
# The elementwise `==` is intentional (`is True` would not compare per element); rows equal
# to neither value end up in neither split.
train_val_col = sdata_leaf["train_val"]
train_idx = np.where(train_val_col == True)[0]  # noqa: E712
val_idx = np.where(train_val_col == False)[0]  # noqa: E712

# Pair each split's one-hot sequences with its enrichment targets
leaf_ohe = sdata_leaf.ohe_seqs
leaf_enrichment = sdata_leaf["enrichment"].values
train_sdataset = dl.SequenceDataset(seqs=leaf_ohe[train_idx], targets=leaf_enrichment[train_idx])
val_sdataset = dl.SequenceDataset(seqs=leaf_ohe[val_idx], targets=leaf_enrichment[val_idx])

In [33]:
# Batch the datasets (assuming to_dataloader forwards these options to a torch DataLoader).
# Training shuffles and drops the ragged last batch. Validation keeps every sequence
# (drop_last=False); otherwise up to 127 validation sequences would be silently left
# out of every validation pass and its metrics.
train_sdataloader = train_sdataset.to_dataloader(batch_size=128, shuffle=True, num_workers=0, drop_last=True)
val_sdataloader = val_sdataset.to_dataloader(batch_size=128, shuffle=False, num_workers=0, drop_last=False)

In [34]:
from eugene import train

In [35]:
# Confirm the directory training logs will be written under
settings.logging_dir

PosixPath('/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21')

In [36]:
# Inspect the hyperparameters stored on the module (includes the full `arch` definition)
model.hparams

"arch":              CNN(
  (conv1d_tower): Conv1DTower(
    (layers): Sequential(
      (0): Conv1d(4, 256, kernel_size=(8,), stride=(1,), padding=valid)
      (1): ELU(alpha=1.0)
      (2): MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (3): Dropout(p=0.3, inplace=False)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=valid)
      (6): ELU(alpha=1.0)
      (7): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
      (8): Dropout(p=0.3, inplace=False)
      (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=valid)
      (11): ELU(alpha=1.0)
      (12): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
      (13): Dropout(p=0.3, inplace=False)
      (14): BatchNorm1d(256, eps=1e-05, momentum=0

In [38]:
# Fit the motif-initialized CNN on the leaf train/validation dataloaders for 50 epochs on
# 1 GPU, logging to TensorBoard under settings.logging_dir with name "grace" and version
# "leaf" (the training-summary cell further down reads <logging_dir>/grace/leaf).
# NOTE(review): no seed is passed here (the trainer log reports "No seed set"), so this
# run is not reproducible.
train.fit(
    model=model, 
    train_dataloader=train_sdataloader,
    val_dataloader=val_sdataloader,
    gpus=1, 
    epochs=50,
    logger="tensorboard",
    log_dir=settings.logging_dir,
    name="grace",
    version="leaf"
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


No seed set


Set SLURM handle signals.

  | Name         | Type    | Params
-----------------------------------------
0 | arch         | CNN     | 3.0 M 
1 | train_metric | R2Score | 0     
2 | val_metric   | R2Score | 0     
3 | test_metric  | R2Score | 0     
-----------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
11.971    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [None]:
from eugene import settings

In [39]:
# EUGENe's plotting module; `pl` is not imported anywhere else in this notebook
from eugene import plot as pl

# Plot the logged training curves for the "grace"/"leaf" run trained above
pl.training_summary(os.path.join(settings.logging_dir, "grace", "leaf"))

In [23]:
# NOTE(review): this displays `state_dict` without calling it. The output is a plain function
# (not a bound method), which suggests `model.optimizer` holds the optimizer class rather
# than an instance — confirm before calling `.state_dict()`.
model.optimizer.state_dict

<function torch.optim.optimizer.Optimizer.state_dict(self)>

In [None]:
from pytorch_lightning import seed_everything


# Function for instantiating a new randomly initialized model
def prep_new_model(
    seed,
    arch,
    config
):
    """Build a seeded model, re-initialize its weights, then load motifs into its conv filters.

    Parameters
    ----------
    seed : int
        Global random seed, applied before the model is constructed.
    arch : str
        Architecture name: "Jores21CNN", "CNN" or "Hybrid". It selects which convolutional
        module receives the motif-derived filters.
    config : str or path-like
        Path to the model's YAML config, passed to ``eu.models.load_config``.

    Returns
    -------
    The initialized model.

    Raises
    ------
    ValueError
        If ``arch`` is not one of the supported architectures.
    """
    # Location of the motif-initialized filters for each supported architecture:
    # (module_name, module_number, kernel_name, kernel_number)
    conv_locations = {
        "Jores21CNN": ("biconv", None, "kernels", 0),
        "CNN": ("convnet", 0, None, None),
        "Hybrid": ("convnet", 0, None, None),
    }
    # Validate up front; an unknown arch would otherwise fail later with an
    # UnboundLocalError on `module_name`.
    if arch not in conv_locations:
        raise ValueError(f"Unsupported arch {arch!r}; expected one of {sorted(conv_locations)}")
    module_name, module_number, kernel_name, kernel_number = conv_locations[arch]

    # Seed before building the model so every random draw, including any made while
    # constructing layers, is reproducible for a given seed.
    seed_everything(seed)

    # Instantiate the model
    # NOTE(review): this call signature (arch=..., model_config=...) differs from the local
    # load_config(config_path) defined earlier in this notebook — confirm the installed
    # EUGENe version provides it.
    model = eu.models.load_config(
        arch=arch,
        model_config=config
    )

    # Initialize the model prior to conv filter initialization
    eu.models.init_weights(model)

    # Initialize the conv filters from the merged motif set
    eu.models.init_from_motifs(
        model,
        all_motifs,
        module_name=module_name,
        module_number=module_number,
        kernel_name=kernel_name,
        kernel_number=kernel_number
    )

    return model

In [None]:
# Instantiate a test model to make sure prep_new_model works for the Jores21CNN architecture.
# Uses the `settings` object imported at the top, the same one the earlier cells configure.
test_model = prep_new_model(
    seed=0,
    arch="Jores21CNN",
    config=os.path.join(settings.config_dir, "Jores21CNN.yaml"),
)

## Train leaf models (quick prototype run)

In [None]:
# Train `trials` models per architecture on the leaf data, each with its own seed.
# NOTE(review): this cell uses quick prototype settings (trials=1, epochs=1), unlike the
# proto and combined cells below (trials=5, epochs=25). The predictions saved at the end
# of this cell therefore come from 1-epoch models; raise both values for a full run.
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 1
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Initialize a fresh model; the trial number doubles as the random seed
        leaf_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=leaf_model,
            sdata=sdata_leaf,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=1,
            batch_size=128,
            num_workers=0,
            name=model_name,
            seed=trial,
            version=f"leaf_trial_{trial}",
            weights_summary=None,
            verbosity=logging.ERROR
        )

        # Get predictions on the training and validation data (presumably stored on
        # sdata_leaf under the given prefix, since sdata_leaf is what gets written below)
        eu.evaluate.train_val_predictions(
            leaf_model,
            sdata=sdata_leaf,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"leaf_trial_{trial}",
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del leaf_model

# Save train and validation predictions
sdata_leaf.write_h5sd(os.path.join(eu.settings.output_dir, "leaf_train_predictions.h5sd"))

# Train proto models

In [None]:
# Train `trials` (=5) models per architecture on the proto data, each with its own seed
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Initialize a fresh model; the trial number doubles as the random seed
        proto_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=proto_model,
            sdata=sdata_proto,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=25,
            batch_size=128,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=f"proto_trial_{trial}",
            verbosity=logging.ERROR,
            weights_summary=None
        )

        # Get predictions on the training and validation data. Uses the same
        # `target_keys` keyword as the fit call above and the leaf cell
        # (this cell previously passed `target=`).
        eu.evaluate.train_val_predictions(
            proto_model,
            sdata=sdata_proto,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"proto_trial_{trial}",
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del proto_model

# Save train and validation predictions
sdata_proto.write_h5sd(os.path.join(eu.settings.output_dir, "proto_train_predictions.h5sd"))

# Train combined models

In [None]:
# Train `trials` (=5) models per architecture on the combined leaf + proto data,
# each with its own seed
model_types = ["CNN", "Hybrid", "Jores21CNN"]
model_names = ["ssCNN", "ssHybrid", "Jores21CNN"]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Initialize a fresh model; the trial number doubles as the random seed
        combined_model = prep_new_model(
            arch=model_type,
            config=os.path.join(eu.settings.config_dir, f"{model_name}.yaml"),
            seed=trial
        )

        # Train the model
        eu.train.fit(
            model=combined_model,
            sdata=sdata_combined,
            gpus=1,
            target_keys="enrichment",
            train_key="train_val",
            epochs=25,
            batch_size=128,
            num_workers=4,
            name=model_name,
            seed=trial,
            version=f"combined_trial_{trial}",
            verbosity=logging.ERROR,
            weights_summary=None
        )

        # Get predictions on the training and validation data. Uses the same
        # `target_keys` keyword as the fit call above and the leaf cell
        # (this cell previously passed `target=`).
        eu.evaluate.train_val_predictions(
            combined_model,
            sdata=sdata_combined,
            target_keys="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"combined_trial_{trial}",
            prefix=f"{model_name}_trial_{trial}_"
        )

        # Make room for the next model
        del combined_model

# Save train and validation predictions
sdata_combined.write_h5sd(os.path.join(eu.settings.output_dir, "combined_train_predictions.h5sd"))

---