In [1]:
# Imports
import os
import torch
import pandas as pd
import numpy as np
import seqdatasets
import seqdata as sd
import xarray as xr
from eugene import preprocess as pp
from eugene.models.zoo import DeepBind, DeepSTARR
from eugene.models import SequenceModule
from eugene.models.base._metrics import calculate_metric
from eugene import plot as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

In [3]:
import os
from os import PathLike
from typing import List, Union
import xarray as xr
import numpy as np
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
from torch.utils.data import DataLoader, Dataset
import seqdata as sd
from eugene import settings
import torch

# Registry mapping the `logger` string accepted by `fit_sequence_module` to the
# PyTorch Lightning logger class that gets instantiated. Only TensorBoard is enabled:
#   - CSVLogger currently hangs training with SequenceModule, so it is disabled.
#   - WandbLogger needs a few extra set-up steps (to be shown in a separate notebook).
LOGGER_REGISTRY = {
    #"csv": CSVLogger,
    "tensorboard": TensorBoardLogger,
    #"wandb": WandbLogger,
}

def fit_sequence_module(
    model: LightningModule,
    sdata,
    seq_key: str,
    target_keys: Union[str, List[str]] = None,
    train_key: str = "train_val",
    epochs: int = 10,
    gpus: int = None,
    batch_size: int = None,
    num_workers: int = None,
    logger: str = "tensorboard",
    log_dir: PathLike = None,
    name: str = None,
    version: str = None,
    train_dataloader: DataLoader = None,
    val_dataloader: DataLoader = None,
    seq_transforms = None,
    early_stopping_metric: str = "val_loss_epoch",
    drop_last=True,
    early_stopping_callback: bool = True,
    early_stopping_patience=5,
    early_stopping_verbose=False,
    model_checkpoint_k = 1,
    model_checkpoint_monitor: str ="val_loss_epoch",
    seed: int = None,
    return_trainer: bool = False,
    **kwargs
):
    """
    Train a model with a PyTorch Lightning Trainer.

    Dataloaders are either passed in directly (`train_dataloader` and
    `val_dataloader`) or built from `sdata` with `sd.get_torch_dataloader`.

    Parameters
    ----------
    model : LightningModule
        The model to train.
    sdata : SeqData
        The dataset to train on. Ignored when `train_dataloader` is given.
        When `target_keys` is given, a stacked "target" variable is assigned
        onto this object before it is split.
    seq_key : str
        The variable in sdata holding the model input sequences.
    target_keys : str or list of str
        The variable(s) in sdata to train on. They are concatenated along a new
        "_targets" dimension into a "target" variable, and sequences with a NaN
        in any target are dropped. If None, sdata must already contain a
        "target" variable.
    train_key : str
        The variable in sdata flagging training sequences. Truthy entries go to
        the training set, all remaining entries to the validation set.
    epochs : int
        The maximum number of epochs to train for.
    gpus : int
        Passed to the Trainer as `devices` (with `accelerator="auto"`).
        Defaults to `settings.gpus`.
    batch_size : int
        The batch size to use. Defaults to `settings.batch_size`.
    num_workers : int
        The number of dataloader workers. Defaults to `settings.dl_num_workers`.
    logger : str
        Key into `LOGGER_REGISTRY` selecting the logger class.
    log_dir : PathLike
        The directory to save the logs to. Defaults to `settings.logging_dir`.
    name : str
        The name of the experiment. Defaults to the model's class name.
    version : str
        The version of the experiment. If None, the logger picks one.
    train_dataloader : DataLoader
        The training dataloader to use. If None, it is created from sdata.
    val_dataloader : DataLoader
        The validation dataloader to use. Required when `train_dataloader`
        is given; otherwise it is created from sdata.
    seq_transforms
        Forwarded as `transforms` to `sd.get_torch_dataloader` for both
        the training and the validation dataloader.
    early_stopping_metric : str
        The metric monitored for early stopping (minimized). If None, no
        early stopping callback is added.
    drop_last : bool
        Whether to drop the last incomplete batch. Applies to both the training
        and the validation dataloader.
    early_stopping_callback : bool
        Whether to add the early stopping callback at all.
    early_stopping_patience : int
        The number of epochs without improvement to wait before stopping.
    early_stopping_verbose : bool
        Whether to print early stopping messages.
    model_checkpoint_k : int
        Passed to ModelCheckpoint as `save_top_k`.
    model_checkpoint_monitor : str
        The metric monitored by ModelCheckpoint. If None, no checkpoint
        callback is added.
    seed : int
        The seed to use for reproducibility. If None, no seed is set.
    return_trainer : bool
        Whether to return the Trainer after fitting.
    kwargs : dict
        Additional keyword arguments to pass to the PL Trainer.

    Returns
    -------
    trainer : Trainer or None
        The PyTorch Lightning Trainer object if `return_trainer` is True,
        otherwise None.

    Raises
    ------
    ValueError
        If neither `train_dataloader` nor `sdata` is provided.
    """
    # Fall back to the global EUGENe settings for anything not given explicitly
    gpus = gpus if gpus is not None else settings.gpus
    batch_size = batch_size if batch_size is not None else settings.batch_size
    num_workers = num_workers if num_workers is not None else settings.dl_num_workers
    log_dir = log_dir if log_dir is not None else settings.logging_dir
    model_name = model.__class__.__name__
    name = name if name is not None else model_name
    if seed is not None:
        seed_everything(seed, workers=True)
    else:
        print("No seed set")

    # Set-up dataloaders
    if train_dataloader is not None:
        assert val_dataloader is not None, "val_dataloader must be provided together with train_dataloader"
    elif sdata is not None:
        if target_keys is not None:
            # A bare string would otherwise be iterated character by character
            if isinstance(target_keys, str):
                target_keys = [target_keys]
            sdata["target"] = xr.concat(
                [sdata[target_key] for target_key in target_keys], dim="_targets"
            ).transpose("_sequence", "_targets")
            targs = sdata["target"].values
            if targs.ndim == 1:
                nan_mask = np.isnan(targs)
            else:
                nan_mask = np.any(np.isnan(targs), axis=1)
            print(f"Dropping {nan_mask.sum()} sequences with NaN targets.")
            sdata = sdata.isel(_sequence=~nan_mask)
        # Keep the split as a boolean mask: negating integer positions with `~`
        # is a bitwise NOT (-i - 1), which selects from the end of the data
        # instead of selecting the complement of the training set.
        train_mask = np.asarray(sdata[train_key].values).astype(bool)
        train_sdata = sdata.isel(_sequence=np.where(train_mask)[0])
        val_sdata = sdata.isel(_sequence=np.where(~train_mask)[0])
        train_dataloader = sd.get_torch_dataloader(
            train_sdata,
            sample_dims=["_sequence"],
            variables=[seq_key, "target"],
            transforms=seq_transforms,
            prefetch_factor=None,
            shuffle=True,
            drop_last=drop_last,
            batch_size=batch_size,
            num_workers=num_workers
        )
        val_dataloader = sd.get_torch_dataloader(
            val_sdata,
            sample_dims=["_sequence"],
            variables=[seq_key, "target"],
            transforms=seq_transforms,
            prefetch_factor=None,
            shuffle=False,
            drop_last=drop_last,
            batch_size=batch_size,
            num_workers=num_workers
        )
    else:
        raise ValueError("No data provided to train on.")

    # Set-up logger and callbacks
    logger = LOGGER_REGISTRY[logger](save_dir=log_dir, name=name, version=version)
    callbacks = []
    if model_checkpoint_monitor is not None:
        # When no version string was given the logger auto-assigns an integer,
        # which os.path.join cannot take; mirror Lightning's "version_<n>" naming.
        logger_version = logger.version
        version_dir = logger_version if isinstance(logger_version, str) else f"version_{logger_version}"
        model_checkpoint_callback = ModelCheckpoint(
            dirpath=os.path.join(logger.save_dir, logger.name, version_dir, "checkpoints"),
            save_top_k=model_checkpoint_k,
            monitor=model_checkpoint_monitor
        )
        callbacks.append(model_checkpoint_callback)
    # Honor the boolean switch instead of silently overwriting it
    if early_stopping_callback and early_stopping_metric is not None:
        early_stopping = EarlyStopping(
            monitor=early_stopping_metric,
            patience=early_stopping_patience,
            mode="min",
            verbose=early_stopping_verbose,
        )
        callbacks.append(early_stopping)
    # Not every LightningModule exposes a `scheduler` attribute
    if getattr(model, "scheduler", None) is not None:
        callbacks.append(LearningRateMonitor())
    trainer = Trainer(
        max_epochs=epochs,
        logger=logger,
        devices=gpus,
        accelerator="auto",
        callbacks=callbacks,
        **kwargs
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader
    )
    if return_trainer:
        return trainer

