# Jores et al 2021 GPU Utilization Analysis
**Authorship:**
David Laub (last updated: *07/19/2023*)
***
**Description:**
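Notebook comparing GPU utilization when training the same hybrid model on the Jores et al. (2021) leaf training set with three training loops: native PyTorch, PyTorch Lightning, and EUGENe. Each loop is tracked as a separate Weights & Biases run in the `EUGENe GPU Utilization` project.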
***

In [None]:
# General imports
import os
import sys
import torch
import wandb
from tqdm.auto import tqdm
import pytorch_lightning as pl

# EUGENe imports and settings
import eugene as eu
import eugene.train
import eugene.models
from eugene import settings
# The dataset directory lives under the `aklie` user, while output, logging and
# config directories point at the `dlaub` copy of the EUGENe_paper project.
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
settings.output_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/output/jores21"
settings.logging_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/logs/jores21"
settings.config_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/configs/jores21"

# EUGENe packages
import seqdata as sd
import motifdata as md

# Other imports
# NOTE(review): `deepcopy` is only referenced by the commented-out motif-loading
# code further down in this notebook.
from copy import deepcopy

# Print versions
# Record the interpreter and library versions alongside the run for reproducibility.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"Eugene version: {eu.__version__}")
print(f"SeqData version: {sd.__version__}")
print(f"MotifData version: {md.__version__}")
print(f"WandB version: {wandb.__version__}")

In [None]:
# Login to wandb
# Authenticate this session with Weights & Biases; each training section of this
# notebook opens its own run via `wandb.init(...)`.
wandb.login()

# Load the data

In [None]:
# Load data
# Build the store path from the configured dataset directory instead of repeating
# the absolute path, so the data location is defined in one place
# (`settings.dataset_dir`, set in the imports cell).
train_zarr_path = os.path.join(settings.dataset_dir, 'jores21_leaf_train.zarr')

# Open the zarr store and expose the `enrichment` variable under the name
# `target`, which is the variable name the training cells request.
sdata = (
    sd.open_zarr(train_zarr_path)
    .rename({"enrichment": 'target'})
)
sdata

In [None]:
# Load data into memory
# `.load()` is called on the two-variable selection and its return value is
# discarded; the trailing `;` suppresses the cell output.
# NOTE(review): this relies on the selection sharing its underlying variables
# with `sdata`, so that `sdata` itself ends up holding in-memory arrays — confirm.
sdata[['ohe_seq', 'target']].load();

In [None]:
# Build dataloader
# Boolean mask along `_sequence`: True where `train_val` marks a training row.
is_train = (sdata.train_val == True).compute()
train_sdata = sdata.sel(_sequence=is_train)


def _swap_axes_1_2(x):
    """Return `x` with axes 1 and 2 exchanged.

    Presumably converts batches from (batch, length, channel) to the
    channels-first layout expected by the model — TODO confirm.
    """
    return x.swapaxes(1, 2)


# Per-variable transforms applied by the dataloader; also reused by the EUGENe run.
transforms = {"ohe_seq": _swap_axes_1_2}

loader_options = dict(
    sample_dims='_sequence',
    variables=['ohe_seq', 'target'],
    transforms=transforms,
    batch_size=1024,
    pin_memory=True,
    drop_last=False,
)
dl = sd.get_torch_dataloader(train_sdata, **loader_options)

In [None]:
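# Motif-based initialization of the first layer is skipped for now. The code below is
# kept for reference: it would read the CPE and TF-cluster MEME files and merge them
# into a single `all_motifs` set.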
# core_promoter_elements = md.read_meme("/cellar/users/aklie/data/eugene/revision/jores21/CPEs.meme")
# tf_clusters = md.read_meme("/cellar/users/aklie/data/eugene/revision/jores21/TF-clusters.meme")
# all_motifs = deepcopy(core_promoter_elements)
# for motif in tf_clusters:
#     all_motifs.add_motif(motif)

## PyTorch run

In [None]:
# Load model
# Instantiate the architecture described by the hybrid config (seed=0) and move
# it onto the GPU.
# NOTE(review): this config path is under the `aklie` user, whereas
# `settings.config_dir` points at the `dlaub` project copy — confirm which is intended.
hybrid_config = '/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21/hybrid.yaml'
model = eu.models.load_config(config_path=hybrid_config, seed=0)
model = model.cuda()

# Initialize the model's weights with the "kaiming_normal" initializer.
eu.models.init_weights(model, initializer="kaiming_normal")

# eu.models.init_motif_weights(
#     model=model,
#     layer_name="arch.conv1d_tower.layers.0",
#     list_index=None,
#     initializer="xavier_uniform",
#     motifs=all_motifs,
#     convert_to_pwm=False,
#     divide_by_bg=True,
#     motif_align="left",
#     kernel_align="left"
# )

In [None]:
# Native PyTorch training loop
def train(model, dloader, optim, loss_fn, device=None):
    """Run one epoch of plain-PyTorch optimization over ``dloader``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimize; put into training mode before the first batch.
    dloader : iterable of dict
        Yields batches containing an ``'ohe_seq'`` input tensor and a
        ``'target'`` tensor.
    optim : torch.optim.Optimizer
        Optimizer already bound to ``model``'s parameters.
    loss_fn : callable
        Called as ``loss_fn(prediction, target)``; must return a tensor that
        supports ``.backward()``.
    device : torch.device or str, optional
        Device each batch is moved to before the forward pass. Defaults to the
        device of the model's first parameter, so a model that was moved with
        ``.cuda()`` keeps receiving CUDA batches. A model without parameters
        falls back to ``"cuda"``, matching the previous hard-coded behavior.

    Returns
    -------
    None
        The model and optimizer are updated in place.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    model.train()
    for batch in tqdm(dloader, position=1, leave=False):
        target = batch['target'].to(device)
        pred = model(batch['ohe_seq'].to(device))
        # Align the prediction with the target's shape explicitly. A bare
        # `squeeze()` also drops the batch axis when the final batch holds a
        # single sample, leaving a 0-d prediction against a (1,) target.
        loss = loss_fn(pred.reshape(target.shape), target)
        loss.backward()
        optim.step()
        # Clear gradients right after the step so nothing accumulates into the
        # next batch (or into the next call).
        optim.zero_grad()

In [None]:
# Train model
# Number of passes over the training dataloader.
epochs = 25
# Adam with its default hyperparameters, bound to the current `model`.
optim = torch.optim.Adam(model.parameters())

# Open a dedicated W&B run for the native loop; the context manager closes the
# run when training finishes or raises.
native_run = wandb.init(project='EUGENe GPU Utilization', name='Native PyTorch', tags=['Jores21'])
with native_run:
    epoch_bar = tqdm(range(epochs), position=0)
    for _ in epoch_bar:
        train(model, dl, optim, model.loss_fxn)

## Lightning run

In [None]:
# Load a fresh model for the Lightning run
# Same config and seed as the other runs, so every training loop starts from an
# identically constructed network.
hybrid_config = '/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21/hybrid.yaml'
model = eu.models.load_config(config_path=hybrid_config, seed=0)
model = model.cuda()

# Initialize the model's weights with the "kaiming_normal" initializer.
eu.models.init_weights(model, initializer="kaiming_normal")

# eu.models.init_motif_weights(
#     model=model,
#     layer_name="arch.conv1d_tower.layers.0",
#     list_index=None,
#     initializer="xavier_uniform",
#     motifs=all_motifs,
#     convert_to_pwm=False,
#     divide_by_bg=True,
#     motif_align="left",
#     kernel_align="left"
# )

# Clear the module's `scheduler` attribute.
# NOTE(review): presumably disables learning-rate scheduling so this run matches
# the scheduler-free native loop — confirm.
model.scheduler = None

In [None]:
# PyTorch Lightning training loop
# Pin the Trainer to a single GPU explicitly. Without `accelerator`/`devices` the
# device choice depends on the installed Lightning version (1.x releases default
# to CPU), which would defeat a GPU-utilization comparison against the
# single-GPU native loop.
# `logger=False` turns off Lightning's own logger; the W&B run is opened
# separately around `trainer.fit`.
trainer = pl.Trainer(max_epochs=25, accelerator="gpu", devices=1, logger=False)

In [None]:
# Train model
# Open a dedicated W&B run for the Lightning loop; the context manager closes
# the run when fitting finishes or raises.
lightning_run = wandb.init(project='EUGENe GPU Utilization', name='PL no val', tags=['Jores21'])
with lightning_run:
    # Only the training dataloader is passed (hence the run name 'PL no val').
    trainer.fit(model, dl)

## EUGENe run

In [None]:
# Load a fresh model for the EUGENe run
# Same config and seed as the other runs, so every training loop starts from an
# identically constructed network.
hybrid_config = '/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21/hybrid.yaml'
model = eu.models.load_config(config_path=hybrid_config, seed=0)
model = model.cuda()

# Initialize the model's weights with the "kaiming_normal" initializer.
eu.models.init_weights(model, initializer="kaiming_normal")

# eu.models.init_motif_weights(
#     model=model,
#     layer_name="arch.conv1d_tower.layers.0",
#     list_index=None,
#     initializer="xavier_uniform",
#     motifs=all_motifs,
#     convert_to_pwm=False,
#     divide_by_bg=True,
#     motif_align="left",
#     kernel_align="left"
# )

# Clear the module's `scheduler` attribute.
# NOTE(review): presumably disables learning-rate scheduling so this run matches
# the scheduler-free native loop — confirm.
model.scheduler = None

In [None]:
# Fit the model
# Options for EUGENe's fit helper: same epoch count and batch size as the other
# two runs, with early stopping and checkpoint monitoring both set to None.
fit_options = dict(
    seq_var="ohe_seq",
    target_vars=['target'],
    in_memory=True,
    train_var="train_val",
    epochs=25,
    batch_size=1024,
    drop_last=False,
    early_stopping_metric=None,
    model_checkpoint_monitor=None,
    transforms=transforms,
)

# Open a dedicated W&B run for the EUGENe loop; the context manager closes the
# run when fitting finishes or raises.
eugene_run = wandb.init(project='EUGENe GPU Utilization', name='EUGENe', tags=['Jores21'])
with eugene_run:
    # The full `sdata` is passed here (not `train_sdata`), with `train_var`
    # naming the split column.
    eu.train.fit_sequence_module(model, sdata, **fit_options)

# DONE!

---

# Scratch