In [42]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio

from robustness.audio_functions.jsinV3DataLoader_precombined import *
from robustness.audio_functions.audio_transforms import *

import lightning_scripts.jsinV3DataLoader_precombined_batched as batched_jsin 
import importlib

In [28]:
# Data locations on the cluster. NOTE(review): absolute paths — consider a configurable root.
example_path_dir = "/mnt/ceph/users/jfeather/data/training_datasets_audio/JSIN_all_v3/subsets"
example_run_id = "RQTTZB4C3TJJVLJUWDV72TYMC7S4MNHH"
# One validation shard, derived from the subsets dir + run id so the two can no longer drift apart.
example_path = f"{example_path_dir}/valid_{example_run_id}/JSIN_all__run_000_{example_run_id}.h5"

# Dataset-side transform pipeline: AudioToTensor, then CombineWithRandomDBSNR
# (presumably mixes foreground and background at a random dB SNR — TODO confirm).
transform = AudioCompose([AudioToTensor(), CombineWithRandomDBSNR()])

example_dset = H5Dataset(example_path, transform=transform, target_keys=['signal/word_int'])
#example_dset_paired = H5DatasetPaired(example_path, transform=transform, target_keys=['signal/word_int'])
example_dset_all_sig = jsinV3_precombined_all_signals(example_path_dir, transform=transform, train=True)

In [29]:
### Run timing test for iteration with vs without transforms in dataset 

In [30]:
import torch