In [4]:
# Load the `random1000` dataset from seqdatasets.
sdata = seqdatasets.random1000()
# Preprocessing helpers from eugene.preprocess. Their return values are not used,
# so they presumably modify `sdata` in place — TODO confirm.
pp.ohe_seqs_sdata(sdata)  # presumably adds the `ohe_seq` variable read below
pp.make_unique_ids_sdata(sdata)
pp.train_test_split_sdata(sdata)  # presumably adds the split column used as `train_key` later — verify
# Reorder the dimensions of `ohe_seq` to (_sequence, _ohe, length).
# NOTE(review): looks like channels-first layout for the Conv1d-based model — confirm.
sdata["ohe_seq"] = sdata["ohe_seq"].transpose("_sequence", "_ohe", "length")
# Display the dataset as the cell output.
sdata

1000it [00:00, 1836.97it/s]


Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type object numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 7.81 kiB 7.81 kiB Shape (1000,) (1000,) Dask graph 1 chunks in 2 graph layers Data type object numpy.ndarray",1000  1,

Unnamed: 0,Array,Chunk
Bytes,7.81 kiB,7.81 kiB
Shape,"(1000,)","(1000,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray


In [5]:
# Names of the sdata variables used as training targets; passed as `target_keys`
# to `fit_sequence_module` further down.
target_keys = ["activity_1", "activity_2", "activity_3"]

