In [None]:
import os, glob, librosa, random
import numpy as np
import soundfile as sf

from collections import defaultdict
from typing import Dict, List
from pathlib import Path

In [11]:
# Recursively walk the MUSAN hd-classical music directory and keep only the
# entries whose path ends in ".wav".
audio_files_musan = [
    candidate
    for candidate in glob.glob(
        "../data/musan/music/hd-classical/**/*.*", recursive=True
    )
    if candidate.endswith(".wav")
]
print(f"Total musan music files: {len(audio_files_musan)}")

Total musan music files: 75


In [12]:
# Recursively walk the parts2songkey directory and keep only the entries
# whose path ends in ".wav".
audio_files_parts2songkey = [
    candidate
    for candidate in glob.glob(
        "../data/acrcloud/parts2songkey/**/*.*", recursive=True
    )
    if candidate.endswith(".wav")
]
print(f"Total parts2songkey music files: {len(audio_files_parts2songkey)}")

Total parts2songkey music files: 351


In [13]:
def extract_base_name(file_path):
    """Split a file's name into its dot-separated components.

    The directory part and the final extension are discarded first, so
    ``"dir/pop.00001.abc.wav"`` yields ``["pop", "00001", "abc"]``.

    Args:
        file_path: Path (or bare file name) as a string.

    Returns:
        A list of the name components; always contains at least one element.
    """
    stem = os.path.splitext(os.path.basename(file_path))[0]
    return stem.split(".")

In [14]:
# Genre name -> integer class index. The index of each genre is simply its
# position in the tuple below, so the order here is significant.
label_to_idx = {
    genre: position
    for position, genre in enumerate(
        (
            "pop",
            "reggae",
            "country",
            "rock",
            "classical",
            "disco",
            "blues",
            "hiphop",
            "jazz",
            "metal",
        )
    )
}

In [15]:
# Accumulate song paths per label index. A set is used so that a song path
# added more than once is kept only once (presumably several part files map
# to the same full song -- confirm against the dataset layout).
label_to_songnames = defaultdict(set)

# —— parts2songkey ——
# The first dot-separated component of each part file name is looked up as the
# genre label, and the last component is used as the name of the full-song mp3.
for metadata_path in audio_files_parts2songkey:
    name_parts = extract_base_name(metadata_path)
    label_idx = label_to_idx[name_parts[0]]  # KeyError on an unknown genre prefix
    song_path = Path("../data/acrcloud/songkey") / f"{name_parts[-1]}.mp3"

    label_to_songnames[label_idx].add(str(song_path))

# —— musan ——
# Every MUSAN file collected above is filed under the "classical" label.
for song_path in audio_files_musan:
    label_idx = label_to_idx["classical"]
    label_to_songnames[label_idx].add(str(song_path))

# Freeze into a plain dict of lists. Both the labels and the paths are sorted:
# iterating a set of strings yields a hash-dependent order that changes between
# interpreter sessions, and the label insertion order depends on glob order, so
# without sorting the downstream selection is not reproducible.
label_to_songnames = {
    lbl: sorted(paths) for lbl, paths in sorted(label_to_songnames.items())
}

In [16]:
def split_audio(
    file_path: str, sample_rate: int, segment_duration: int = 30, overlap: int = 0
) -> List[np.ndarray]:
    """Load an audio file and cut it into fixed-length segments.

    Segments start every ``segment_duration - overlap`` seconds. A trailing
    remainder shorter than one full segment is dropped, so audio shorter than
    ``segment_duration`` yields an empty list.

    Args:
        file_path: Path of the audio file, passed to ``librosa.load``.
        sample_rate: Sample rate passed to ``librosa.load`` as ``sr``; also
            used to convert durations from seconds to samples.
        segment_duration: Length of each segment in seconds; must be positive.
        overlap: Seconds shared between consecutive segments; must be smaller
            than ``segment_duration``.

    Returns:
        A list of slices of the loaded audio, each ``segment_duration *
        sample_rate`` samples long.

    Raises:
        ValueError: If ``segment_duration`` is not positive or ``overlap`` is
            not smaller than ``segment_duration``.
    """
    # Validate before the (expensive) decode. A zero step makes range() raise
    # only after the file has been loaded, and a negative step silently yields
    # no segments or slices at negative offsets.
    if segment_duration <= 0:
        raise ValueError(
            f"segment_duration must be positive, got {segment_duration}"
        )
    if overlap >= segment_duration:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than "
            f"segment_duration ({segment_duration})"
        )

    audio, _ = librosa.load(file_path, sr=sample_rate)
    seg_samples = segment_duration * sample_rate
    step_samples = (segment_duration - overlap) * sample_rate

    # The stop value guarantees every slice is exactly seg_samples long.
    return [
        audio[start : start + seg_samples]
        for start in range(0, len(audio) - seg_samples + 1, step_samples)
    ]


