In [1]:
import os
import shutil
import random
from pathlib import Path
from typing import List


def chunkify(lst: List[str], num_chunks: int) -> List[List[str]]:
    """
    Splits a list into `num_chunks` roughly equal parts.
    """
    avg_len = len(lst) // num_chunks
    return [lst[i * avg_len: (i + 1) * avg_len] for i in range(num_chunks)]


def distribute_iid_clients(
    good_dir: str,
    bad_dir: str,
    output_base_dir: str,
    num_clients: int = 5,
    seed: int = 42
) -> None:
    """
    Distributes a balanced number of 'good' and 'bad' audio files to clients.

    This ensures each client gets the same number of files per class.

    Args:
        good_dir (str): Path to directory with good audio files.
        bad_dir (str): Path to directory with bad audio files.
        output_base_dir (str): Root output folder where client dirs will be created.
        num_clients (int): Number of federated clients.
        seed (int): Random seed for shuffling.
    """
    random.seed(seed)

    # Collect and shuffle file names
    good_files = sorted([f for f in os.listdir(good_dir) if f.endswith(".wav")])
    bad_files = sorted([f for f in os.listdir(bad_dir) if f.endswith(".wav")])

    print(f"[INFO] Total Good Files: {len(good_files)}")
    print(f"[INFO] Total Bad  Files: {len(bad_files)}")

    min_files = min(len(good_files), len(bad_files))
    files_per_client_per_class = min_files // num_clients

    print(f"[INFO] Each client will get {files_per_client_per_class} good + {files_per_client_per_class} bad files")

    # Trim and shuffle
    good_files = good_files[:files_per_client_per_class * num_clients]
    bad_files = bad_files[:files_per_client_per_class * num_clients]
    random.shuffle(good_files)
    random.shuffle(bad_files)

    good_chunks = chunkify(good_files, num_clients)
    bad_chunks = chunkify(bad_files, num_clients)

    # Distribute
    for i in range(num_clients):
        client_path = Path(output_base_dir) / f"client_{i+1}"
        good_out = client_path / "good"
        bad_out = client_path / "bad"
        good_out.mkdir(parents=True, exist_ok=True)
        bad_out.mkdir(parents=True, exist_ok=True)

        for fname in good_chunks[i]:
            shutil.copy2(Path(good_dir) / fname, good_out / fname)
        for fname in bad_chunks[i]:
            shutil.copy2(Path(bad_dir) / fname, bad_out / fname)

        print(f"\nðŸ“¦ Client {i+1}")
        print(f"   â””â”€ Good files: {len(good_chunks[i])}")
        print(f"   â””â”€ Bad  files: {len(bad_chunks[i])}")

    print(f"\nâœ… Distributed {files_per_client_per_class + files_per_client_per_class} files per class across {num_clients} clients.")


distribute_iid_clients(
    good_dir="../../resources/material/train-data/augmented-good-material-taps",
    bad_dir="../../resources/material/train-data/augmented-bad-material-taps",
    output_base_dir="../../resources/material/train-data/federated/IID",
    num_clients=5
)

[INFO] Total Good Files: 11025
[INFO] Total Bad  Files: 11619
[INFO] Each client will get 2205 good + 2205 bad files

ðŸ“¦ Client 1
   â””â”€ Good files: 2205
   â””â”€ Bad  files: 2205

ðŸ“¦ Client 2
   â””â”€ Good files: 2205
   â””â”€ Bad  files: 2205

ðŸ“¦ Client 3
   â””â”€ Good files: 2205
   â””â”€ Bad  files: 2205

ðŸ“¦ Client 4
   â””â”€ Good files: 2205
   â””â”€ Bad  files: 2205

ðŸ“¦ Client 5
   â””â”€ Good files: 2205
   â””â”€ Bad  files: 2205

âœ… Distributed 4410 files per class across 5 clients.