In [6]:
# Instantiate the DeepBind architecture from the EUGENe model zoo.
# output_dim=3 gives one output per entry of `target_keys`.
# input_len=100 is assumed to match the sequence length in `sdata` — TODO confirm.
arch = DeepBind(
    input_len=100,
    output_dim=3
)
# Display the layer structure as the cell output.
arch

DeepBind(
  (conv1d_tower): Conv1DTower(
    (layers): Sequential(
      (0): Conv1d(4, 16, kernel_size=(16,), stride=(1,), padding=valid)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
    )
  )
  (max_pool): MaxPool1d(kernel_size=85, stride=85, padding=0, dilation=1, ceil_mode=False)
  (avg_pool): AvgPool1d(kernel_size=(85,), stride=(85,), padding=(0,))
  (dense_block): DenseBlock(
    (layers): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=32, out_features=3, bias=True)
    )
  )
)

In [7]:
# Wrap the architecture in a SequenceModule configured for regression, with a
# reduce-on-plateau learning-rate scheduler and an R2 score metric.
# NOTE(review): torchmetrics' R2Score documents a `num_outputs` argument; confirm that
# SequenceModule accepts/translates `num_classes` in `metric_kwargs` as intended.
model = SequenceModule(
    arch=arch,
    task="regression",
    arch_name="DeepBind",
    model_name="random1000_DeepBind_multitask_regression",
    scheduler="reduce_lr_on_plateau",
    scheduler_kwargs={"patience": 2},
    metric="r2score",
    metric_kwargs={"num_classes": 3}
)
# Print the model configuration and parameter summary.
model.summary()

Model: DeepBind
Sequence length: 100
Output dimension: 3
Task: regression
Loss function: mse_loss
Optimizer: Adam
	Optimizer parameters: {}
	Optimizer starting learning rate: 0.001
Scheduler: ReduceLROnPlateau
	Scheduler parameters: {'patience': 2}
Metric: r2score
	Metric parameters: {'num_classes': 3}
Seed: None
Parameters summary:


  | Name         | Type     | Params
------------------------------------------
0 | arch         | DeepBind | 2.2 K 
1 | train_metric | R2Score  | 0     
2 | val_metric   | R2Score  | 0     
3 | test_metric  | R2Score  | 0     
------------------------------------------
2.2 K     Trainable params
0         Non-trainable params
2.2 K     Total params
0.009     Total estimated model params size (MB)

In [10]:
# Train the multitask regression model on `sdata` using the function defined above.
# NOTE(review): no `seed` is passed, so the run is not seeded ("No seed set" is printed).
# NOTE(review): `log_dir` is a hard-coded absolute path on the author's machine;
# consider a configurable or relative directory before sharing this notebook.
fit_sequence_module(
    model=model,
    sdata=sdata,
    seq_key="ohe_seq",
    target_keys=target_keys,
    # Convert the `ohe_seq` arrays to float32 tensors before they reach the model
    seq_transforms={"ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32)},
    epochs=10,
    batch_size=128,
    num_workers=0,
    log_dir="/cellar/users/aklie/projects/ML4GLand/EUGENe/notebooks/tests",
    name="random1000_DeepBind_multitask_regression",
    version="0.0.1"
)

No seed set
Dropping 0 sequences with NaN targets.


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type     | Params
------------------------------------------
0 | arch         | DeepBind | 2.2 K 
1 | train_metric | R2Score  | 0     
2 | val_metric   | R2Score  | 0     
3 | test_metric  | R2Score  | 0     
------------------------------------------
2.2 K     Trainable params
0         Non-trainable params
2.2 K     Total params
0.009     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
