## Set up dependencies

In [None]:
!pip install kaggle
!pip install torch torchaudio torchvision
!pip install matplotlib
!sudo apt install libsox-dev
!mkdir ~/.kaggle
!touch ~/.kaggle/kaggle.json
!echo '{"username":"antoinedangeard","key":"445fa2e3c51d7c9afd628cc57cd7fa33"}' > ~/.kaggle/kaggle.json

In [6]:
import torch
import torchaudio
import torchvision
import matplotlib.pyplot as plt
from IPython.display import Audio
from PIL import Image
import random
import gc

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: {}".format(device))

Device: cpu


## Load dataset

In [None]:
!kaggle datasets download -d andradaolteanu/gtzan-dataset-music-genre-classification
!unzip gtzan-dataset-music-genre-classification.zip -d GTZAN
!rm gtzan-dataset-music-genre-classification.zip

In [15]:
GENRE_TO_LABEL_MAPPING = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"]
GTZAN_SAMPLE_RATE = 22050
TARGET_SAMPLE_RATE = 11025

# Load the dataset
print(f"Loading raw images from dataset at content/GTZAN")
samples = []
labels = []
with open("GTZAN/Data/features_30_sec.csv", 'r') as file:
    for line in file:
        fields = line.strip().split(",")
        genre = fields[-1]
        wav_filename = "GTZAN/Data/genres_original/{}/{}".format(genre, fields[0])
        try:
          raw_sample, _ = torchaudio.load(wav_filename)
          sample = torchaudio.functional.resample(raw_sample, orig_freq=GTZAN_SAMPLE_RATE, new_freq=TARGET_SAMPLE_RATE)
          label = GENRE_TO_LABEL_MAPPING.index(genre)
          samples.append(sample)
          labels.append(label)
        except RuntimeError:
          print("Missing sample {}".format(wav_filename))

Loading raw images from dataset at content/GTZAN
Missing sample GTZAN/Data/genres_original/label/filename
Missing sample GTZAN/Data/genres_original/jazz/jazz.00054.wav


In [16]:
# Sanity check: play one of the samples to check it is correctly loaded
random_index = random.randint(0, len(samples)-1)
print("Listening to {}:".format(GENRE_TO_LABEL_MAPPING[labels[random_index]]))
Audio(samples[random_index].squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

Listening to jazz:


## Augment Dataset

In [None]:
gains = [-4, -2, 2, 4]
pitches = [100 * n for n in [-2, -1, 1, 2]]

total_iterations = len(samples) * (len(gains) + len(pitches))
iterations_completed = 0

print("Adding {} new samples to the dataset with augmentations...".format(total_iterations))

augmented_samples = []
augmented_labels = []

for i in range(len(samples)):
  for gain in gains:
    print("{}%               ".format(100 * iterations_completed / total_iterations))

    effects = [
        ["gain", str(gain)]
    ]

    augmented_sample, _ = torchaudio.sox_effects.apply_effects_tensor(samples[i], TARGET_SAMPLE_RATE, effects)
    augmented_sample = augmented_sample.to(device)
    augmented_samples.append(augmented_sample)
    augmented_labels.append(torch.tensor(labels[i]).to(device))
    iterations_completed += 1
  for cent_pitch_shift in pitches:
    print("{}%               ".format(100 * iterations_completed / total_iterations))

    effects = [
        ["pitch", str(cent_pitch_shift)]
    ]

    augmented_sample, _ = torchaudio.sox_effects.apply_effects_tensor(samples[i], TARGET_SAMPLE_RATE, effects)
    augmented_sample = augmented_sample.to(device)
    augmented_samples.append(augmented_sample)
    augmented_labels.append(torch.tensor(labels[i]).to(device))
    iterations_completed += 1

print("Added {} new samples to the dataset with augmentations.".format(len(augmented_samples)))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
26.276276276276278%               
26.28878878878879%               
26.3013013013013%               
26.313813813813812%               
26.326326326326328%               
26.33883883883884%               
26.35135135135135%               
26.363863863863862%               
26.376376376376378%               
26.38888888888889%               
26.4014014014014%               
26.413913913913913%               
26.426426426426428%               
26.43893893893894%               
26.45145145145145%               
26.463963963963963%               
26.476476476476478%               
26.48898898898899%               
26.5015015015015%               
26.514014014014013%               
26.526526526526528%               
26.53903903903904%               
26.55155155155155%               
26.564064064064063%               
26.576576576576578%               
26.58908908908909%               
26.6016016016016%               
26.61411

In [19]:
# Sanity check: listen to a raw vs. augmented version of the same sample to ensure it is not too extreme
random_raw_sample_index = random.randint(0, len(samples)-1)
print("Original sample of {}:".format(GENRE_TO_LABEL_MAPPING[labels[random_raw_sample_index]]))
Audio(samples[random_raw_sample_index].squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

Original sample of metal:


In [20]:
augmented_samples_per_original_sample = int(len(augmented_samples) / len(samples))
random_augmented_sample_index = random.randint(random_raw_sample_index * augmented_samples_per_original_sample, ((random_raw_sample_index + 1) * augmented_samples_per_original_sample) - 1)
print("Augmented sample:")
Audio(augmented_samples[random_augmented_sample_index].cpu().squeeze().numpy(), rate=TARGET_SAMPLE_RATE)

ValueError: empty range for randrange() (0, 0, 0)

In [None]:
# train / test split
test_ratio = 0.2
test_size = int(test_ratio * len(dataset))
train_size = len(dataset) - test_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
print(f"{train_size} images for training, {test_size} images for testing.")