# 3. Balanced Dataset

This notebook builds a balanced dataset: the train split keeps at most 25 songs per genre, and the
remaining train songs are moved to the test split. The validate split is left unchanged. You can see
the results in the [results notebook](./6-results.ipynb).

In [None]:
!pip install evaluate transformers[torch] torchaudio wandb

In [None]:
import json
from tqdm import tqdm
from collections import Counter
from datasets import DatasetDict, concatenate_datasets

In [None]:
# Read the genre list from disk; it is indexed by the integer label further down.
genres_path = "./dataset/genres.json"
with open(genres_path, encoding="utf-8") as genres_file:
    genres = json.loads(genres_file.read())
genres

In [None]:
# Load the previously saved dataset splits from disk and display their summary.
preprocessed_path = "./dataset/music_lib"
preprocessed_dataset = DatasetDict.load_from_disk(preprocessed_path)
preprocessed_dataset

In [None]:
# Upper bound on how many distinct songs of one genre may stay in the train split.
max_songs_per_genre = 25

# Number of songs already kept in the train split, keyed by genre name.
counter = { genre: 0 for genre in genres }

# Row indices (into the original train split) that stay in train / move to test.
indices_to_train = []
indices_to_test = []

# A new song is detected whenever "paths" differs from the previous sample, and the
# keep/move decision made at that point is reused for the following samples with the
# same path.
# NOTE(review): this assumes all samples of one song are stored consecutively - confirm.
current_song = ""
keep = True
train_split = preprocessed_dataset["train"]
for idx, sample in tqdm(enumerate(train_split), desc="Processing train samples", total=len(train_split)):
    if current_song != sample["paths"]:
        current_song = sample["paths"]
        genre = genres[sample["labels"]]
        # Compare against the configured cap (not a hard-coded literal) so that changing
        # max_songs_per_genre actually takes effect.
        keep = counter[genre] < max_songs_per_genre
        if keep:
            counter[genre] += 1
    if keep:
        indices_to_train.append(idx)
    else:
        indices_to_test.append(idx)

len(indices_to_train), len(indices_to_test)

In [None]:
# Number of validate-split rows per genre name, ordered from the rarest genre upwards.
validate_label_counts = Counter(preprocessed_dataset["validate"]["labels"])
{
    genres[label]: count
    for label, count in sorted(validate_label_counts.items(), key=lambda pair: pair[1])
}

In [None]:
# Assemble the new splits:
#   train    - only the rows selected for training,
#   validate - taken over unchanged,
#   test     - the original test rows followed by the rows removed from train.
original_train = preprocessed_dataset["train"]
balanced_splits = {
    "train": original_train.select(indices_to_train),
    "validate": preprocessed_dataset["validate"],
    "test": concatenate_datasets(
        [preprocessed_dataset["test"], original_train.select(indices_to_test)]
    ),
}
balanced_dataset = DatasetDict(balanced_splits)

balanced_dataset

In [None]:
# Number of rows per genre in the balanced train split, ordered from the rarest genre
# upwards. Label ids are translated to genre names so the output is readable.
# These are row counts, not song counts.
# NOTE(review): a song presumably contributes several rows, so values may exceed the
# per-genre song cap - confirm.
{ genres[k]: v for k,v in sorted(Counter(balanced_dataset["train"]["labels"]).items(), key=lambda item: item[1])}

In [None]:
# Persist the balanced splits to disk.
balanced_path = "./dataset/music_lib_balanced"
balanced_dataset.save_to_disk(balanced_path)