In [1]:
import os
import csv
import random
import shutil
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import pandas as pd

# Mount Google Drive at /content/drive (Colab-only); the dataset paths used
# by this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Paths
# Every location used by this notebook hangs off one dataset root on the mounted Drive.
DATASET_ROOT = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset"

AUGMENTED_DATA_DIR = f"{DATASET_ROOT}/augmented"      # input: one sub-folder per class
OUTPUT_SPLIT_DIR = f"{DATASET_ROOT}/splits"           # output: <split_type>/<client>/<class>/
CSV_LOG_PATH = f"{DATASET_ROOT}/split_metadata.csv"   # output: one row per copied image

In [3]:
# Configs
# Names of the simulated federated clients; their list position is the
# client index used as key in every per-split index dictionary.
clients = [f"Client-{n}" for n in (1, 2, 3)]

# One output folder is produced per split type.
split_types = (
    ["IID_equal", "Quantity_skew", "Label_skew"]
    + ["Dirichlet_label", "Pathological"]
    + ["Feature_skew", "Concept_shift"]
)

# Seed both RNGs used below (`random` and the legacy global NumPy generator).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [4]:
# Load image metadata
# Walk AUGMENTED_DATA_DIR (one sub-folder per class) and build three parallel
# arrays: full image path, class label (the folder name) and augmentation tag.
#
# os.listdir() returns entries in arbitrary order. Every split below works on
# positional indices into these arrays, so both listings are sorted; without
# that, the fixed random seeds would not yield the same split on every run.
all_images = []
labels = []
aug_types = []

for cls in sorted(os.listdir(AUGMENTED_DATA_DIR)):
    cls_path = os.path.join(AUGMENTED_DATA_DIR, cls)
    if not os.path.isdir(cls_path):
        continue  # ignore stray files next to the class folders
    for img in sorted(os.listdir(cls_path)):
        all_images.append(os.path.join(cls_path, img))
        labels.append(cls)
        # File names containing "_aug" are tagged as augmented, everything else as original.
        aug_types.append("aug" if "_aug" in img else "orig")

all_images = np.array(all_images)
labels = np.array(labels)
aug_types = np.array(aug_types)
unique_labels = np.unique(labels)  # sorted distinct class names

In [5]:
# Helpers
# Accumulates one dict per copied image across all split types; it is written
# to CSV_LOG_PATH once all splits are done. Re-running this cell resets it,
# while re-running a split cell appends that split's rows again.
metadata_rows = []

def make_client_dirs(base_dir):
    """Create ``base_dir/<client>/<class>`` for every client and every class.

    Directories that already exist are left as they are.
    """
    for client_name in clients:
        client_root = os.path.join(base_dir, client_name)
        for class_name in unique_labels:
            os.makedirs(os.path.join(client_root, class_name), exist_ok=True)

def record_metadata(split_type, client_id, img_path, label, aug_type):
    """Append one row describing a copied image to ``metadata_rows``.

    Only the file name of ``img_path`` is stored, not its directory.
    """
    row = dict(split_type=split_type, client_id=client_id)
    row["image_name"] = os.path.basename(img_path)
    row["class"] = label
    row["augmentation"] = aug_type
    metadata_rows.append(row)

def copy_and_record(indices, split_type):
    """Materialise one split on disk and log every copied image.

    Args:
        indices: mapping of client index (position in ``clients``) to an
            iterable of positions into ``all_images`` / ``labels`` / ``aug_types``.
        split_type: name of the split; used as the folder name under
            ``OUTPUT_SPLIT_DIR`` and stored in each metadata row.

    Files are copied to ``<OUTPUT_SPLIT_DIR>/<split_type>/<client>/<class>/<file name>``.
    Files already present in the target folders are not removed.
    """
    split_root = os.path.join(OUTPUT_SPLIT_DIR, split_type)
    make_client_dirs(split_root)
    for client_pos, sample_positions in indices.items():
        for pos in sample_positions:
            client_name = clients[client_pos]
            source = all_images[pos]
            class_name, aug_kind = labels[pos], aug_types[pos]
            target = os.path.join(
                split_root, client_name, class_name, os.path.basename(source)
            )
            shutil.copy(source, target)
            record_metadata(split_type, client_name, source, class_name, aug_kind)

