In [1]:
import torch
from torch.utils.data import Dataset

import torchaudio
import torchaudio.transforms

import sys, os

from pprint import pprint

from tqdm.autonotebook import tqdm

import json

import numpy as np

import matplotlib.pylab as plt
import seaborn as sns

import librosa
import librosa.display

import pandas as pd

from pathlib import Path

import gc

# Seed every RNG library used by the notebook so any stochastic step is reproducible.
MANUAL_SEED = 69
torch.manual_seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from datetime import date
from datetime import datetime

import os.path
from os import path
  
import json

import time

import copy

from matplotlib import pyplot as plt
# Render inline figures and saved figures at the same resolution.
plt.rcParams.update({'figure.dpi': 150, 'savefig.dpi': 150})

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler


In [2]:
# NOTE(review): presumably enables the ipywidgets extension for the classic
# Notebook so the tqdm.autonotebook progress bars render as widgets; this is
# one-time environment setup and could live outside the notebook.
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [3]:
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
  """Display a power spectrogram on a dB colour scale.

  Args:
    spec: power spectrogram, converted with ``librosa.power_to_db`` and drawn
      with ``imshow`` (rows along the y axis, frames along the x axis); presumably
      2-D -- confirm with callers.
    title: figure title; 'Spectrogram (db)' is used when it is falsy.
    ylabel: label for the y axis.
    aspect: forwarded to ``imshow``.
    xmax: when truthy, the x axis is limited to ``(0, xmax)``.
  """
  fig, ax = plt.subplots(1, 1)
  ax.set_title(title if title else 'Spectrogram (db)')
  ax.set_ylabel(ylabel)
  ax.set_xlabel('frame')

  spec_db = librosa.power_to_db(spec)
  image = ax.imshow(spec_db, origin='lower', aspect=aspect)

  if xmax:
    ax.set_xlim((0, xmax))

  fig.colorbar(image, ax=ax)
  plt.show(block=False)

In [4]:
def make_dir_if_absent(dir_path):
  """Create ``dir_path`` and any missing parents; do nothing if it already exists.

  ``exist_ok=True`` makes the call idempotent and removes the race between an
  ``os.path.exists`` check and ``os.makedirs`` when two processes create the same
  directory. An empty ``dir_path`` (a file path with no directory part) refers to
  the current directory, which always exists, so it is skipped instead of making
  ``os.makedirs`` raise ``FileNotFoundError``.

  Raises:
    FileExistsError: if ``dir_path`` exists but is not a directory.
  """
  if dir_path:
    os.makedirs(dir_path, exist_ok=True)
    

