# I have had some serious issues with the collate function in pytorch
IT seems that with our loader we have these collate features, we also get this deffault collate, let's see if we can figure out why my collate is so weird

In [1]:
import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/')
from src.dataloaders.genomics import HG38
from src.dataloaders.datasets.profile_atac_long import ProfileATACLong


In [3]:
from typing import Any, List, Union
from torch.utils.data.dataloader import DataLoader, Dataset
class ProfileATACLongLoader(HG38): #for unique cell type tokens
    _name_ = "ProfileATACLongLoader"
    l_output = 0  # need to set this for decoder to work correctly
    #global in the context of the class or its instances. potentially used by hydra? I am unsure of what this does...

    def __init__(self, dataset_name, dest_path=None, tokenizer_name='char', d_output=None, rc_aug=False,
                max_length=1024, use_padding=True, max_length_val=None, max_length_test=None,
                padding_side='left', return_mask=False, val_ratio=0.0005, val_split_seed=2357, add_eos=False, 
                detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
                shuffle=True, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
                fast_forward_epochs=None, fast_forward_batches=None, single_cell_type = None,
                train_bias=False, data_path=None,jitter=0, *args, **kwargs):
        self.dataset_name = dataset_name
        self.dest_path = dest_path
        self.tokenizer_name = tokenizer_name
        self.d_output = d_output
        self.rc_aug = rc_aug
        self.max_length = max_length
        self.use_padding = use_padding
        self.max_length_val = max_length_val if max_length_val is not None else max_length
        self.max_length_test = max_length_test if max_length_test is not None else max_length
        self.padding_side = padding_side
        self.return_mask = return_mask
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        self.val_only = val_only
        self.add_eos = add_eos
        self.detokenize = detokenize
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.single_cell_type = single_cell_type
        self.train_bias = train_bias
        self.data_path = data_path
        self.jitter=jitter

        # if self.dest_path is None:
        #     self.dest_path = default_data_path / self._name_

        if fault_tolerant:
            assert self.shuffle
        self.fault_tolerant = fault_tolerant
        if ddp:
            assert fault_tolerant
        self.ddp = ddp
        self.fast_forward_epochs = fast_forward_epochs
        self.fast_forward_batches = fast_forward_batches
        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
            assert ddp and fault_tolerant

    def setup(self, stage=None):
        # TODO instantiate with registry
        #what we need to do is have characters be the list of cell indices 0-161
        characters = ['A', 'C', 'G', 'T', 'N']

        # Combine the two lists to form the final list of tokens
        # characters = number_tokens + nucleotide_tokens
        # if self.tokenizer_name == 'char':
        #     print("**Using Char-level tokenizer**")
        #     self.tokenizer = CharacterTokenizer(
        #         characters=characters,
        #         model_max_length=self.max_length + 2,  # add 2 since default adds eos/eos tokens, crop later
        #         add_special_tokens=False,
        #         padding_side=self.padding_side,
        #     )
        self.tokenizer=None
        
        # Create all splits: torch datasets (only train/test in this benchmark)
        self.dataset_train, self.dataset_val = [
            ProfileATACLong(split=split,
                                max_length=max_len,
                                # dataset_name=self.dataset_name,
                                tokenizer=self.tokenizer,  # pass the tokenize wrapper
                                tokenizer_name=self.tokenizer_name,
                                use_padding=self.use_padding,
                                d_output=self.d_output, #we manually defined it in the dataset
                                add_eos=self.add_eos,
                                # dest_path=self.dest_path,
                                rc_aug=self.rc_aug,
                                return_augs=False,
                                single_cell_type = self.single_cell_type,
                                data_path=self.data_path,
                                train_bias=self.train_bias,
                                jitter = self.jitter,
                                # return_mask=self.return_mask,
            )
            for split, max_len in zip(['train', 'val'], [self.max_length, self.max_length_val])
        ] #uses dataset class and makes a train and validation using the basic loader
        
    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The test dataloader, it's a dummy loader just to make the trainer happy, we don't use it."""
        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
    
    #need a new collate fn
    # @classmethod
    # def _collate_fn(cls, batch, *args, **kwargs): #my custom collate function that is used since it's better and works for this custom class
    #     """
    #     Custom collate function to handle nested tuples of tensors.
    #     """
    #     print("Using custom collate function")
    #     # Unzip the batch into separate components
    #     (seqs, one_hot_seqs), (cts, counts), *z = zip(*batch)
        
    #     # Collate each component separately
    #     seqs = cls._collate(seqs, *args, **kwargs)
    #     one_hot_seqs = cls._collate(one_hot_seqs, *args, **kwargs)
    #     cts = cls._collate(cts, *args, **kwargs)
    #     counts = cls._collate(counts, *args, **kwargs)
        
    #     # Combine the collated components back into the original structure
    #     x = (seqs, one_hot_seqs)
    #     y = (cts, counts)
        
    #     return_value = (x, y, *z)
    #     return cls._return_callback(return_value, *args, **kwargs)
    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs): #my custom collate function that is used since it's better and works for this custom class
        #we will literally just return it as is
        return batch

