In [1]:
# General imports
import os
import torch
import numpy as np
import pandas as pd
from copy import deepcopy 

# EUGENe imports and settings
from eugene import models
from eugene import train
from eugene import settings
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/jores21"
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21"
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/jores21"

# EUGENe packages
import seqdata as sd
import motifdata as md

# Jores21CNN model

# Data stuff (sort this later)

In [3]:
sdata = sd.open_zarr(os.path.join(settings.dataset_dir, "jores21_leaf_train.zarr"))

In [4]:
sdata["ohe_seq"].shape, sdata["train_val"].to_dataframe().value_counts(normalize=True)

((65004, 170, 4),
 train_val
 True         0.899991
 False        0.100009
 dtype: float64)

In [5]:
seq_key = "ohe_seq"
target_keys = "enrichment"
train_key = "train_val"
seq_transforms = {seq_key: lambda x: torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)}
batch_size = 128
num_workers = 4
drop_last = True

In [6]:
if isinstance(target_keys, str):
    target_keys = [target_keys]
if len(target_keys) == 1:
    sdata["target"] = sdata[target_keys[0]]
else:
    sdata["target"] = xr.concat([sdata[target_key] for target_key in target_keys], dim="_targets").transpose("_sequence", "_targets")
targs = sdata["target"].values
if len(targs.shape) == 1:
    nan_mask = np.isnan(targs)
else:
    nan_mask = np.any(np.isnan(targs), axis=1)
print(f"Dropping {nan_mask.sum()} sequences with NaN targets.")
sdata = sdata.isel(_sequence=~nan_mask)

Dropping 0 sequences with NaN targets.


In [7]:
# Load training data into memory
sdata["ohe_seq"].load()
sdata["enrichment"].load()
sdata["train_val"].load()

In [9]:
targs = sdata["enrichment"].values

In [11]:
import xarray as xr

In [12]:
nan_mask = xr.DataArray(np.isnan(targs), dims=["_sequence"])

In [14]:
sdata = sdata.where(~nan_mask, drop=True)

0

In [22]:
print(f"Dropping {int(nan_mask.sum().values)} sequences with NaN targets.")

Dropping 0 sequences with NaN targets.


In [15]:
sdata

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,78.12 kiB
Shape,"(65004,)","(10000,)"
Dask graph,7 chunks in 5 graph layers,7 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 507.84 kiB 78.12 kiB Shape (65004,) (10000,) Dask graph 7 chunks in 5 graph layers Data type float64 numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,78.12 kiB
Shape,"(65004,)","(10000,)"
Dask graph,7 chunks in 5 graph layers,7 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [8]:
sdata.where(sdata["train_val"], drop=True)

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 457.05 kiB 457.05 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 457.05 kiB 457.05 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 457.05 kiB 457.05 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 457.05 kiB 457.05 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 457.05 kiB 457.05 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,457.05 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,70.54 kiB
Shape,"(58503,)","(9029,)"
Dask graph,7 chunks in 6 graph layers,7 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 457.05 kiB 70.54 kiB Shape (58503,) (9029,) Dask graph 7 chunks in 6 graph layers Data type float64 numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,457.05 kiB,70.54 kiB
Shape,"(58503,)","(9029,)"
Dask graph,7 chunks in 6 graph layers,7 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [212]:
sdata.where(~sdata["train_val"], drop=True)

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 50.79 kiB 50.79 kiB Shape (6501,) (6501,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",6501  1,

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 50.79 kiB 50.79 kiB Shape (6501,) (6501,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",6501  1,

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 50.79 kiB 50.79 kiB Shape (6501,) (6501,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",6501  1,

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 50.79 kiB 50.79 kiB Shape (6501,) (6501,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",6501  1,

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 50.79 kiB 50.79 kiB Shape (6501,) (6501,) Dask graph 1 chunks in 6 graph layers Data type object numpy.ndarray",6501  1,

Unnamed: 0,Array,Chunk
Bytes,50.79 kiB,50.79 kiB
Shape,"(6501,)","(6501,)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,object numpy.ndarray,object numpy.ndarray


In [206]:
sdata.where(~sdata.train_val)

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type float64 numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 7 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 7 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,337.24 MiB,337.24 MiB
Shape,"(65004, 170, 4)","(65004, 170, 4)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 337.24 MiB 337.24 MiB Shape (65004, 170, 4) (65004, 170, 4) Dask graph 1 chunks in 6 graph layers Data type float64 numpy.ndarray",4  170  65004,

Unnamed: 0,Array,Chunk
Bytes,337.24 MiB,337.24 MiB
Shape,"(65004, 170, 4)","(65004, 170, 4)"
Dask graph,1 chunks in 6 graph layers,1 chunks in 6 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 7 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 7 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 7 graph layers Data type object numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type float64 numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 507.84 kiB 507.84 kiB Shape (65004,) (65004,) Dask graph 1 chunks in 5 graph layers Data type float64 numpy.ndarray",65004  1,

Unnamed: 0,Array,Chunk
Bytes,507.84 kiB,507.84 kiB
Shape,"(65004,)","(65004,)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [192]:
train_mask = np.where(sdata[train_key])[0]
train_sdata = sdata.isel(_sequence=train_mask)
val_sdata = sdata.isel(_sequence=~train_mask)
train_dataloader = sd.get_torch_dataloader(
    train_sdata,
    sample_dims=["_sequence"],
    variables=[seq_key, "target"],
    transforms=seq_transforms,
    prefetch_factor=2,
    shuffle=True,
    drop_last=drop_last,
    batch_size=batch_size,
    num_workers=num_workers
)
val_dataloader = sd.get_torch_dataloader(
    val_sdata,
    sample_dims=["_sequence"],
    variables=[seq_key, "target"],
    transforms=seq_transforms,
    prefetch_factor=2,
    shuffle=False,
    drop_last=drop_last,
    batch_size=batch_size,
    num_workers=num_workers
)

In [199]:
train_sdata

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,1.79 MiB,1.79 MiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 1.79 MiB 1.79 MiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,1.79 MiB,1.79 MiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,57.13 kiB,57.13 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 57.13 kiB 57.13 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type bool numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,57.13 kiB,57.13 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


In [200]:
val_sdata

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,1.79 MiB,1.79 MiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 1.79 MiB 1.79 MiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,1.79 MiB,1.79 MiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,
"Array Chunk Bytes 228.53 kiB 228.53 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type",58503  1,

Unnamed: 0,Array,Chunk
Bytes,228.53 kiB,228.53 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,57.13 kiB,57.13 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 57.13 kiB 57.13 kiB Shape (58503,) (58503,) Dask graph 1 chunks in 3 graph layers Data type bool numpy.ndarray",58503  1,

Unnamed: 0,Array,Chunk
Bytes,57.13 kiB,57.13 kiB
Shape,"(58503,)","(58503,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


In [193]:
batch = next(iter(train_dataloader))
batch_ohe_seq = batch[seq_key]
batch_target = batch["target"]
batch_ohe_seq.shape, batch_target.shape

(torch.Size([128, 4, 170]), torch.Size([128]))

In [194]:
from tqdm.auto import tqdm

In [195]:
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc="Looping over train dataloader"):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]

Looping over train dataloader:   0%|          | 0/457 [00:00<?, ?it/s]

In [197]:
for i, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc="Looping over val dataloader"):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]

