# Dataloader

A short notebook showing that the `MidiDataset` dataloader can be imported and that it loads the data.

In [None]:
# Add the src folder to the path
import sys
sys.path.insert(0, '../src/')

from data.dataloader import MidiDataset
from data.bar_transform import BarTransform

## Initialize the transform to pass into the dataset

In [None]:
# Build the bar transform handed to MidiDataset below; the last expression
# displays its repr as the cell output.
# NOTE(review): `bars` / `note_count` presumably set bars per sample and the
# number of pitch rows -- confirm against BarTransform.
transform = BarTransform(note_count=60, bars=16)
transform

In [None]:
# Load the dataset listed in concat.csv, applying `transform` to it.
midi_dataset = MidiDataset(csv_file='./concat.csv', transform=transform)
# Use the len() builtin rather than calling the dunder method directly.
len(midi_dataset)

## Get the first item of the dataset

In [None]:
# Index the dataset directly; subscripting dispatches to __getitem__.
first = midi_dataset[0]
first

In [None]:
# Print the length along the first axis, then display the full shape.
# NOTE(review): assumes 'piano_rolls' is an array/tensor exposing `.shape`.
piano_rolls = first['piano_rolls']
print(len(piano_rolls))
piano_rolls.shape

## Iterating the dataset

In [None]:
# Preview the timestep count of the first samples (indices 0-10 inclusive).
# Bounding the range up front replaces the manual `if i == 10: break`.
n_preview = min(11, len(midi_dataset))
for i in range(n_preview):
    sample = midi_dataset[i]
    print(f"{i}, {len(sample['piano_rolls'])} timesteps")

In [None]:
# Display whatever MidiDataset.get_mem_usage() reports for this dataset.
mem_usage = midi_dataset.get_mem_usage()
mem_usage

## Use a dataloader to batch and iterate over the dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
# Wrap the dataset in a DataLoader: batches of 32, reshuffled each epoch,
# loaded by 4 worker subprocesses.
dataloader = DataLoader(
    midi_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [None]:
import itertools

# Inspect the first 4 batches (indices 0-3), then stop. islice bounds the
# iteration instead of a manual `break`, so no 5th batch is requested.
for i_batch, sample_batched in enumerate(itertools.islice(dataloader, 4)):
    batch = sample_batched['piano_rolls']
    sample = batch[-1]  # last sample in this batch
    print(f"Batch no: {i_batch}, Batch size: {len(batch)} samples, "
          f"Timesteps per sample: {len(sample)}")