In [4]:
#now let's create a dataloader instance
loader = ProfileATACLongLoader(dataset_name='train')

In [8]:
#can test the output
loader.setup() #simply instantiates the datasets
loader.dataset_train[0]

((tensor([ 8, 10,  7,  ...,  8,  7, 10]), []),
 (tensor([1., 0., 3.,  ..., 0., 1., 0.]), tensor([6.7081])))

In [9]:
#now iterate over the loader
for i, batch in enumerate(loader.train_dataloader()):
    print(i)
    print(batch)
    if i>2:
        break

0
[[tensor([[10,  8, 10,  ..., 10,  7,  8],
        [ 9,  8, 10,  ...,  8,  8,  7],
        [ 7,  9,  7,  ...,  8, 10,  7],
        ...,
        [ 9,  8,  8,  ...,  8,  7,  9],
        [ 8,  8,  7,  ...,  8,  7, 10],
        [ 7, 10, 10,  ..., 10, 10,  8]]), []], [tensor([[1., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 2.,  ..., 0., 2., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 4.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), tensor([[6.2729],
        [5.6937],
        [6.2766],
        [5.7557],
        [6.9441],
        [6.0137],
        [5.4848],
        [6.5889],
        [6.3919],
        [7.2506],
        [5.9814],
        [6.3351],
        [5.6419],
        [6.7286],
        [7.3421],
        [9.5280],
        [6.1399],
        [6.3630],
        [8.2602],
        [7.5110],
        [6.8427],
        [7.0148],
        [7.6183],
        [6.3456],
        [5.7589],
        [6.6846],
        [7.4

In [16]:
print(len(batch))
print(len(batch[0]))
print(batch[0][0].shape) #this is the seq data in tensor format!


2
2
torch.Size([32, 1024])


In [17]:
batch[0][0]

tensor([[ 8, 10,  9,  ...,  7,  7, 10],
        [ 7, 10, 10,  ..., 10,  8,  7],
        [ 7, 10,  8,  ...,  7,  7,  7],
        ...,
        [ 7,  7,  7,  ...,  7,  7, 10],
        [10,  7,  9,  ..., 10, 10,  7],
        [ 9,  9,  9,  ..., 10,  9,  7]])

In [10]:
loader.fault_tolerant

False

In [11]:
#import dataloader from pytorch
from torch.utils.data import DataLoader
manual_loader = DataLoader(loader.dataset_train, batch_size=32, shuffle=True, num_workers=1, collate_fn=loader._collate_fn)

In [18]:
a=next(iter(manual_loader))

In [20]:
print(len(a))
#see so this returns it very differently...

32


# testing the dataset to ensure it actually works

In [1]:
import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/')
from src.dataloaders.genomics import HG38
from src.dataloaders.datasets.profile_atac_long import ProfileATACLong


In [2]:
dataset = ProfileATACLong(split='train', max_length=32768, tokenizer=None, tokenizer_name='char', use_padding=True, d_output=None, add_eos=False, rc_aug=False, return_augs=False, jitter=0)

In [4]:
out = dataset[0]
print(out)

((tensor([ 7,  8,  7,  ...,  7, 10,  7]), tensor([0., 0., 0.,  ..., 0., 0., 0.])), (tensor([0., 0., 0.,  ..., 0., 0., 0.]), tensor([10.1966])))


In [10]:
x,y = out
print(x[0].shape)
print(x[1].shape)
print(y[0].shape)
print(y[1].shape)

torch.Size([32768])
torch.Size([32768])
torch.Size([32768])
torch.Size([1])


In [8]:
x[0].shape[0] == 32768

True

In [11]:
#now let's go through the dataset
from tqdm import tqdm
for i in tqdm(range(len(dataset))):
    out = dataset[i]
    x,y = out
    if x[0].shape[0] != 32768:
        print(i)
        break
    if x[1].shape[0] != 32768:
        print(i)
        break
    if y[0].shape[0] != 32768:
        print(i)
        break
    if y[1].shape[0] != 1:
        print(i)
        break

 19%|█▉        | 42917/220311 [00:08<00:34, 5183.05it/s]

42917





In [14]:
print(x[0].shape)
print(x[1].shape)
print(y[0].shape)
print(y[1].shape)

#are we at the end of the genome? let's see

torch.Size([26787])
torch.Size([26787])
torch.Size([26787])
torch.Size([1])


In [15]:
dataset.peak_coords[i]

array(['chr15', '101980786', 'f', '1'], dtype='<U21')

In [16]:
print(len(dataset.genome['chr15']))

101991189


In [17]:
len(dataset.genome['chr15']) - int(dataset.peak_coords[i][1])

10403

In [24]:
len(dataset.peak_coords)

220311

In [25]:
#only 10k to the right! We need to do some sort of filtering
import numpy as np
remove_array=np.ones(len(dataset.peak_coords))
for i,row in enumerate(dataset.peak_coords):
    start = int(row[1])
    if start + 32768 > len(dataset.genome[row[0]]) or start < 32768:
        remove_array[i]=0
#this is very quick

In [27]:
np.where(remove_array==0)

(array([ 20280,  20281,  20282,  20283,  20284,  32025,  32026,  32027,
         32028,  32029,  42917,  51375,  53014,  53015,  53016,  85967,
         88788,  95373,  95374, 118312, 142040, 147674, 195135, 195136,
        196790, 196894, 197016, 197017, 214822, 215779, 216887, 216888,
        216889, 216890, 217089, 217090, 217091, 217092, 217093]),)

In [28]:
#remove those arrays
dataset.peak_coords = dataset.peak_coords[remove_array==1]

In [29]:
print(len(dataset.peak_coords))

220272


In [30]:
#now let's go through the dataset
from tqdm import tqdm
for i in tqdm(range(len(dataset))):
    out = dataset[i]
    x,y = out
    if x[0].shape[0] != 32768:
        print(i)
        break
    if x[1].shape[0] != 32768:
        print(i)
        break
    if y[0].shape[0] != 32768:
        print(i)
        break
    if y[1].shape[0] != 1:
        print(i)
        break

100%|██████████| 220272/220272 [00:40<00:00, 5448.47it/s]


In [1]:
#helllllllll yeah! Now let's add this to the thing and test it
import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna/')
from src.dataloaders.genomics import HG38
from src.dataloaders.datasets.profile_atac_long import ProfileATACLong

dataset = ProfileATACLong(split='train', max_length=32768, tokenizer=None, tokenizer_name='char', use_padding=True, d_output=None, add_eos=False, rc_aug=False, return_augs=False, jitter=100_000)
#huge jitter just because we're trying to test it

In [2]:
#now let's go through the dataset
#see that it's a decent amount shorter
from tqdm import tqdm
for i in tqdm(range(len(dataset))):
    out = dataset[i]
    x,y = out
    if x[0].shape[0] != 32768:
        print(i)
        break
    if x[1].shape[0] != 32768:
        print(i)
        break
    if y[0].shape[0] != 32768:
        print(i)
        break
    if y[1].shape[0] != 1:
        print(i)
        break

100%|██████████| 219854/219854 [00:39<00:00, 5500.60it/s]


In [None]:
#ok finally seems good! Let's run the actual thing