In [5]:
class FMADataset(Dataset):
  """Map-style dataset over the audio files found under ``path``.

  Each item loads one audio file, right-pads it with zeros to ``audio_num_frames``,
  computes its mel spectrogram and, as a side effect, writes both the waveform and
  the mel spectrogram to disk with ``torch.save``. Iterating the dataset once
  therefore acts as an export pass.

  Export layout. This assumes ``path`` is ``<root>/<audio_dir>/<dataset_name>``
  and contains one sub-folder per genre:
    <root>/waveform/<dataset_name>/<genre>/<track>.waveform
    <root>/mel_spectrogram/<dataset_name>_n_fft_<..>_win_length_<..>_hop_length_<..>_n_mels_<..>/<genre>/<track>.mel_spec

  Args:
    path: dataset directory, walked recursively; every file found is an item.
    normalize_audio: forwarded to ``torchaudio.load(normalize=...)``.
    audio_num_frames: ``num_frames`` passed to ``torchaudio.load``; shorter
      clips are zero-padded on the right up to this length.
    waveform_mean, waveform_std: stored but not read anywhere in this class.
    mel_spectrogram_n_fft, mel_spectrogram_win_length,
    mel_spectrogram_hop_length, mel_spectrogram_n_mels: forwarded to
      ``torchaudio.transforms.MelSpectrogram`` and encoded in the mel export
      folder name.
  """

  # Genre folder names in one-hot order: GENRES[i] maps to a 1 at position i.
  GENRES = ("Pop", "Hip-Hop", "Electronic", "Rock", "Folk", "Jazz")

  def __init__(
    self, path, normalize_audio, audio_num_frames, waveform_mean, waveform_std,
    mel_spectrogram_n_fft, mel_spectrogram_win_length,
    mel_spectrogram_hop_length, mel_spectrogram_n_mels
  ):
    self.path = path
    self.normalize_audio = normalize_audio
    self.audio_num_frames = audio_num_frames
    self.waveform_mean = waveform_mean
    self.waveform_std = waveform_std

    self.mel_spectrogram_n_fft = mel_spectrogram_n_fft
    self.mel_spectrogram_win_length = mel_spectrogram_win_length
    self.mel_spectrogram_hop_length = mel_spectrogram_hop_length
    self.mel_spectrogram_n_mels = mel_spectrogram_n_mels

    self.data = self._load_audio_list()

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    """Load sample ``idx``, export it to disk and return it.

    Returns:
      ``(waveform, waveform_path, mel_spectrogram, mel_spectrogram_path,
      label_one_hot)``, or ``None`` if processing the file failed. In that case
      the exception and its traceback are printed so that a full export pass
      can continue past bad files.

    Raises:
      IndexError: if ``idx`` is out of range. Plain ``for sample in dataset``
        iteration stops on this, so it is deliberately raised outside the
        try/except below.
    """
    audio_path = self.data[idx]

    try:
      waveform, sample_rate = torchaudio.load(
        filepath=audio_path,
        normalize=self.normalize_audio,
        num_frames=self.audio_num_frames,
      )

      if waveform.shape[1] < self.audio_num_frames:
        waveform = self._apply_padding(waveform)

      # The genre label is the name of the folder that contains the file.
      label = os.path.basename(os.path.dirname(audio_path))
      label_one_hot = self._label_from_str_to_one_hot(label)

      mel_spectrogram = self._get_mel_spectrogram(
        sample_rate=sample_rate, waveform=waveform
      )

      mel_spectrogram_path = self._mel_spectrogram_path(audio_path)
      make_dir_if_absent(os.path.dirname(mel_spectrogram_path))
      torch.save(mel_spectrogram, mel_spectrogram_path)

      waveform_path = self._waveform_path(audio_path)
      make_dir_if_absent(os.path.dirname(waveform_path))
      torch.save(waveform, waveform_path)

      return waveform, waveform_path, mel_spectrogram, mel_spectrogram_path, label_one_hot

    except Exception:
      import traceback  # only needed on the error path

      print(f"Got the following exception for the file {audio_path}")
      print("\n\n")
      print(traceback.format_exc())
      return None

  def _get_mel_spectrogram(self, sample_rate, waveform):
    """Compute the power (``power=2.0``) mel spectrogram of ``waveform``."""
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=self.mel_spectrogram_n_fft,
        win_length=self.mel_spectrogram_win_length,
        hop_length=self.mel_spectrogram_hop_length,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm='slaney',
        onesided=True,
        n_mels=self.mel_spectrogram_n_mels,
        mel_scale="htk",
    )

    return mel_spectrogram(waveform)

  def _apply_padding(self, to_pad):
    """Zero-pad the last dimension of a 2-D ``to_pad`` on the right up to ``audio_num_frames``."""
    padding_size = self.audio_num_frames - to_pad.shape[1]
    return torch.nn.functional.pad(to_pad, (0, padding_size))

  def _label_from_str_to_one_hot(self, label_str: str):
    """Return the float32 one-hot vector for the genre ``label_str``.

    Raises:
      ValueError: if ``label_str`` is not in ``GENRES``. An unexpected folder
        name is reported where it occurs instead of producing a ``None`` label.
    """
    if label_str not in self.GENRES:
      raise ValueError(
        f"Unknown genre label {label_str!r}; expected one of {self.GENRES}"
      )
    one_hot = torch.zeros(len(self.GENRES), dtype=torch.float32)
    one_hot[self.GENRES.index(label_str)] = 1.0
    return one_hot

  def _export_path(self, audio_path, export_kind, dataset_name_suffix, extension):
    """Map an audio file under ``self.path`` to its export file.

    The folder that contains the dataset directory is replaced by
    ``export_kind``, ``dataset_name_suffix`` is appended to the dataset folder
    name, and the file extension is swapped for ``extension``. The
    genre/track part of the path is kept verbatim. A whole-string
    ``str.replace`` would instead also rewrite any other component that happens
    to contain "audio" or "mp3".
    """
    dataset_dir = self.path.rstrip("/")
    audio_dir, dataset_name = os.path.split(dataset_dir)
    root_dir = os.path.dirname(audio_dir)
    relative_stem, _ = os.path.splitext(os.path.relpath(audio_path, dataset_dir))
    return os.path.join(
      root_dir, export_kind, dataset_name + dataset_name_suffix, relative_stem + extension
    )

  def _waveform_path(self, audio_path):
    """Export path of the (padded) waveform tensor for ``audio_path``."""
    return self._export_path(audio_path, "waveform", "", ".waveform")

  def _mel_spectrogram_path(self, audio_path):
    """Export path of the mel spectrogram; the folder name records the STFT/mel settings."""
    settings_suffix = (
      f"_n_fft_{self.mel_spectrogram_n_fft}"
      f"_win_length_{self.mel_spectrogram_win_length}"
      f"_hop_length_{self.mel_spectrogram_hop_length}"
      f"_n_mels_{self.mel_spectrogram_n_mels}"
    )
    return self._export_path(audio_path, "mel_spectrogram", settings_suffix, ".mel_spec")

  def _load_audio_list(self):
    """Return the paths of every file under ``self.path``, in sorted walk order.

    No extension filtering is done. ``os.walk`` yields entries in
    filesystem-dependent order, so directories and file names are sorted to make
    dataset indices reproducible across machines.
    """
    audio_path_list = []
    for dir_path, subdirs, files in tqdm(os.walk(self.path), colour="magenta"):
      subdirs.sort()  # sorted in place so os.walk also descends in sorted order
      audio_path_list.extend(os.path.join(dir_path, name) for name in sorted(files))
    return audio_path_list
        
        

