# In [None]:
# Import libraries
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

# Import custom data splitting module
import data_splitting as ds  # Assuming the file is named 'data_splitting.py'

# ---------------------------------------------------------------------------
# CIFAR-100 training data
# ---------------------------------------------------------------------------
# Augmentation / preprocessing pipeline applied to every training sample:
#   1. random 32x32 crop after padding each side by 4 pixels,
#   2. random horizontal flip,
#   3. conversion to a tensor,
#   4. per-channel normalization with the hard-coded mean / std below
#      (the commonly quoted CIFAR-100 training-set statistics).
transform = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5071, 0.4865, 0.4409),
        std=(0.2673, 0.2564, 0.2761),
    ),
])

# download=True fetches the archive into ./data if it is not already there.
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# ---------------------------------------------------------------------------
# Partition the training set across federated clients
# ---------------------------------------------------------------------------
num_clients = 100

# IID split: delegated to the project-local data_splitting helper.
iid_clients = ds.iid_sharding(train_dataset, num_clients)
print("IID Sharding Completed.")

# Non-IID split: each client is restricted to Nc classes.
Nc = 5  # Number of classes per client
non_iid_clients = ds.non_iid_sharding(train_dataset, num_clients, Nc)
print("Non-IID Sharding Completed.")

# Visualize Data Distribution (Optional)
def plot_data_distribution(client_data, num_clients, title):
    """Show a bar chart of how many samples each client holds.

    Args:
        client_data: Mapping whose values are the per-client collections of
            sample indices (anything supporting ``len``). Bars are drawn in
            the mapping's iteration order, so bar ``i`` is the ``i``-th
            entry. This is presumably client ``i`` if the sharding helpers
            insert keys 0..num_clients-1 in order (TODO confirm against
            data_splitting).
        num_clients: Number of clients; must equal ``len(client_data)``.
        title: Title displayed above the chart.

    Raises:
        ValueError: If ``num_clients`` differs from the number of entries in
            ``client_data``.
    """
    client_counts = [len(indices) for indices in client_data.values()]

    # plt.bar broadcasts x and height against each other. A mismatch would
    # either raise an opaque "shape mismatch" error or, when one side has
    # length 1, silently draw a misleading chart, so fail loudly here.
    if len(client_counts) != num_clients:
        raise ValueError(
            f"num_clients={num_clients} does not match the "
            f"{len(client_counts)} clients present in client_data"
        )

    plt.figure(figsize=(10, 6))
    plt.bar(range(num_clients), client_counts)
    plt.xlabel("Client ID")
    plt.ylabel("Number of Samples")
    plt.title(title)
    plt.show()

# Chart per-client sample counts for both sharding strategies, IID first.
for _client_map, _chart_title in (
    (iid_clients, "IID Data Distribution"),
    (non_iid_clients, "Non-IID Data Distribution"),
):
    plot_data_distribution(_client_map, num_clients, _chart_title)
del _client_map, _chart_title  # keep the notebook namespace unchanged
