# In [None]:
import torch
import random
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from collections import defaultdict

# Step 1: Define the IID and Non-IID Splitting Functions

def iid_split(dataset, num_clients):
    """Randomly partition the indices of ``dataset`` into ``num_clients`` disjoint shards.

    Every index ``0 .. len(dataset) - 1`` is assigned to exactly one client.
    If ``len(dataset)`` is not divisible by ``num_clients``, the first
    ``len(dataset) % num_clients`` clients get one extra sample. The remainder
    is no longer silently discarded.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        num_clients: Number of shards to create; must be positive.

    Returns:
        dict[int, list[int]]: client id -> list of dataset indices. Shard
        sizes differ by at most one.

    Raises:
        ValueError: if ``num_clients`` is not positive.
    """
    if num_clients <= 0:
        raise ValueError(f"num_clients must be positive, got {num_clients}")

    num_samples = len(dataset)
    indices = torch.randperm(num_samples).tolist()
    base, extra = divmod(num_samples, num_clients)

    client_data = {}
    start = 0
    for client_id in range(num_clients):
        # The first `extra` clients absorb the remainder, one sample each.
        size = base + (1 if client_id < extra else 0)
        client_data[client_id] = indices[start:start + size]
        start += size
    return client_data

def non_iid_split(dataset, num_clients, num_classes_per_client):
    """Give each client samples from a small random subset of classes.

    Each client picks ``num_classes_per_client`` distinct classes uniformly at
    random. For each picked class it receives ``len(class) // num_clients``
    samples. Samples are drawn without replacement across clients: each
    class's index pool is shuffled once and handed out in consecutive chunks,
    so no dataset index is assigned to more than one client.

    Args:
        dataset: Indexable dataset whose items are ``(input, label)`` pairs.
            If it exposes a ``targets`` attribute and has no
            ``target_transform`` (e.g. torchvision CIFAR100 as loaded in this
            file), labels are read from ``targets``. This avoids loading and
            transforming every sample.
        num_clients: Number of clients to create.
        num_classes_per_client: Number of distinct classes per client.

    Returns:
        defaultdict[int, list[int]]: client id -> list of dataset indices.

    Raises:
        ValueError: propagated from ``random.sample`` when
            ``num_classes_per_client`` exceeds the number of classes present.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Tensors / arrays are converted to plain Python values so that labels
        # hash by value rather than by object identity.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        labels = (label for _, label in dataset)

    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    all_classes = list(class_indices.keys())

    # Shuffle each class pool once; next_free[cls] is the position of the first
    # index not yet handed out. A class can be picked by at most num_clients
    # clients, each taking len // num_clients samples, so a pool never runs out.
    for pool in class_indices.values():
        random.shuffle(pool)
    next_free = defaultdict(int)

    client_data = defaultdict(list)
    for client_id in range(num_clients):
        selected_classes = random.sample(all_classes, num_classes_per_client)
        for cls in selected_classes:
            data_per_client = len(class_indices[cls]) // num_clients
            start = next_free[cls]
            client_data[client_id].extend(class_indices[cls][start:start + data_per_client])
            next_free[cls] = start + data_per_client
    return client_data

# Step 2: Visualizing the Split Data
def _class_counts_per_client(client_data, dataset):
    """Return ``{client_id: {label: count}}`` for every client with at least one index."""
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Reading labels from `targets` avoids decoding/transforming each image.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        labels = None

    counts = {}
    for client_id, indices in client_data.items():
        dist = defaultdict(int)
        for idx in indices:
            if labels is not None:
                label = labels[idx]
            else:
                _, label = dataset[idx]
            dist[label] += 1
        if dist:
            counts[client_id] = dist
    return counts


def visualize_client_data(client_data, num_clients, dataset, title="Data Distribution"):
    """Show two figures describing a client split.

    The first figure is a bar chart of the number of samples per client for
    client ids ``0 .. num_clients - 1``. The second is a per-class histogram
    with one stacked segment per client.

    Args:
        client_data: Mapping client id -> iterable of dataset indices, as
            returned by the split functions. It is not mutated.
        num_clients: Number of clients to show in the first figure.
        dataset: Indexable dataset whose items are ``(input, label)`` pairs.
            If it has ``targets`` and no ``target_transform``, labels are read
            from ``targets``.
        title: Prefix for both figure titles.
    """
    # Samples per client. `.get` avoids inserting empty entries into a
    # caller's defaultdict and tolerates plain dicts with missing clients.
    client_sample_counts = [len(client_data.get(client, ())) for client in range(num_clients)]
    plt.figure(figsize=(10, 6))
    plt.bar(range(num_clients), client_sample_counts, color='skyblue')
    plt.xlabel("Client ID")
    plt.ylabel("Number of Samples")
    plt.title(f"{title}: Samples Per Client")
    plt.show()

    # Class distribution per client, stacked so that each client's share of a
    # class stays visible. Overlapping translucent bars would hide each other.
    class_distribution = _class_counts_per_client(client_data, dataset)
    all_labels = sorted({label for dist in class_distribution.values() for label in dist})
    bottom = [0] * len(all_labels)

    plt.figure(figsize=(15, 6))
    for client_id, dist in class_distribution.items():
        counts = [dist.get(label, 0) for label in all_labels]
        plt.bar(all_labels, counts, bottom=bottom, label=f"Client {client_id}")
        bottom = [b + c for b, c in zip(bottom, counts)]
    plt.xlabel("Class Label")
    plt.ylabel("Number of Samples")
    plt.title(f"{title}: Class Distribution Per Client")
    plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1.0), ncol=1, fontsize='small')
    plt.show()

# Reproducibility: seed both RNGs the splitters draw from (torch.randperm in
# iid_split, the `random` module in non_iid_split) so that re-running the
# notebook yields the same client splits.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Step 3: Load CIFAR-100 Dataset
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)

# Step 4: Generate IID and Non-IID Splits
num_clients = 10
iid_clients = iid_split(dataset, num_clients)
non_iid_clients = non_iid_split(dataset, num_clients, num_classes_per_client=2)

# Step 5: Visualize the Data Splits
visualize_client_data(iid_clients, num_clients, dataset, title="IID Split")
visualize_client_data(non_iid_clients, num_clients, dataset, title="Non-IID Split")
