In [43]:
import os
import mmh3
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import scipy
import xarray as xr
from tqdm import tqdm

import seqdata as sd

from eugene import preprocess as pp

In [44]:
LEFT_ADAPTER = "TGCATTTTTTTCACATC" 
RIGHT_ADAPTER = "GGTTACGGCTGTT"

PLASMID = "aactctcaaggatcttaccgctgttgagatccagttcgatgtaacccactcgtgcacccaactgatcttcagcatcttttactttcaccagcgtttctgggtgagcaaaaacaggaaggcaaaatgccgcaaaaaagggaataagggcgacacggaaatgttgaatactcatactcttcctttttcaatattattgaagcatttatcagggttattgtctcatgagcggatacatatttgaatgtatttagaaaaataaacaaataggggttccgcgcacatttccccgaaaagtgccacctgacgtcatctatattaccctgttatccctagcggatctgccggtagaggtgtggtcaataagagcgacctcatactatacctgagaaagcaacctgacctacaggaaagagttactcaagaataagaattttcgttttaaaacctaagagtcactttaaaatttgtatacacttattttttttataacttatttaataataaaaatcataaatcataagaaattcgcttatttagaagtGGCGCGCCGGTCCGttacttgtacagctcgtccatgccgccggtggagtggcggccctcggcgcgttcgtactgttccacgatggtgtagtcctcgttgtgggaggtgatgtccaacttgatgttgacgttgtaggcgccgggcagctgcacgggcttcttggccttgtaggtggtcttgacctcagcgtcgtagtggccgccgtccttcagcttcagcctctgcttgatctcgcccttcagggcgccgtcctcggggtacatccgctcggaggaggcctcccagcccatggtcttcttctgcattacggggccgtcggaggggaagttggtgccgcgcagcttcaccttgtagatgaactcgccgtcctgcagggaggagtcctgggtcacggtcaccacgccgccgtcctcgaagttcatcacgcgctcccacttgaagccctcggggaaggacagcttcaagtagtcggggatgtcggcggggtgcttcacgtaggccttggagccgtacatgaactgaggggacaggatgtcccaggcgaagggcagggggccacccttggtcaccttcagcttggcggtctgggtgccctcgtaggggcggccctcgccctcgccctcgatctcgaactcgtggccgttcacggagccctccatgtgcaccttgaagcgcatgaactccttgatgatggccatgttatcctcctcgcccttgctcacCATGGTACTAGTGTTTAGTTAATTATAGTTCGTTGACCGTATATTCTAAAAACAAGTACTCCTTAAAAAAAAACCTTGAAGGGAATAAACAAGTAGAATAGATAGAGAGAAAAATAGAAAATGCAAGAGAATTTATATATTAGAAAGAGAGAAAGAAAAATGGAAAAAAAAAAATAGGAAAAGCCAGAAATAGCACTAGAAGGAGCGACACCAGAAAAGAAGGTGATGGAACCAATTTAGCTATATATAGTTAACTACCGGCTCGATCATCTCTGCCTCCAGCATAGTCGAAGAAGAATTTTTTTTTTCTTGAGGCTTCTGTCAGCAACTCGTATTTTTTCTTTCTTTTTTGGTGAGCCTAAAAAGTTCCCACGTTCTCTTGTACGACGCCGTCACAAACAACCTTATGGGTAATTTGTCGCGGTCTGGGTGTATAAATGTGTGGGTGCAACATGAATGTACGGAGGTAGTTTGCTGATTGGCGGTCTATAGATACCTTGGTTATGGCGCCCTCACAGCCGGCAGGGGAAGCGCCTACGCTTGACATCTACTATATGTAAGTATACGGCCCCATATATAggccctttcgtctcgcgcgtttcggtgatgacggtgaaaacctctgacacatgcagctcccggagacggtcacagcttgtctgtaagcggatgccgggagcagacaagcccgtcagggcgcgtcagcgggtgttggcgggtgtcggggctggcttaactatgcggcatcagagcagattgtactgagagtgcaccatatggacatattgtcgttagaacgcggctacaattaatacataaccttatgtatcatacacatacgatttaggtgacactatagaacgcggccgccagctgaagctttaactatgcggcatcagagcagattgtactgagagtgcaccataccaccttttcaattcatcattttttttttattcttttttttgatttcggtttccttgaaatttttttgattcggtaatctccgaacagaaggaagaacgaaggaaggagcacagacttagattggtatatatacgcatatgtagtgttgaagaaacatgaaattgcccagtattcttaacccaactgcacagaacaaaaacctgcaggaaacgaagataaatcatgtcgaaagctacatataaggaacgtgctgctactcatcctagtcctgttgctgccaagctatttaatatcatgcacgaaaagcaaacaaacttgtgtgcttcattggatgttcgtaccaccaaggaattactggagttagttgaagcattaggtcccaaaatttgtttactaaaaacacatgtggatatcttgactgatttttccatggagggcacagttaagccgctaaaggcattatccgccaagtacaattttttactcttcgaagacagaaaatttgctgacattggtaatacagtcaaattgcagtactctgcgggtgtatacagaatagcagaatgggcagacattacgaatgcacacggtgtggtgggcccaggtattgttagcggtttgaagcaggcggcagaagaagtaacaaaggaacctagaggccttttgatgttagcagaattgtcatgcaagggctccctatctactggagaatatactaagggtactgttgacattgcgaagagcgacaaagattttgttatcggctttattgctcaaagagacatgggtggaagagatgaaggttacgattggttgattatgacacccggtgtgggtttagatgacaagggagacgcattgggtcaacagtatagaaccgtggatgatgtggtctctacaggatctgacattattattgttggaagaggactatttgcaaagggaagggatgctaaggtagagggtgaacgttacagaaaagcaggctgggaagcatatttgagaagatgcggccagcaaaactaaaaaactgtattataagtaaatgcatgtatactaaactcacaaattagagcttcaatttaattatatcagttattaccctatgcggtgtgaaataccgcacagatgcgtaaggagaaaataccgcatcaggaaattgtaagcgttaatattttgttaaaattcgcgttaaatttttgttaaatcagctcattttttaaccaataggccgaaatcggcaaaatcccttataaatcaaaagaatagaccgagatagggttgagtgttgttccagtttggaacaagagtccactattaaagaacgtggactccaacgtcaaagggcgaaaaaccgtctatcagggcgatggcccactacgtgaaccatcaccctaatcaagtGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCATTTTTTTCACATCNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNGGTTACGGCTGTTTCTTAATTAAAAAAAGATAGAAAACATTAGGAGTGTAACACAAGACTTTCGGATCCTGAGCAGGCAAGATAAACGAAGGCAAAGatgtctaaaggtgaagaattattcactggtgttgtcccaattttggttgaattagatggtgatgttaatggtcacaaattttctgtctccggtgaaggtgaaggtgatgctacttacggtaaattgaccttaaaattgatttgtactactggtaaattgccagttccatggccaaccttagtcactactttaggttatggtttgcaatgttttgctagatacccagatcatatgaaacaacatgactttttcaagtctgccatgccagaaggttatgttcaagaaagaactatttttttcaaagatgacggtaactacaagaccagagctgaagtcaagtttgaaggtgataccttagttaatagaatcgaattaaaaggtattgattttaaagaagatggtaacattttaggtcacaaattggaatacaactataactctcacaatgtttacatcactgctgacaaacaaaagaatggtatcaaagctaacttcaaaattagacacaacattgaagatggtggtgttcaattagctgaccattatcaacaaaatactccaattggtgatggtccagtcttgttaccagacaaccattacttatcctatcaatctgccttatccaaagatccaaacgaaaagagagaccacatggtcttgttagaatttgttactgctgctggtattacccatggtatggatgaattgtacaaataaggcgcgccacttctaaataagcgaatttcttatgatttatgatttttattattaaataagttataaaaaaaataagtgtatacaaattttaaagtgactcttaggttttaaaacgaaaattcttattcttgagtaactctttcctgtaggtcaggttgctttctcaggtatagtatgaggtcgctcttattgaccacacctctaccggcagatccgctagggataacagggtaatataGATCTGTTTAGCTTGCCTCGTCCCCGCCGGGTCACCCGGCCAGCGACATGGAGGCCCAGAATACCCTCCTTGACAGTCTTGACGTGCGCAGCTCAGGGGCATGATGTGACTGTCGCCCGTACATTTAGCCCATACATCCCCATGTATAATCATTTGCATCCATACATTTTGATGGCCGCACGGCGCGAAGCAAAAATTACGGCTCCTCGCTGCAGACCTGCGAGCAGGGAAACGCTCCCCTCACAGACGCGTTGAATTGTCCCCACGCCGCGCCCCTGTAGAGAAATATAAAAGGTTAGGATTTGCCACTGAGGTTCTTCTTTCATATACTTCCTTTTAAAATCTTGCTAGGATACAGTTCTCACATCACATCCGAACATAAACAACCATGGGTACCACTCTTGACGACACGGCTTACCGGTACCGCACCAGTGTCCCGGGGGACGCCGAGGCCATCGAGGCACTGGATGGGTCCTTCACCACCGACACCGTCTTCCGCGTCACCGCCACCGGGGACGGCTTCACCCTGCGGGAGGTGCCGGTGGACCCGCCCCTGACCAAGGTGTTCCCCGACGACGAATCGGACGACGAATCGGACGACGGGGAGGACGGCGACCCGGACTCCCGGACGTTCGTCGCGTACGGGGACGACGGCGACCTGGCGGGCTTCGTGGTCGTCTCGTACTCCGGCTGGAACCGCCGGCTGACCGTCGAGGACATCGAGGTCGCCCCGGAGCACCGGGGGCACGGGGTCGGGCGCGCGTTGATGGGGCTCGCGACGGAGTTCGCCCGCGAGCGGGGCGCCGGGCACCTCTGGCTGGAGGTCACCAACGTCAACGCACCGGCGATCCACGCGTACCGGCGGATGGGGTTCACCCTCTGCGGCCTGGACACCGCCCTGTACGACGGCACCGCCTCGGACGGCGAGCAGGCGCTCTACATGAGCATGCCCTGCCCCTAATCAGTACTGACAATAAAAAGATTCTTGTTTTCAAGAACTTGTCATTTGTATAGTTTTTTTATATTGTAGTTGTTCTATTTTAATCAAATGTTAGCGTGATTTATATTTTTTTTCGCCTCGACATCATCTGCCCAGATGCGAAGTTAAGTGCGCAGAAAGTAATATCATGCGTCAATCGTATGTGAATGCTGGTCGCTATACTGCTGTCGATTCGATACTAACGCCGCCATCCAGTGTCGAAAACGAGCTCGaattcctgggtccttttcatcacgtgctataaaaataattataatttaaattttttaatataaatatataaattaaaaatagaaagtaaaaaaagaaattaaagaaaaaatagtttttgttttccgaagatgtaaaagactctagggggatcgccaacaaatactaccttttatcttgctcttcctgctctcaggtattaatgccgaattgtttcatcttgtctgtgtagaagaccacacacgaaaatcctgtgattttacattttacttatcgttaatcgaatgtatatctatttaatctgcttttcttgtctaataaatatatatgtaaagtacgctttttgttgaaattttttaaacctttgtttatttttttttcttcattccgtaactcttctaccttctttatttactttctaaaatccaaatacaaaacataaaaataaataaacacagagtaaattcccaaattattccatcattaaaagatacgaggcgcgtgtaagttacaggcaagcgatccgtccGATATCatcagatccactagtggcctatgcggccgcggatctgccggtctccctatagtgagtcgtattaatttcgataagccaggttaacctgcattaatgaatcggccaacgcgcggggagaggcggtttgcgtattgggcgctcttccgcttcctcgctcactgactcgctgcgctcggtcgttcggctgcggcgagcggtatcagctcactcaaaggcggtaatacggttatccacagaatcaggggataacgcaggaaagaacatgtgagcaaaaggccagcaaaaggccaggaaccgtaaaaaggccgcgttgctggcgtttttccataggctccgcccccctgacgagcatcacaaaaatcgacgctcaagtcagaggtggcgaaacccgacaggactataaagataccaggcgtttccccctggaagctccctcgtgcgctctcctgttccgaccctgccgcttaccggatacctgtccgcctttctcccttcgggaagcgtggcgctttctcaTAgctcacgctgtaggtatctcagttcggtgtaggtcgttcgctccaagctgggctgtgtgcacgaaccccccgttcagcccgaccgctgcgccttatccggtaactatcgtcttgagtccaacccggtaagacacgacttatcgccactggcagcagccactggtaacaggattagcagagcgaggtatgtaggcggtgctacagagttcttgaagtggtggcctaactacggctacactagaagAacagtatttggtatctgcgctctgctgaagccagttaccttcggaaaaagagttggtagctcttgatccggcaaacaaaccaccgctggtagcggtggtttttttgtttgcaagcagcagattacgcgcagaaaaaaaggatctcaagaagatcctttgatcttttctacggggtctgacgctcagtggaacgaaaactcacgttaagggattttggtcatgagattatcaaaaaggatcttcacctagatccttttaaattaaaaatgaagttttaaatcaatctaaagtatatatgagtaaacttggtctgacagttaccaatgcttaatcagtgaggcacctatctcagcgatctgtctatttcgttcatccatagttgcctgactccccgtcgtgtagataactacgatacgggagggcttaccatctggccccagtgctgcaatgataccgcgagacccacgTtcaccggctccagatttatcagcaataaaccagccagccggaagggccgagcgcagaagtggtcctgcaactttatccgcctccatccagtctattaattgttgccgggaagctagagtaagtagttcgccagttaatagtttgcgcaacgttgttgccattgctacaggcatcgtggtgtcacgctcgtcgtttggtatggcttcattcagctccggttcccaacgatcaaggcgagttacatgatcccccatgttgtgcaaaaaagcggttagctccttcggtcctccgatcgttgtcagaagtaagttggccgcagtgttatcactcatggttatggcagcactgcataattctcttactgtcatgccatccgtaagatgcttttctgtgactggtgagtactcaaccaagtcattctgagaatagtgtatgcggcgaccgagttgctcttgcccggcgtcaatacgggataataccgcgccacatagcagaactttaaaagtgctcatcattggaaaacgttcttcggggcgaa"
PLASMID = PLASMID.upper()
INSERT_START = PLASMID.find('N'*80)

