In [None]:
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms

from data.pipeline import get_data_raw

In [None]:
from torch import nn, tanh
from torch.nn import Linear, Conv2d, ConvTranspose2d, BatchNorm2d, ReLU, LeakyReLU

In [None]:
# Shape settings shared by the layer experiments and the encoder/decoder below.
# input_dim: side length of the square dummy input.
input_dim = 32
# num_channels: channel count of the dummy input (re-derived from the real data further down).
num_channels = 3
# hidden_dim: size of the code vector produced by the encoder.
hidden_dim = 16

In [None]:
# Random dummy batch of 4 inputs laid out as (batch, channels, height, width).
# No seed is set, so the values differ on every run.
inp = torch.randn(4, num_channels, input_dim, input_dim)

In [None]:
# Stage 1: 5x5 convolution, stride 2, padding 2, num_channels -> 64 feature maps.
conv1 = Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2)

In [None]:
# Spread of the activations produced by the freshly initialised layer.
torch.std(conv1(inp))

In [None]:
# Stride 2 halves each spatial side: the 32x32 dummy input becomes (4, 64, 16, 16).
conv1(inp).shape

In [None]:
# Batch normalisation over the 64 channels produced by conv1.
conv1_bn = BatchNorm2d(64)

In [None]:
# Mean of the normalised activations; expected to be close to 0 for a new
# BatchNorm2d layer in training mode.
torch.mean(conv1_bn(conv1(inp)))

In [None]:
# Stage 2: 64 -> 128 feature maps, same kernel/stride/padding.
conv2 = Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2)

In [None]:
# Expected shape: (4, 128, 8, 8).
conv2(conv1(inp)).shape

In [None]:
# Stage 3: 128 -> 256 feature maps.
conv3 = Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2)

In [None]:
# Expected shape: (4, 256, 4, 4).
conv3(conv2(conv1(inp))).shape

In [None]:
# Stage 4: 256 -> 512 feature maps.
conv4 = Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)

In [None]:
# Expected shape: (4, 512, 2, 2), i.e. 512 * 2 * 2 = 2048 values per sample.
conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
# Same shape kept in a variable; it is not referenced again in the cells shown here.
that_size = conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
# Projection from the 2048 flattened conv features to a 16-dimensional code
# (16 equals the hidden_dim value set above).
fc1 = Linear(2048, 16)

In [None]:
# Flatten each of the 4 samples before the linear layer; expected shape: (4, 16).
fc1(conv4(conv3(conv2(conv1(inp)))).view((4, -1))).shape

In [None]:
class CustomView(nn.Module):
    """Reshape layer usable inside ``nn.Sequential`` (e.g. to flatten feature maps).

    Args:
        shape: sequence of ints handed to ``Tensor.view``; may contain ``-1``
            for a dimension that should be inferred.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the target shape stored at construction."""
        target_shape = self.shape
        return x.view(*target_shape)

In [None]:
# Encoder: four 5x5 stride-2 convolutions (each followed by LeakyReLU and batch
# norm) reduce a 32x32 input to 512 feature maps of size 2x2, which are then
# flattened and projected to `hidden_dim` values.
encoder = nn.Sequential(
    nn.Conv2d(num_channels, 64, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(),
    nn.BatchNorm2d(num_features=64),
    nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(),
    nn.BatchNorm2d(num_features=128),
    nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(),
    nn.BatchNorm2d(num_features=256),
    nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(),
    nn.BatchNorm2d(num_features=512),
    # Flatten (N, 512, 2, 2) to (N, 2048) so the linear layer can consume it.
    CustomView((-1, 2048)),
    nn.Linear(2048, hidden_dim),
)

In [None]:
# Smoke test: push the dummy batch through the encoder and display the resulting codes.
encoder(inp)

In [None]:
# Decoder: mirrors the encoder. A linear layer expands the code to 2048 values,
# which are reshaped to (N, 512, 2, 2) and upsampled by four stride-2 transposed
# convolutions; the final Tanh bounds the output to [-1, 1].
decoder = nn.Sequential(
    nn.Linear(in_features=hidden_dim, out_features=2048),
    nn.ReLU(),
    nn.BatchNorm1d(num_features=2048),
    CustomView((-1, 512, 2, 2)),
    nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5,
                       stride=2, padding=2, output_padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=256),
    # With output_padding 0 this stage maps 4 -> 7 instead of 4 -> 8, which is
    # what a 28x28 target needs (7 -> 14 -> 28 in the remaining stages).
    nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5,
                       stride=2, padding=2,
                       output_padding=(0 if input_dim == 28 else 1)),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=128),
    nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5,
                       stride=2, padding=2, output_padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=64),
    nn.ConvTranspose2d(in_channels=64, out_channels=num_channels, kernel_size=5,
                       stride=2, padding=2, output_padding=1),
    nn.Tanh(),
)

