In [1]:
import numpy as np
import tensorflow as tf
import logging
import os
import random
import torch
from abc import ABC, abstractmethod
import time
from sklearn.manifold import TSNE
from torch.utils.data import Dataset
from torchvision import datasets, transforms

In [2]:
# test

In [2]:
# Data loading
# `download=True` is safe to leave on: torchvision only fetches the files when they are
# missing under `root`, so a fresh checkout works and an existing copy is reused.

transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST train and test splits
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

print(f"Data: {mnist_trainset.data.shape}, Targets: {mnist_trainset.targets.shape}")

Data: torch.Size([60000, 28, 28]), Targets: torch.Size([60000])


# Helper Functions

In [3]:
import numpy as np

def dirichlet_partition(dataset, num_partitions, num_classes, alpha=0.5, seed=42, min_partition_size=10):
    """
    Partition the dataset into multiple subsets using a Dirichlet distribution.

    For every class found in the labels, that class's sample indices are shuffled
    and split across the partitions according to proportions drawn from a
    symmetric Dirichlet(alpha). A partition that already holds at least
    N / num_partitions samples receives no share of the class being processed.
    The whole assignment is redrawn until every partition holds at least
    ``min_partition_size`` samples.

    Args:
        dataset: Object exposing a ``targets`` attribute with one label per sample
                 (NumPy array, an object with a ``.numpy()`` method such as a
                 torch.Tensor, or anything ``np.asarray`` accepts).
        num_partitions (int): Number of partitions (clients) to divide the dataset into.
        num_classes (int): Kept for interface compatibility. It is not used: the
                           classes are derived from the labels with ``np.unique``.
        alpha (float): Dirichlet concentration parameter (controls imbalance).
        seed (int): Random seed for reproducibility.
        min_partition_size (int): Minimum number of samples every partition must
                                  receive before the assignment is accepted.

    Returns:
        dict: A dictionary where keys are partition indices (0 to num_partitions-1)
              and values are lists of indices corresponding to the samples in each partition.

    Raises:
        ValueError: If ``num_partitions`` is smaller than 1, or if the dataset has
                    fewer than ``num_partitions * min_partition_size`` samples (the
                    minimum-size requirement could then never be met and the retry
                    loop would spin forever).

    Note:
        This function reseeds NumPy's global random generator via ``np.random.seed``.
    """
    if num_partitions < 1:
        raise ValueError(f"num_partitions must be at least 1, got {num_partitions}")

    # Side effect kept on purpose: the global NumPy RNG is reseeded here.
    np.random.seed(seed)

    # Extract labels as a NumPy array
    targets = dataset.targets
    if isinstance(targets, np.ndarray):
        y_train = targets
    elif hasattr(targets, "numpy"):  # For torch.Tensor
        y_train = targets.numpy()
    else:
        y_train = np.asarray(targets)

    classes = np.unique(y_train)
    num_samples = y_train.shape[0]  # Total number of samples

    if num_samples < num_partitions * min_partition_size:
        raise ValueError(
            f"Cannot give each of {num_partitions} partitions at least "
            f"{min_partition_size} samples: dataset only has {num_samples} samples"
        )

    # A partition holding this many samples or more stops receiving new ones.
    capacity = num_samples / num_partitions

    # Redraw the whole assignment until the minimum partition size is satisfied.
    while True:
        idx_batch = [[] for _ in range(num_partitions)]
        for k in classes:
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)

            proportions = np.random.dirichlet(np.repeat(alpha, num_partitions))
            has_room = np.array([len(idx_j) < capacity for idx_j in idx_batch])
            proportions = proportions * has_room
            total = proportions.sum()
            if not np.isfinite(total) or total <= 0:
                # Degenerate draw (all mass on full partitions, or NaN): fall back
                # to equal shares among the partitions that still have room. At
                # least one always has room because fewer than N samples have
                # been assigned so far.
                proportions = has_room / has_room.sum()
            else:
                proportions = proportions / total

            # Cumulative proportions become the cut positions inside idx_k.
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [
                idx_j + part.tolist()
                for idx_j, part in zip(idx_batch, np.split(idx_k, split_points))
            ]

        if min(len(idx_j) for idx_j in idx_batch) >= min_partition_size:
            break

    net_dataidx_map = {}
    for j in range(num_partitions):
        # Shuffle so samples of the same class are not contiguous within a partition.
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    return net_dataidx_map


# Initial Setup

In [4]:
# Federated-learning setup for the partitioning step below
num_classes = 10  # unique classes in MNIST
num_clients = 10  # clients (participants in federated learning)

# The MNIST training split is the dataset that gets partitioned
dataset = mnist_trainset

In [5]:
partitions = dirichlet_partition(
    dataset,
    num_partitions=num_clients,
    num_classes=num_classes,
    alpha=0.5,
    seed=42,
)

# Report the size of every client's partition, then the overall total
for i, partition in partitions.items():
    print(f"Client {i}: {len(partition)} samples")
x = sum(len(client_indices) for client_indices in partitions.values())
print(x)


Client 0: 5744 samples
Client 1: 7794 samples
Client 2: 4667 samples
Client 3: 2937 samples
Client 4: 7538 samples
Client 5: 6137 samples
Client 6: 6201 samples
Client 7: 7199 samples
Client 8: 5441 samples
Client 9: 6342 samples


60000
