In [17]:
import torch
from torch.utils.data import Dataset

import torchaudio
import torchaudio.transforms

import sys, os

from pprint import pprint

from tqdm.autonotebook import tqdm

import json

import numpy as np

import matplotlib.pylab as plt
import seaborn as sns

import librosa
import librosa.display

import pandas as pd

from pathlib import Path

import gc

MANUAL_SEED = 69

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from datetime import date
from datetime import datetime

import os.path
from os import path
  
import json

import time

import copy

from matplotlib import pyplot as plt
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150

from sklearn.decomposition import PCA



In [18]:
def make_dir_if_absent(dir_path):
  
  # print("making dir: ", dir_path)
  
  if not os.path.exists(dir_path):
    os.makedirs(dir_path)
    

In [19]:
class FMADataset(Dataset):

  def __init__(self, path, normalize_audio, audio_num_frames):
    self.path = path
    self.normalize_audio = normalize_audio
    self.audio_num_frames = audio_num_frames
    
    self.data = self._load_audio_list()

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    
    try: 
      
      waveform, _ = torchaudio.load(
        filepath=self.data[idx], 
        normalize=self.normalize_audio,
        num_frames=self.audio_num_frames
      )
      
      if waveform.shape[1] < self.audio_num_frames:
        waveform = self._apply_padding(waveform)
      
      label = self.data[idx].split("/")[-2]
      label_one_hot = self._label_from_str_to_one_hot(label)
      
      return waveform, label_one_hot
    
    except Exception as e:
      print(f"Got the following exception for the file {self.data[idx]}")
      print("\n\n")
      print(e)
      
  def _apply_padding(self, to_pad):
    padding_size = self.audio_num_frames - to_pad.shape[1]
    
    return torch.nn.functional.pad(
      to_pad, (0, padding_size)
    )
  
  def _label_from_str_to_one_hot(self, label_str: str): 
  
    if label_str == "Pop":
      return torch.tensor([1, 0, 0, 0, 0, 0]).float()
    
    if label_str == "Hip-Hop":
      return torch.tensor([0, 1, 0, 0, 0, 0]).float()
    
    if label_str == "Electronic":
      return torch.tensor([0, 0, 1, 0, 0, 0]).float()
    
    if label_str == "Rock":
      return torch.tensor([0, 0, 0, 1, 0, 0]).float()

    if label_str == "Folk":
      return torch.tensor([0, 0, 0, 0, 1, 0]).float()

    if label_str == "Jazz":
      return torch.tensor([0, 0, 0, 0, 0, 1]).float()
    
  
  def _load_audio_list(self):
    
    audio_path_list = []
    
    for path, subdirs, files in tqdm(os.walk(self.path), colour="magenta"):
      for name in files:
          
        file_audio_path = os.path.join(path, name)
        
        audio_path_list.append(file_audio_path)
        
    return audio_path_list
        
        

In [20]:
DATASET_SIZE = "extra_large"
DATASET_FOLDER = "../data/audio"

DATASET_NUM_SAMPLES_PER_SECOND = 8000
DATASET_NUM_CHANNELS = 1

DATASET_NAME = f"fma_{DATASET_SIZE}_organized_by_label_resampled_{DATASET_NUM_SAMPLES_PER_SECOND}_rechanneled_{DATASET_NUM_CHANNELS}"

dataset_path = f"{DATASET_FOLDER}/{DATASET_NAME}"

TRAINING_LOGS_FOLDER = "./logs"

NORMALIZE_AUDIO = True
AUDIO_NUM_FRAMES = 238000

In [21]:
# fma_dataset = FMADataset(
fma_dataset = FMADataset(
  path=dataset_path, 
  normalize_audio=NORMALIZE_AUDIO, 
  audio_num_frames=AUDIO_NUM_FRAMES
)

0it [00:00, ?it/s]

In [22]:
fma_dataloader = torch.utils.data.DataLoader(
  fma_dataset, batch_size=64, num_workers=16
)

In [23]:
def get_mean_std(dataloader):

  channels_sum, channels_squared_sum, num_batches = 0, 0, 0

  for batch_inputs, _ in dataloader:
    channels_sum += torch.mean(batch_inputs, dim=[0, 2])
    channels_squared_sum += torch.mean(batch_inputs ** 2, dim=[0, 2])
    num_batches += 1

  mean = channels_sum / num_batches

  std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

  return mean, std


In [24]:
mean, std = get_mean_std(fma_dataloader)

In [25]:
to_export = {
  "dataset_name": DATASET_NAME,
  "mean": mean.item(),
  "std": std.item()
}

In [26]:
make_dir_if_absent(dir_path=f"{dataset_path}_summary_statistics")

In [27]:
with open(
  f"{dataset_path}_summary_statistics/{DATASET_NAME}_summary_statistics.json", 
  'w'
) as fp:
    json.dump(to_export, fp)

In [28]:
to_export

{'dataset_name': 'fma_extra_large_organized_by_label_resampled_8000_rechanneled_1',
 'mean': 0.0006753258057869971,
 'std': 0.2897718548774719}

In [29]:
# to_compute = []

# for batch_x, batch_y in fma_dataloader:
#   to_compute.append(batch_x)

In [30]:
# to_compute_tensor = torch.cat(to_compute, dim=0)

In [31]:
# to_compute_tensor.shape

In [32]:
# torch.std_mean(to_compute_tensor)