In [1]:
import torch
from torch.utils.data import Dataset

import torchaudio
import torchaudio.transforms

import sys, os

from pprint import pprint

from tqdm.notebook import tqdm


In [6]:
class DatasetConverter(Dataset):
    """Eagerly load an audio dataset organised as one sub-directory per genre.

    On construction every file found below ``path`` is loaded with
    ``torchaudio.load``, down-sampled by a factor of ``audio_hop_length`` and
    converted to a mel spectrogram. Both representations are kept in memory
    as two parallel lists of dicts:

    * ``data_waveform[i]``        -> ``{"label": ..., "waveform": ...}``
    * ``data_mel_spectrogram[i]`` -> ``{"label": ..., "mel_spectrogram": ...}``

    Args:
        path: Root directory that is walked recursively. The name of the
            directory that directly contains a file is used as its label.
        use_spectrogram: Stored on the instance; not read by this class.
        normalize_audio: Forwarded to ``torchaudio.load(normalize=...)``.
        audio_num_frames: Forwarded to ``torchaudio.load(num_frames=...)``.
        audio_hop_length: Down-sampling factor applied to each waveform
            (resampled with ``orig_freq=audio_hop_length, new_freq=1``).
    """

    # Genre names; the position in this tuple is the index of the 1 in the
    # one-hot label.
    GENRES = ("Pop", "Hip-Hop", "Electronic", "Rock", "Folk", "Jazz")

    def __init__(
        self, path: str, use_spectrogram: bool, normalize_audio: bool,
        audio_num_frames: int, audio_hop_length: int
    ):
        self.path = path
        self.use_spectrogram = use_spectrogram
        self.normalize_audio = normalize_audio
        self.audio_num_frames = audio_num_frames
        self.audio_hop_length = audio_hop_length

        # MelSpectrogram transforms keyed by sample rate, so the transform is
        # built once per distinct rate instead of once per file. Must exist
        # before load_audio_data() runs.
        self._mel_transforms = {}

        # TODO load the raw dataset only once, then use it in another class/cell
        # to perform all the desired "compressions"/samplings/whatever
        self.data_waveform, self.data_mel_spectrogram = self.load_audio_data()

        # TODO export self.data_waveform to disk
        # TODO export self.data_spectrogram to disk

    def __len__(self):
        """Return the number of successfully loaded audio files."""
        # The two lists are filled in lock-step by load_audio_data(), so either
        # one gives the dataset size (there is no ``self.data`` attribute).
        return len(self.data_waveform)

    def __getitem__(self, idx):
        """Return the ``(waveform_dict, mel_spectrogram_dict)`` pair at ``idx``."""
        return self.data_waveform[idx], self.data_mel_spectrogram[idx]

    def label_from_str_to_one_hot(self, label_str: str):
        """Convert a genre name into a one-hot tensor with six entries.

        Returns:
            An integer tensor with a single 1 at the position of ``label_str``
            in ``GENRES``, or ``None`` when ``label_str`` is not a known genre.
        """
        if label_str not in self.GENRES:
            return None

        one_hot = torch.zeros(len(self.GENRES), dtype=torch.int64)
        one_hot[self.GENRES.index(label_str)] = 1
        return one_hot

    def get_mel_spectrogram(self, waveform, sample_rate):
        """Compute the mel spectrogram of ``waveform``.

        The transform uses ``n_fft=1024``, ``hop_length=1`` and ``n_mels=128``.
        One transform is created per distinct ``sample_rate`` and reused for
        subsequent calls.
        """
        transform = self._mel_transforms.get(sample_rate)

        if transform is None:
            n_fft = 1024
            win_length = None
            # hop_length = 512
            hop_length = 1
            n_mels = 128

            transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=True,
                pad_mode="reflect",
                power=2.0,
                norm='slaney',
                onesided=True,
                n_mels=n_mels,
                mel_scale="htk",
            )
            self._mel_transforms[sample_rate] = transform

        return transform(waveform)

    def load_audio_data(self):
        """Walk ``self.path`` and load every file that torchaudio can decode.

        Files that fail to load are reported on stdout and skipped.

        Returns:
            A ``(data_waveform, data_mel_spectrogram)`` tuple of equally long
            lists of dicts, one entry per successfully loaded file.
        """
        data_waveform, data_mel_spectrogram = [], []

        for dir_path, _, files in tqdm(os.walk(self.path), colour="teal"):
            for name in tqdm(files, colour="turquoise"):

                file_audio_path = os.path.join(dir_path, name)

                try:
                    waveform, sample_rate = torchaudio.load(
                        file_audio_path, normalize=self.normalize_audio,
                        num_frames=self.audio_num_frames
                    )
                except Exception as exc:
                    # Best effort: skip files that cannot be decoded, but let
                    # KeyboardInterrupt / SystemExit propagate and report why
                    # the file was rejected.
                    print(
                        f"Got an error while loading {file_audio_path}: {exc!r}"
                    )
                    continue

                # Only the ratio of the two frequencies matters to resample(),
                # so this keeps one sample out of every `audio_hop_length`.
                waveform = torchaudio.functional.resample(
                    waveform, orig_freq=self.audio_hop_length, new_freq=1
                    # orig_freq=sample_rate,
                    # new_freq=sample_rate / self.audio_hop_length
                )

                # The label is the name of the directory holding the file;
                # os.path is used so this also works with non-"/" separators.
                label = os.path.basename(os.path.dirname(file_audio_path))
                label_one_hot = self.label_from_str_to_one_hot(label)

                # NOTE(review): `sample_rate` is the rate reported by
                # torchaudio.load, i.e. before the down-sampling above, so the
                # mel filterbank is built for the original rate - confirm this
                # is intended.
                mel_spectrogram = self.get_mel_spectrogram(waveform, sample_rate)

                data_waveform.append({
                    "label": label_one_hot,
                    "waveform": waveform
                })
                data_mel_spectrogram.append({
                    "label": label_one_hot,
                    "mel_spectrogram": mel_spectrogram
                })

        return data_waveform, data_mel_spectrogram


In [7]:
# Build the converter; the constructor walks `path` and loads every audio file
# it finds, so this cell does all the heavy lifting.
dataset_converter = DatasetConverter(
  # Alternative (larger) dataset:
  # path="./data/fma_large_6_top_level_downsampled_organized_by_label/",
  path="./data/fma_extra_small_organized_by_label/",
  audio_num_frames=1320000,
  audio_hop_length=512,
  normalize_audio=True,
  # KEEP THIS SET TO TRUE WHILE EXPORTING THE DATASETS IN THE "COMPRESSED" FORMAT
  use_spectrogram=True,
)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

In [10]:
# Sanity check on one loaded item: print the mel-spectrogram shape, then the
# (down-sampled) waveform shape, one per line.
print(
  dataset_converter.data_mel_spectrogram[3]["mel_spectrogram"].shape,
  dataset_converter.data_waveform[3]["waveform"].shape,
  sep="\n",
)

torch.Size([2, 128, 2580])
torch.Size([2, 2579])
