# Dataloader

A short notebook demonstrating that the dataloader can be imported and that it loads the data.

In [None]:
# Add the src folder to the path
# NOTE: '../src/' is a relative path, so this assumes the notebook is run from
# a directory that sits next to src/ — TODO confirm.
import sys
sys.path.insert(0, '../src/')  # index 0 so this entry is searched before the existing ones

from data.dataloader import MidiDataset
from data.bar_transform import BarTransform

## Initialize the transform to pass into the dataset

In [None]:
# Name the two settings before building the transform.
bars = 16
note_count = 60 + 1  # +1 for silences
transform = BarTransform(bars=bars, note_count=note_count)
# Bare expression so the notebook displays the transform's repr.
transform

In [None]:
# Build the dataset from the CSV file and hand it the transform created above.
midi_dataset = MidiDataset(csv_file='./concat.csv', transform=transform)
# Use the built-in len() rather than calling the __len__ dunder directly.
len(midi_dataset)

## Get the first item of the dataset

In [None]:
# Index the dataset with [] (as the iteration cell does) instead of calling
# the __getitem__ dunder directly.
first = midi_dataset[0]
# Bare expression so the notebook displays the item.
first

In [None]:
# Bind the piano-roll entry once, then show its length and its full shape.
piano_rolls = first['piano_rolls']
print(len(piano_rolls))
piano_rolls.shape

In [None]:
# Shape after dropping the last index along the second axis.
first['piano_rolls'][:,:-1].shape

## Iterating the dataset

In [None]:
# Preview at most the first 11 items (indices 0 through 10) by index.
preview_count = min(len(midi_dataset), 11)
for idx in range(preview_count):
    item = midi_dataset[idx]
    print("{}, {} timesteps".format(idx, len(item['piano_rolls'])))

In [None]:
# Ask the dataset for its memory usage; the notebook displays whatever the
# helper returns (units and format are defined by MidiDataset — TODO confirm).
midi_dataset.get_mem_usage()

## Use a dataloader to batch and iterate over the whole dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

# Wrap the dataset in a DataLoader: batches of 32, shuffled, 4 workers.
dataloader = DataLoader(
    midi_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [None]:
for i_batch, sample_batched in enumerate(dataloader):
    batch = sample_batched['piano_rolls']
    # Inspect the first sample of the batch. Indexing with i_batch (the batch
    # counter) would pick a different sample each iteration and raise
    # IndexError once the counter reaches the size of a short batch.
    sample = batch[0]
    print("Batch no: {}, Batch size: {} samples, Timesteps per sample: {}".format(i_batch, len(batch), len(sample)))

    # Stop after the 4th batch (i_batch is zero-based).
    if i_batch == 3:
        break

## Example of splitting custom dataset into test and train sets

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler

random_seed = 42
batch_size = 512
test_split = .2
shuffle = True

if random_seed is not None:
    np.random.seed(random_seed)
    # random_split and DataLoader shuffling draw from torch's RNG, not
    # NumPy's, so torch has to be seeded too for the split to be reproducible.
    torch.manual_seed(random_seed)

# Split sizes: the test set gets the floor of test_split * size, and the
# train set gets the remainder so the two always add up to the full dataset.
dataset_size = len(midi_dataset)
test_size = int(test_split * dataset_size)
train_size = dataset_size - test_size
train_dataset, test_dataset = random_split(midi_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=shuffle, batch_size=batch_size)

print("Train size: {}, Test size: {}".format(train_size, test_size))

# Peek at the first batch of each loader only.
for split_name, loader in (("Train", train_loader), ("Test", test_loader)):
    for i_batch, sample_batched in enumerate(loader):
        batch = sample_batched['piano_rolls']
        # First sample of the batch; the batch counter is not a sample index.
        sample = batch[0]
        print("{} Batch no: {}, Batch size: {} samples, Timesteps per sample: {}".format(split_name, i_batch, len(batch), len(sample)))

        break # don't actually enumerate the whole thing..