In [None]:
# Round trip through encoder and decoder; the result should have the same
# (batch, channels, height, width) shape as `inp`.
decoder(encoder(inp)).shape#.permute((0, 2, 3, 1))

In [None]:
# Placeholder per-channel values for the broadcasting check below; the real
# statistics are computed from the data later in the notebook.
means = (1, 2, 3)

In [None]:
# Adding two trailing axes gives shape (3, 1, 1).
np.expand_dims(means, axis=[1, 2]).shape

In [None]:
means = np.expand_dims(means, axis=[1, 2])

In [None]:
# (3, 1, 1) broadcasts against the channel-first output (N, 3, H, W), scaling
# each channel by its own value. detach() is needed before converting to NumPy
# because the output is part of the autograd graph.
decoder(encoder(inp)).detach().numpy() * means

In [None]:
def merge(images, size, channels=3):
    """Tile a batch of equally sized images into one grid image.

    Images are placed row by row: image ``idx`` goes to grid row
    ``idx // size[1]`` and grid column ``idx % size[1]``. Grid cells that
    receive no image stay zero.

    Args:
        images: batch of shape (N, H, W, C) with channels last, e.g. a NumPy
            array; only ``.shape`` and iteration over the first axis are used.
        size: ``(rows, cols)`` of the grid.
        channels: number of channels of the output canvas. Defaults to 3,
            the value that used to be hard-coded.

    Returns:
        ``np.ndarray`` (float64) of shape ``(H * rows, W * cols, channels)``.

    Raises:
        ValueError: raised by NumPy when an image cannot be broadcast into its
            cell, e.g. when there are more images than grid cells.
    """
    rows, cols = size[0], size[1]
    h, w = images.shape[1], images.shape[2]
    canvas = np.zeros((h * rows, w * cols, channels))
    for idx, image in enumerate(images):
        # Row-major placement: fill one grid row completely before the next.
        row, col = divmod(idx, cols)
        canvas[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return canvas

In [None]:
# Experiment configuration. The defaults apply when the notebook is run
# interactively without extra command-line arguments.
parser = ArgumentParser()
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--hidden_dim', default=16, type=int)
parser.add_argument('--dataset', default='cifar', type=str)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--party_data_size', default=4000, type=int)
parser.add_argument('--candidate_data_size', default=10000, type=int)
parser.add_argument('--split', default='equaldisjoint', type=str)
# Jupyter launches the kernel with `-f <connection file>`; declaring the option
# lets the parser consume it.
parser.add_argument('-f', type=str)
# parse_known_args() skips any other launcher-specific flags, whereas
# parse_args() would stop the cell with SystemExit on the first unknown one.
args = parser.parse_known_args()[0]

In [None]:
# Load the per-party splits and the candidate pool for the configured experiment.
(party_datasets,
 party_labels,
 candidate_dataset,
 candidate_labels) = get_data_raw(
    dataset=args.dataset,
    num_classes=args.num_classes,
    party_data_size=args.party_data_size,
    candidate_data_size=args.candidate_data_size,
    split=args.split,
)

In [None]:
# Index of the party split to inspect.
party = 0

In [None]:
# First 64 samples of that party, tiled into an 8x8 grid.
plt.imshow(merge(party_datasets[party][:64], [8,8]))

In [None]:
# Labels of the same 64 samples.
party_labels[party][:64]

In [None]:
# First 64 samples of the candidate pool, tiled into an 8x8 grid.
plt.imshow(merge(candidate_dataset[:64], [8,8]))

In [None]:
# Labels of the same 64 candidate samples.
candidate_labels[:64]

In [None]:
# Largest value across all party splits.
np.max(party_datasets)

In [None]:
# Smallest value across all party splits.
np.min(party_datasets)

In [None]:
# Largest value in the candidate pool.
np.max(candidate_dataset)

In [None]:
# Smallest value in the candidate pool, completing the min/max range check of
# the image arrays.
np.min(candidate_dataset)

In [None]:
# Take the channel count from the data itself (size of the last axis); this
# replaces the value that was set for the dummy input.
num_channels = party_datasets.shape[-1]

In [None]:
# Per-channel mean of the candidate pool alone: collapse every axis except the
# channel axis, then average.
np.mean(candidate_dataset.reshape([-1, num_channels]), axis=0)

In [None]:
# Join all party splits along the first axis and append the candidate pool, so
# the statistics below cover every sample.
combined = np.concatenate([np.concatenate(party_datasets), candidate_dataset])

In [None]:
combined.shape

In [None]:
# Per-channel mean over all samples; overwrites the placeholder `means` from
# the broadcasting demo.
means = np.mean(combined.reshape(-1, num_channels), axis=0)

In [None]:
means

In [None]:
# Per-channel standard deviation over all samples.
stds = np.std(combined.reshape(-1, num_channels), axis=0)

In [None]:
stds

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets: item ``i`` is the tuple of every dataset's item ``i``.

    The length is that of the shortest wrapped dataset, so every exposed index
    is valid for all of them.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __len__(self):
        lengths = [len(ds) for ds in self.datasets]
        return min(lengths)

    def __getitem__(self, i):
        items = []
        for ds in self.datasets:
            items.append(ds[i])
        return tuple(items)

In [None]:
# Standardise every party split and the candidate pool with the global
# per-channel statistics, and wrap each (images, labels) pair in a TensorDataset.
# The resulting list holds the party datasets first and the candidate dataset last.
# NOTE(review): `datasets` rebinds the name imported from torchvision at the top
# of the notebook; the torchvision module is unreachable under that name afterwards.
datasets = []
for i in range(len(party_datasets)):
    # `means` / `stds` have one entry per channel and broadcast over the last axis.
    transformed = (party_datasets[i] - means) / stds
    dataset = TensorDataset(torch.tensor(transformed), torch.tensor(party_labels[i]))
    datasets.append(dataset)
# Same normalisation for the candidate pool.
transformed = (candidate_dataset - means) / stds
dataset = TensorDataset(torch.tensor(transformed), torch.tensor(candidate_labels))
datasets.append(dataset)

In [None]:
# Zip the party datasets and the candidate dataset so that one index yields one
# sample from each; the length is that of the shortest dataset.
concat_dataset = ConcatDataset(*datasets)

In [None]:
# Each batch holds one (images, labels) pair per wrapped dataset.
# shuffle=True draws a new sample order for every pass over the data.
loader = torch.utils.data.DataLoader(
            concat_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True)

In [None]:
# Manual iterator so that single batches can be pulled with next().
my_iterator = iter(loader)

In [None]:
# Which entry of the batch to inspect.
# NOTE(review): `datasets` holds the party splits followed by the candidate pool,
# so whether index 5 is a party or the candidate pool depends on the number of
# parties — confirm.
party = 5

In [None]:
# Pull the next batch and keep only the selected dataset's images and labels.
batch_images, batch_labels = next(my_iterator)[party]

In [None]:
# The loader yields standardised images; undo (x - means) / stds before plotting
# so imshow gets the original value range instead of clipping the data.
plt.imshow(merge(batch_images.numpy() * stds + means, [8, 8]))

In [None]:
# Labels of the images in the grid above, in the same row-by-row order.
batch_labels