def hash_fun(seq, seed):
    return mmh3.hash(seq, seed, signed=False) % 10

def preprocess_data(data, length):
    data = data.copy()
    add_part = PLASMID[INSERT_START-length:INSERT_START]
    data.seq = data.seq.apply(lambda x:  add_part + x[len(LEFT_ADAPTER):])
    data.seq = data.seq.str.slice(-length, None)
    return data

In [45]:
# Change working dir
os.chdir("/cellar/users/aklie/data/datasets/deBoer_random-promoters_GPRA/")

# Load data

In [46]:
# Read tabular
training_df = pd.read_table("dataset_preparation/2023_12_22/dream/train_sequences.txt", header=None, sep="\t")
training_df.columns = ["seq", "bin"]
print(len(training_df))
training_df.head()

6739258


Unnamed: 0,seq,bin
0,TGCATTTTTTTCACATCTCTTTGCCACGGGGTGAAGGATAGGATGG...,11.0
1,TGCATTTTTTTCACATCTATGTTGCGTTAGAACGATATTGGAACAC...,6.0
2,TGCATTTTTTTCACATCTGTGAAGAATATCAGCTTTCAATCGTATT...,8.0
3,TGCATTTTTTTCACATCAATCCGAGATATCTGTTGATAAACTTACC...,9.0
4,TGCATTTTTTTCACATCAAGTTATCTGGTGTACGTTTTCTCGTATA...,12.0