In [6]:
# 1. IID Equal - Shuffle all image indices uniformly at random, then cut the
#    shuffled order into len(clients) file lists of (near-)equal size.
#    Class proportions are not enforced per client; they match only in expectation.
shuffled = [*range(len(all_images))]
random.shuffle(shuffled)
chunks = np.array_split(shuffled, len(clients))
iid_indices = {client_pos: list(chunk) for client_pos, chunk in enumerate(chunks)}
copy_and_record(iid_indices, "IID_equal")

In [7]:
# 2. Quantity Skew - Same class mix, but the clients receive different amounts
#    of data in the ratio 1 : 2 : 3 (e.g. 80 / 160 / 240 images).
#    NOTE: re-shuffles the `shuffled` list left behind by the IID cell, so that
#    cell has to run first.
ratios = [1, 2, 3]
total = sum(ratios)
n_images = len(all_images)
counts = [int(n_images * r / total) for r in ratios]
# Truncation can leave a few images unassigned; they all go to the first client.
rest = n_images - sum(counts)
counts[0] += rest
random.shuffle(shuffled)
qs_indices = {}
start = 0
for client_pos, cnt in enumerate(counts):
    stop = start + cnt
    qs_indices[client_pos] = shuffled[start:stop]
    start = stop
copy_and_record(qs_indices, "Quantity_skew")

In [8]:
# 3. Label Skew - Every client is dominated by one class and sees only a few
#    samples of the others (Client A: mostly "Covid", B: mostly "Normal",
#    C: mostly "Viral Pneumonia").
#
#    The i-th class (in `unique_labels` order) sends DOMINANT_SHARE of its
#    images to client i (wrapping around if there are more classes than
#    clients); the remainder is divided evenly among the other clients.
#    Splitting every class evenly across all clients would instead give each
#    client the same class mix, i.e. no label skew at all.
DOMINANT_SHARE = 0.8

ls_indices = defaultdict(list)
for i, cls in enumerate(unique_labels):
    cls_idxs = np.where(labels == cls)[0]
    np.random.shuffle(cls_idxs)

    dominant = i % len(clients)
    others = [j for j in range(len(clients)) if j != dominant]

    # With a single client there is nobody to share with: it takes everything.
    n_major = int(round(DOMINANT_SHARE * len(cls_idxs))) if others else len(cls_idxs)
    ls_indices[dominant].extend(cls_idxs[:n_major])

    if others:
        minority_parts = np.array_split(cls_idxs[n_major:], len(others))
        for j, part in zip(others, minority_parts):
            ls_indices[j].extend(part)
copy_and_record(ls_indices, "Label_skew")

In [9]:
# 4. Dirichlet Label - For every class, draw the share each client receives
#    from a Dir(alpha) distribution (smaller alpha, e.g. 0.5, means stronger skew).
alpha = 0.5
dl_indices = defaultdict(list)
n_clients = len(clients)
for cls in unique_labels:
    cls_idxs = np.where(labels == cls)[0]
    n_cls = len(cls_idxs)
    proportions = np.random.dirichlet([alpha] * n_clients)
    counts = (proportions * n_cls).astype(int)
    # Truncating to int can leave some images unassigned; hand all of them to
    # the client with the largest drawn share.
    counts[np.argmax(proportions)] += max(n_cls - counts.sum(), 0)
    np.random.shuffle(cls_idxs)
    # Cut the shuffled class indices at the cumulative count boundaries.
    parts = np.split(cls_idxs, np.cumsum(counts)[:-1])
    for client_pos, part in enumerate(parts):
        dl_indices[client_pos].extend(part)
copy_and_record(dl_indices, "Dirichlet_label")

