In [143]:
import h5py
import numpy as np

from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Lambda
import torch

In [153]:
# Location of the HDF5 file, relative to the notebook's working directory.
# Forward slashes are used on purpose: with backslashes the sequences "\F" and "\w"
# are invalid string escapes (a SyntaxWarning on recent Python versions), and the
# path would only resolve on Windows. Forward slashes are accepted on Windows and
# POSIX alike.
path = "datasets/FEMNIST_by_write/write_all_labels.hdf5"

# Opened read-only and intentionally left open: the cells below index into this
# handle (binary_data_file[writer]['images'] / ['labels']).
binary_data_file = h5py.File(path, "r")

# Top-level keys of the file, one per writer; sorted so the writer order is stable.
writers = sorted(binary_data_file.keys())

FEMNIST dataset contains images from 3597 writers


In [146]:
# Per-writer train/test split.
#
# For every writer, a random TEST_FRACTION of that writer's sample positions is held
# out (dic_test_indices) and the remaining positions are kept for training
# (dic_train_indices). Both dictionaries are keyed by writer.
TEST_FRACTION = 0.1  # share of each writer's samples held out for testing
SPLIT_SEED = 0       # fixed seed so a kernel restart reproduces the same split

# A local Generator keeps the split reproducible without touching numpy's global
# random state.
split_rng = np.random.default_rng(SPLIT_SEED)

dic_train_indices = dict()
dic_test_indices = dict()

for writer in writers:
    labels = binary_data_file[writer]['labels']

    lst_range = np.arange(0, len(labels))
    lst_random = split_rng.permutation(lst_range)

    # The first n_test entries of the permutation are the test positions; everything
    # after them is, by construction, the complement. Slicing replaces a per-element
    # "i not in test_indices" scan that was quadratic in the writer's sample count.
    n_test = int(len(lst_random) * TEST_FRACTION)
    test_indices = lst_random[:n_test]
    # Kept as an ascending list of numpy integers, the same form the filter-based
    # version produced.
    train_indices = list(np.sort(lst_random[n_test:]))

    dic_train_indices[writer] = train_indices
    dic_test_indices[writer] = test_indices

In [147]:
class FEMNIST_Test_Dataset(Dataset):
    """Held-out samples pooled across all writers.

    On construction, for every entry of the module-level ``writers`` the rows of
    ``binary_data_file[writer]['images']`` and ``['labels']`` selected by
    ``dic_test_indices[writer]`` are read and appended to in-memory lists.

    Args:
        transform: optional callable applied to an image when it is fetched.
        target_transform: optional callable applied to a label when it is fetched.
    """

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

        self.images = []
        self.labels = []
        for writer in writers:
            writer_group = binary_data_file[writer]
            # Indices are put in ascending order before indexing (h5py list
            # selections are expected to be increasing).
            held_out = sorted(dic_test_indices[writer])
            self.images.extend(writer_group['images'][held_out])
            self.labels.extend(writer_group['labels'][held_out])

    def __len__(self):
        """Number of pooled test samples."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index``, passed through any configured transforms."""
        img, label = self.images[index], self.labels[index]
        img = self.transform(img) if self.transform else img
        label = self.target_transform(label) if self.target_transform else label
        return img, label

In [148]:
class FEMNIST_Train_Dataset(Dataset):
    """Training samples belonging to a single writer.

    On construction, the rows of ``binary_data_file[writer]['images']`` and
    ``['labels']`` selected by ``dic_train_indices[writer]`` are read into memory.

    Args:
        writer: key of the writer in ``binary_data_file`` / ``dic_train_indices``.
        transform: optional callable applied to an image when it is fetched.
        target_transform: optional callable applied to a label when it is fetched.
    """

    def __init__(self, writer, transform=None, target_transform=None):
        writer_group = binary_data_file[writer]
        # Indices are put in ascending order before indexing (h5py list selections
        # are expected to be increasing).
        kept = sorted(dic_train_indices[writer])

        self.images = writer_group['images'][kept]
        self.labels = writer_group['labels'][kept]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of training samples held for this writer."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index``, passed through any configured transforms."""
        img, label = self.images[index], self.labels[index]
        img = self.transform(img) if self.transform else img
        label = self.target_transform(label) if self.target_transform else label
        return img, label

In [150]:
def _train_label_to_one_hot(y):
    """Return a length-62 tensor of zeros with a 1 written at position ``y``.

    62 is presumably the number of FEMNIST classes -- TODO confirm against the data.
    """
    position = torch.tensor(y, dtype=torch.int64)
    return torch.zeros(62).scatter_(dim=0, index=position, value=1)


# One training dataset per writer, collected in the same order as `writers`.
client_datasets = []

for writer in writers:
    client_dataset = FEMNIST_Train_Dataset(
        writer,
        transform=ToTensor(),
        target_transform=Lambda(_train_label_to_one_hot),
    )
    client_datasets.append(client_dataset)

In [151]:
# Label transform for the test set: a length-62 tensor of zeros with a 1 written at
# the label's position (62 is presumably the number of FEMNIST classes -- TODO confirm).
test_label_to_one_hot = Lambda(
    lambda y: torch.zeros(62).scatter_(0, torch.tensor(y, dtype=torch.int64), value=1)
)

# Single test dataset pooling the held-out samples of every writer.
test_dataset = FEMNIST_Test_Dataset(
    transform=ToTensor(),
    target_transform=test_label_to_one_hot,
)