In [3]:
import os
import shutil
import random

def _client_class_counts(num_clients, skew, samples_per_client):
    """Return one (good_count, bad_count) tuple per client, validating `skew`."""
    if not skew:
        # No skew supplied: every client gets an (almost) even class split.
        good_count = samples_per_client // 2
        return [(good_count, samples_per_client - good_count)] * num_clients

    if len(skew) < 2 * num_clients:
        raise ValueError(
            f"skew needs {2 * num_clients} ratios (good, bad per client), got {len(skew)}."
        )

    counts = []
    for i in range(num_clients):
        g_ratio, b_ratio = skew[2 * i], skew[2 * i + 1]
        # Tolerance instead of == because pairs such as 0.7 + 0.3 are not exact in floats.
        if not 0 <= g_ratio <= 1 or abs(g_ratio + b_ratio - 1) > 1e-6:
            raise ValueError(
                f"Invalid skew for client {i + 1}: good={g_ratio}, bad={b_ratio} "
                "(each ratio must lie in [0, 1] and the pair must sum to 1)."
            )
        # round() rather than int(): a product like 28.999999999999996 must not lose a file.
        good_count = round(samples_per_client * g_ratio)
        counts.append((good_count, samples_per_client - good_count))
    return counts


def distribute_noniid_clients(
    good_dir,
    bad_dir,
    output_base_dir,
    num_clients=5,
    skew=None,
    seed=42,
    samples_per_client=4410,
):
    """
    Copy 'good' and 'bad' .wav files into per-client folders with class imbalance (non-IID).

    Each client receives `samples_per_client` files, split between the two classes
    according to `skew`. No source file is given to more than one client. Files are
    copied to ``<output_base_dir>/client_<k>/good`` and ``.../bad`` (k starts at 1).
    Existing client folders are reused, not cleared.

    Args:
        good_dir (str): Path to 'good' audio .wav files.
        bad_dir (str): Path to 'bad' audio .wav files.
        output_base_dir (str): Output directory to store client folders.
        num_clients (int): Number of federated clients.
        skew (list): Flat list of ratios (good_ratio_1, bad_ratio_1, ..., good_ratio_n, bad_ratio_n),
                     one pair per client, where good + bad = 1 for each client.
                     If None or empty, every client gets a 50/50 split.
        seed (int): Seed for the shuffle that decides which files each client gets.
        samples_per_client (int): Total number of files per client (default 4410).

    Raises:
        ValueError: If `samples_per_client` is not positive, `skew` is too short or
            contains an invalid pair, or there are not enough files of either class
            to serve all clients. All checks run before anything is copied.
    """
    if samples_per_client < 1:
        raise ValueError(f"samples_per_client must be positive, got {samples_per_client}.")

    # Sort first so the shuffle below is reproducible regardless of os.listdir() order.
    good_files = sorted(os.path.join(good_dir, f) for f in os.listdir(good_dir) if f.endswith('.wav'))
    bad_files = sorted(os.path.join(bad_dir, f) for f in os.listdir(bad_dir) if f.endswith('.wav'))

    # A private RNG makes `seed` actually drive the assignment without touching
    # the global `random` state.
    rng = random.Random(seed)
    rng.shuffle(good_files)
    rng.shuffle(bad_files)

    counts = _client_class_counts(num_clients, skew, samples_per_client)

    # Check each class against what the requested skew really needs; slicing past
    # the end of a list would otherwise hand a client fewer files without any error.
    need_good = sum(g for g, _ in counts)
    need_bad = sum(b for _, b in counts)
    if need_good > len(good_files) or need_bad > len(bad_files):
        raise ValueError(
            f"Too many clients ({num_clients}) for available data: "
            f"need {need_good} good / {need_bad} bad files, "
            f"found {len(good_files)} good / {len(bad_files)} bad."
        )

    print(f"[INFO] Total Good Files: {len(good_files)}")
    print(f"[INFO] Total Bad  Files: {len(bad_files)}")
    print(f"[INFO] Each client will receive: {samples_per_client} files")

    good_index, bad_index = 0, 0

    for i, (g_count, b_count) in enumerate(counts, start=1):
        # Consecutive, non-overlapping slices keep the clients' data disjoint.
        g_samples = good_files[good_index:good_index + g_count]
        b_samples = bad_files[bad_index:bad_index + b_count]
        good_index += g_count
        bad_index += b_count

        client_dir = os.path.join(output_base_dir, f'client_{i}')
        g_client_dir = os.path.join(client_dir, 'good')
        b_client_dir = os.path.join(client_dir, 'bad')

        os.makedirs(g_client_dir, exist_ok=True)
        os.makedirs(b_client_dir, exist_ok=True)

        for f in g_samples:
            shutil.copy(f, os.path.join(g_client_dir, os.path.basename(f)))
        for f in b_samples:
            shutil.copy(f, os.path.join(b_client_dir, os.path.basename(f)))

        print(f"\n📦 Client {i} Total: {len(g_samples) + len(b_samples)}")
        print(f"   └─ Good files: {len(g_samples)}")
        print(f"   └─ Bad  files: {len(b_samples)}")

    print(f"\n✅ Successfully distributed {num_clients} clients with {samples_per_client} files each (non-IID).")

# Build the five-client non-IID split of the augmented material-tap recordings.
# `skew` is a flat list holding one (good, bad) ratio pair per client, in client order.
distribute_noniid_clients(
    num_clients=5,
    seed=42,
    good_dir="../../resources/material/train-data/augmented-good-material-taps",
    bad_dir="../../resources/material/train-data/augmented-bad-material-taps",
    output_base_dir="../../resources/material/train-data/federated/non-IID",
    skew=[
        # Client 1: 70% good, 30% bad
        0.7, 0.3,
        # Client 2: 50% good, 50% bad
        0.5, 0.5,
        # Client 3: 30% good, 70% bad
        0.3, 0.7,
        # Client 4: 60% good, 40% bad
        0.6, 0.4,
        # Client 5: 40% good, 60% bad
        0.4, 0.6,
    ],
)


[INFO] Total Good Files: 11025
[INFO] Total Bad  Files: 11619
[INFO] Each client will receive: 4410 files

📦 Client 1 Total: 4410
   └─ Good files: 3087
   └─ Bad  files: 1323

📦 Client 2 Total: 4410
   └─ Good files: 2205
   └─ Bad  files: 2205

📦 Client 3 Total: 4410
   └─ Good files: 1323
   └─ Bad  files: 3087

📦 Client 4 Total: 4410
   └─ Good files: 2646
   └─ Bad  files: 1764

📦 Client 5 Total: 4410
   └─ Good files: 1764
   └─ Bad  files: 2646

✅ Successfully distributed 5 clients with 4410 files each (non-IID).
