# 4. Song Dataset

This notebook builds a new dataset whose entries are whole songs rather than individual snippets,
so that a song can be classified from all six of its snippets at once.

You can look at the training process here: [5. Training Songs](./5-training_songs.ipynb)

In [1]:
!pip install evaluate transformers[torch] torchaudio wandb

In [2]:
import os
import json
from tqdm import tqdm
from torch import tensor
from datasets import DatasetDict, Dataset, concatenate_datasets, Features, Value, ClassLabel, Array3D
from transformers import ASTFeatureExtractor

In [3]:
# Read the genre names from the JSON file; they are used below as the ClassLabel names.
with open("./dataset/genres.json", mode="r", encoding="utf-8") as genres_file:
    genres = json.loads(genres_file.read())

# Show the loaded value as the cell output.
genres

In [4]:
# Checkpoint whose feature extractor provides the spectrogram dimensions
# (max_length, num_mel_bins) used for the song features further down.
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

# Column names of the song dataset: the label and path columns are fixed strings,
# the spectrogram column takes the extractor's first model input name.
labels_name, paths_name = "labels", "paths"
model_input_name = feature_extractor.model_input_names[0]

In [5]:
# Load the preprocessed DatasetDict from disk. The cells below group its rows
# into songs, samples_per_song consecutive rows at a time.
preprocessed_dataset = DatasetDict.load_from_disk("./dataset/music_lib")
# Display the loaded splits as the cell output.
preprocessed_dataset

In [6]:
# Number of consecutive snippet rows that are grouped into one song entry.
samples_per_song = 6

# Fixed shape of one song entry:
# (samples_per_song, feature_extractor.max_length, feature_extractor.num_mel_bins).
song_shape = (samples_per_song, feature_extractor.max_length, feature_extractor.num_mel_bins)

# Schema of the song-level dataset: stacked spectrograms, genre label and path.
song_features = Features(
    {
        model_input_name: Array3D(shape=song_shape, dtype="float32"),
        labels_name: ClassLabel(names=genres),
        paths_name: Value(dtype="string"),
    }
)

# Songs are written to disk in fixed-size batches so that at most `batch_size` songs
# (each holding `samples_per_song` spectrograms) are buffered in memory at once.
# The per-batch datasets are merged again in a later cell.
batch_size = 250

for split, dataset in preprocessed_dataset.items():
    # A song is built from `samples_per_song` consecutive rows. If the split length is not a
    # multiple of that, the last group would be incomplete and could not match the fixed
    # Array3D shape declared in `song_features`, so fail early with a clear message.
    if len(dataset) % samples_per_song != 0:
        raise ValueError(
            f"Split '{split}' has {len(dataset)} snippets, which is not a multiple of "
            f"samples_per_song={samples_per_song}; cannot group them into whole songs."
        )
    num_songs = len(dataset) // samples_per_song

    # Buffers for the batch that is currently being filled.
    batch_spectrograms = []
    batch_labels = []
    batch_paths = []

    def save_batch(idx: int) -> None:
        """Write the currently buffered songs to ``./songs/{split}_batch_{idx}``.

        The buffers and ``split`` are read from the enclosing notebook scope at call
        time, so the function always sees the most recently assigned lists.

        Args:
            idx: Index of the batch within the current split; part of the directory name.
        """
        batch_dict = {model_input_name: batch_spectrograms, labels_name: batch_labels, paths_name: batch_paths}
        partial_dataset = Dataset.from_dict(batch_dict, features=song_features)
        partial_dataset.save_to_disk(f"./songs/{split}_batch_{idx}")

    idx = 0
    for i in tqdm(range(0, len(dataset), samples_per_song), desc=f"Processing {split} samples", total=num_songs):
        # Read the rows of one song with a single slice instead of one slice plus two
        # full-row lookups; a slice returns a dict that maps column names to lists.
        song_rows = dataset[i:i + samples_per_song]
        batch_spectrograms.append(song_rows[model_input_name])
        # Label and path are taken from the first snippet of the group.
        # NOTE(review): assumes all snippets of a song share label and path - confirm upstream.
        batch_labels.append(song_rows[labels_name][0])
        batch_paths.append(song_rows[paths_name][0])

        if len(batch_spectrograms) == batch_size:
            save_batch(idx)
            # Start fresh buffers for the next batch.
            batch_spectrograms = []
            batch_labels = []
            batch_paths = []
            idx += 1

    # Flush the remaining songs. When the song count is an exact multiple of `batch_size`
    # the buffers are empty here and no extra empty batch is written; an empty split still
    # gets a single (empty) batch so that the split can be loaded in the next cell.
    if batch_spectrograms or idx == 0:
        save_batch(idx)

In [7]:
# Directory that holds the per-split batch datasets, named "{split}_batch_{idx}".
songs_dir = "./songs"


def _sorted_batch_dirs(split: str) -> list:
    """Return the batch directories of ``split`` ordered by their numeric batch index.

    ``os.listdir`` yields entries in arbitrary order, and a plain string sort would place
    ``batch_10`` before ``batch_2``; sorting by the integer suffix keeps the songs in the
    order in which they were written, so the merged dataset is reproducible.

    Only names of the exact form ``{split}_batch_{digits}`` are accepted, so unrelated
    entries in the directory are ignored. Batch directories left over from an earlier
    run are still picked up, so clear the directory before regenerating the batches.

    Args:
        split: Name of the split, e.g. ``"train"``.

    Returns:
        Paths of the matching batch directories in ascending batch-index order.
    """
    prefix = f"{split}_batch_"
    batch_names = [
        name for name in os.listdir(songs_dir)
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    batch_names.sort(key=lambda name: int(name[len(prefix):]))
    return [f"{songs_dir}/{name}" for name in batch_names]


# Load the batches of every split in batch-index order.
train_batches = [Dataset.load_from_disk(path) for path in _sorted_batch_dirs("train")]
validate_batches = [Dataset.load_from_disk(path) for path in _sorted_batch_dirs("validate")]
test_batches = [Dataset.load_from_disk(path) for path in _sorted_batch_dirs("test")]

# Merge the batches of each split back into a single dataset.
train_dataset = concatenate_datasets(train_batches)
validate_dataset = concatenate_datasets(validate_batches)
test_dataset = concatenate_datasets(test_batches)

song_dataset = DatasetDict({
    "train": train_dataset,
    "validate": validate_dataset,
    "test": test_dataset
})

# Display the assembled splits as the cell output.
song_dataset

In [8]:
# Persist the assembled song-level DatasetDict (train/validate/test splits) to disk.
# NOTE(review): presumably read by the training notebook linked in the introduction - confirm.
song_dataset.save_to_disk("./dataset/music_lib_songs")

In [9]:
# Sanity check: show the tensor shape of the stacked spectrograms of the first training song.
first_train_song = song_dataset["train"][0]
tensor(first_train_song[model_input_name]).shape