In [None]:
%pip install librosa python-dotenv pydot
%pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
import os
from torchvision import models
import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
import librosa
import utils
import numpy as np
import warnings
import random

# Run on the first visible CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the FMA metadata tables with the project's `utils.load` helper.
full_tracks = utils.load('../fma_metadata/tracks.csv')
# Keep the tracks whose ('set', 'subset') compares <= 'small'.
# NOTE(review): presumably an ordered categorical (small < medium < large),
# so this selects the fma_small subset — confirm against utils.load.
TRACKS = full_tracks.loc[full_tracks['set', 'subset'] <= 'small']
GENRES, FEATURES, ECHONEST = (
    utils.load('../fma_metadata/genres.csv'),
    utils.load('../fma_metadata/features.csv'),
    utils.load('../fma_metadata/echonest.csv'),
)

In [None]:
# Track ids to skip: some audio files are mis-clipped (see FMA issues #41 and #8).
SHORTER_IDS = [
    99134, 108925, 133297,  # empty
    98565, 98567, 98569,    # shorter than 2 seconds
]


def _save_with_sr(path, spectrogram, sr):
    """Save ``spectrogram`` and its sample rate ``sr`` to ``path`` as a
    2-element object array, readable with ``np.load(path, allow_pickle=True)``
    and unpackable as ``spectrogram, sr = ...``.
    """
    # Build the object array explicitly: passing the tuple (array, int) to
    # np.save makes numpy create a ragged array implicitly, which raises
    # ValueError on NumPy >= 1.24. The on-disk format is unchanged.
    payload = np.empty(2, dtype=object)
    payload[0] = spectrogram
    payload[1] = sr
    np.save(path, payload)


def prepare_fullset(musicSet, is_logmel=True, verbose=True):
    """Make sure every track of ``musicSet`` has its spectrogram cache on disk.

    For each track id (the index of ``musicSet``) not listed in
    ``SHORTER_IDS``, write ``<audio path without extension>_stft.npy``
    (magnitude STFT, n_fft=2048, hop_length=512) and, when ``is_logmel`` is
    True, ``..._log_mel.npy`` (dB-scaled mel spectrogram built from that STFT)
    if they are missing. Each file stores ``(spectrogram, sr)``.

    Args:
        musicSet: DataFrame indexed by FMA track id (only the index is used).
        is_logmel: also build the log-mel cache.
        verbose: print a progress line every 100 processed tracks.
    """
    i = 0
    for _id, _ in musicSet.iterrows():
        if _id in SHORTER_IDS:
            continue
        filename = utils.get_audio_path('../fma_small/', _id)
        filename_root = os.path.splitext(filename)[0]
        file_stft = filename_root + "_stft.npy"
        file_logmel = filename_root + "_log_mel.npy"
        need_stft = not os.path.isfile(file_stft)
        # Only require the log-mel cache when it was asked for; otherwise a
        # cached STFT with is_logmel=False would still trigger an audio reload.
        need_logmel = is_logmel and not os.path.isfile(file_logmel)
        if need_stft or need_logmel:
            print(f"Music {_id} does not have precomputed features... Computing them now")
            if need_stft:
                x, sr = librosa.load(filename, sr=None, mono=True)
                stft = np.abs(librosa.stft(x, n_fft=2048, hop_length=512))
                _save_with_sr(file_stft, stft, sr)
            else:
                # Reuse THIS track's cached STFT; without this, `stft` would be
                # unbound here and the log-mel would be built from the previous
                # track's STFT (or raise NameError on the first track).
                stft, sr = np.load(file_stft, allow_pickle=True)
            if need_logmel:
                mel = librosa.feature.melspectrogram(sr=sr, S=stft**2)
                # NOTE(review): `mel` is a power spectrogram (S=stft**2), for
                # which librosa.power_to_db is the matching conversion.
                # amplitude_to_db is kept so new caches match existing ones.
                log_mel = librosa.amplitude_to_db(mel)
                _save_with_sr(file_logmel, log_mel, sr)
        i += 1
        if verbose and i % 100 == 0:
            print(f"Loaded {i} samples")


