In [None]:
import os
import shutil
import random
import math

#### Data Partitioning

In [None]:
dir_train_images = "../../data/inputs/kvasir/train/images"
dir_train_masks = "../../data/inputs/kvasir/train/masks"

# Get both mask and image names (the same file names are later used to look up
# the masks inside dir_train_masks).
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort the
# listing so that the seeded shuffle below selects the same files on every
# machine and on every run.
names = sorted(os.listdir(dir_train_images))
# Files whose stem (the text before the first '.') is purely numeric go to the
# CVC-ColonDB group; every other file goes to the Kvasir group.
names_cvccolondb = [name for name in names if name.split('.')[0].isdigit()]
names_kvasir = [name for name in names if not name.split('.')[0].isdigit()]

## Distribute the Data
# Client counts to generate. Each dataset is split into n_clients // 2 partitions,
# so every client holds samples from exactly one dataset.
fl_clients = [2, 4, 6, 8]
random.seed(64)
# Shuffled index permutations, drawn once and reused by every configuration.
sample_cvccolondb = random.sample(range(len(names_cvccolondb)), len(names_cvccolondb))
sample_kvasir = random.sample(range(len(names_kvasir)), len(names_kvasir))
# Maps a client count to its list of partitions (lists of file names), ordered as
# [cvccolondb_1, kvasir_1, cvccolondb_2, kvasir_2, ...].
sample_map = {}

for n_clients in fl_clients:
    # Number of partitions per dataset
    n_partitions_dataset = n_clients // 2
    # Number of samples per partition (integer division: the remainder is handed
    # to the last partition below)
    n_samples_cvccolondb = (len(names_cvccolondb) * 2) // n_clients
    n_samples_kvasir = (len(names_kvasir) * 2) // n_clients

    partitions = []

    for partition in range(n_partitions_dataset):
        ix_cvccolondb = partition * n_samples_cvccolondb
        ix_kvasir = partition * n_samples_kvasir

        # Get the dataset keys for each partition. The last partition runs to the
        # end of the permutation so that all leftover samples are included.
        if partition == n_partitions_dataset - 1:
            keys_cvccolondb = sample_cvccolondb[ix_cvccolondb:]
            keys_kvasir = sample_kvasir[ix_kvasir:]
        else:
            keys_cvccolondb = sample_cvccolondb[ix_cvccolondb: ix_cvccolondb + n_samples_cvccolondb]
            keys_kvasir = sample_kvasir[ix_kvasir: ix_kvasir + n_samples_kvasir]

        # Translate the shuffled indices into file names and append one
        # partition per dataset
        partitions.append([names_cvccolondb[key] for key in keys_cvccolondb])
        partitions.append([names_kvasir[key] for key in keys_kvasir])

    sample_map[n_clients] = partitions

#### Export Data to Disk

In [None]:
## Export the Data
output_dir = '../../data/inputs/kvasir_federated'

# Start from an empty output directory: drop any previous export, then recreate it
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

os.mkdir(output_dir)

# For every client-count configuration, create one folder per client and copy
# that client's images and masks into it
for partition_conf, partitions in sample_map.items():
    for n_partition, file_names in enumerate(partitions):
        # Clients are numbered from 1 inside "<n>_flclients"
        partition_dir = f'{output_dir}/{partition_conf}_flclients/flclient_{n_partition + 1}'

        # Create the images/ and masks/ sub-folders the first time this client is seen
        if not os.path.exists(partition_dir):
            os.makedirs(f'{partition_dir}/images')
            os.makedirs(f'{partition_dir}/masks')

        # Copy all images first ...
        for file_name in file_names:
            shutil.copy(f'{dir_train_images}/{file_name}', f'{partition_dir}/images/{file_name}')

        # ... then all masks, which are looked up under the same file names
        for file_name in file_names:
            shutil.copy(f'{dir_train_masks}/{file_name}', f'{partition_dir}/masks/{file_name}')

#### Analysis of Results

In [None]:
# Summarise the exported data: for every client configuration print the total and
# unique number of images/masks and the per-client counts.
# os.listdir() returns names in arbitrary order, so both listings are sorted to
# keep the report order stable between runs. This is a plain string sort:
# names containing multi-digit numbers would order lexicographically.
client_configs = sorted(os.listdir(output_dir))

for config in client_configs:
    config_dir = f'{output_dir}/{config}'
    clients = sorted(os.listdir(config_dir))
    # File names collected across all clients, used to count unique files
    imgs = []
    masks = []
    tot_images = 0
    tot_masks = 0
    image_client_str = ''
    mask_client_str = ''

    for client in clients:
        imgs_client = os.listdir(f'{config_dir}/{client}/images')
        masks_client = os.listdir(f'{config_dir}/{client}/masks')
        imgs += imgs_client
        masks += masks_client

        n_images = len(imgs_client)
        n_masks = len(masks_client)
        tot_images += n_images
        tot_masks += n_masks

        # Folder names look like "flclient_<k>"; report only the "<k>" part
        client_id = client.split("_")[1]
        image_client_str += f'{client_id}: {n_images}, '
        mask_client_str += f'{client_id}: {n_masks}, '

    print(f'Client Config: {config} has images: {tot_images} and masks: {tot_masks}')
    # If the unique count equals the total, no file was given to more than one client
    print(f'Unique Images: {len(set(imgs))}, Unique Masks: {len(set(masks))}')
    print(f'Images per Client: {image_client_str}')
    print(f'Masks per Client: {mask_client_str}')
    print()