Looping over val dataloader:   0%|          | 0/457 [00:00<?, ?it/s]

In [167]:
for i, batch in enumerate(val_dataloader):
    batch_ohe_seq = batch[seq_key]
    batch_target = batch["target"]
    print(batch_ohe_seq.shape, batch_target.shape)
    if i > 10:
        break

torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])
torch.Size([32, 4, 170]) torch.Size([32])


In [144]:
to_decode = batch_ohe_seq[0].numpy()

In [146]:
to_decode.shape

(4, 170)

In [147]:
DNA = ["A", "C", "G", "T"]
RNA = ["A", "C", "G", "U"]

def _get_vocab(vocab):
    if vocab == "DNA":
        return DNA
    elif vocab == "RNA":
        return RNA
    else:
        raise ValueError("Invalid vocab, only DNA or RNA are currently supported")

# exact concise
def _get_index_dict(vocab):
    """
    Returns a dictionary mapping each token to its index in the vocabulary.
    """
    return {i: l for i, l in enumerate(vocab)}

# modified dinuc_shuffle
def _one_hot2token(one_hot, neutral_value=-1, consensus=False):
    """
    Converts a one-hot encoding into a vector of integers in the range [0, D]
    where D is the number of classes in the one-hot encoding.

    Parameters
    ----------
    one_hot : np.array
        L x D one-hot encoding
    neutral_value : int, optional
        Value to use for neutral values.
    
    Returns
    -------
    np.array
        L-vector of integers in the range [0, D]
    """
    if consensus:
        return np.argmax(one_hot, axis=0)
    tokens = np.tile(neutral_value, one_hot.shape[1])  # Vector of all D
    seq_inds, dim_inds = np.where(one_hot.transpose()==1)
    tokens[seq_inds] = dim_inds
    return tokens

def _sequencize(tvec, vocab="DNA", neutral_value=-1, neutral_char="N"):
    """
    Converts a token vector into a sequence of symbols of a vocab.
    """
    vocab = _get_vocab(vocab) 
    index_dict = _get_index_dict(vocab)
    index_dict[neutral_value] = neutral_char
    return "".join([index_dict[i] for i in tvec])

def decode_seq(arr, vocab="DNA", neutral_value=-1, neutral_char="N"):
    """Convert a single one-hot encoded array back to string"""
    if isinstance(arr, torch.Tensor):
        arr = arr.numpy()
    return _sequencize(
        tvec=_one_hot2token(arr, neutral_value),
        vocab=vocab,
        neutral_value=neutral_value,
        neutral_char=neutral_char,
    )

In [152]:
len(val_dataloader)

1828

In [151]:
len(train_dataloader)

1828

In [149]:
decode_seq(to_decode)

'CATATCATTTATGTACCAAGGGGTTTAGGGTTAATTGTTGAATATTGTGAGTGAGATGTACCATTTTCTATAAGTGTGTCTACATTTTGTTCTTTATCTAAATATTCTTTTAATAGATGCACGAGGAAGATAGACTAATACAGAGGCATACACTACCACACATGTAGCTA'

In [150]:
batch_target[0]

tensor(1.6883, dtype=torch.float64)