# Kopp et al. (2021) Training

**Authorship:**
Adam Klie (last updated: *06/10/2023*)
***
**Description:**
Notebook to train models on the Kopp et al. (2021) dataset. Alternatively, use the `kopp21_training_{FCN|CNN|Hybrid|Kopp21CNN}.py` scripts to run the training as standalone scripts.
***

In [None]:
# Enable automatic reloading of edited modules. The guard skips %load_ext when the
# extension is already loaded, so re-running this cell is harmless.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload 
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [1]:
# General imports
import os
import sys
import torch
import numpy as np
import pandas as pd
from copy import deepcopy 
import pytorch_lightning
from itertools import product

# EUGENe imports and settings
import eugene as eu
from eugene import dataload as dl
from eugene import models, train, evaluate
from eugene.dataload._augment import RandomRC
from eugene import settings
# Root directories for data, outputs, logs and model configs.
# Each one can be overridden with an environment variable so the notebook can run
# outside the original cluster. The defaults are the original author's absolute paths,
# so behaviour is unchanged when the variables are unset.
settings.dataset_dir = os.environ.get(
    "KOPP21_DATASET_DIR", "/cellar/users/aklie/data/eugene/revision/kopp21"
)
settings.output_dir = os.environ.get(
    "KOPP21_OUTPUT_DIR", "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/kopp21"
)
settings.logging_dir = os.environ.get(
    "KOPP21_LOGGING_DIR", "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/kopp21"
)
settings.config_dir = os.environ.get(
    "KOPP21_CONFIG_DIR", "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/kopp21"
)

# EUGENe packages
import seqdata as sd

# Record the interpreter and library versions so results can be tied to an exact environment.
print(f"Python version: {sys.version}")
_library_versions = (
    ("NumPy", np.__version__),
    ("Pandas", pd.__version__),
    ("Eugene", eu.__version__),
    ("SeqData", sd.__version__),
    ("PyTorch", torch.__version__),
    ("PyTorch Lightning", pytorch_lightning.__version__),
)
for _label, _version in _library_versions:
    print(f"{_label} version: {_version}")

  pkg_resources.require(self.requirement)
  pkg_resources.require(self.requirement)


Python version: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:39:03) 
[GCC 11.3.0]
NumPy version: 1.23.5
Pandas version: 1.5.2
Eugene version: 0.0.8
SeqData version: 0.0.1
PyTorch version: 2.0.0
PyTorch Lightning version: 2.0.0


# Load in the `SeqData`

In [2]:
# Open the training SeqData store from the dataset directory and display its summary.
# The rendered output shows Dask-backed arrays, i.e. nothing is loaded into memory yet.
kopp21_train_zarr = os.path.join(settings.dataset_dir, 'kopp21_train.zarr')
sdata = sd.open_zarr(kopp21_train_zarr)
sdata

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 7.24 MiB 463.27 kiB Shape (948771,) (59299,) Dask graph 16 chunks in 2 graph layers Data type object numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.24 MiB 463.27 kiB Shape (948771,) (59299,) Dask graph 16 chunks in 2 graph layers Data type int64 numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 7.24 MiB 463.27 kiB Shape (948771,) (59299,) Dask graph 16 chunks in 2 graph layers Data type int64 numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,904.82 MiB,1.81 MiB
Shape,"(948771, 1, 500)","(29650, 1, 32)"
Dask graph,512 chunks in 2 graph layers,512 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 904.82 MiB 1.81 MiB Shape (948771, 1, 500) (29650, 1, 32) Dask graph 512 chunks in 2 graph layers Data type uint16 numpy.ndarray",500  1  948771,

Unnamed: 0,Array,Chunk
Bytes,904.82 MiB,1.81 MiB
Shape,"(948771, 1, 500)","(29650, 1, 32)"
Dask graph,512 chunks in 2 graph layers,512 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.77 GiB,3.56 MiB
Shape,"(948771, 500, 4)","(59299, 63, 1)"
Dask graph,512 chunks in 2 graph layers,512 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 1.77 GiB 3.56 MiB Shape (948771, 500, 4) (59299, 63, 1) Dask graph 512 chunks in 2 graph layers Data type uint8 numpy.ndarray",4  500  948771,

Unnamed: 0,Array,Chunk
Bytes,1.77 GiB,3.56 MiB
Shape,"(948771, 500, 4)","(59299, 63, 1)"
Dask graph,512 chunks in 2 graph layers,512 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,452.41 MiB,1.81 MiB
Shape,"(948771, 500)","(59299, 32)"
Dask graph,256 chunks in 2 graph layers,256 chunks in 2 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray
"Array Chunk Bytes 452.41 MiB 1.81 MiB Shape (948771, 500) (59299, 32) Dask graph 256 chunks in 2 graph layers Data type |S1 numpy.ndarray",500  948771,

Unnamed: 0,Array,Chunk
Bytes,452.41 MiB,1.81 MiB
Shape,"(948771, 500)","(59299, 32)"
Dask graph,256 chunks in 2 graph layers,256 chunks in 2 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 7.24 MiB 463.27 kiB Shape (948771,) (59299,) Dask graph 16 chunks in 2 graph layers Data type object numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,7.24 MiB,463.27 kiB
Shape,"(948771,)","(59299,)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 0.90 MiB 231.63 kiB Shape (948771,) (237193,) Dask graph 4 chunks in 2 graph layers Data type uint8 numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 0.90 MiB 231.63 kiB Shape (948771,) (237193,) Dask graph 4 chunks in 2 graph layers Data type bool numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 0.90 MiB 231.63 kiB Shape (948771,) (237193,) Dask graph 4 chunks in 2 graph layers Data type bool numpy.ndarray",948771  1,

