In [1]:
!pip install evaluate transformers[torch] torchaudio wandb

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting wandb
  Downloading wandb-0.19.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting transformers[torch]
  Downloading transformers-4.48.1-py3-none-any.whl.metadata (44 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py312-none-any.whl.metadata (7.2 kB)
Collecting huggingface-hub>=0.7.0 (from evaluate)
  Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers[torch])
  Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers[torch])
 

In [1]:
import json
from tqdm import tqdm
from collections import Counter
from datasets import DatasetDict, concatenate_datasets

In [2]:
with open("./dataset/genres.json", "r", encoding="utf-8") as f:
    genres = json.load(f)
genres

['Bad',
 'Bassy',
 'Big Room',
 'Bounce',
 'Chill',
 'Chillstep',
 'Classic',
 'Coding',
 'Country',
 'Cro',
 'Deep House',
 'Drum and Bass',
 'Dubstep',
 'EDM',
 'Electro',
 'Electro House',
 'Emotional',
 'Epic',
 'Folk',
 'Frenchcore',
 'Glitch Hop',
 'God',
 'Groove',
 'Hands Up',
 'Hardcore',
 'Hardstyle',
 'Harp',
 'Hip Hop & Rap',
 'Historic',
 'Latino',
 'Lounge',
 'Malle',
 'Minimal',
 'Motivation',
 'Orchestra Pop',
 'Orchestral Electro',
 'OVERWERK',
 'Pop',
 'Pop mit Beat',
 'Psy',
 'Psytrance',
 'RnB',
 'Rock',
 'Synthpop',
 'Techno',
 'Tekk',
 'Trance',
 'Weihnachten']

In [3]:
preprocessed_dataset = DatasetDict.load_from_disk("./dataset/music_lib")
preprocessed_dataset

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 16378
    })
    validate: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 2342
    })
    test: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 4675
    })
})

In [5]:
max_songs_per_genre = 25

counter = { genre: 0 for genre in genres }

indices_to_train = []
indices_to_test = []

current_song = ""
keep = True
for idx, sample in tqdm(enumerate(preprocessed_dataset["train"]), desc="Processing train samples", total=len(preprocessed_dataset["train"])):
    if current_song != sample["paths"]:
        current_song = sample["paths"]
        genre = genres[sample["labels"]]
        if counter[genre] == 25:
           keep = False
        else:
            keep = True
            counter[genre] += 1
    if keep:
        indices_to_train.append(idx)
    else:
        indices_to_test.append(idx)
    
len(indices_to_train), len(indices_to_test)

Processing train samples: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16378/16378 [00:58<00:00, 278.97it/s]


(4764, 11614)

In [5]:
{ genres[k]: v for k,v in sorted(Counter(preprocessed_dataset["validate"]["labels"]).items(), key=lambda item: item[1])}

{'Harp': 5,
 'Coding': 5,
 'Folk': 5,
 'Synthpop': 5,
 'Psytrance': 5,
 'Epic': 5,
 'Pop': 5,
 'OVERWERK': 5,
 'Weihnachten': 10,
 'God': 10,
 'Orchestra Pop': 10,
 'Bassy': 10,
 'Emotional': 10,
 'Pop mit Beat': 10,
 'Historic': 10,
 'Dubstep': 15,
 'Glitch Hop': 15,
 'Orchestral Electro': 15,
 'Chillstep': 15,
 'Tekk': 15,
 'RnB': 19,
 'Minimal': 20,
 'Chill': 20,
 'Deep House': 20,
 'Latino': 20,
 'Drum and Bass': 25,
 'EDM': 25,
 'Cro': 30,
 'Big Room': 30,
 'Psy': 35,
 'Frenchcore': 35,
 'Classic': 39,
 'Malle': 45,
 'Country': 45,
 'Groove': 49,
 'Electro House': 55,
 'Hip Hop & Rap': 55,
 'Motivation': 55,
 'Lounge': 55,
 'Techno': 60,
 'Bounce': 79,
 'Hardcore': 80,
 'Rock': 83,
 'Electro': 85,
 'Hardstyle': 110,
 'Bad': 148,
 'Trance': 150,
 'Hands Up': 685}

In [7]:
balanced_dataset = DatasetDict({
    "train": preprocessed_dataset["train"].select(indices_to_train),
    "validate": preprocessed_dataset["validate"],
    "test": concatenate_datasets([preprocessed_dataset["test"], preprocessed_dataset["train"].select(indices_to_test)])
})

balanced_dataset

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 4764
    })
    validate: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 2342
    })
    test: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 16289
    })
})

In [8]:
{ k: v for k,v in sorted(Counter(balanced_dataset["train"]["labels"]).items(), key=lambda item: item[1])}

{26: 20,
 18: 20,
 43: 20,
 37: 25,
 40: 30,
 36: 40,
 17: 40,
 7: 40,
 28: 56,
 38: 65,
 21: 70,
 47: 75,
 34: 80,
 16: 80,
 1: 85,
 5: 99,
 45: 100,
 35: 105,
 20: 108,
 30: 120,
 19: 120,
 0: 123,
 3: 123,
 9: 123,
 39: 124,
 15: 124,
 13: 124,
 14: 125,
 23: 125,
 27: 125,
 42: 125,
 22: 125,
 41: 125,
 46: 125,
 24: 125,
 2: 125,
 25: 125,
 31: 125,
 4: 125,
 8: 125,
 29: 125,
 11: 125,
 6: 125,
 10: 125,
 32: 125,
 33: 125,
 44: 125,
 12: 125}

In [9]:
balanced_dataset.save_to_disk("./dataset/music_lib_balanced")

Saving the dataset (0/3 shards):   0%|          | 0/4764 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/2342 [00:00<?, ? examples/s]

Saving the dataset (0/9 shards):   0%|          | 0/16289 [00:00<?, ? examples/s]