# Testing Sequence Dataloading

**Authorship:**
Adam Klie, *03/02/2022*
***
**Description:**
Notebook for testing out PyTorch DataLoading for SeqDataSet and SeqDataModule classes

<div class="alert alert-block alert-warning">
<b>TODOs</b>:
<ul>
    <b><li>Add test cases for each step</li></b>
    <b><li>Feel like initial loading will break given certain data</li></b>
    </ul>
</div>

In [1]:
import numpy as np
import pandas as pd

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

In [2]:
# Define directories
OLS_TSV = "/cellar/users/aklie/projects/EUGENE/data/2021_OLS_Library/2021_OLS_Library.tsv"
NUMPY_OHE = "/cellar/users/aklie/projects/EUGENE/data/2021_OLS_Library/ohe_seq/0.09-0.4_X-train-0.9_ohe-seq.npy"
FASTA_SEQS = "/cellar/users/aklie/projects/EUGENE/data/2021_OLS_Library/fasta/0.09-0.4_X-test-0.1_fasta.fa"
BINARY_TARGET = "/cellar/users/aklie/projects/EUGENE/data/2021_OLS_Library/binary/0.09-0.4_y-train-0.9_binary.txt"

# SeqDataset Class
PyTorch Dataset class for loading sequence data. Here tare the steps for loading:
 1. Load the dataset from files of different supported types using functions from `load_data.py`
 2. Generate an SeqDataset object from sequences and targets
     - Pass in the seqs and targets
     - Compose torchvision transforms
 3. Pass the dataset to DataLoader

In [3]:
from eugene.dataloading.SeqDataset import SeqDataset
SeqDataset?

