# Jores et al 2021 Training 
**Authorship:**
Adam Klie, *08/11/2022*
***
**Description:**
This notebook trains simple models on the Jores et al. (2021) dataset.
***

In [2]:
# Load the autoreload extension only if this IPython session has not loaded it yet,
# so the cell can be re-run safely.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
# Mode 2 reloads all modules before each cell executes
# (presumably so local edits to `eugene` are picked up without a kernel restart).
%autoreload 2

import os
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu

Global seed set to 13
2022-08-12 23:42:47.645958: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-12 23:42:47.646062: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [3]:
# Point EUGENe at the jores21 project directories (paths are relative to this notebook).
eu.settings.dataset_dir = "../../_data/datasets/jores21"
eu.settings.output_dir = "../../_output/jores21"
eu.settings.logging_dir = "../../_logs/jores21"
eu.settings.config_dir = "../../_configs/jores21"
# Global defaults for batch size and worker count
# (presumably consumed by the dataloaders built during training/prediction — confirm in eugene).
eu.settings.batch_size = 128
eu.settings.dl_num_workers = 4
# Left disabled: would set the global verbosity to ERROR.
#eu.settings.verbosity = logging.ERROR

# Load in the `leaf`, `proto` and `combined` `SeqData`s 

In [4]:
# Locate the preprocessed training files for each dataset
leaf_train_path = os.path.join(eu.settings.dataset_dir, "leaf_processed_train.h5sd")
proto_train_path = os.path.join(eu.settings.dataset_dir, "proto_processed_train.h5sd")

# Read both sets, then concatenate them under the keys "leaf" and "proto"
sdata_leaf = eu.dl.read(leaf_train_path)
sdata_proto = eu.dl.read(proto_train_path)
sdata_combined = eu.dl.concat(
    [sdata_leaf, sdata_proto],
    keys=["leaf", "proto"],
)

# Display all three objects as the cell output
sdata_leaf, sdata_proto, sdata_combined

