### REI505M Final project: Music genre classification starter pack

The following Dataset class operates on the GTZAN dataset.

* Most GTZAN files are 30 seconds long ($30 \times 22050 = 661500$ samples), but some are slightly shorter (approx. 29.9 seconds). For this reason we truncate at 660000 samples below.
* It may be beneficial to work with smaller chunks than ~30 seconds.
* You may want to perform the data augmentations in the `__getitem__` function.
* For now, `train_dataset` contains all the dataset, you need to set aside some examples for validation and test sets.

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import scipy.io.wavfile as wav
import os

# Root folder of the dataset. Each genre lives in its own sub-folder and every
# file name starts with the genre name followed by a dot, for example:
audio_dir = 'music/' # Path to folder with GTZAN files
# music/
#  - rock/
#       rock.00099.wav
#       ...
#  - reggae/
#  ...
#  - blues/

# Number of examples per mini-batch; passed to every DataLoader created below.
batch_size = 32

class AudioDataset(Dataset):
    """Dataset of fixed-length raw-audio clips read from GTZAN-style WAV files.

    Each file is looked up at ``<audio_path>/<genre>/<file name>``, where
    ``<genre>`` is the part of the file name before its first dot
    (``'rock.00099.wav'`` -> ``'rock'``).

    Args:
        audio_files: File names (not full paths) of the audio clips.
        labels: Label for each entry of ``audio_files``, in the same order.
        audio_path: Root folder that contains one sub-folder per genre.
        maxlen: Samples beyond this index are discarded from every file.
        sampling_rate: Samples per second; converts seconds to sample indices.
        duration: Length of the returned clip, in seconds.
        tstart: Offset of the clip from the start of the song, in seconds
            (a hyper-parameter). Defaults to 0, i.e. the start of the file.

    Raises:
        ValueError: If ``tstart`` is negative.
    """

    def __init__(self, audio_files, labels, audio_path,
                 maxlen, sampling_rate, duration, tstart=0):
        if tstart < 0:
            raise ValueError(f"tstart must be non-negative, got {tstart}")
        self.audio_files = audio_files
        self.audio_path = audio_path
        self.labels = labels
        self.maxlen = maxlen
        self.sampling_rate = sampling_rate
        self.duration = duration
        self.tstart = tstart

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load one file and return ``(clip, label)``.

        ``clip`` is a ``float32`` tensor covering ``duration`` seconds starting
        ``tstart`` seconds into the file, never reaching past ``maxlen``
        samples. Files that end before the window does are zero-padded, so
        every item of the dataset has the same length and can be batched.
        """
        label = self.labels[idx]
        audio_file = self.audio_files[idx]
        # The genre folder name is the file-name prefix before the first dot.
        genre_dir = audio_file[:audio_file.index('.')]
        file_path = os.path.join(self.audio_path, genre_dir, audio_file)

        # The sampling rate stored in the file is not used; the window is
        # computed from self.sampling_rate.
        _, samples = wav.read(file_path)
        samples = torch.from_numpy(samples).to(torch.float32)

        # Window in samples, capped at maxlen (equivalent to truncating first).
        start = int(self.sampling_rate * self.tstart)
        stop = min(int(self.sampling_rate * (self.tstart + self.duration)),
                   self.maxlen)
        clip = samples[start:stop]

        # Zero-pad short files so all clips share one length; the default
        # DataLoader collation cannot stack tensors of different sizes.
        missing = max(stop - start, 0) - clip.shape[0]
        if missing > 0:
            padding = clip.new_zeros((missing,) + tuple(clip.shape[1:]))
            clip = torch.cat([clip, padding], dim=0)

        return clip, label

# Genre name (which is also the sub-folder name and the file-name prefix)
# mapped to its integer class index.
label_map = {'blues': 0, 'classical': 1, 'country': 2,
             'disco': 3, 'hiphop': 4, 'jazz': 5,
             'metal': 6, 'pop': 7, 'reggae': 8, 'rock': 9}

# Collect every audio file name together with the label encoded in its prefix.
audio_files = []
labels = []
for root, subdirs, files in os.walk(audio_dir):
    # os.walk yields entries in an OS-dependent order; sorting (in place for
    # the sub-folders, which steers the traversal) keeps the file list, and
    # therefore the seeded split below, identical across machines.
    subdirs.sort()
    for fname in sorted(files):
        # Skip hidden files such as macOS '.DS_Store'; they have no genre prefix.
        if fname.startswith('.'):
            continue
        audio_files.append(fname)
        labels.append(label_map[fname[:fname.index('.')]])

if not audio_files:
    raise FileNotFoundError(f"No audio files found under {audio_dir!r}")

torch.manual_seed(0) # Reproducible results

# Random 80/10/10 train/validation/test split over a seeded permutation.
num_files = len(audio_files)
num_held_out = num_files // 10  # size of the validation set and of the test set
shuffled = torch.randperm(num_files).tolist()
validation_idx = shuffled[:num_held_out]
test_idx = shuffled[num_held_out:2 * num_held_out]
train_idx = shuffled[2 * num_held_out:]

train_files = [audio_files[i] for i in train_idx]
train_labels = [labels[i] for i in train_idx]
validation_files = [audio_files[i] for i in validation_idx]
validation_labels = [labels[i] for i in validation_idx]
test_files = [audio_files[i] for i in test_idx]
test_labels = [labels[i] for i in test_idx]

print("Training set:", len(train_files))
print("Validation set:", len(validation_files))
print("Test set:", len(test_files))

# The three subsets must be pairwise disjoint. A single three-way intersection
# would miss a file shared by only two of them.
assert not set(train_files) & set(validation_files)
assert not set(train_files) & set(test_files)
assert not set(validation_files) & set(test_files)

# Clip settings shared by all three datasets: truncate at 660000 samples
# and use the first 25 seconds at 22050 samples per second.
clip_settings = dict(audio_path=audio_dir,
                     maxlen=660000, sampling_rate=22050, duration=25)

train_dataset = AudioDataset(audio_files=train_files, labels=train_labels,
                             **clip_settings)
test_dataset = AudioDataset(audio_files=test_files, labels=test_labels,
                            **clip_settings)
val_dataset = AudioDataset(audio_files=validation_files, labels=validation_labels,
                           **clip_settings)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Smoke test: pull one batch and report its shapes.
tmp_features, tmp_labels = next(iter(train_loader))
print(f"Feature batch shape: {tmp_features.size()}")
print(f"Labels batch shape: {tmp_labels.size()}")