In [47]:
# Make a fold column
fold = list(map(lambda x: hash_fun(x, 1234), training_df.seq))
training_df['fold'] = fold
training_df = training_df.sort_values('fold')
training_df["fold"].value_counts()

7    674870
3    674634
0    674598
6    674300
8    674223
1    674008
2    673957
5    673238
4    672990
9    672440
Name: fold, dtype: int64

In [48]:
# Create the true sequence
training_df = preprocess_data(training_df, 150)
training_df.head()

Unnamed: 0,seq,bin,fold
4288439,AGTGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCATT...,11.0,0
5319238,AGTGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCATT...,14.0,0
672928,AAGTGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCAT...,7.090078,0
672929,AGTGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCATT...,12.0,0
2485732,AGTGCTAGCAGGAATGATGCAAAAGGTTCCCGATTCGAACTGCATT...,14.0,0


In [49]:
# Add a singleton
training_df["is_singleton"] = np.array([x.is_integer() for x in training_df["bin"].values])

# Make SeqData

In [50]:
# Make an xarray
training_sdata = training_df.to_xarray()
training_sdata = training_sdata.rename_dims({"index": "_sequence"})
training_sdata

In [51]:
# Ohe
pp.ohe_seqs_sdata(sdata=training_sdata)

In [52]:
# Train-val split
training_sdata["train_val"] = xr.DataArray(training_sdata["fold"] != 9)
training_sdata

