# Kopp et al. (2021) GPU Utilization Analysis
**Authorship:**
David Laub (last updated: *07/19/2023*)
***
**Description:**
Notebook to analyze GPU utilization when training a model on the Kopp et al. (2021) dataset with native PyTorch, PyTorch Lightning, and EUGENe
***

In [None]:
# General imports
import os
import sys
import torch
import wandb
from tqdm import tqdm

# EUGENe imports and settings
import eugene as eu
import eugene.train
import eugene.models
from eugene import settings

# Where EUGENe reads datasets/configs from and writes outputs/logs to.
# NOTE(review): absolute cluster paths — adjust before running elsewhere.
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/kopp21"
settings.output_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/output/kopp21"
settings.logging_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/logs/kopp21"
settings.config_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/configs/kopp21"

# EUGENe packages
import seqdata as sd

# Record the library versions this analysis ran with (provenance)
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"Eugene version: {eu.__version__}")
print(f"SeqData version: {sd.__version__}")
print(f"WandB version: {wandb.__version__}")

In [None]:
# Login to wandb
# Authenticates with Weights & Biases so the `wandb.init(...)` runs below can log.
# NOTE(review): presumably picks up credentials from the environment/netrc or prompts
# interactively — keep API keys out of the notebook itself.
wandb.login()

## Load dataset

In [None]:
# Load data
# Open the Kopp et al. (2021) training Zarr store from `settings.dataset_dir`;
# the bare `sdata` on the last line displays the dataset.
sdata = sd.open_zarr(os.path.join(settings.dataset_dir, 'kopp21_train.zarr'))
sdata

In [None]:
# Load data into memory
# Eagerly load the one-hot sequences and targets, presumably so training is not
# bound by reads from the Zarr store.
# NOTE(review): relies on the column subset sharing its underlying variables with
# `sdata`, so `.load()` populates `sdata` itself — confirm. Trailing `;` hides the repr.
sdata[['ohe_seq', 'target']].load();

In [None]:
# Define transforms
# Per-variable transforms handed to the dataloaders: swap axes 1 and 2 of the
# one-hot sequences. NOTE(review): presumably (batch, length, alphabet) ->
# (batch, alphabet, length), the channels-first layout conv layers expect — confirm.
transforms = dict(ohe_seq=lambda seq_batch: seq_batch.swapaxes(1, 2))

In [None]:
# Build dataloader
# Training split only (`train_val == True`). Shuffle every epoch, as the PyTorch
# Lightning train dataloader below does, so the GPU-utilization runs being compared
# do equivalent data-loading work.
# NOTE(review): `pin_memory=True` is set here but not on the Lightning dataloaders —
# keep in mind when comparing utilization numbers.
train_sdata = sdata.sel(_sequence=(sdata.train_val == True).compute())
dloader = sd.get_torch_dataloader(
    train_sdata,
    sample_dims='_sequence',
    variables=['ohe_seq', 'target'],
    transforms=transforms,
    batch_size=2048,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)

## Build or choose a medium-sized model

In [None]:
def prep_new_model(
    config,
    seed,
):
    """Build a freshly initialized EUGENe model from a config file.

    Parameters
    ----------
    config : str
        Config passed to ``eu.models.load_config`` as ``config_path``
        (e.g. ``"kopp21_cnn.yaml"``; presumably resolved against
        ``settings.config_dir`` — confirm).
    seed : int
        Seed forwarded to ``eu.models.load_config``.

    Returns
    -------
    The model produced by ``load_config`` after ``eu.models.init_weights``
    has been applied to it.
    """
    new_model = eu.models.load_config(config_path=config, seed=seed)
    # Re-initialize the weights; nothing else (e.g. conv-filter init) follows here.
    eu.models.init_weights(new_model)
    return new_model

In [None]:
# Fresh seed-0 model for the native PyTorch run, moved to the GPU.
kopp21 = prep_new_model("kopp21_cnn.yaml", seed=0).to('cuda')

## Use native PyTorch to train the model and record GPU utilization

In [None]:
def train(model, dloader, optim, loss_fn, device='cuda'):
    """Run one training epoch over ``dloader`` with a plain PyTorch loop.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train; switched to train mode here. It must already live
        on ``device``.
    dloader : iterable of dict
        Yields batches with ``'ohe_seq'`` (model input) and ``'target'`` tensors.
    optim : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    loss_fn : callable
        ``loss_fn(pred, target) -> scalar tensor``.
    device : str or torch.device, default 'cuda'
        Device the batch tensors are moved to.

    Raises
    ------
    RuntimeError
        If the model output cannot be reshaped to the target's shape
        (different number of elements).
    """
    model.train()
    for batch in tqdm(dloader, leave=False):
        # non_blocking lets copies from pinned host memory overlap with GPU work;
        # it is a no-op for CPU devices or unpinned tensors.
        seqs = batch['ohe_seq'].to(device, non_blocking=True)
        target = batch['target'].to(device, non_blocking=True)
        pred = model(seqs)
        # Align with the target explicitly: a bare `.squeeze()` turns a final
        # batch of size 1 into a 0-d tensor (rejected by e.g. BCEWithLogitsLoss)
        # and would silently broadcast against a 2-D target.
        loss = loss_fn(pred.reshape(target.shape), target)
        loss.backward()
        optim.step()
        optim.zero_grad()

