In [None]:
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms

from data.pipeline import get_data_raw

In [None]:
def merge(images, size):
    """Tile a batch of images into a single grid image, filled row by row.

    Args:
        images: array-like (NumPy array, CPU tensor, or sequence of equally shaped
            arrays) of shape ``(N, H, W)`` or ``(N, H, W, C)``.
        size: grid shape ``(rows, cols)``; ``N`` must not exceed ``rows * cols``.
            Grid cells without an image stay 0.

    Returns:
        float64 array of shape ``(rows*H, cols*W)`` for ``(N, H, W)`` input, or
        ``(rows*H, cols*W, C)`` for ``(N, H, W, C)`` input. Single-channel input is
        expanded to 3 identical channels so ``plt.imshow`` renders it as gray RGB.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    images = np.asarray(images)
    n, h, w = images.shape[:3]
    rows, cols = size[0], size[1]
    if n > rows * cols:
        raise ValueError(f"{n} images do not fit in a {rows}x{cols} grid")

    if images.ndim == 3:
        # No channel axis: produce a 2-D (grayscale) grid.
        channel_shape = ()
    else:
        channels = images.shape[3]
        # Broadcast 1-channel images to RGB; otherwise keep the input's channel count (e.g. RGBA).
        channel_shape = (3 if channels == 1 else channels,)

    grid = np.zeros((h * rows, w * cols) + channel_shape)
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = image

    return grid

In [None]:
# Notebook "command line": the defaults below are the values every later cell uses.
parser = ArgumentParser()
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--hidden_dim', default=16, type=int)
parser.add_argument('--dataset', default='cifar', type=str)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--party_data_size', default=4000, type=int)
parser.add_argument('--candidate_data_size', default=10000, type=int)
parser.add_argument('--split', default='equaldisjoint', type=str)
# Absorbs the `-f <connection-file>` argument the Jupyter kernel is launched with.
parser.add_argument('-f', type=str)
# parse_known_args ignores any other launcher flags instead of exiting with SystemExit.
# Index [0] rather than `args, _ = ...`: binding `_` would stop IPython updating its
# last-output variable.
args = parser.parse_known_args()[0]

In [None]:
# Load the data splits: per-party images and labels plus a separate candidate pool.
# (Presumably unnormalized — per-channel standardization is applied in a later cell.)
(
    party_datasets,
    party_labels,
    candidate_dataset,
    candidate_labels,
) = get_data_raw(
    dataset=args.dataset,
    num_classes=args.num_classes,
    party_data_size=args.party_data_size,
    candidate_data_size=args.candidate_data_size,
    split=args.split,
)

In [None]:
# Index into party_datasets / party_labels for the raw-sample previews in the next cells.
party = 0

In [None]:
# First 64 raw images of the selected party, tiled as an 8x8 grid.
plt.imshow(merge(images=party_datasets[party][:64], size=(8, 8)))

In [None]:
# Labels for the 8x8 = 64 party images plotted in the previous cell.
party_labels[party][: 8 * 8]

In [None]:
# First 64 raw candidate-pool images, tiled as an 8x8 grid.
plt.imshow(merge(images=candidate_dataset[:64], size=(8, 8)))

In [None]:
# Labels for the 8x8 = 64 candidate images plotted in the previous cell.
candidate_labels[: 8 * 8]

In [None]:
# Largest raw pixel value across all parties.
np.amax(party_datasets)

In [None]:
# Smallest raw pixel value across all parties.
np.amin(party_datasets)

In [None]:
# Largest raw pixel value in the candidate pool.
np.amax(candidate_dataset)

In [None]:
# Smallest raw pixel value in the candidate pool; together with the party max/min and
# candidate max cells this gives the value range of the images before normalization.
np.min(candidate_dataset)

In [None]:
# Size of the trailing axis; the statistics below treat it as the channel axis
# (channels-last layout — assumed, consistent with how merge() tiles images).
num_channels = np.shape(party_datasets)[-1]

In [None]:
# Per-channel mean of the candidate pool alone, for comparison with the combined means below.
np.mean(np.reshape(candidate_dataset, (-1, num_channels)), axis=0)

In [None]:
# Stack every party's images and then the candidate pool along the sample axis,
# so normalization statistics are computed over all available data.
combined = np.vstack([np.vstack(party_datasets), candidate_dataset])

In [None]:
# Total sample count and per-image shape of the combined data.
np.shape(combined)

In [None]:
# Per-channel mean over every pixel of every combined image.
means = combined.reshape(-1, num_channels).mean(axis=0)

In [None]:
# Per-channel means used to standardize the datasets below.
means

In [None]:
# Per-channel (population, ddof=0) standard deviation over every combined pixel.
stds = combined.reshape(-1, num_channels).std(axis=0)

In [None]:
# Per-channel standard deviations used to standardize the datasets below.
stds

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together index by index.

    Unlike ``torch.utils.data.ConcatDataset`` (which chains datasets end to end),
    item ``i`` here is the tuple ``(datasets[0][i], datasets[1][i], ...)``, so a single
    index draws one sample from every wrapped dataset at once.

    Args:
        *datasets: indexable, sized datasets (e.g. ``TensorDataset``) to zip together.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        """Return the ``i``-th item of every wrapped dataset, as a tuple in wrap order."""
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        """Length of the shortest wrapped dataset, or 0 when none are wrapped.

        Samplers only draw indices below this length, so the tail of longer
        datasets is never visited.
        """
        # default=0: an empty zip has no items instead of raising ValueError from min().
        return min((len(d) for d in self.datasets), default=0)

In [None]:
def to_normalized_tensor_dataset(images, labels, mean, std):
    """Standardize ``images`` per channel and pair them with ``labels`` in a TensorDataset.

    ``mean`` and ``std`` broadcast against the trailing (channel) axis of ``images``.
    """
    normalized = (images - mean) / std
    return TensorDataset(torch.tensor(normalized), torch.tensor(labels))


# One dataset per party, in party order, followed by the candidate pool as the last entry.
# NOTE: this rebinds `datasets`, shadowing the `torchvision.datasets` import of the first cell.
datasets = [
    to_normalized_tensor_dataset(images, labels, means, stds)
    for images, labels in zip(party_datasets, party_labels)
]
datasets.append(to_normalized_tensor_dataset(candidate_dataset, candidate_labels, means, stds))

In [None]:
# Zip all party datasets and the candidate dataset: index i yields the i-th sample of each,
# and the length is that of the shortest one.
concat_dataset = ConcatDataset(*datasets)

In [None]:
# Shuffled loader over the zipped datasets: each batch holds up to args.batch_size
# index-aligned samples from every party and from the candidate pool.
loader = DataLoader(
    concat_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    # Pinned memory only speeds up host-to-GPU copies; without CUDA it just emits a warning.
    pin_memory=torch.cuda.is_available(),
    # Fixed shuffle seed so the batch previewed below is the same after Restart & Run All.
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Keep a single iterator so each `next(my_iterator)` returns the next shuffled batch;
# re-running this cell starts a fresh pass over the loader.
my_iterator = iter(loader)

In [None]:
# Rebinds `party` (previously 0) as a position in each zipped batch: positions
# 0..len(party_datasets)-1 are the parties and the last one is the candidate pool.
# NOTE(review): 5 is hard-coded — confirm the chosen split has more than 5 parties;
# otherwise this selects the candidate pool or raises IndexError on the next cell.
party = 5

In [None]:
# A zipped batch has one collated (images, labels) entry per dataset; pick the chosen one.
party_batch = next(my_iterator)[party]
batch_images, batch_labels = party_batch

In [None]:
# The batch was standardized per channel; undo it so pixels are back on the raw-data scale
# plotted in the earlier previews (imshow clips float RGB data to [0, 1], so standardized
# values would render wrongly). Cap at 64 images so the 8x8 grid also fits when
# --batch_size > 64.
batch_preview = batch_images[:64].numpy() * stds + means
plt.imshow(merge(batch_preview, [8, 8]))

In [None]:
# Labels of the batch images plotted in the previous cell.
batch_labels