In [6]:
DATASET_SIZE = "xl"
DATASET_FOLDER = "./data/audio"

# Sample rate (Hz) and channel count; both are encoded in the dataset folder name.
DATASET_NUM_SAMPLES_PER_SECOND = 8000
DATASET_NUM_CHANNELS = 1

DATASET_NAME = f"fma_{DATASET_SIZE}_resampled_{DATASET_NUM_SAMPLES_PER_SECOND}_rechanneled_{DATASET_NUM_CHANNELS}"

dataset_path = f"{DATASET_FOLDER}/{DATASET_NAME}"

TRAINING_LOGS_FOLDER = "./logs"

NORMALIZE_AUDIO = True

# FMADataset loads at most AUDIO_NUM_FRAMES frames per file and zero-pads shorter
# clips up to it. The length is expressed as a duration so it stays consistent
# if the sample rate above changes (29.75 s * 8000 Hz = 238000 frames).
AUDIO_NUM_SECONDS = 29.75
AUDIO_NUM_FRAMES = round(AUDIO_NUM_SECONDS * DATASET_NUM_SAMPLES_PER_SECOND)

In [7]:
# summary_statistics_json = open(
#   f"{dataset_path}_summary_statistics/{DATASET_NAME}_summary_statistics.json"
# )

# summary_statistics_dict = json.load(summary_statistics_json)

In [8]:
# The waveform mean/std are -1 placeholders: the summary-statistics cell above is
# commented out, and FMADataset stores these values without reading them.
fma_dataset = FMADataset(
  path=dataset_path,
  normalize_audio=NORMALIZE_AUDIO,
  audio_num_frames=AUDIO_NUM_FRAMES,
  waveform_mean=-1,
  waveform_std=-1,
  # Mel spectrogram settings; they are also encoded in the mel export folder name.
  mel_spectrogram_n_fft=1024,
  mel_spectrogram_win_length=None,
  mel_spectrogram_hop_length=128,
  mel_spectrogram_n_mels=128,
)

0it [00:00, ?it/s]

In [9]:
# Iterating the dataset exports every waveform and mel spectrogram to disk
# (FMADataset.__getitem__ saves them as a side effect).
# __getitem__ returns None for files it could not process. Those are skipped
# instead of crashing on tuple unpacking, and their indices are kept for inspection.
failed_sample_indices = []
for sample_index, data_sample in enumerate(tqdm(fma_dataset, colour="cyan")):
  if data_sample is None:
    failed_sample_indices.append(sample_index)
    continue
  # Overwritten on every iteration. After the loop these hold the last exported
  # sample, which the following cells reload to check the save/load round trip.
  waveform, waveform_path, _, _, _ = data_sample

len(failed_sample_indices)

  0%|          | 0/23386 [00:00<?, ?it/s]

: 

: 

In [None]:
# Reload the last waveform that FMADataset exported with torch.save.
loaded_waveform = torch.load(waveform_path)

In [None]:
# Last dimension should be AUDIO_NUM_FRAMES (loads are truncated/padded to it);
# presumably the first is DATASET_NUM_CHANNELS -- confirm.
loaded_waveform.size()

In [None]:
# Fraction of samples that survive the torch.save / torch.load round trip
# (1.0 means identical). Shapes are checked first because isclose broadcasts,
# so mismatched shapes could otherwise yield a misleading fraction.
assert loaded_waveform.shape == waveform.shape, (loaded_waveform.shape, waveform.shape)
torch.isclose(loaded_waveform, waveform).float().mean()