# EUGENe DeepSTARR model training
Adam Klie (last updated: *09/20/2023*)
***
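This notebook trains a DeepSTARR model with EUGENe: it loads the `deAlmeida22` train and validation splits, preprocesses the sequences, instantiates the model, builds the dataloaders, and runs training.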

# Set-up

In [None]:
import os
import time
import numpy as np
import seqdatasets
from eugene import preprocess as pp
from eugene import dataload as dl
from eugene import train
from eugene import models
import seqdata as sd
import xarray as xr

# Load and preprocess dataset

In [None]:
# Training split: fetch it, then run the sequence preprocessing step on it.
# `pp.ohe_seqs_sdata` is called without capturing a return value; presumably it
# one-hot encodes the sequences and stores them on the object (the dataloaders
# later request an "ohe_seq" variable) -- TODO confirm.
sdata_train = seqdatasets.deAlmeida22("train")
pp.ohe_seqs_sdata(sdata_train)

# Validation split: identical two steps.
sdata_val = seqdatasets.deAlmeida22("val")
pp.ohe_seqs_sdata(sdata_val)

# Instantiate model

In [None]:
from eugene.models.zoo import DeepSTARR

In [None]:
# Architecture hyperparameters, collected in one place so they are easy to tune.
deepstarr_config = {
    "input_len": 249,   # length of each input sequence
    "output_dim": 2,    # number of values predicted per sequence
    "optimizer_lr": 0.002,
    "optimizer_kwargs": {"weight_decay": 1e-6},
}
# NOTE(review): the optimizer settings are handed to the architecture here, while
# the optimizer name is given to `models.SequenceModule` in the next cell --
# confirm the installed EUGENe version expects them on `DeepSTARR`.
arch = DeepSTARR(**deepstarr_config)

# Initialize the freshly built architecture's weights via EUGENe's helper.
models.init_weights(arch)

In [None]:
# Wrap the architecture in a SequenceModule configured for regression with an
# MSE loss and the Adam optimizer.
module_config = {
    "task": "regression",
    "loss_fxn": "mse",
    "optimizer": "adam",
}
model = models.SequenceModule(arch=arch, **module_config)

# Build dataloader

In [None]:
# Keyword arguments common to both dataloaders: iterate over the "_sequence"
# dimension, yield the "ohe_seq" and "target" variables, and load batches with
# 4 worker processes, each prefetching 2 batches.
shared_dl_kwargs = {
    "sample_dims": "_sequence",
    "variables": ["ohe_seq", "target"],
    "num_workers": 4,
    "prefetch_factor": 2,
}

# Training loader: shuffled, batches of 100.
train_dl = sd.get_torch_dataloader(
    sdata_train, batch_size=100, shuffle=True, **shared_dl_kwargs
)

# Validation loader: fixed order, batches of 128.
val_dl = sd.get_torch_dataloader(
    sdata_val, batch_size=128, shuffle=False, **shared_dl_kwargs
)

# Train model

In [None]:
# Directory passed to `train.fit` as `log_dir`. Set the EUGENE_LOG_DIR environment
# variable to redirect it without editing this cell; when the variable is unset,
# the original author's path is used unchanged.
# TODO: change to your own path (or export EUGENE_LOG_DIR).
log_dir = os.environ.get(
    "EUGENE_LOG_DIR",
    "/cellar/users/aklie/projects/ML4GLand/models/",
)

# Timestamp used as the run version, so each run gets a distinct version string.
run_version = time.strftime("%Y-%m-%d_%H-%M-%S")

# Train for up to 100 epochs on one GPU with early-stopping patience 10 and seed 13.
train.fit(
    model=model,
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    epochs=100,
    gpus=1,
    log_dir=log_dir,
    name="DeepSTARR",
    version=run_version,
    early_stopping_patience=10,
    seed=13,
)

# DONE!

---