In [10]:
# 5. Pathological - Each client holds a single class only: the i-th class in
#    `unique_labels` goes entirely to the i-th client.
#    This relies on there being no more classes than clients; a class index
#    beyond the last client fails inside copy_and_record.
pt_indices = defaultdict(
    list,
    {
        client_pos: list(np.where(labels == cls)[0])
        for client_pos, cls in enumerate(unique_labels)
    },
)
copy_and_record(pt_indices, "Pathological")

In [11]:
# 6. Feature Skew - Intended to simulate site differences (e.g. client A sees
#    low-contrast X-rays, B high-contrast, C noisy images).
#    NOTE(review): this cell only deals the images out at random in equal
#    chunks and copies them unchanged; no per-client image transformation is
#    applied here.
shuffled = [*range(len(all_images))]
random.shuffle(shuffled)
chunks = np.array_split(shuffled, len(clients))
fs_indices = defaultdict(
    list, {client_pos: list(chunk) for client_pos, chunk in enumerate(chunks)}
)
copy_and_record(fs_indices, "Feature_skew")

In [12]:
# 7. Concept Shift - Intended to model clients whose labels follow different
#    annotation styles (e.g. different thresholds).
#    NOTE(review): this cell only deals the images out at random in equal
#    chunks and keeps the original folder label for every image; no
#    per-client relabelling is applied here.
shuffled = [*range(len(all_images))]
random.shuffle(shuffled)
chunks = np.array_split(shuffled, len(clients))
cs_indices = defaultdict(
    list, {client_pos: list(chunk) for client_pos, chunk in enumerate(chunks)}
)
copy_and_record(cs_indices, "Concept_shift")

In [13]:
# Save metadata
# Write every collected row (all split types) to a single CSV file, header first.
csv_columns = ["split_type", "client_id", "image_name", "class", "augmentation"]
with open(CSV_LOG_PATH, mode='w', newline='') as csv_file:
    csv_writer = csv.DictWriter(csv_file, fieldnames=csv_columns)
    csv_writer.writeheader()
    for row in metadata_rows:
        csv_writer.writerow(row)

In [14]:
# Count the images actually present on disk for every
# split type / client / class folder and collect one record per folder.

def _subdirs(parent):
    """Yield ``(name, path)`` for each immediate sub-directory of ``parent``.

    Entries come in ``os.listdir`` order; plain files are skipped.
    """
    for name in os.listdir(parent):
        path = os.path.join(parent, name)
        if os.path.isdir(path):
            yield name, path


# Traverse split type -> client -> class folder
summary_data = [
    {
        "split_type": split_type,
        "client": client,
        "disease_class": disease,
        # Only regular files count as images; nested folders are ignored.
        "image_count": sum(
            os.path.isfile(os.path.join(disease_path, f))
            for f in os.listdir(disease_path)
        ),
    }
    for split_type, split_path in _subdirs(OUTPUT_SPLIT_DIR)
    for client, client_path in _subdirs(split_path)
    for disease, disease_path in _subdirs(client_path)
]

# Convert to DataFrame
summary_df = pd.DataFrame(summary_data)

Unnamed: 0,split_type,client,disease_class,image_count
0,IID_equal,Client-1,Covid,236
1,IID_equal,Client-1,Normal,137
2,IID_equal,Client-1,Viral Pneumonia,129
3,IID_equal,Client-2,Covid,210
4,IID_equal,Client-2,Normal,144


In [15]:
# Display the image counts per split type, client and class.
summary_df

Unnamed: 0,split_type,client,disease_class,image_count
0,IID_equal,Client-1,Covid,236
1,IID_equal,Client-1,Normal,137
2,IID_equal,Client-1,Viral Pneumonia,129
3,IID_equal,Client-2,Covid,210
4,IID_equal,Client-2,Normal,144
...,...,...,...,...
58,Concept_shift,Client-2,Normal,134
59,Concept_shift,Client-2,Viral Pneumonia,156
60,Concept_shift,Client-3,Covid,223
61,Concept_shift,Client-3,Normal,149