In [None]:
# Native PyTorch baseline: optimize the underlying network (`kopp21.arch`) directly,
# outside the Lightning module, for 25 epochs inside a wandb run (presumably so
# wandb's system metrics capture GPU utilization).
# NOTE(review): Adam uses its default learning rate here, which may differ from the
# learning rate configured for the Lightning/EUGENe runs — confirm.
optim = torch.optim.Adam(kopp21.arch.parameters())

with wandb.init(project='EUGENe GPU Utilization', name='Native PyTorch'):
    epochs = 25
    for _ in tqdm(range(epochs)):
        train(kopp21.arch, dloader, optim, kopp21.loss_fxn)

## Use PyTorch Lightning to train the model and record GPU utilization

In [None]:
import xarray as xr
import pytorch_lightning as pl

In [None]:
# Data-preparation options for the Lightning runs. These use the same option names
# as the `eu.train.fit_sequence_module` call below, presumably to reproduce its
# preprocessing outside of EUGENe.
batch_size = 2048
num_workers = 0
target_vars = ['target']
seq_var = 'ohe_seq'
in_memory = True
train_var = 'train_val'
prefetch_factor = None
drop_last = False

if target_vars is not None:
    if isinstance(target_vars, str):
        target_vars = [target_vars]
    # Collect the requested target(s) into one "target" variable; multiple targets
    # are stacked along a new "_targets" dimension.
    if len(target_vars) == 1:
        sdata["target"] = sdata[target_vars[0]]
    else:
        sdata["target"] = xr.concat(
            [sdata[target_var] for target_var in target_vars], dim="_targets"
        ).transpose("_sequence", "_targets")
    # A sequence is dropped if any of its targets is NaN.
    nan_mask = sdata['target'].isnull()
    if sdata["target"].ndim > 1:
        nan_mask = nan_mask.any('_targets')
    print(f"Dropping {nan_mask.sum().compute().item()} sequences with NaN targets.")
    # Actually drop them (previously only the count was reported).
    sdata = sdata.sel(_sequence=(~nan_mask).compute())
if in_memory:
    print(f"Loading {seq_var} and {target_vars} into memory")
    sdata[seq_var].load()
    # "target" only exists when target_vars was given.
    if target_vars is not None:
        sdata["target"].load()

In [None]:
# Split on the `train_var` column: True -> training set, False -> validation set.
# `.compute()` materializes each mask in case it is lazily backed.
train_sdata = sdata.sel(_sequence=(sdata[train_var] == True).compute())
val_sdata = sdata.sel(_sequence=(sdata[train_var] == False).compute())

In [None]:
# Training dataloader for the Lightning run: reshuffled each epoch.
train_dataloader = sd.get_torch_dataloader(
    train_sdata,
    sample_dims=["_sequence"],
    variables=['ohe_seq', "target"],
    transforms=transforms,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_last,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
)

In [None]:
# Validation dataloader: same settings as training, but in a fixed order.
val_dataloader = sd.get_torch_dataloader(
    val_sdata,
    sample_dims=["_sequence"],
    variables=['ohe_seq', "target"],
    transforms=transforms,
    batch_size=batch_size,
    shuffle=False,
    drop_last=drop_last,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
)

In [None]:
# Lightning trainer for the "PL no val" run: 25 epochs on the GPU with Lightning's
# own logger disabled (the fit cell wraps training in a wandb run instead).
# NOTE(review): calling `fit` again on an already-fitted trainer may not retrain;
# re-run this cell before re-running the fit cell.
trainer = pl.Trainer(
    max_epochs=25,
    logger=False,
    devices='auto',
    accelerator="gpu",
)

In [None]:
# Disable the LR scheduler, presumably because it monitors a validation metric that
# the "PL no val" run never produces.
# NOTE(review): the next cell rebuilds `kopp21` with `prep_new_model`, so this applies
# to the previous model object and is discarded before training.
kopp21.scheduler = None

In [None]:
# PyTorch Lightning run without a validation loop. The scheduler must be disabled on
# the freshly built model itself; setting it on the previous `kopp21` object has no
# effect here. Presumably the configured scheduler monitors a validation metric that
# this run never computes.
kopp21 = prep_new_model("kopp21_cnn.yaml", seed=0).to('cuda')
kopp21.scheduler = None
with wandb.init(project='EUGENe GPU Utilization', name='PL no val'):
    trainer.fit(kopp21, train_dataloaders=train_dataloader)

## Use EUGENe to train the model and record GPU utilization

EUGENe requires a validation set, which it evaluates at every epoch and uses to conditionally reduce the learning rate.

In [None]:
# EUGENe run: a fresh seed-0 model trained for 25 epochs via
# `eu.train.fit_sequence_module`, given the full `sdata`. The `train_val` column
# defines the train/validation split. Batch size and transforms match the runs above.
# Early stopping and checkpoint monitoring are set to None, presumably so that
# all 25 epochs run.
kopp21 = prep_new_model("kopp21_cnn.yaml", seed=0).to('cuda')
with wandb.init(project='EUGENe GPU Utilization', name='EUGENe'):
    eu.train.fit_sequence_module(
        kopp21,
        sdata,
        gpus=1,
        seq_var="ohe_seq",
        target_vars=["target"],
        in_memory=True,
        train_var="train_val",
        epochs=25,
        batch_size=2048,
        drop_last=False,
        transforms=transforms,
        early_stopping_metric=None,
        model_checkpoint_monitor=None,
    )

# DONE!

---

# Scratch