In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [2]:
# IID function

def mnistIID(dataset, num_users):
    """Split the sample positions of ``dataset`` evenly and at random between users.

    Every user receives ``int(len(dataset) / num_users)`` positions, drawn
    without replacement from the positions no earlier user has taken.  The
    global NumPy RNG is reseeded with the user id before each draw, so the
    split is the same on every call (and the global RNG state is left at the
    state produced by the last user's draw).

    Args:
        dataset: any object supporting ``len()``.
        num_users: number of users to create groups for.

    Returns:
        dict mapping user id (0 .. num_users - 1) to a set of sample positions.
    """
    total = len(dataset)
    per_user = int(total / num_users)
    remaining = list(range(total))
    users_dict = {}
    for user in range(num_users):
        # Reseeding per user keeps each user's draw reproducible.
        np.random.seed(user)
        picked = set(np.random.choice(remaining, per_user, replace=False))
        users_dict[user] = picked
        # Drop what was just handed out so later users cannot receive it.
        remaining = list(set(remaining) - picked)
    return users_dict

In [3]:
# non-IID function

def mnistNonIID(dataset, num_users):
    """Partition a labelled dataset into non-IID client groups of equal shard count.

    Sample positions are sorted by label and cut into 100 contiguous shards
    (600 samples each for the 60 000-sample MNIST training split; shard size
    adapts to other dataset lengths).  Each user then receives 2 distinct
    shards chosen at random without replacement, so a user only sees the few
    labels its shards happen to cover.

    Args:
        dataset: object exposing its labels as ``targets`` (or the legacy
            ``train_labels``), either as a tensor with ``.numpy()`` or as any
            array-like.
        num_users: number of clients; at most 50 (2 shards each out of 100).

    Returns:
        dict mapping user id to a 1-D integer array of sample positions.

    Raises:
        ValueError: if ``num_users`` needs more shards than exist.
    """
    classes, shards_per_user = 100, 2
    if num_users * shards_per_user > classes:
        raise ValueError(
            "num_users=%r needs %d shards but only %d are available"
            % (num_users, num_users * shards_per_user, classes))

    # Prefer `targets`; fall back to the deprecated `train_labels` attribute.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels
    labels = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)

    # argsort yields sample positions ordered by label; a stable sort keeps
    # the split reproducible.  array_split tolerates lengths that are not a
    # multiple of the shard count instead of silently dropping samples.
    shards = np.array_split(np.argsort(labels, kind="stable"), classes)

    classes_indx = list(range(classes))
    users_dict = {}
    for user in range(num_users):
        temp = {int(t) for t in
                np.random.choice(classes_indx, shards_per_user, replace=False)}
        # Remove the drawn shards so no two users share data.
        classes_indx = list(set(classes_indx) - temp)
        users_dict[user] = np.concatenate([shards[t] for t in temp])
    return users_dict

In [4]:
# non-IIDUnequal function

