# Train a DeepMEL model using EUGENe on `pbmc-granulocyte-sorted-3k_10x-Multiome`
Adam Klie (last updated: *09/20/2023*)
***
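This notebook shows how to train a DeepMEL model using EUGENe on the `pbmc-granulocyte-sorted-3k_10x-Multiome` dataset: it loads the processed SeqData, builds the architecture and the train/validation dataloaders, and fits the model.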

# Set-up

In [None]:
import os
import sys
import torch
import seqdata as sd
import seqpro as sp
from eugene import models
from eugene.models.zoo import DeepMEL
from eugene import train
from eugene.models.base._metrics import calculate_metric

In [None]:
# Dataset identifier and the directory holding its processed files
# (TODO: point the directory at your own copy of the data)
dataset_name = "pbmc-granulocyte-sorted-3k_10x-Multiome"
# The processed directory sits under a folder named after the dataset
input_dir = f"/cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/{dataset_name}/processed"

# Load data

In [None]:
# Open the training SeqData store, `<dataset_name>.train.zarr`, from the input directory
train_zarr_path = os.path.join(input_dir, f"{dataset_name}.train.zarr")
sdata = sd.open_zarr(train_zarr_path)

In [None]:
# Rename the sequence/topic variables and the `_topic` dimension to the names
# used in the rest of the notebook ("ohe_seq", "target", "_targets")
sdata = sdata.rename_vars({"ohe_seqs": "ohe_seq", "topics": "target"})
sdata = sdata.rename_dims({"_topic": "_targets"})
# Number of topics = length of the `_targets` dimension. `sizes` is the
# name -> length mapping; indexing `dims` that way is deprecated in xarray.
n_topics = sdata.sizes["_targets"]
# NOTE(review): `.load()` is called on a two-variable selection; confirm this
# materializes the arrays held by `sdata` itself rather than only the selection.
sdata[["ohe_seq", "target"]].load()

In [None]:
# Instantiate the DeepMEL architecture with one output per topic
deepmel_conv_kwargs = {"conv_channels": [1024]}
arch = DeepMEL(input_len=500, output_dim=n_topics, conv_kwargs=deepmel_conv_kwargs)

In [None]:
# Wrap the architecture in a trainable SequenceModule: multilabel classification
# with a BCE loss and the Adam optimizer; metrics get one label per topic
multilabel_metric_kwargs = {"task": "multilabel", "num_labels": n_topics}
model = models.SequenceModule(
    arch=arch,
    task="multilabel_classification",
    loss_fxn="bce",
    optimizer="adam",
    metric_kwargs=multilabel_metric_kwargs,
)

In [None]:
# Initialize the model's weights with EUGENe's `init_weights` helper
# (no options are passed, so the helper's own defaults apply)
models.init_weights(model)

In [None]:
# Split on the `train_val` flag: rows equal to True form the training set,
# rows equal to False form the validation set
is_train = (sdata["train_val"] == True).compute()  # noqa
is_val = (sdata["train_val"] == False).compute()  # noqa
train_sdata = sdata.sel(_sequence=is_train)
val_sdata = sdata.sel(_sequence=is_val)

In [None]:
# Build the dataloaders
# The two splits share every loader setting except shuffling and dropping the
# last partial batch, so the common configuration lives in one helper.
def _build_dataloader(split_sdata, shuffle, drop_last):
    """Wrap one SeqData split in a torch dataloader.

    Parameters
    ----------
    split_sdata
        The SeqData subset to iterate over (training or validation rows).
    shuffle : bool
        Whether to shuffle the samples.
    drop_last : bool
        Whether to drop the final batch when it is smaller than the batch size.

    Returns
    -------
    The object returned by `sd.get_torch_dataloader`; each batch exposes the
    "ohe_seq" and "target" variables.
    """
    return sd.get_torch_dataloader(
        split_sdata,
        sample_dims=["_sequence"],
        variables=["ohe_seq", "target"],
        batch_size=128,
        num_workers=4,
        prefetch_factor=2,
        transforms={
            # permute(0, 2, 1) swaps the last two axes of each batch
            # (presumably (batch, length, channels) -> (batch, channels, length); TODO confirm)
            "ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).permute(0, 2, 1),
            # Targets are float32 for BOTH splits: the model uses a BCE loss, and
            # the validation loader used to cast to int64, disagreeing with training.
            "target": lambda x: torch.tensor(x, dtype=torch.float32),
        },
        shuffle=shuffle,
        drop_last=drop_last,
    )


# Training: shuffled, and only full batches
train_dataloader = _build_dataloader(train_sdata, shuffle=True, drop_last=True)
# Validation: fixed order, and every sample is kept
val_dataloader = _build_dataloader(val_sdata, shuffle=False, drop_last=False)

In [None]:
# Pull the first batch from a fresh iterator over the training dataloader;
# it is reused below for a quick metric check
batch = next(iter(train_dataloader))

In [None]:
# Smoke test: run the model on the sampled batch and compute AUROC
# against that batch's targets, using the model's own metric configuration
batch_preds = model(batch["ohe_seq"])
calculate_metric(
    model.train_metric,
    "auroc",
    model.metric_kwargs,
    batch_preds,
    batch["target"],
)

In [None]:
# Fit the model for 25 epochs on one GPU; logs go under `<input_dir>/<dataset_name>`
# with the run labelled by name and version
run_log_dir = os.path.join(input_dir, dataset_name)
run_name = f"{dataset_name}.DeepMEL.revision"
train.fit(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    gpus=1,
    epochs=25,
    log_dir=run_log_dir,
    name=run_name,
    version="v0",
)

# DONE!

---