# Baseline loader: transforms run per sample inside `example_dset_all_sig`,
# and the default collate stacks 16 samples into each batch.
loader = torch.utils.data.DataLoader(
    example_dset_all_sig,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [31]:
%%timeit 
for ix, _ in enumerate(loader):
    if ix == 100:
        break
    

3 s ± 26.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [65]:
### Timing test: iteration WITHOUT transforms in the dataset
# transform=None, so items are presumably the raw, un-combined signals plus labels — TODO confirm.
example_dset_all_sig_raw = jsinV3_precombined_all_signals(
    example_path_dir, transform=None, train=True
)
raw_loader = torch.utils.data.DataLoader(
    example_dset_all_sig_raw,
    batch_size=16,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
)

In [40]:
%%timeit 
for ix, _ in enumerate(collated_loader):
    if ix == 100:
        break
    

2.35 s ± 27.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [99]:
# Batched variant: batch_size=16 is given to the dataset itself (presumably each item is a whole
# batch — TODO confirm in batched_jsin), so the DataLoader uses batch_size=1. Transforms still
# run inside the dataset.
importlib.reload(batched_jsin)  # pick up module edits without restarting the kernel
batched_jsinV3_precombined_all_signals = batched_jsin.jsinV3_precombined_all_signals(
    example_path_dir, transform=transform, train=True, batch_size=16
)
batched_loader = torch.utils.data.DataLoader(
    batched_jsinV3_precombined_all_signals,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [100]:
%%timeit 
for ix, _ in enumerate(batched_loader):
    if ix == 100:
        break
    

586 ms ± 5.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [64]:
### Test if faster using transforms in collate function rather than dataset 

In [113]:
def collate_fn(batch, audio_transform=None):
    """Collate one pre-batched dataset item into ``(signals, labels)`` tensors.

    Meant for a ``DataLoader`` with ``batch_size=1`` over a dataset whose items are
    already whole batches, so ``batch`` arrives as a one-element list wrapping that item.
    The item is indexed as ``item[0]`` / ``item[1]`` for the per-sample signal arrays
    that get combined (presumably foreground / background — TODO confirm against
    ``batched_jsin``) and ``item[-1]`` for the already-collated labels.

    Args:
        batch: one-element sequence holding the pre-batched item.
        audio_transform: callable ``(fg, bg) -> (signal, _)`` applied per sample.
            Defaults to the notebook-level ``transform``, so the existing
            ``DataLoader(..., collate_fn=collate_fn)`` usage is unchanged.

    Returns:
        ``(signals, labels)``. ``signals`` is ``torch.vstack`` of the per-sample
        transformed signals. ``labels`` is a tensor, or a NEW dict of tensors when the
        item carries a dict of per-task labels. The item's own dict is never modified.
    """
    item = batch[0]  # unbox the batch_size=1 wrapper added by DataLoader
    if audio_transform is None:
        audio_transform = transform  # notebook-level pipeline from the setup cell

    raw_labels = item[-1]  # labels are already collated for the whole batch
    # Build a fresh dict instead of overwriting the dataset-returned one in place:
    # mutating it would leave tensors behind, and a second conversion would fail
    # (torch.from_numpy rejects tensors).
    if isinstance(raw_labels, dict):
        labels = {
            task_key: torch.as_tensor(task_labels)
            for task_key, task_labels in raw_labels.items()
        }
    else:
        labels = torch.as_tensor(raw_labels)

    # Labels need no per-sample work; only the signal pairs are transformed.
    signals = [audio_transform(fg, bg)[0] for fg, bg in zip(item[0], item[1])]
    return torch.vstack(signals), labels


# Rebuild the batched dataset with transform=None so the per-sample signal combination
# happens in `collate_fn`, inside the DataLoader worker processes (num_workers=2).
# NOTE(review): this rebinds `batched_jsinV3_precombined_all_signals`; `batched_loader`
# still holds the earlier transform-applied dataset object.
importlib.reload(batched_jsin)  # pick up module edits without restarting the kernel
batched_jsinV3_precombined_all_signals = batched_jsin.jsinV3_precombined_all_signals(
    example_path_dir,
    transform=None,
    train=True,
    batch_size=16,
)
collated_loader = torch.utils.data.DataLoader(
    batched_jsinV3_precombined_all_signals,
    batch_size=1,
    collate_fn=collate_fn,
    num_workers=2,
    shuffle=False,
    pin_memory=True,
)

In [96]:
# Fetch one pre-batched item directly from the batched dataset (last built with
# transform=None) to inspect its raw layout.
batch = batched_jsinV3_precombined_all_signals[0]

In [83]:
# Peek at the first per-sample pair: zip(*batch[:2]) pairs element i of batch[0] with
# element i of batch[1] (presumably foreground/background arrays — TODO confirm).
# Relies on `batch` from the previous cell; this cell originally ran before it (In [83] < In [96]).
row = next(zip(*batch[:2]), None)
if row is not None:
    print(row)

(array([ 0.57629395,  0.27570432, -0.17647317, ..., -1.2251    ,
       -1.9700506 , -1.8175945 ], dtype=float32), array([0.8073209 , 2.523126  , 0.37178677, ..., 0.12427753, 0.20780419,
       0.40587074], dtype=float32))


In [105]:
%%timeit 
for ix, _ in enumerate(batched_loader):
    if ix == 7:
        break
    

start ix: 0 
start ix: 16 
start ix: 32 
start ix: 48 
start ix: 80 start ix: 64 



In [94]:
# Grab a single collated (signals, labels) batch from `collated_loader` for inspection.
# NOTE(review): this cell originally ran (In [94]) before `collated_loader` was redefined
# (In [113]); on Restart & Run All it uses the collate_fn-based loader defined above.
batch = next(iter(collated_loader))

In [106]:
### Iter test to make sure different workers are grabbing distinct batches 


In [115]:
# Pull 7 batches (indices 0..6) through the 2-worker `collated_loader` and read the
# "start ix ... on pid ..." lines (presumably printed by the dataset inside each worker).
# Distinct start indices alternating across the two pids indicate the workers fetch distinct batches.
# Indices beyond the 7th batch are presumably worker prefetch.
# NOTE(review): start ix 112 is absent from the recorded output — TODO check.
for ix, batch in zip(range(7), collated_loader):
    pass
    

start ix: 0 on pid 596819start ix: 16 on pid 596820

start ix: 32 on pid 596819
start ix: 48 on pid 596820
start ix: 64 on pid 596819
start ix: 96 on pid 596819start ix: 80 on pid 596820

start ix: 128 on pid 596819
