# 1. Preprocessing

I ran this notebook on my laptop to create a dataset of spectrograms and genre labels from my
personal music library.

## 1.1 Imports

In [None]:
import os
import json
import torch
import hashlib
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import tensor
from datasets import Dataset, DatasetDict, concatenate_datasets, Audio, ClassLabel, Features, Value, Array2D
from transformers import ASTFeatureExtractor

## 1.2 Load Music Paths

My music library consists of multiple top-level directories, like "Big Room", "Techno",
"Classic", etc., which define the genre of the songs inside them. Each of these directories can
have multiple subdirectories, e.g. for CDs, but the genre remains the same. Some folders like my own
songs ("Eigenes"), audio books ("Hörbücher"), and mixes ("Mixes") are excluded from the dataset.

A genre needs at least 6 songs to be included in the dataset, because the split into train,
validation, and test sets requires at least one song per genre in each set.

In [None]:
music_lib: dict[str, list[str]] = {}
def add_songs(dir: str, skip_dirs: list[str] = [], genre: str = ""):
    for dirent in os.listdir(dir): 
        path = os.path.join(dir, dirent)
        if os.path.isdir(path) and dirent not in skip_dirs:
            add_songs(path, genre=genre if genre else dirent)
        elif os.path.isfile(path) and path.endswith(".mp3"):
            if genre not in music_lib:
                music_lib[genre] = []
            music_lib[genre].append(path)

add_songs("D:\\Music", ["$RECYCLE.BIN", "downloads", "Eigenes", "Hörbücher", "Mixes", "System Volume Information"])
for genre, songs in list(music_lib.items()):
    if len(songs) < 6: # at least one song in each split
        del music_lib[genre]
    else:
        print(f"{genre}: {len(songs)} songs")
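
To make the genre rule easier to check, here is a minimal, self-contained sketch that runs on a throwaway directory tree. The helper `collect_songs` and the file names are made up for this illustration. Unlike `add_songs`, it returns the dict instead of filling a global, ignores MP3 files that lie directly in the root, and matches the extension case-insensitively.

```python
import os
import tempfile


def collect_songs(root: str, skip_dirs: tuple[str, ...] = ()) -> dict[str, list[str]]:
    """Map each top-level directory name (the genre) to the MP3 files below it."""
    library: dict[str, list[str]] = {}

    def walk(directory: str, genre: str) -> None:
        for entry in sorted(os.listdir(directory)):
            path = os.path.join(directory, entry)
            if os.path.isdir(path):
                # skip_dirs is only checked at the top level (genre still empty);
                # nested directories inherit the genre of their top-level parent.
                if genre or entry not in skip_dirs:
                    walk(path, genre or entry)
            elif os.path.isfile(path) and genre and entry.lower().endswith(".mp3"):
                library.setdefault(genre, []).append(path)

    walk(root, "")
    return library


with tempfile.TemporaryDirectory() as root:
    # Hypothetical library: two Techno songs (one on a "CD"), one excluded mix, one cover image.
    for rel in ["Techno/a.mp3", "Techno/CD1/b.mp3", "Mixes/c.mp3", "Folk/cover.jpg"]:
        path = os.path.join(root, *rel.split("/"))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()
    demo_lib = collect_songs(root, skip_dirs=("Mixes",))
    demo_counts = {genre: len(songs) for genre, songs in demo_lib.items()}

print(demo_counts)
```

Both Techno files end up under the same genre, although one of them sits in a subdirectory, while the skipped folder and the non-MP3 file do not appear at all.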

As you can see, the classes are highly unbalanced because I prefer some genres over others. For
example, "Hands Up" has over 1373 songs, while "Folk" has only 6 (the minimum).

The next cell lists the under- and over-represented genres, i.e. those below the 25 % quantile and
those above the 75 % quantile of the number of songs per genre.

In [None]:
lengths = { genre: len(songs) for genre, songs in music_lib.items() }
# Quantiles of the number of songs per genre
counts = tensor(list(lengths.values()), dtype=torch.float)
threshold25 = counts.quantile(0.25).item()
threshold75 = counts.quantile(0.75).item()

print("Under-represented", [genre for genre, length in lengths.items() if length < threshold25])
print("Over-represented", [genre for genre, length in lengths.items() if length > threshold75])
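
To show what these thresholds do, here is a small sketch in plain Python with made-up song counts. The helper `linear_quantile` is hypothetical. It interpolates linearly between the two nearest ranks, which, as far as I know, is also the default behaviour of `torch.quantile`.

```python
def linear_quantile(values: list[float], q: float) -> float:
    """Quantile with linear interpolation between the two nearest ranks."""
    ordered = sorted(values)
    position = q * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


# Made-up song counts per genre, for illustration only.
toy_lengths = {"A": 6, "B": 10, "C": 20, "D": 50, "E": 400}

low = linear_quantile(list(toy_lengths.values()), 0.25)
high = linear_quantile(list(toy_lengths.values()), 0.75)

# Strict comparisons: a genre that sits exactly on a threshold is in neither group.
under = [genre for genre, n in toy_lengths.items() if n < low]
over = [genre for genre, n in toy_lengths.items() if n > high]
print(low, high, under, over)
```

With five genres, the thresholds fall exactly on the second and fourth count, so only the smallest and the largest genre are reported.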

In [None]:
plt.bar(lengths.keys(), lengths.values())
plt.title("Distribution of number of songs per genre")
plt.xticks(rotation=90)
plt.show()

Apparently, I have a favorite genre :D

From the loaded music dict, I create the list of genres, plus one file path and one label (the
index of the genre in that list) per song.

In [None]:
genres = list(music_lib.keys())
filepaths: list[str] = []
labels: list[int] = []
for genre_idx, songs in enumerate(music_lib.values()):
    for filepath in songs:
        filepaths.append(filepath)
        labels.append(genre_idx)

Save the genres for later use.

In [None]:
with open("../dataset/genres.json", "w+", encoding="utf-8") as f:
    json.dump(genres, f, ensure_ascii=False)
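
Since a label is just the position of a genre in this list, the order inside the file matters. The following sketch shows how the file could be read back and turned into lookup tables. It writes to a temporary file instead of `../dataset/genres.json`, and the names `id2label` and `label2id` are my assumption here, not something the notebooks above define.

```python
import json
import os
import tempfile

demo_genres = ["Big Room", "Techno", "Classic"]  # genre names mentioned in the text above

with tempfile.TemporaryDirectory() as tmp:
    genres_file = os.path.join(tmp, "genres.json")
    # ensure_ascii=False keeps non-ASCII genre names readable in the file.
    with open(genres_file, "w", encoding="utf-8") as f:
        json.dump(demo_genres, f, ensure_ascii=False)
    with open(genres_file, encoding="utf-8") as f:
        loaded = json.load(f)

# The list index is the label, so both directions can be derived from the list.
id2label = dict(enumerate(loaded))
label2id = {genre: idx for idx, genre in id2label.items()}
print(id2label, label2id)
```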

## 1.3 Create an audio dataset

Load the feature extractor of the pre-trained model `MIT/ast-finetuned-audioset-10-10-0.4593` to
get the name of the model input column, and define the other column names.

In [None]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

model_input_name = feature_extractor.model_input_names[0]
labels_name = "labels"
paths_name = "paths"

Next, a function hashes the music dict, because Hugging Face datasets support setting fingerprints
to load from cache. This way, the same splits are reused as long as nothing changes in the music
library; otherwise, new splits are created.

In [None]:
def hash_dict(d):
    dict_tuple = tuple(sorted(d.items()))
    return hashlib.sha256(repr(dict_tuple).encode()).hexdigest()[:48]

music_lib_hash = hash_dict(music_lib)
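
A quick sketch of how this hash behaves, using tiny made-up libraries. The function is repeated so that the snippet runs on its own.

```python
import hashlib


def hash_dict(d):
    # Same approach as above: sorting the items makes the key order irrelevant.
    dict_tuple = tuple(sorted(d.items()))
    return hashlib.sha256(repr(dict_tuple).encode()).hexdigest()[:48]


lib_a = {"Techno": ["t1.mp3", "t2.mp3"], "Folk": ["f1.mp3"]}
lib_b = {"Folk": ["f1.mp3"], "Techno": ["t1.mp3", "t2.mp3"]}            # same content, other key order
lib_c = {"Folk": ["f1.mp3", "f2.mp3"], "Techno": ["t1.mp3", "t2.mp3"]}  # one song added
lib_d = {"Folk": ["f1.mp3"], "Techno": ["t2.mp3", "t1.mp3"]}            # same songs, other list order

same_for_key_order = hash_dict(lib_a) == hash_dict(lib_b)
changes_on_new_song = hash_dict(lib_a) != hash_dict(lib_c)
changes_on_list_order = hash_dict(lib_a) != hash_dict(lib_d)
print(same_for_key_order, changes_on_new_song, changes_on_list_order)
```

The last case is worth keeping in mind: the song lists themselves are not sorted, so a different listing order of the same files also produces a new hash and therefore new splits. Sorting each list before hashing would avoid that, if it ever becomes a problem.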

The dataset is created with an `Audio` feature for the audio files and a `ClassLabel` feature for
their labels. To avoid loading the songs during the dataset split, temporary features with a plain
string value for the song path are used at first.

The split is then done twice: first into 70 % train and 30 % test, then the 30 % part is split
again, with one third used for validation and two thirds for the actual test set. After that, all
three splits are combined in one `DatasetDict`, which is cast to the final features. The
`stratify_by_column` option keeps the genre proportions similar across the splits, such that each
genre has at least one song in each set.

In [None]:
features = Features({
    model_input_name: Audio(sampling_rate=feature_extractor.sampling_rate),
    labels_name: ClassLabel(names=genres)
})

tmp_features = Features({
    model_input_name: Value("string"),
    labels_name: ClassLabel(names=genres)
})

dataset = Dataset.from_dict({
    model_input_name: filepaths,
    labels_name: labels,
}, features=tmp_features).train_test_split(
    test_size=0.3,
    shuffle=True,
    stratify_by_column=labels_name,
    load_from_cache_file=True,
    train_indices_cache_file_name="../cache/train_indices",
    test_indices_cache_file_name="../cache/validate_test_indices",
    train_new_fingerprint=f"train_{music_lib_hash}",
    test_new_fingerprint=f"validate_test_{music_lib_hash}"
)

validation_dataset = dataset["test"].train_test_split(
    test_size=2/3,
    shuffle=True,
    stratify_by_column=labels_name,
    load_from_cache_file=True,
    train_indices_cache_file_name="../cache/validate_indices",
    test_indices_cache_file_name="../cache/test_indices",
    train_new_fingerprint=f"validate_{music_lib_hash}",
    test_new_fingerprint=f"test_{music_lib_hash}"
)

dataset["validate"] = validation_dataset["train"]
dataset["test"] = validation_dataset["test"]

dataset = dataset.cast(features)
dataset
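
The overall proportions that result from the two-stage split can be verified with a few lines of exact arithmetic. This is only a sketch of the shares. The exact number of songs per genre and split depends on how the stratified split rounds.

```python
from fractions import Fraction

first_test_size = Fraction(3, 10)   # test_size=0.3 in the first split
second_test_size = Fraction(2, 3)   # test_size=2/3 in the second split

train_share = 1 - first_test_size
validate_share = first_test_size * (1 - second_test_size)
test_share = first_test_size * second_test_size

shares = {"train": train_share, "validate": validate_share, "test": test_share}
print({name: str(share) for name, share in shares.items()})
```

So the final dataset is split 70 % / 10 % / 20 % into train, validate, and test.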

## 1.4 Create a spectrogram dataset

Next, a function to create spectrograms is provided. This function gets a dataset row with the
decoded `Audio` as input and uses the pre-defined parameters `sample_len_sec` and
`samples_per_song` to create multiple spectrograms of a given length from the song. Since the
pre-trained AST model works on 10.24 seconds of audio, this is the default snippet length.

At first, I thought that the feature extractor would create a spectrogram from the whole song. Thus,
my first dataset contained only one spectrogram per song, covering only the first ten seconds.
Fine-tuning the model on it led to a disappointing accuracy of around 55 %. During the analysis, I
saw how the feature extractor really works.

Then, I decided to use snippets of 5 seconds, taken at the beginning and after every 30 seconds of
a song, with at most 5 snippets per song. This led to a variable number of snippets per song due to
the different song lengths. However, using a snippet length other than 10.24 s requires changes to
the model config: I had to change the `max_length` field to be able to process another snippet
length. I didn't realize that this caused the model to ignore the pre-trained weights of the
position embeddings, which I then had to train from scratch. The best model reached a performance
of around 79 %.

With the pre-trained length, everything worked far better: the training converged faster and led to
a better accuracy. In addition, I now used 6 snippets per song, equally distributed over the length
of the song, which made it easier to build a model that works with whole songs. The best model then
had an accuracy of around 87.5 %.

Visualizations of the results are inside the [results-notebook](./6-results.ipynb). 

In [None]:
sample_len_sec = 10.24
samples_per_song = 6
def preprocess_song(song):
    audio = song[model_input_name]
    sampling_rate = audio["sampling_rate"]
    wav = audio["array"]
    song_spectrograms = []
    song_labels = []
    song_paths = []
    # Evenly spaced start positions yield exactly `samples_per_song` snippets per song.
    for start_sample in range(0, len(wav) - len(wav) % samples_per_song, len(wav) // samples_per_song):
        end_sample = start_sample + round(sample_len_sec * sampling_rate)
        input_wav = wav[start_sample:end_sample]
        song_spectrograms.append(feature_extractor(input_wav, sampling_rate=sampling_rate, return_tensors="pt")[model_input_name][0])
        song_labels.append(song[labels_name])
        song_paths.append(audio["path"])

    return song_spectrograms, song_labels, song_paths
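
To see where the snippets are taken from, the index arithmetic of the loop can be looked at on its own. The helper `snippet_bounds` below is made up for this illustration and does not touch any audio. The sampling rate of 16 kHz and the two song lengths are assumptions. The end index is clamped to the song length, which is what the slicing in `preprocess_song` does implicitly.

```python
def snippet_bounds(num_samples: int, sampling_rate: int,
                   sample_len_sec: float = 10.24,
                   samples_per_song: int = 6) -> list[tuple[int, int]]:
    """Return the (start, end) sample indices of the evenly spaced snippets."""
    step = num_samples // samples_per_song
    snippet_len = round(sample_len_sec * sampling_rate)
    stop = num_samples - num_samples % samples_per_song
    return [(start, min(start + snippet_len, num_samples))
            for start in range(0, stop, step)]


sampling_rate = 16_000                                      # assumed sampling rate
long_song = snippet_bounds(180 * sampling_rate, sampling_rate)   # hypothetical 3-minute song
short_song = snippet_bounds(40 * sampling_rate, sampling_rate)   # hypothetical 40-second song

for start, end in long_song:
    print(f"{start / sampling_rate:6.1f} s  ->  {end / sampling_rate:6.1f} s")
```

For the 3-minute song, a snippet starts every 30 seconds. For the 40-second song, the snippets overlap and the last one is shorter than 10.24 s, so the feature extractor has to pad it. Songs with fewer samples than `samples_per_song` would make the step zero and raise a `ValueError`, but that should not happen with real music files.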

The following small helper function visualizes spectrograms. You will see it again in several places.

In [None]:
def visualize_spectrum(specs, size=(10,6), cols=1, rows=1):
    plt.figure(figsize=size)
    for idx, spec in enumerate(specs):
        # One subplot per spectrogram, transposed so that time runs along the x-axis.
        plt.subplot(rows, cols, idx + 1)
        plt.imshow(spec.T, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label="Amplitude")
        plt.xlabel("Time Frames")
        plt.ylabel("Frequency Bins")
    plt.tight_layout()
    plt.show()

Now we can, for example, look at the spectrograms of the first song in the train split using the
`preprocess_song` function. Since one song has six spectrograms, we choose a larger figure size
and set the cols to 2 and the rows to 3.

In [None]:
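# Separate names, so that the global `labels` list from section 1.2 is not overwritten.
specs, song_labels, song_paths = preprocess_song(dataset["train"][0])
print(song_paths[0])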
visualize_spectrum(specs, size=(20, 12), cols=2, rows=3)

These are the spectrograms of a "Drum and Bass" song. You can clearly see where drums are playing
and where not.

Now we can preprocess the whole dataset using the `preprocess_song` function. The new dataset
features consist of a 2D array for the spectrogram, the song label, and the song path. The
preprocessing has to be done in batches, so that the process does not run out of memory.

In [None]:
spectrogram_features = Features({
    model_input_name: Array2D(shape=(feature_extractor.max_length, feature_extractor.num_mel_bins), dtype="float32"),
    labels_name: ClassLabel(names=genres),
    paths_name: Value(dtype="string")
})

for split, sub_dataset in dataset.items():
    batch_size = 2_000  # minimum number of spectrograms per batch on disk
    batch_spectrograms = []
    batch_labels = []
    batch_paths = []
    def save_batch(idx: int):
        # Write the currently collected spectrograms into their own directory.
        batch_dict = {model_input_name: batch_spectrograms, labels_name: batch_labels, paths_name: batch_paths}
        partial_dataset = Dataset.from_dict(batch_dict, features=spectrogram_features)
        partial_dataset.save_to_disk(f"../spectrums/{split}_batch_{idx}")

    idx = 0
    for song in tqdm(sub_dataset, desc=f"Preparing {split} spectrums", total=len(sub_dataset)):
        song_spectrograms, song_labels, song_paths = preprocess_song(song)
        batch_spectrograms.extend(song_spectrograms)
        batch_labels.extend(song_labels)
        batch_paths.extend(song_paths)
        if len(batch_spectrograms) >= batch_size:
            save_batch(idx)
            batch_spectrograms = []
            batch_labels = []
            batch_paths = []
            idx += 1
    if batch_spectrograms:  # save the remaining, partially filled batch
        save_batch(idx)
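
A rough estimate shows why batching is needed. The sketch below assumes that every song yields exactly six spectrograms and that one spectrogram has 1024 time frames and 128 mel bins stored as `float32`, which is what I expect from the feature extractor's `max_length` and `num_mel_bins`. It counts only the raw values, without any overhead.

```python
samples_per_song = 6
batch_size = 2_000

# A batch is only saved after a whole song was added, so the first multiple of
# samples_per_song that reaches batch_size determines the real batch length.
songs_per_batch = -(-batch_size // samples_per_song)  # ceiling division
spectrograms_per_batch = songs_per_batch * samples_per_song

# Assumed spectrogram shape: 1024 time frames x 128 mel bins, 4 bytes per float32 value.
bytes_per_spectrogram = 1024 * 128 * 4
batch_mib = spectrograms_per_batch * bytes_per_spectrogram / 2**20

print(songs_per_batch, spectrograms_per_batch, batch_mib)
```

Under these assumptions, a full batch holds slightly more than `batch_size` spectrograms and takes about 1 GiB of raw data, so keeping the spectrograms of a whole library in memory at once quickly becomes impractical.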

Now, we can load the spectrograms from disk using the `Dataset.load_from_disk` function, concatenate
them into a final preprocessed dataset, and save it to disk.

In [None]:
batch_paths = [f"../spectrums/{batch_dir}" for batch_dir in os.listdir("../spectrums")]

train_batches = [Dataset.load_from_disk(path) for path in batch_paths if "train" in path]
validate_batches = [Dataset.load_from_disk(path) for path in batch_paths if "validate" in path]
test_batches = [Dataset.load_from_disk(path) for path in batch_paths if "test" in path]

train_dataset = concatenate_datasets(train_batches)
validate_dataset = concatenate_datasets(validate_batches)
test_dataset = concatenate_datasets(test_batches)

preprocessed_dataset = DatasetDict({
    "train": train_dataset,
    "validate": validate_dataset,
    "test": test_dataset
})

In [None]:
preprocessed_dataset.save_to_disk("../dataset/music_lib")

In [None]:
preprocessed_dataset = DatasetDict.load_from_disk("../dataset/music_lib")
preprocessed_dataset

As you can see, the resulting dataset has six times as many entries as the original dataset due to
the six spectrograms per song.

To check that everything worked as expected, we can again look at the first spectrogram of the train
split...

In [None]:
visualize_spectrum(tensor([preprocessed_dataset["train"][0][model_input_name]]))

...and see that everything is fine. Now we can move on to notebook
[2. Augmentation](./2-augmentation.ipynb).