In [108]:
def bin_transform(x, shift=0.5, scale=0.5):
    x = np.asarray(x)  # Ensure x is a NumPy array for vectorization
    norm = scipy.stats.norm(loc=x + shift, scale=scale)
    
    # Calculate the cumulative distribution function (CDF) for each point
    cumprobs = norm.cdf(POINTS[:, None]) if x.ndim > 0 else norm.cdf(POINTS)
    
    # Calculate the differences to get the probabilities for each bin
    probs = cumprobs[1:] - cumprobs[:-1]
    
    # If the input is multidimensional, the output should be correspondingly structured
    return probs.T if x.ndim > 0 else probs

In [93]:
training_sdata["probs_bin"] = xr.DataArray(bin_transform(training_sdata["bin"].values), dims=("_sequence", "_probs"))

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 51.42 MiB 51.42 MiB Shape (6739258,) (6739258,) Dask graph 1 chunks in 1 graph layer Data type object numpy.ndarray",6739258  1,

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 51.42 MiB 51.42 MiB Shape (6739258,) (6739258,) Dask graph 1 chunks in 1 graph layer Data type float64 numpy.ndarray",6739258  1,

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 51.42 MiB 51.42 MiB Shape (6739258,) (6739258,) Dask graph 1 chunks in 1 graph layer Data type int64 numpy.ndarray",6739258  1,