(SeqData object with = 65004 seqs
 seqs = (65004,)
 names = (65004,)
 rev_seqs = (65004,)
 ohe_seqs = (65004, 170, 4)
 ohe_rev_seqs = (65004, 170, 4)
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None
 seqsm: None
 uns: None,
 SeqData object with = 68213 seqs
 seqs = (68213,)
 names = (68213,)
 rev_seqs = (68213,)
 ohe_seqs = (68213, 170, 4)
 ohe_rev_seqs = (68213, 170, 4)
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None
 seqsm: None
 uns: None,
 SeqData object with = 133217 seqs
 seqs = (133217,)
 names = (133217,)
 rev_seqs = (133217,)
 ohe_seqs = (133217, 170, 4)
 ohe_rev_seqs = (133217, 170, 4)
 seqs_annot: 'GC', 'barcodes', 'batch', 'chromosome', 'end', 'enrichment', 'gene', 'mutations', 'set', 'sp', 'start', 'strand', 'train_val', 'type'
 pos_annot: None

In [5]:
# Parse the two MEME files that supply the initialization motifs
cpe_meme_path = os.path.join(eu.settings.dataset_dir, 'CPEs.meme')
tf_meme_path = os.path.join(eu.settings.dataset_dir, 'TF-clusters.meme')
core_promoter_elements = eu.utils.MinimalMEME(cpe_meme_path)
tf_groups = eu.utils.MinimalMEME(tf_meme_path)

# Merge both motif collections into one dict; on a duplicate key the TF-cluster entry wins
all_motifs = dict(core_promoter_elements.motifs)
all_motifs.update(tf_groups.motifs)

# Show how many motifs are available in total
len(all_motifs)

78

In [6]:
def prep_new_model(
    seed,
    arch,
    config
):
    """Instantiate a model from a config file and initialize its conv filters from motifs.

    The model is built with ``eu.models.load_config``, passed through
    ``eu.models.base.init_weights``, and then a conv layer is overwritten by
    ``eu.models.init_from_motifs`` using the notebook-level ``all_motifs`` dict.

    Parameters
    ----------
    seed : int
        Accepted for interface compatibility but currently unused: nothing in
        this function reseeds a random number generator.
    arch : str
        Architecture name. Must be ``"Jores21CNN"`` or ``"CNN"``.
    config : str
        Path to the model config file handed to ``eu.models.load_config``.

    Returns
    -------
    The instantiated, initialized model.

    Raises
    ------
    ValueError
        If ``arch`` is not one of the supported architectures. Previously an
        unsupported value only failed after the model had been built, with an
        unbound-local error on ``layer_name``.
    """
    # Where the motif-initialized filters live for each supported architecture:
    # (layer_name, kernel_name, kernel_number, module_number)
    motif_init_targets = {
        "Jores21CNN": ("biconv", "kernels", 0, None),
        "CNN": ("convnet", None, None, 0),
    }

    # Fail fast, before any model is built, when the architecture is unknown
    if arch not in motif_init_targets:
        supported = ", ".join(sorted(motif_init_targets))
        raise ValueError(
            f"Unsupported arch {arch!r}; expected one of: {supported}"
        )
    layer_name, kernel_name, kernel_number, module_number = motif_init_targets[arch]

    # Instantiate the model
    model = eu.models.load_config(
        arch=arch,
        model_config=config
    )

    # Initialize the model prior to conv filter initialization
    eu.models.base.init_weights(model)

    # Initialize the conv filters
    eu.models.init_from_motifs(
        model,
        all_motifs,
        layer_name=layer_name,
        kernel_name=kernel_name,
        kernel_number=kernel_number,
        module_number=module_number,
    )

    # Return the model
    return model

# Train leaf models

In [None]:
# Quick test run on the first 100 leaf sequences:
# for each architecture, fit `trials` models and store their train/val predictions.
model_types = ["CNN", "Jores21CNN"]
model_names = ["ssCNN", "Jores21CNN"]
sdata_leaf_sub = sdata_leaf[:100]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    # The config file is named after the model, so it is the same for every trial
    config_path = os.path.join(eu.settings.config_dir, f"{model_name}.yaml")

    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Fresh model for this trial
        leaf_model = prep_new_model(seed=13, arch=model_type, config=config_path)

        # Arguments shared by the fit and predict calls
        run_kwargs = dict(
            sdata=sdata_leaf_sub,
            target="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"test_leaf_trial_{trial}",
        )

        # Fit for a single epoch, seeding training with the trial number
        # (the original cell kept `gpus=1` and `epochs=25` commented out as alternatives)
        eu.train.fit(
            model=leaf_model,
            epochs=1,
            seed=trial,
            verbosity=logging.ERROR,
            **run_kwargs,
        )

        # Predict on the same data, using a per-model, per-trial prefix
        eu.predict.train_val_predictions(
            leaf_model,
            prefix=f"{model_name}_test_trial_{trial}_",
            **run_kwargs,
        )
        del leaf_model

    # Written after each architecture finishes, overwriting the same file
    sdata_leaf_sub.write_h5sd(os.path.join(eu.settings.output_dir, "leaf_sub_predictions.h5sd"))

# Train proto models

In [None]:
# Quick test run on the first 100 proto sequences:
# for each architecture, fit `trials` models and store their train/val predictions.
model_types = ["CNN", "Jores21CNN"]
model_names = ["ssCNN", "Jores21CNN"]
sdata_proto_sub = sdata_proto[:100]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    # The config file is named after the model, so it is the same for every trial
    config_path = os.path.join(eu.settings.config_dir, f"{model_name}.yaml")

    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Fresh model for this trial
        proto_model = prep_new_model(seed=13, arch=model_type, config=config_path)

        # Arguments shared by the fit and predict calls
        run_kwargs = dict(
            sdata=sdata_proto_sub,
            target="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"test_proto_trial_{trial}",
        )

        # Fit for a single epoch, seeding training with the trial number
        # (the original cell kept `gpus=1` and `epochs=25` commented out as alternatives)
        eu.train.fit(
            model=proto_model,
            epochs=1,
            seed=trial,
            verbosity=logging.ERROR,
            **run_kwargs,
        )

        # Predict on the same data, using a per-model, per-trial prefix
        eu.predict.train_val_predictions(
            proto_model,
            prefix=f"{model_name}_test_trial_{trial}_",
            **run_kwargs,
        )
        del proto_model

    # Written after each architecture finishes, overwriting the same file
    sdata_proto_sub.write_h5sd(os.path.join(eu.settings.output_dir, "proto_sub_predictions.h5sd"))

# Train combined models

In [None]:
# Quick test run on the first 100 combined sequences:
# for each architecture, fit `trials` models and store their train/val predictions.
model_types = ["CNN", "Jores21CNN"]
model_names = ["ssCNN", "Jores21CNN"]
sdata_combined_sub = sdata_combined[:100]
trials = 5
for model_name, model_type in zip(model_names, model_types):
    # The config file is named after the model, so it is the same for every trial
    config_path = os.path.join(eu.settings.config_dir, f"{model_name}.yaml")

    for trial in range(1, trials+1):
        print(f"{model_name} trial {trial}")

        # Fresh model for this trial
        combined_model = prep_new_model(seed=13, arch=model_type, config=config_path)

        # Arguments shared by the fit and predict calls
        run_kwargs = dict(
            sdata=sdata_combined_sub,
            target="enrichment",
            train_key="train_val",
            name=model_name,
            version=f"test_combined_trial_{trial}",
        )

        # Fit for a single epoch, seeding training with the trial number
        # (the original cell kept `gpus=1` and `epochs=25` commented out as alternatives)
        eu.train.fit(
            model=combined_model,
            epochs=1,
            seed=trial,
            verbosity=logging.ERROR,
            **run_kwargs,
        )

        # Predict on the same data, using a per-model, per-trial prefix
        eu.predict.train_val_predictions(
            combined_model,
            prefix=f"{model_name}_test_trial_{trial}_",
            **run_kwargs,
        )
        del combined_model

# Written once, after every architecture and trial has finished
sdata_combined_sub.write_h5sd(os.path.join(eu.settings.output_dir, "combined_sub_predictions.h5sd"))

---

# Scratch

In [None]:
# Test conv kernel initialization, this needs a fix!
# Build one model per architecture from its own config file
sscnn_config = os.path.join(eu.settings.config_dir, "ssCNN.yaml")
jores_config = os.path.join(eu.settings.config_dir, "Jores21CNN.yaml")
cnn = prep_new_model(seed=0, arch="CNN", config=sscnn_config)
jores = prep_new_model(seed=0, arch="Jores21CNN", config=jores_config)

# Compare the first filter of each model's motif-initialized conv layer element-wise
cnn_first_filter = cnn.convnet.module[0].weight[0]
jores_first_filter = jores.biconv.kernels[0][0]
torch.all(cnn_first_filter == jores_first_filter)