## Set up dependencies

In [None]:
!pip install kaggle
!pip install torch torchaudio torchvision
!pip install matplotlib
!sudo apt install libsox-dev
!mkdir ~/.kaggle
!touch ~/.kaggle/kaggle.json
!echo '{"username":"antoinedangeard","key":"445fa2e3c51d7c9afd628cc57cd7fa33"}' > ~/.kaggle/kaggle.json

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libsox-dev is already the newest version (14.4.2+git20190427-2+deb11u2ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.
mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [None]:
import gc
import random

import torch
import torchaudio
import torchvision
import matplotlib.pyplot as plt
from IPython.display import Audio
from PIL import Image

# Seed every RNG this notebook draws from (random.randint in the sanity checks,
# torch.randperm for the train/test split) so Restart & Run All is reproducible.
# torch.manual_seed also seeds the CUDA generators.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available; later cells move the dataset tensors to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: {}".format(device))

Device: cuda


## Load dataset

In [None]:
!kaggle datasets download -d andradaolteanu/gtzan-dataset-music-genre-classification
!unzip gtzan-dataset-music-genre-classification.zip -d GTZAN
!rm gtzan-dataset-music-genre-classification.zip

In [None]:
import csv

GENRE_TO_LABEL_MAPPING = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"]
GTZAN_SAMPLE_RATE = 22050
TARGET_SAMPLE_RATE = 11025
WAVEFORM_LENGTH = 30 * TARGET_SAMPLE_RATE  # 30 seconds of audio at the target rate
GTZAN_DATA_DIR = "GTZAN/Data"

# Load every clip listed in the metadata CSV, resample it to TARGET_SAMPLE_RATE and force
# it to exactly WAVEFORM_LENGTH samples. `samples` holds the waveforms and `labels` the
# integer genre indices into GENRE_TO_LABEL_MAPPING.
print("Loading audio clips from dataset at {}".format(GTZAN_DATA_DIR))
samples = []
labels = []
with open("{}/features_30_sec.csv".format(GTZAN_DATA_DIR), newline="") as file:
    reader = csv.reader(file)
    next(reader)  # skip the header row; it is not a clip
    for fields in reader:
        filename, genre = fields[0], fields[-1]
        wav_filename = "{}/genres_original/{}/{}".format(GTZAN_DATA_DIR, genre, filename)
        try:
            raw_sample, file_sample_rate = torchaudio.load(wav_filename)
        except RuntimeError:
            # Some files cannot be decoded (e.g. jazz.00054.wav); skip them instead of aborting.
            print("Missing sample {}".format(wav_filename))
            continue
        # Resample from the file's actual rate rather than assuming GTZAN_SAMPLE_RATE.
        sample = torchaudio.functional.resample(raw_sample, orig_freq=file_sample_rate, new_freq=TARGET_SAMPLE_RATE)
        sample = sample[:, :WAVEFORM_LENGTH]
        # Some clips are slightly shorter than 30 s. Zero-pad them so every waveform, and
        # later every spectrogram, has the same length and can be torch.stack-ed.
        sample = torch.nn.functional.pad(sample, (0, WAVEFORM_LENGTH - sample.shape[-1]))
        samples.append(sample)
        labels.append(GENRE_TO_LABEL_MAPPING.index(genre))

Loading raw images from dataset at content/GTZAN
Missing sample GTZAN/Data/genres_original/label/filename
Missing sample GTZAN/Data/genres_original/jazz/jazz.00054.wav


In [None]:
# Sanity check: pick a random loaded clip, report its genre and tensor shape, and play it
# so a human can confirm the audio decoded and resampled correctly.
random_index = random.randint(0, len(samples) - 1)
chosen_clip = samples[random_index]
chosen_genre = GENRE_TO_LABEL_MAPPING[labels[random_index]]
print("Listening to {}:".format(chosen_genre))
print(chosen_clip.shape)
Audio(chosen_clip.squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

Listening to metal:
torch.Size([1, 330750])


## Prepare and augment the dataset

In [None]:
# Augmentation: for every original clip, one sox "gain" copy ("-n" normalizes the level)
# followed by one copy per pitch shift (-2 and +2 semitones). The copies of a clip are
# appended consecutively and in the same order as `samples`, which is what the later
# sanity-check and train/test-split cells rely on.
gains = ["-n"]
pitch_transforms = [torchaudio.transforms.PitchShift(TARGET_SAMPLE_RATE, n) for n in [-2, 2]]

total_iterations = len(samples) * (len(gains) + len(pitch_transforms))
iterations_completed = 0

print("Adding {} new samples to the dataset with augmentations...".format(total_iterations))

augmented_samples = []
augmented_labels = []

# This is preprocessing, not training. PitchShift stores its resampling kernel as an
# nn.Parameter, so without no_grad every output would keep an autograd graph alive.
with torch.no_grad():
    for sample, label in zip(samples, labels):
        # One progress line, rewritten in place (\r), instead of thousands of printed lines.
        print("{:.1f}%".format(100 * iterations_completed / total_iterations), end="\r")

        for gain in gains:
            effects = [["gain", str(gain)]]
            augmented_sample, _ = torchaudio.sox_effects.apply_effects_tensor(sample, TARGET_SAMPLE_RATE, effects)
            augmented_samples.append(augmented_sample[:, :WAVEFORM_LENGTH])
            augmented_labels.append(torch.tensor(label))
            iterations_completed += 1

        for pitch_shift in pitch_transforms:
            augmented_sample = pitch_shift(sample)
            augmented_samples.append(augmented_sample[:, :WAVEFORM_LENGTH])
            augmented_labels.append(torch.tensor(label))
            iterations_completed += 1

print("Added {} new samples to the dataset with augmentations.".format(len(augmented_samples)))



Adding 2997 new samples to the dataset with augmentations...
0.0%               
0.0333667000333667%               
0.0667334000667334%               
0.1001001001001001%               
0.1334668001334668%               
0.1668335001668335%               
0.2002002002002002%               
0.2335669002335669%               
0.2669336002669336%               
0.3003003003003003%               
0.333667000333667%               
0.3670337003670337%               
0.4004004004004004%               
0.4337671004337671%               
0.4671338004671338%               
0.5005005005005005%               
0.5338672005338672%               
0.5672339005672339%               
0.6006006006006006%               
0.6339673006339673%               
0.667334000667334%               
0.7007007007007007%               
0.7340674007340674%               
0.7674341007674341%               
0.8008008008008008%               
0.8341675008341675%               
0.8675342008675342%               
0.900900900

In [None]:
# Sanity check, part 1 of 2: play an original clip. The next cell plays one of its
# augmented copies, so the two can be compared to judge whether augmentation is too extreme.
random_raw_sample_index = random.randint(0, len(samples) - 1)
raw_clip = samples[random_raw_sample_index]
raw_genre = GENRE_TO_LABEL_MAPPING[labels[random_raw_sample_index]]
print("Original sample of {}:".format(raw_genre))
Audio(raw_clip.detach().squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

Original sample of rock:


In [None]:
# Sanity check, part 2 of 2: play a random augmented copy of the clip chosen above.
# The augmentation loop appends each clip's copies consecutively, so the copies of
# clip k occupy indices [k * per_clip, (k + 1) * per_clip).
augmented_samples_per_original_sample = len(augmented_samples) // len(samples)
first_copy_index = random_raw_sample_index * augmented_samples_per_original_sample
last_copy_index = first_copy_index + augmented_samples_per_original_sample - 1
random_augmented_sample_index = random.randint(first_copy_index, last_copy_index)
print("Augmented sample:")
chosen_copy = augmented_samples[random_augmented_sample_index]
Audio(chosen_copy.detach().squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

Augmented sample:


In [None]:
# Combine original and augmented clips, turn them into mel spectrograms, and split them
# into train/test sets. The split is done per *source recording*: an original clip and
# all of its augmented copies always land on the same side. A random split of the
# individual clips would put gain- or pitch-shifted near-duplicates of training clips
# into the test set and inflate test accuracy.

# For every entry, record the index of the original clip it derives from. The
# augmentation cell appends each clip's copies consecutively, in the order of `samples`.
n_original_clips = len(samples)
augmentations_per_clip = len(augmented_samples) // n_original_clips
assert augmentations_per_clip * n_original_clips == len(augmented_samples), "unexpected augmentation layout"
assert len(augmented_labels) == len(augmented_samples) and len(labels) == n_original_clips
source_clip_ids = torch.cat([
    torch.arange(len(augmented_samples)) // max(augmentations_per_clip, 1),
    torch.arange(n_original_clips),
])

# Combine the original and augmented samples. Original labels are plain ints and the
# augmented ones are 0-d tensors; torch.stack only accepts tensors, so convert them.
augmented_samples.extend(samples)
augmented_labels.extend(torch.tensor(label) for label in labels)

# Convert the waveforms into spectrograms, replacing list entries in place so each
# augmented waveform can be freed once it has been converted. No autograd graph is needed.
spec_transform = torchaudio.transforms.MelSpectrogram(TARGET_SAMPLE_RATE)
with torch.no_grad():
    for i in range(len(augmented_samples)):
        augmented_samples[i] = spec_transform(augmented_samples[i])
    augmented_samples = torch.stack(augmented_samples, dim=0)
augmented_labels = torch.stack(augmented_labels, dim=0)

# Pick whole source recordings for the test set.
test_split = 0.2
n_test_clips = int(test_split * n_original_clips)
test_clip_ids = torch.randperm(n_original_clips)[:n_test_clips]
is_test = torch.isin(source_clip_ids, test_clip_ids)

# Shuffle within each split so batches are not ordered by source recording.
train_indices = torch.nonzero(~is_test).squeeze(1)
test_indices = torch.nonzero(is_test).squeeze(1)
train_indices = train_indices[torch.randperm(len(train_indices))]
test_indices = test_indices[torch.randperm(len(test_indices))]

# Index on the CPU, then move only the final split tensors to the device.
train_samples, train_labels = augmented_samples[train_indices].to(device), augmented_labels[train_indices].to(device)
test_samples, test_labels = augmented_samples[test_indices].to(device), augmented_labels[test_indices].to(device)
n_test_samples = test_labels.shape[0]
assert train_labels.shape[0] + n_test_samples == augmented_labels.shape[0]
print(f"{train_labels.shape[0]} spectrograms for training, {n_test_samples} spectrograms for testing.")

# Free the intermediates; only the train/test tensors are needed from here on.
del source_clip_ids, test_clip_ids, is_test, train_indices, test_indices
del augmented_samples
del augmented_labels
del samples
del labels
gc.collect()