def build_balanced_segments_roundrobin(
    label_to_songnames: Dict[int, List[str]],
    sample_rate: int,
    target_per_label: int,
    segment_duration: int = 30,
    overlap: int = 0,
    shuffle_songs: bool = True,
    shuffle_within_song: bool = True,
) -> Dict[int, List[np.ndarray]]:
    """Pick up to ``target_per_label`` segments per label, round-robin over songs.

    For each label the songs are visited in cycles; each visit takes one
    segment from the song, and a song drops out of the rotation once it has no
    segments left. Selection for a label stops as soon as the target is reached
    or every song is exhausted (in which case a warning is printed).

    Songs are split lazily, on their first visit, so songs that are never
    reached before the target is met are not loaded at all.

    Args:
        label_to_songnames: Mapping from label to the list of song paths for
            that label. The lists are not modified.
        sample_rate: Forwarded to ``split_audio``.
        target_per_label: Maximum number of segments to collect per label.
        segment_duration: Forwarded to ``split_audio``.
        overlap: Forwarded to ``split_audio``.
        shuffle_songs: Shuffle the visiting order of the songs (``random``).
        shuffle_within_song: Shuffle each song's segments before drawing.

    Returns:
        A ``defaultdict(list)`` with one entry per input label holding the
        selected segments.
    """
    balanced = defaultdict(list)
    for label, song_paths in label_to_songnames.items():
        # Work on a copy: shuffling in place would reorder the caller's list.
        ordered_paths = list(song_paths)
        if shuffle_songs:
            random.shuffle(ordered_paths)
        # Collapse duplicate paths, keeping the first occurrence's position.
        songs_cycle = list(dict.fromkeys(ordered_paths))

        selected = balanced[label]  # also registers labels that end up empty
        remaining: Dict[str, List[np.ndarray]] = {}  # loaded songs -> unused segments

        while len(selected) < target_per_label and songs_cycle:
            next_cycle = []
            for song in songs_cycle:
                if song not in remaining:
                    # First visit: decode and split now rather than up front.
                    segs = split_audio(song, sample_rate, segment_duration, overlap)
                    if shuffle_within_song:
                        random.shuffle(segs)
                    remaining[song] = segs
                segs = remaining[song]
                if segs:
                    selected.append(segs.pop())
                if segs:
                    next_cycle.append(song)
                else:
                    # Exhausted songs are never revisited; drop the reference.
                    del remaining[song]
                if len(selected) >= target_per_label:
                    break
            songs_cycle = next_cycle

        if len(selected) < target_per_label:
            print(
                f"[WARN] label {label} only got {len(balanced[label])} segments "
                f"(Target {target_per_label}); All labels are exhaustive."
            )
    return balanced

In [17]:
# Run configuration.
SR = 16000  # sample rate (Hz) used for loading and for the written files
TARGET_PER_LABEL = 30  # segments to keep per label
SEG_DURATION = 30  # segment length in seconds
OVERLAP = 0  # seconds shared between consecutive segments
OUTPUT_ROOT = "../data/segments_30s"
SEED = 42  # fixes the shuffles performed during segment selection

# The selection below shuffles with the stdlib `random` module; seeding it
# here makes a re-run of this cell pick the same segments for the same inputs.
random.seed(SEED)

segments_by_label = build_balanced_segments_roundrobin(
    label_to_songnames,
    sample_rate=SR,
    target_per_label=TARGET_PER_LABEL,
    segment_duration=SEG_DURATION,
    overlap=OVERLAP,
)

os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Write each label's segments to <OUTPUT_ROOT>/label_<idx>/<00000..>.wav.
# NOTE(review): files left over from an earlier run with more segments are not
# removed, so a directory may contain stale higher-numbered files.
for lbl, seg_list in segments_by_label.items():
    out_dir = os.path.join(OUTPUT_ROOT, f"label_{lbl}")
    os.makedirs(out_dir, exist_ok=True)

    for i, seg in enumerate(seg_list):
        wav_path = os.path.join(out_dir, f"{i:05d}.wav")
        sf.write(wav_path, seg, SR, subtype="PCM_16")

    print(f"[INFO] label {lbl}: saved {len(seg_list)} segments to {out_dir}")

[INFO] label 2: saved 30 segments to ../data/segments_30s/label_2
[INFO] label 0: saved 30 segments to ../data/segments_30s/label_0
[INFO] label 1: saved 30 segments to ../data/segments_30s/label_1
[INFO] label 7: saved 30 segments to ../data/segments_30s/label_7
[INFO] label 8: saved 30 segments to ../data/segments_30s/label_8
[INFO] label 3: saved 30 segments to ../data/segments_30s/label_3
[INFO] label 6: saved 30 segments to ../data/segments_30s/label_6
[INFO] label 9: saved 30 segments to ../data/segments_30s/label_9
[INFO] label 5: saved 30 segments to ../data/segments_30s/label_5
[INFO] label 4: saved 30 segments to ../data/segments_30s/label_4