class MusicSetv1(Dataset):
    """Eager dataset that loads every precomputed spectrogram of ``musicSet``
    into a single zero-padded tensor on ``device``.

    Each item is a dict with keys:
      ``"song"``    -- padded spectrogram with frames along dim 0 (the cached
                       librosa ``(freq, time)`` array is transposed);
      ``"length"``  -- number of real (unpadded) frames of that song;
      ``"sr"``      -- sample rate stored alongside the spectrogram;
      ``"feature"`` -- the track's row of ``FEATURES`` as a tensor.
    """

    def __init__(self,
                 musicSet,
                 is_logmel=True,
                 device=torch.device("cpu"),
                 verbose=True):
        """
        Args:
            musicSet: DataFrame indexed by FMA track id (only the index is used).
            is_logmel: load the ``_log_mel`` cache if True, else the ``_stft`` cache.
            device: torch device for ``songs`` and ``features``.
            verbose: print a progress line every 100 loaded tracks.

        Raises:
            ValueError: if ``musicSet`` contains no usable track.
        """
        # Ensure the caches exist; the log-mel is always built so either
        # representation can be loaded below.
        prepare_fullset(musicSet, is_logmel=True, verbose=False)
        songs = []
        features = []
        self.srs = []
        i = 0
        for _id, _ in musicSet.iterrows():
            if _id in SHORTER_IDS:
                continue
            filename = utils.get_audio_path('../fma_small/', _id)
            data_path = (os.path.splitext(filename)[0]
                         + ("_log_mel" if is_logmel else "_stft")
                         + ".npy")
            precomputed, sr = np.load(data_path, allow_pickle=True)
            songs.append(torch.from_numpy(precomputed.T))
            self.srs.append(sr)
            features.append(FEATURES.loc[_id])
            i += 1
            if verbose and i % 100 == 0:
                print(f"Loaded {i} samples")
        if not songs:
            raise ValueError("musicSet contains no usable tracks")
        # enforce_sorted=False: pack_sequence sorts by length internally and
        # pad_packed_sequence restores the input order, so songs[k] stays
        # aligned with srs[k] and features[k] (sorting `songs` alone did not).
        packed = pack_sequence(songs, enforce_sorted=False)
        # batch_first=True yields (n_songs, max_frames, n_bins), so len() and
        # [idx] address songs rather than time steps. Padding is done on CPU,
        # where the lengths tensor lives, before moving to `device`.
        padded, self.lengths = pad_packed_sequence(packed, batch_first=True)
        self.songs = padded.to(device)
        self.features = torch.from_numpy(np.array(features)).to(device)

    def __len__(self):
        return len(self.songs)

    def __getitem__(self, idx):
        return {"song": self.songs[idx],
                "length": self.lengths[idx],
                "sr": self.srs[idx],
                "feature": self.features[idx]}

In [None]:
class MusicSet(Dataset):
    """Lazy dataset: reads one precomputed spectrogram from disk per item.

    Items are tuples ``(spectrogram, sr, features, track_id)`` where
    ``spectrogram`` is the cached array as a tensor on ``device`` (not
    transposed, unlike ``MusicSetv1``), ``sr`` is the stored sample rate and
    ``features`` is the track's row of ``FEATURES`` (a pandas Series).

    The caches must already exist next to the audio files (see
    ``prepare_fullset``); nothing is computed here.
    """

    def __init__(self,
                 musicSet,
                 is_logmel=True,
                 device=torch.device("cpu"),
                 verbose=True):
        """
        Args:
            musicSet: DataFrame indexed by FMA track id; only its index is used.
            is_logmel: read the ``_log_mel`` cache if True, else the ``_stft`` cache.
            device: torch device each loaded spectrogram is moved to.
            verbose: unused; kept for signature parity with ``MusicSetv1``.
        """
        # Honour both arguments (is_logmel was previously forced to True and
        # the global TRACKS was used instead of `musicSet`).
        self.is_logmel = is_logmel
        self.device = device
        # Filter instead of list.remove(), which raised ValueError whenever
        # `musicSet` (e.g. a train/test split) lacked one of SHORTER_IDS.
        excluded = set(SHORTER_IDS)
        self.ids = [_id for _id in musicSet.index if _id not in excluded]
        self.path_extension = ("_log_mel" if self.is_logmel else "_stft") + ".npy"

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(spectrogram, sr, features, track_id)`` for item ``idx``."""
        _id = self.ids[idx]
        file_audio = utils.get_audio_path('../fma_small/', _id)
        data_path = os.path.splitext(file_audio)[0] + self.path_extension
        precomputed, sr = np.load(data_path, allow_pickle=True)
        precomputed = torch.from_numpy(precomputed).to(self.device)
        # NOTE(review): `features` is a pandas Series; torch's default collate
        # function presumably rejects it, so iterating a DataLoader over this
        # dataset likely needs a custom collate_fn. TODO confirm.
        features = FEATURES.loc[_id]
        return (precomputed, sr, features, _id)

In [None]:
# Lazy dataset over the fma_small tracks selected in the metadata cell,
# requesting the STFT representation (is_logmel=False).
# The metadata cell defines TRACKS; the lowercase `tracks` was never defined.
ms = MusicSet(TRACKS, is_logmel=False, verbose=False, device=device)

In [None]:
# Smoke check: fetch the first item, a (spectrogram, sr, features, track_id) tuple.
ms[0]

In [None]:
# Wrap the lazy dataset in a DataLoader (defaults: batch_size=1, no shuffling).
# NOTE(review): torch's default collate function presumably cannot batch the
# pandas Series returned as `features`, and variable-length spectrograms cannot
# be stacked for batch_size > 1 — a custom collate_fn is likely needed before
# iterating. TODO confirm.
dl = DataLoader(ms)
dl