[0;31mInit signature:[0m [0mSeqDataset[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwds[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m      Sequence based PyTorch dataset definition
[0;31mInit docstring:[0m
Args:
    seqs (iterable): list of sequences to serve as input into models
    names (iterable, optional): list of identifiers for sequences
    targets (iterable): aligned list of targets for each sequence
    rev_seqs (iterable, optional): Optional reverse complements of seqs
    transform (callable, optional): Optional transform to be applied
        on a sample.
[0;31mFile:[0m           /mnt/beegfs/users/aklie/projects/EUGENE/eugene/dataloading/SeqDataset.py
[0;31mType:[0m           type
[0;31mSubclasses:[0m     


## Load from TSV

In [61]:
from eugene.dataloading.load_data import load_csv

In [134]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load_csv("test_100seqs_66/test_seqs.tsv", seq_col="SEQ", name_col="NAME", target_col="LABEL", rev_comp=True)
names, len(seqs), seqs[0][:5], len(rev_seqs), rev_seqs[0][-5:], targets[0]

(array(['seq001', 'seq002', 'seq003', 'seq004', 'seq005', 'seq006',
        'seq007', 'seq008', 'seq009', 'seq010', 'seq011', 'seq012',
        'seq013', 'seq014', 'seq015', 'seq016', 'seq017', 'seq018',
        'seq019', 'seq020', 'seq021', 'seq022', 'seq023', 'seq024',
        'seq025', 'seq026', 'seq027', 'seq028', 'seq029', 'seq030',
        'seq031', 'seq032', 'seq033', 'seq034', 'seq035', 'seq036',
        'seq037', 'seq038', 'seq039', 'seq040', 'seq041', 'seq042',
        'seq043', 'seq044', 'seq045', 'seq046', 'seq047', 'seq048',
        'seq049', 'seq050', 'seq051', 'seq052', 'seq053', 'seq054',
        'seq055', 'seq056', 'seq057', 'seq058', 'seq059', 'seq060',
        'seq061', 'seq062', 'seq063', 'seq064', 'seq065', 'seq066',
        'seq067', 'seq068', 'seq069', 'seq070', 'seq071', 'seq072',
        'seq073', 'seq074', 'seq075', 'seq076', 'seq077', 'seq078',
        'seq079', 'seq080', 'seq081', 'seq082', 'seq083', 'seq084',
        'seq085', 'seq086', 'seq087', 'seq088', 

In [135]:
targets.shape

(100,)

In [136]:
# Compose different data transforms for this particular load
data_transform = transforms.Compose([
    Augment(randomize_linker_p=0.1, enhancer="WT-otx-a"), ReverseComplement(ohe_encoded=False), OneHotEncode(), ToTensor(transpose=False)
])

In [141]:
np.any(test_dataset.name_lengths != test_dataset.name_lengths[0])

False

In [140]:
test_dataset.name_lengths != test_dataset.name_lengths[0]

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False])

In [139]:
# Instantiate a Dataset
test_dataset = SeqDataset(seqs, names=names, targets=targets, transform=data_transform)
test_dataset[0]

(tensor([115., 101., 113.,  48.,  48.,  49.]),
 tensor([[0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
        

In [32]:
# Check the Dataset Class
for i in range(len(test_dataset)):
    sample = test_dataset[i]
    print(i, sample[1].size(), sample[2].size(), sample[3])
    if i == 3:
        break

0 torch.Size([66, 4]) torch.Size([66, 4]) tensor(0.9078)
1 torch.Size([66, 4]) torch.Size([66, 4]) tensor(0.5071)
2 torch.Size([66, 4]) torch.Size([66, 4]) tensor(0.6962)
3 torch.Size([66, 4]) torch.Size([66, 4]) tensor(0.8218)


## Load from Numpy arrays

In [34]:
from eugene.dataloading.load_data import load_numpy

In [62]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load_numpy("../data/2021_OLS_Library/ohe_seq/0.09-0.4_X-all_ohe-seq.npy",
                                            target_file="../data/2021_OLS_Library/binary/0.09-0.4_y-all_binary.txt",
                                            is_target_text=True)
names, seqs[0][:5], rev_seqs, targets[0]

(None,
 array([[0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.]]),
 None,
 1.0)

In [58]:
targets.shape

(460800,)

In [60]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load_numpy("test_100seqs_1000/test_ohe_seqs.npy", 
                                            target_file="test_100seqs_66/test_labels.npy",
                                            names_file="test_100seqs_66/test_ids.npy", 
                                            rev_seq_file="test_100seqs_66/test_rev_ohe_seqs.npy")
len(names), names[0], len(seqs), seqs[0][:5], len(rev_seqs), rev_seqs[0][-5:], targets[0]

(100,
 'seq001',
 100,
 array([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]], dtype=int8),
 100,
 array([[0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0]], dtype=int8),
 0)

In [61]:
targets.shape

(100,)

In [279]:
# Compose different data transforms for this particular load
data_transform = transforms.Compose([ToTensor(transpose=False)])

In [280]:
# Instantiate a Dataset
test_dataset = MPRADataset(seqs, names=names, rev_seqs=rev_seqs, transform=data_transform)

In [282]:
# Check the Dataset Class
for i in range(len(test_dataset)):
    sample = test_dataset[i]
    print(i, sample[0].size(), sample[1].size(), sample[2].size(), sample[3])
    if i == 3:
        break

0 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor([-1.])
1 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor([-1.])
2 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor([-1.])
3 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor([-1.])


## Load from Fasta

In [283]:
from load_data import load_fasta

In [284]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load_fasta("test_seqs.fa", "test_labels.npy", rev_comp=True)

In [285]:
# Compose different data transforms for this particular load
data_transform = transforms.Compose([
    Augment(randomize_linker_p=0.1, enhancer="WT-otx-a"), OneHotEncode(), ToTensor(transpose=False)
])

In [289]:
# Instantiate a Dataset
test_dataset = MPRADataset(seqs, names=names, targets=targets, rev_seqs=rev_seqs, transform=data_transform)

In [290]:
# Check the Dataset Class
for i in range(len(test_dataset)):
    sample = test_dataset[i]
    print(i, sample[0].size(), sample[1].size(), sample[2].size(), sample[3])
    if i == 3:
        break

0 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
1 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
2 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
3 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)


## Build DataLoader

In [26]:
from load_data import load

In [27]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load("test_ohe_seqs.npy", names_file="test_ids.npy", target_file="test_labels.npy", rev_seq_file="test_rev_ohe_seqs.npy")
len(names), names[0], len(seqs), seqs[0][:5], len(rev_seqs), rev_seqs[0][-5:], len(targets), targets[0]

(100,
 'seq001',
 100,
 array([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]], dtype=int8),
 100,
 array([[0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0]], dtype=int8),
 100,
 0)

In [28]:
data_transform = transforms.Compose([
    ToTensor(transpose=False)
])

In [29]:
test_dataset = MPRADataset(seqs, names=names, targets=targets, rev_seqs=rev_seqs, transform=data_transform)

In [30]:
# Instantiate a DataLoader
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [33]:
# Check the Dataset Class
for i in range(len(test_dataset)):
    sample = test_dataset[i]
    print(i, sample[0].size(), sample[1].size(), sample[2].size(), sample[3])
    #if i == 3:
    #    break

0 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
1 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
2 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
3 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
4 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
5 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
6 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
7 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
8 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
9 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
10 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(0.)
11 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
12 torch.Size([6]) torch.Size([1000, 4]) torch.Size([1000, 4]) tensor(1.)
13 torch.Size([6]) torch.Size([1000, 4]) torch.S

In [22]:
# Check the DataLoader
for i_batch, sample_batched in enumerate(test_dataloader):
    print(i_batch, sample_batched[1].size(), sample_batched[2].size(), sample_batched[3].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        break

NameError: name 'test_dataloader' is not defined

## Variable length

In [126]:
seq_file="/cellar/users/aklie/projects/EUGENE/data/All_Genomic_Sequences/All_Genomic_Sequences.tsv"

In [127]:
import pandas as pd
pd.read_csv(seq_file, sep="\t").head()

Unnamed: 0,NAME,SEQ,FXN_LABEL,TILE,SEQ_LEN,DATASET,FXN_DESCRIPTION
0,scaffold_1:462149:462232,AAAGTAGGCTATATGCTACAGCCCAGAGCTCATGGATTTTAATGGG...,1,full,195,2010_Khoueiry_CellPress,
1,scaffold_102:102675:102754,AAAGTAGGCTATCGCGACCACTGATCGTCGCGTAATTATTTTGAGG...,1,full,183,2010_Khoueiry_CellPress,
2,scaffold_3:475040:475120,AAAGTAGGCTTTGTATAATTCCCAATTAGAACACACAGGTAGGGTA...,1,full,180,2010_Khoueiry_CellPress,
3,scaffold_3:781073:781145,AAAGTAGGCTTTAACATGGAATCTATTCCCCGTGACGATGAGATAA...,1,full,154,2010_Khoueiry_CellPress,
4,scaffold_357:46416:46496,AAAGTAGGCTTCTTGCTTTGTTGTTTTTGTTTTGAAACTGGTGAGC...,0,full,186,2010_Khoueiry_CellPress,


In [145]:
from eugene.dataloading.load_data import load

In [146]:
# Load the sequences as numpy arrays
names, seqs, rev_seqs, targets = load(seq_file, seq_col="SEQ", name_col="NAME", target_col="FXN_LABEL", rev_comp=True)
len(names), names[0], len(seqs), seqs[0][:5], len(rev_seqs), rev_seqs[0][-5:], len(targets), targets[0]

(42, 'scaffold_1:462149:462232', 42, 'AAAGT', 42, 'ACTTT', 42, 1.0)

In [147]:
test_dataset = SeqDataset(seqs, names=names, targets=targets, rev_seqs=rev_seqs, transform=None)
test_dataset[0]

array([array([115.,  99.,  97., 102., 102., 111., 108., 100.,  95.,  49.,  58.,
               52.,  54.,  50.,  49.,  52.,  57.,  58.,  52.,  54.,  50.,  50.,
               51.,  50.,  36.,  36.])                                         ,
       'AAAGTAGGCTATATGCTACAGCCCAGAGCTCATGGATTTTAATGGGATCGGCTATCTAGGCCGACCCTCGCTCTCCCAAGGAAATGTCCACCTTCCAGCCGGGAAAAGATAACCGCTCGCCAGAGCGACGCTTTCCGGCTGACAAATTGTGTCGGACCTTGATAGCATTCCTGTTCCCTATCGGACCCAACTTT ',
       ' AAAGTTGGGTCCGATAGGGAACAGGAATGCTATCAAGGTCCGACACAATTTGTCAGCCGGAAAGCGTCGCTCTGGCGAGCGGTTATCTTTTCCCGGCTGGAAGGTGGACATTTCCTTGGGAGAGCGAGGGTCGGCCTAGATAGCCGATCCCATTAAAATCCATGAGCTCTGGGCTGTAGCATATAGCCTACTTT',
       1.0], dtype=object)

In [148]:
test_dataset[0][0]

array([115.,  99.,  97., 102., 102., 111., 108., 100.,  95.,  49.,  58.,
        52.,  54.,  50.,  49.,  52.,  57.,  58.,  52.,  54.,  50.,  50.,
        51.,  50.,  36.,  36.])

In [149]:
test_str = "".join([chr(int(letter)) for letter in test_dataset[0][0]]).replace("$", "")

In [153]:
test_strs = [ascii_decode(thing) for thing in test_dataset.ascii_names]

array(['scaffold_1:462149:462232', 'scaffold_102:102675:102754',
       'scaffold_3:475040:475120', 'scaffold_3:781073:781145',
       'scaffold_357:46416:46496', 'scaffold_577:39771:39829',
       'scaffold_6:512297:512370', 'scaffold_98:245430:245507',
       'scaffold_248:118347:118482', 'Scaffold_31:173924:174038',
       'scaffold_6:562995:563124', 'scaffold_11:490974:491037',
       'scaffold_366:40184:40258', 'scaffold_37:273449:273512',
       'scaffold_371:68262:68340', 'scaffold_40:155281:155360',
       'scaffold_48:226447:226527', 'scaffold_490:18446:18523',
       'scaffold_50:51150:51235', 'scaffold_522:36260:36327', 'Otxa-WTg1',
       'Otxa-WTg2', 'Otxa-WTg3', 'Otxa-WTg5', 'Negative_control', 'nc2',
       'nc3', 'nc4', 'nc5', 'EM.Ci1', 'EM.Dr1', 'EM.Ggd1', 'EM.Mm6',
       'EM.Hs9', 'EM.Mm2', 'EM.Hs1', 'EM.Hs2', 'EM.Hs3', 'EM.Hs8',
       'EM.Mm1', 'EM.Hs4', 'EM.Dr2'], dtype='<U26')

In [158]:
np.array(test_strs) == test_dataset.names

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True])

In [None]:
for asci in 

In [151]:
test_str

'scaffold_1:462149:462232'

In [91]:
test_dataset.ascii_names.shape

(42, 26)

In [92]:
np.any(test_dataset.name_lengths == test_dataset.name_lengths[0])

True

In [93]:
test_dataset.name_lengths

array([24, 26, 24, 24, 24, 24, 24, 25, 26, 25, 24, 25, 24, 25, 24, 25, 25,
       24, 23, 24,  9,  9,  9,  9, 16,  3,  3,  3,  3,  6,  6,  7,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6])

In [9]:
import numpy as np

In [51]:
test_dataset.longest_name

AttributeError: 'SeqDataset' object has no attribute 'longest_name'

In [107]:
from eugene.utils.seq_utils import ascii_encode, ascii_decode

In [122]:
asci_2 = np.zeros((len(test_dataset.names), test_dataset.longest_name))
for i, name in enumerate(test_dataset.names):
    pad_len = test_dataset.longest_name-len(name)
    asci_2[i] = ascii_encode(name, pad=pad_len)
    #np_name = np.array([ord(letter) for letter in name], dtype=int)
    #asci[i] = np.pad(np_name, pad_width=(0, pad_len), mode="constant", constant_values=36)

In [121]:
ascii_encode(name)

array([ 69,  77,  46,  68, 114,  50])

In [125]:
np.all(asci == asci_2)

True

In [114]:
asci.shape

(42, 26)

In [123]:
asci_2.shape

(42, 26)

In [49]:
test_dataset.names

array(['scaffold_1:462149:462232', 'scaffold_102:102675:102754',
       'scaffold_3:475040:475120', 'scaffold_3:781073:781145',
       'scaffold_357:46416:46496', 'scaffold_577:39771:39829',
       'scaffold_6:512297:512370', 'scaffold_98:245430:245507',
       'scaffold_248:118347:118482', 'Scaffold_31:173924:174038',
       'scaffold_6:562995:563124', 'scaffold_11:490974:491037',
       'scaffold_366:40184:40258', 'scaffold_37:273449:273512',
       'scaffold_371:68262:68340', 'scaffold_40:155281:155360',
       'scaffold_48:226447:226527', 'scaffold_490:18446:18523',
       'scaffold_50:51150:51235', 'scaffold_522:36260:36327', 'Otxa-WTg1',
       'Otxa-WTg2', 'Otxa-WTg3', 'Otxa-WTg5', 'Negative_control', 'nc2',
       'nc3', 'nc4', 'nc5', 'EM.Ci1', 'EM.Dr1', 'EM.Ggd1', 'EM.Mm6',
       'EM.Hs9', 'EM.Mm2', 'EM.Hs1', 'EM.Hs2', 'EM.Hs3', 'EM.Hs8',
       'EM.Mm1', 'EM.Hs4', 'EM.Dr2'], dtype='<U26')

In [47]:
asci.shape

(42, 26)

In [36]:
np_name

array([115,  99,  97, 102, 102, 111, 108, 100,  95,  49,  58,  52,  54,
        50,  49,  52,  57,  58,  52,  54,  50,  50,  51,  50])

In [41]:
test_dataset.longest_name

26

In [44]:
.shape

(26,)

In [20]:
np.pad(test_dataset.ascii_names, [(0, 0), (0, 0)], mode="constant", constant_values=-1)

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2)  and requested shape (1,2)

# MPRADataModule
PyTorch Lightning DataModule class for MPRA data that allows for asbtracting most of the dataloading process. These DataModules can be passed straight to trainers for model training.

In [77]:
from MPRADataModule import MPRADataModule

In [78]:
data_transform = transforms.Compose([
    Augment(randomize_linker_p=0.1, enhancer="WT-otx-a"), 
    OneHotEncode(), 
    ToTensor(transpose=True)
])

In [79]:
test_datamodule = MPRADataModule(
    seq_file="test_seqs.tsv",
    transform=data_transform,
    num_workers=0,
    batch_size=16,
    load_kwargs=dict(seq_col="SEQ"))

In [80]:
test_datamodule.setup()
test_dataset = test_datamodule.train_dataloader().dataset

In [81]:
# Check the Dataset Class
for i in range(len(test_datamodule.train_dataloader().dataset)):
    sample = test_datamodule.train_dataloader().dataset[i]
    print(i, sample[0].size(), sample[1].size(), sample[2].size(), sample[3])
    if i == 3:
        break

0 torch.Size([1]) torch.Size([4, 1000]) torch.Size([1]) tensor([-1.])
1 torch.Size([1]) torch.Size([4, 1000]) torch.Size([1]) tensor([-1.])
2 torch.Size([1]) torch.Size([4, 1000]) torch.Size([1]) tensor([-1.])
3 torch.Size([1]) torch.Size([4, 1000]) torch.Size([1]) tensor([-1.])


In [82]:
# Check the DataLoader
for i_batch, sample_batched in enumerate(test_datamodule.train_dataloader()):
    print(i_batch, sample_batched[1].size(), sample_batched[2].size(), sample_batched[3].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        break

0 torch.Size([16, 4, 1000]) torch.Size([16, 1]) torch.Size([16, 1])
1 torch.Size([16, 4, 1000]) torch.Size([16, 1]) torch.Size([16, 1])
2 torch.Size([16, 4, 1000]) torch.Size([16, 1]) torch.Size([16, 1])
3 torch.Size([16, 4, 1000]) torch.Size([16, 1]) torch.Size([16, 1])


In [21]:
test_datamodule.

(tensor([-1.]),
 tensor([[0., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]]),
 tensor([-1.]),
 tensor([-1.]))