def mnistNonIIDUnequal(dataset, num_users):
    """Partition a labelled dataset into non-IID client groups of unequal size.

    Sample positions are sorted by label and cut into 1200 contiguous shards
    (50 samples each for the 60 000-sample MNIST training split; shard size
    adapts to other dataset lengths).  Each user is given a random share of
    the shards: a weight is drawn uniformly from 1..30 per user and the
    weights are rescaled so they sum to roughly 1200.

    * If rounding makes the shares exceed the number of shards, every user
      first gets one shard (so nobody is empty) and then up to ``share - 1``
      more while shards remain.
    * Otherwise every user gets exactly its share and any leftover shards go
      to the user that currently holds the fewest samples.

    Args:
        dataset: object exposing its labels as ``targets`` (or the legacy
            ``train_labels``), either as a tensor with ``.numpy()`` or as any
            array-like.
        num_users: number of clients, between 1 and 1200.

    Returns:
        dict mapping user id to a 1-D integer array of sample positions.

    Raises:
        ValueError: if ``num_users`` is outside 1..1200.
    """
    classes = 1200
    min_cls_per_client = 1
    max_cls_per_client = 30

    if not 1 <= num_users <= classes:
        raise ValueError(
            "num_users must be between 1 and %d, got %r" % (classes, num_users))

    # Prefer `targets`; fall back to the deprecated `train_labels` attribute.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels
    labels = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)

    # Sample positions ordered by label, cut into `classes` contiguous shards.
    shards = np.array_split(np.argsort(labels, kind="stable"), classes)

    classes_indx = list(range(classes))          # shards not yet handed out
    assigned = {user: [] for user in range(num_users)}

    def give(user, count):
        """Move up to `count` randomly chosen free shards to `user`."""
        nonlocal classes_indx
        # Clamp: never ask for more shards than remain, and treat a
        # non-positive request (e.g. share 0 minus 1) as "nothing to do".
        count = min(int(count), len(classes_indx))
        if count <= 0:
            return
        temp = {int(t) for t in
                np.random.choice(classes_indx, count, replace=False)}
        classes_indx = list(set(classes_indx) - temp)
        assigned[user].extend(shards[t] for t in temp)

    random_selected_classes = np.random.randint(
        min_cls_per_client, max_cls_per_client + 1, size=num_users)
    random_selected_classes = np.around(
        random_selected_classes / random_selected_classes.sum() * classes
    ).astype(int)

    if random_selected_classes.sum() > classes:
        # Shares over-subscribe the shards: guarantee one shard each first...
        for user in range(num_users):
            give(user, 1)
        # ...then top every user up with the rest of its share while any remain.
        for user in range(num_users):
            give(user, random_selected_classes[user] - 1)
    else:
        for user in range(num_users):
            give(user, random_selected_classes[user])
        if classes_indx:
            # Leftover shards go to the user holding the fewest samples.
            k = min(assigned,
                    key=lambda u: sum(len(part) for part in assigned[u]))
            give(k, len(classes_indx))

    return {
        user: (np.concatenate(parts) if parts else np.array([], dtype=np.intp))
        for user, parts in assigned.items()
    }

In [5]:


def load_dataset(num_users, iidtype):
    """Load MNIST (downloading into ``./data`` if needed) and partition it between users.

    Args:
        num_users: number of users passed to the partition function.
        iidtype: ``'iid'`` selects ``mnistIID``, ``'noniid'`` selects
            ``mnistNonIID``; any other value selects ``mnistNonIIDUnequal``.

    Returns:
        ``(train_dataset, test_dataset, train_group, test_group)`` where the
        two groups are whatever the selected partition function returns for
        the train and test split respectively.
    """
    # Convert images to tensors, then normalise with the mean/std given here.
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    transform = transforms.Compose(steps)

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

    # Pick the partition strategy once, then apply it to both splits.
    if iidtype == 'iid':
        partition = mnistIID
    elif iidtype == 'noniid':
        partition = mnistNonIID
    else:
        partition = mnistNonIIDUnequal

    train_group = partition(train_dataset, num_users)
    test_group = partition(test_dataset, num_users)

    return train_dataset, test_dataset, train_group, test_group

In [6]:
class FedDataset(Dataset):
    """View of ``dataset`` restricted to the sample positions listed in ``indx``.

    Item ``n`` of this dataset is item ``indx[n]`` of the wrapped dataset,
    returned as a pair of tensors ``(images, label)`` that are independent
    copies of the underlying data.
    """

    def __init__(self, dataset, indx):
        """Store the wrapped dataset and the selected positions.

        Args:
            dataset: indexable object yielding ``(images, label)`` pairs.
            indx: iterable of positions into ``dataset``; values may be
                floats or NumPy scalars and are converted to plain ``int``.
        """
        self.dataset = dataset
        self.indx = [int(i) for i in indx]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.indx)

    @staticmethod
    def _as_tensor_copy(value):
        """Return a detached tensor copy of ``value``.

        ``torch.tensor(t)`` on an existing tensor emits a UserWarning on
        every call, so tensors are copied with ``clone().detach()`` and only
        non-tensor values (e.g. an ``int`` label) go through ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        """Return ``(images, label)`` tensors for the ``item``-th selected sample."""
        images, label = self.dataset[self.indx[item]]
        return self._as_tensor_copy(images), self._as_tensor_copy(label)

In [7]:
def getActualImages(dataset, indices, batch_size):
    """Build a shuffling ``DataLoader`` over the samples of ``dataset`` selected by ``indices``.

    Args:
        dataset: dataset to draw samples from.
        indices: positions into ``dataset`` that the loader should cover.
        batch_size: number of samples per batch.

    Returns:
        A ``DataLoader`` wrapping ``FedDataset(dataset, indices)`` with
        ``shuffle=True``.
    """
    subset = FedDataset(dataset, indices)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
    return loader