In [90]:
import os

import h5py
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda

In [91]:
# Build the path portably. The previous backslash literal only worked because
# "\F" and "\w" happen to be invalid escape sequences (a SyntaxWarning on
# Python 3.12+), and a backslash-separated path only resolves on Windows.
path = os.path.join("datasets", "FEMNIST_by_write", "write_all_labels.hdf5")

# Opened read-only and left open: later cells keep indexing into it.
binary_data_file = h5py.File(path, "r")

# One top-level key per writer; sorted so iteration order is stable across runs.
writers = sorted(binary_data_file.keys())
print(f'datasets contains images from {len(writers)} writers')

datasets contains images from 3597 writers


In [92]:
# Per-writer split tables: writer key -> sample indices for training / testing.
dic_train_indices = {}
dic_test_indices = {}

In [93]:
# Reproducible per-writer train/test split.
# NOTE: int() rounds down, so writers with fewer than 1 / TEST_FRACTION
# samples get an empty test split.
TEST_FRACTION = 0.1
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

for writer in writers:
    labels = binary_data_file[writer]['labels']
    n_samples = len(labels)

    shuffled = split_rng.permutation(n_samples)
    n_test = int(n_samples * TEST_FRACTION)

    # Both slices come from one permutation, so they are disjoint and together
    # cover every index: the train split is just the remainder. This replaces
    # the O(n^2) `i not in test_indices` filter. Train indices stay an
    # ascending list, as before.
    test_indices = shuffled[:n_test]
    train_indices = np.sort(shuffled[n_test:]).tolist()

    dic_train_indices[writer] = train_indices
    dic_test_indices[writer] = test_indices

In [95]:
class CustomDataset(Dataset):
    """In-memory dataset holding a subset of one writer's images and labels.

    The selected samples are read from the data file once, in ``__init__``,
    and kept as arrays; ``__getitem__`` then only applies the transforms.

    Parameters
    ----------
    writer : str
        Key of the writer's group in the data file.
    transform : callable, optional
        Applied to each image in ``__getitem__``.
    target_transform : callable, optional
        Applied to each label in ``__getitem__``.
    indices : iterable of int, optional
        Sample indices to load. Defaults to ``dic_train_indices[writer]``.
        Pass ``dic_test_indices[writer]`` to build the matching test set.
    data_file : mapping, optional
        Object indexed as ``data_file[writer]['images']`` and
        ``data_file[writer]['labels']``. Defaults to the module-level
        ``binary_data_file``.
    """

    def __init__(self, writer, transform=None, target_transform=None,
                 indices=None, data_file=None):
        if data_file is None:
            data_file = binary_data_file
        if indices is None:
            indices = dic_train_indices[writer]

        # h5py fancy indexing requires increasing indices. Sort once and use
        # the same selection for images and labels so they stay aligned.
        selection = sorted(indices)
        group = data_file[writer]
        self.images = group['images'][selection]
        self.labels = group['labels'][selection]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of samples loaded for this writer."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index`` with the optional transforms applied."""
        img = self.images[index]
        label = self.labels[index]

        if self.transform:
            img = self.transform(img)

        if self.target_transform:
            label = self.target_transform(label)

        return img, label

In [None]:
# Size of the one-hot label vector (presumably FEMNIST's 10 digits + 26
# upper-case + 26 lower-case letters; confirm against the label range).
NUM_CLASSES = 62


def one_hot_label(y):
    """Return a one-hot float vector of length NUM_CLASSES with a 1 at integer label ``y``."""
    return torch.zeros(NUM_CLASSES).scatter_(
        dim=0, index=torch.tensor(y, dtype=torch.int64), value=1)


# The transforms are loop-invariant, so build them once and share them across
# every writer's dataset instead of re-creating them on each iteration.
image_transform = ToTensor()
label_transform = Lambda(one_hot_label)

client_datasets = []
for writer in writers:
    # `customDataset` stays bound to the last writer's dataset; the inspection
    # cell that follows reads it.
    customDataset = CustomDataset(writer,
                                  transform=image_transform,
                                  target_transform=label_transform)
    client_datasets.append(customDataset)

In [88]:
# Sanity-check the last writer's dataset: its size, then the first sample's
# image tensor shape, its label (after target_transform), and its pixel values.
print(len(customDataset))
first_image, first_label = customDataset[0]
print(first_image.shape)
print(first_label)
print(first_image)

344
torch.Size([1, 28, 28])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0

In [89]:
# Show the held-out labels of one writer. Name the writer and its test indices
# explicitly rather than relying on `writer` / `test_indices` leaking out of
# earlier loops, which silently breaks when cells run out of order.
sample_writer = writers[-1]
binary_data_file[sample_writer]['labels'][sorted(dic_test_indices[sample_writer])]

array([ 0,  1,  2,  3,  4,  5,  6,  6,  8,  8,  8,  9,  9, 12, 18, 22, 23],
      dtype=uint8)