Unnamed: 0,Array,Chunk
Bytes,51.42 MiB,51.42 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.43 MiB,6.43 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 6.43 MiB 6.43 MiB Shape (6739258,) (6739258,) Dask graph 1 chunks in 1 graph layer Data type bool numpy.ndarray",6739258  1,

Unnamed: 0,Array,Chunk
Bytes,6.43 MiB,6.43 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,bool numpy.ndarray,bool numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,3.77 GiB
Shape,"(6739258, 150, 4)","(6739258, 150, 4)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 3.77 GiB 3.77 GiB Shape (6739258, 150, 4) (6739258, 150, 4) Dask graph 1 chunks in 1 graph layer Data type uint8 numpy.ndarray",4  150  6739258,

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,3.77 GiB
Shape,"(6739258, 150, 4)","(6739258, 150, 4)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.43 MiB,6.43 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 6.43 MiB 6.43 MiB Shape (6739258,) (6739258,) Dask graph 1 chunks in 1 graph layer Data type bool numpy.ndarray",6739258  1,

Unnamed: 0,Array,Chunk
Bytes,6.43 MiB,6.43 MiB
Shape,"(6739258,)","(6739258,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,bool numpy.ndarray,bool numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.90 GiB,0.90 GiB
Shape,"(6739258, 18)","(6739258, 18)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 0.90 GiB 0.90 GiB Shape (6739258, 18) (6739258, 18) Dask graph 1 chunks in 1 graph layer Data type float64 numpy.ndarray",18  6739258,

Unnamed: 0,Array,Chunk
Bytes,0.90 GiB,0.90 GiB
Shape,"(6739258, 18)","(6739258, 18)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [155]:
training_sdata["seq"] = training_sdata["seq"].astype(str)

In [None]:
training_sdata = training_sdata.chunk({'_sequence': 1000})

In [159]:
sd.to_zarr(training_sdata, "training/2023_12_24/dream/training.zarr", mode="w")

In [141]:
# Define training transformations
POINTS = np.array([-np.inf, *range(1, 18, 1), np.inf])
from eugene.dataload._augment import RandomRC


def seq_transform(x):
    return x.swapaxes(1, 2)

def to_tensor(x):
    return tuple(torch.tensor(arr, dtype=torch.float32) for arr in x)

def random_rc(x):
    return RandomRC(rc_prob=0.5)(*x)

def bin_transform(x, shift=0.5, scale=0.5):
    norm = scipy.stats.norm(loc=x+shift, scale=scale)
    cumprobs = norm.cdf(POINTS)
    probs = cumprobs[1:] - cumprobs[:-1]
    return probs

In [142]:
train_dl = sd.get_torch_dataloader(
    training_sdata,
    sample_dims=['_sequence'],
    variables=['ohe_seq', 'bin', 'probs_bin'],
    prefetch_factor=None,
    batch_size=32,
    transforms={
        'ohe_seq': seq_transform,
#        ('ohe_seq', 'bin'): to_tensor,
#        'ohe_seq': random_rc,
    },
    return_tuples=True,
    shuffle=True,
)

In [143]:
batch = next(iter(train_dl))
batch[0].shape, batch[1].shape, batch[2].shape

(torch.Size([32, 4, 150]), torch.Size([32]), torch.Size([32, 18]))

In [138]:
torch.arange(start=0, end=18, step=1, requires_grad=False)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

In [140]:
F.softmax(batch[2][0])

NameError: name 'F' is not defined

In [None]:
x = F.softmax(x, dim=1)
score = (x * self.bins).sum(dim=1)