# Training scBasset with EUGENe
Adam Klie (last updated: *09/21/2023*)
***
This notebook shows how to load preprocessed SeqData objects and train scBasset models on them with EUGENe.

# Set-up

In [None]:
# General imports
import os
import sys
import torch
import numpy as np
import pandas as pd

# EUGENe imports
import eugene as eu

# EUGENe packages
import seqdatasets
import seqdata as sd
from eugene import dataload as dl
from eugene import models
from eugene.models.zoo import scBasset
from eugene import train

# Report the interpreter and library versions used for this run
_versions = {
    "Python": sys.version,
    "NumPy": np.__version__,
    "Pandas": pd.__version__,
    "PyTorch": torch.__version__,
    "Eugene": eu.__version__,
}
for _label, _ver in _versions.items():
    print(f"{_label} version: {_ver}")

In [None]:
# Locations of the preprocessed inputs and of the training outputs
# (TODO: change to your own paths)
project_dir = '/cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/Buenrostro_2018'

# Preprocessed train/validation SeqData zarr stores live under <project>/processed
input_dir = f'{project_dir}/processed'
train_seqdata, val_seqdata = (
    os.path.join(input_dir, fname)
    for fname in ('train_seqdata.zarr', 'val_seqdata.zarr')
)

# Passed to the trainer as its log directory further down
output_dir = f'{project_dir}/models'

# Load data

In [None]:
# Open the train and validation SeqData zarr stores
sdata_train, sdata_val = (
    sd.open_zarr(zarr_path) for zarr_path in (train_seqdata, val_seqdata)
)

# Combine both splits into one SeqData, labelled with the keys 'train' and 'val'
# NOTE(review): how the keys are stored on the result depends on dl.concat — confirm
sdata_train_val = dl.concat(
    [sdata_train, sdata_val],
    keys=['train', 'val'],
    axis=0,
)

In [None]:
# Build the shuffled training dataloader, sampling along the '_sequence' dimension.
# `transforms` maps a variable name to the callable applied to that variable; writing a
# bare `{lambda ...}` would create a set, not a dict, so the key is required.
# The transpose swaps axes 1 and 2, which only applies to the 3-D one-hot array,
# hence the transform is keyed by 'ohe_seq'.
train_dl = sd.get_torch_dataloader(
    sdata_train,
    sample_dims=['_sequence'],
    variables=['ohe_seq', 'bin_counts'],
    prefetch_factor=None,
    batch_size=128,
    transforms={
        # assumes ohe_seq arrives as (batch, length, alphabet) and the model
        # expects (batch, alphabet, length) — TODO confirm
        'ohe_seq': lambda x: torch.tensor(x, dtype=torch.float32).transpose(1, 2)
    },
    shuffle=True,
)

In [None]:
# Build the validation dataloader (no shuffling), sampling along the '_sequence' dimension.
# `transforms` maps a variable name to the callable applied to that variable; writing a
# bare `{lambda ...}` would create a set, not a dict, so the key is required.
# The transpose swaps axes 1 and 2, which only applies to the 3-D one-hot array,
# hence the transform is keyed by 'ohe_seq'.
val_dl = sd.get_torch_dataloader(
    sdata_val,
    sample_dims=['_sequence'],
    variables=['ohe_seq', 'bin_counts'],
    prefetch_factor=None,
    batch_size=128,
    transforms={
        # assumes ohe_seq arrives as (batch, length, alphabet) and the model
        # expects (batch, alphabet, length) — TODO confirm
        'ohe_seq': lambda x: torch.tensor(x, dtype=torch.float32).transpose(1, 2)
    },
    shuffle=False,
)

In [None]:
# Sanity check: pull one batch from the training loader and list the shape of each element
batch = next(iter(train_dl))
list(map(lambda item: item.shape, batch))

# Load model

In [1]:
# Instantiate the scBasset architecture from the EUGENe model zoo
arch = scBasset(
    num_cells=2711,  # presumably the number of cells in this dataset — verify against the data
    l1=0.01,  # NOTE(review): l1/l2 look like regularization strengths — confirm in scBasset
    l2=0.01,
)

In [None]:
# Draw a batch from the training loader to feed through the architecture
# (train_dl was created with shuffle=True, so this may differ from any batch drawn earlier)
batch = next(iter(train_dl))

In [None]:
# Push the first element of the batch through the untrained architecture
# (assumes batch[0] holds the one-hot sequences — TODO confirm),
# then display the shape of the first entry of the output
seq_batch = batch[0]
outs = arch(seq_batch)
outs[0].shape

In [None]:
# Wrap the architecture in a SequenceModule so it can be trained with EUGENe
module_config = {
    'arch': arch,
    'task': 'multilabel_classification',
    'loss_fn': 'bce',
    'metrics': ['auroc'],
}
model = models.SequenceModule(**module_config)

# Train the model

In [None]:
# Fit the model on the pre-built train/validation dataloaders.
# Only options that apply to ready-made dataloaders are passed. Settings that describe
# how to build loaders from a SeqData (seq_var, target_vars, in_memory, train_var,
# batch_size, num_workers, prefetch_factor, drop_last) are already fixed by
# train_dl / val_dl and are therefore not repeated here.
# NOTE(review): those options appear to belong to `train.fit_sequence_module`
# rather than `train.fit` — confirm against the installed EUGENe version.
train.fit(
    model,
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    epochs=5,
    # Checkpoint on the epoch-level validation AUROC
    model_checkpoint_monitor="val_auroc_epoch",
    # Logs and checkpoints are organized under log_dir by name and version
    log_dir=output_dir,
    name="eugene",
    version="20Sep23"
)

# DONE!

---