Unnamed: 0,Array,Chunk
Bytes,0.90 MiB,231.63 kiB
Shape,"(948771,)","(237193,)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


# Model instantiation and initialization 

In [3]:
def prep_new_model(
    config,
    seed,
):
    """Build a freshly initialized EUGENe model from a YAML config.

    Parameters
    ----------
    config : str
        Config path handed to ``eu.models.load_config``. Callers in this notebook pass
        either a bare filename or a path joined onto ``settings.config_dir``.
    seed : int
        Seed forwarded to ``eu.models.load_config``.

    Returns
    -------
    The object returned by ``load_config``, after ``eu.models.init_weights`` has been
    called on it.
    """
    new_model = eu.models.load_config(config_path=config, seed=seed)
    # init_weights is called only for its effect on the model; its return value is unused.
    eu.models.init_weights(new_model)
    return new_model

In [12]:
# Smoke test: build one architecture to confirm its config loads and its weights initialize.
# To check another architecture, swap the filename for "dshybrid.yaml", "dsfcn.yaml" or "kopp21_cnn.yaml".
# NOTE(review): the training loop further down uses "fcn.yaml", "cnn.yaml" and "hybrid.yaml".
# Confirm these smoke-test filenames refer to the same configs that are actually trained.
model = prep_new_model(config="dscnn.yaml", seed=0)

[rank: 0] Global seed set to 0


In [13]:
# Display the module tree of the smoke-test model built in the previous cell.
model

SequenceModule(
  (arch): dsCNN(
    (revcomp): RevComp()
    (conv1d_tower): Conv1DTower(
      (layers): Sequential(
        (0): Conv1d(4, 10, kernel_size=(11,), stride=(1,), padding=valid)
        (1): ReLU()
        (2): MaxPool1d(kernel_size=30, stride=1, padding=0, dilation=1, ceil_mode=False)
        (3): Dropout(p=0.2, inplace=False)
        (4): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Conv1d(10, 8, kernel_size=(3,), stride=(1,), padding=valid)
        (6): ReLU()
        (7): Dropout(p=0.2, inplace=False)
        (8): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (dense_block): DenseBlock(
      (layers): Sequential(
        (0): Linear(in_features=7344, out_features=64, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.2, inplace=False)
        (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): Linear(in_features=64, out

In [None]:
# Architectures to train. Each YAML file lives in settings.config_dir, and its filename stem names the run.
configs = ["fcn.yaml", "cnn.yaml", "hybrid.yaml", "kopp21_cnn.yaml"]
trials = 1

# Architectures trained WITHOUT RandomRC (reverse-complement) augmentation.
NO_RC_AUGMENT_MODELS = {"kopp21_cnn"}


def _ohe_to_tensor(x):
    """Cast a batch of one-hot sequences to float32 and swap axes 1 and 2.

    NOTE(review): assumes batches arrive as (batch, length, alphabet), like the
    (N, 500, 4) ``ohe_seq`` array shown above, giving (batch, alphabet, length).
    """
    return torch.tensor(x, dtype=torch.float32).swapaxes(1, 2)


def _target_to_tensor(x):
    """Cast targets to a float32 tensor."""
    return torch.tensor(x, dtype=torch.float32)


def _make_transforms(augment_rc):
    """Return the ``{key: transform}`` dict handed to EUGENe for ``target`` and ``ohe_seq``.

    Parameters
    ----------
    augment_rc : bool
        If True, ``RandomRC`` is applied to the sequences after tensor conversion.
        This is an augmentation and belongs to training only; never use it when
        producing predictions.
    """
    transforms = {"target": _target_to_tensor}
    if augment_rc:
        # Bind a dedicated RandomRC instance in this closure rather than a notebook global.
        random_rc = RandomRC()
        transforms["ohe_seq"] = lambda x: random_rc(_ohe_to_tensor(x))
    else:
        transforms["ohe_seq"] = _ohe_to_tensor
    return transforms


for config, trial in product(configs, range(1, trials + 1)):
    model_name = os.path.splitext(config)[0]
    print(model_name)

    # Fresh model for every (config, trial) pair. The trial number doubles as the seed.
    model = prep_new_model(os.path.join(eu.settings.config_dir, config), seed=trial)

    train_transforms = _make_transforms(augment_rc=model_name not in NO_RC_AUGMENT_MODELS)
    # Predictions must be deterministic, so evaluation never uses RandomRC. Previously the
    # augmented training transforms were reused here, which randomly flipped evaluated sequences.
    eval_transforms = _make_transforms(augment_rc=False)

    # Fit the model.
    # NOTE(review): fit_sequence_module takes a single transforms dict; confirm whether it
    # also applies it to validation batches (in which case validation loss is augmented).
    eu.train.fit_sequence_module(
        model,
        sdata,
        gpus=1,
        seq_key="ohe_seq",
        target_keys=["target"],
        in_memory=True,
        train_key="train_val",
        epochs=25,
        early_stopping_metric='val_loss_epoch',
        early_stopping_patience=5,
        batch_size=64,
        num_workers=4,
        prefetch_factor=2,
        drop_last=False,
        name=model_name,
        version=f"trial_{trial}",
        transforms=train_transforms,
        seed=trial,
    )

    # Evaluate the model on the train and validation sets with deterministic transforms.
    evaluate.train_val_predictions_sequence_module(
        model,
        sdata,
        seq_key="ohe_seq",
        target_keys=["target"],
        in_memory=True,
        train_key="train_val",
        batch_size=1024,
        num_workers=4,
        prefetch_factor=2,
        name=model_name,
        version=f"trial_{trial}",
        transforms=eval_transforms,
        prefix=f"{model_name}_trial_{trial}_"
    )

    del model
    # Release cached GPU memory before the next architecture is built.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# DONE!

---

# Scratch