In [1]:
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import gpytorch
from scipy.optimize import linprog

In [2]:
from torch import nn, tanh
from torch.nn import Linear, Conv2d, ConvTranspose2d, BatchNorm2d, ReLU, LeakyReLU
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms

In [4]:
from core.mmd import mmd_neg_biased, mmd_neg_unbiased
from data.pipeline import get_data_features
from core.kernel import get_kernel
from core.reward_calculation import get_v
from data.pipeline import get_data_raw

## Check data

In [10]:
# Configuration for the CIFAR feature check below.
dataset, num_classes = 'cifar', 10
d = 8  # dimensionality argument forwarded to get_data_features — presumably the feature size; confirm
num_parties, party_data_size, candidate_data_size = 5, 5000, 100000
split = 'equaldisjoint'

In [11]:
# Load features for the configuration above. The unpacked names suggest per-party
# datasets/labels, a reference dataset and candidate datasets/labels.
# NOTE(review): confirm the return contract against data.pipeline.get_data_features.
party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features(dataset,
                                                                                            num_classes,
                                                                                            d,
                                                                                            num_parties,
                                                                                            party_data_size,
                                                                                            candidate_data_size,
                                                                                            split)

## Kernel optimization

In [None]:
# Replace the CIFAR features with the synthetic 'gmm' dataset used for kernel optimisation.
# NOTE(review): the positional arguments presumably follow the order of the cell above
# (num_classes=5, d=2, num_parties=5, party_data_size=1000, candidate_data_size=5000, split);
# the meaning of gamma=0 is defined in data.pipeline — confirm.
party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features('gmm',
                                                                                            5,
                                                                                            2,
                                                                                            5,
                                                                                            1000,
                                                                                            5000,
                                                                                            'equaldisjoint',
                                                                                            gamma=0)

In [None]:
device = 'cuda:0'

In [None]:
train_test_split_idx = 600

In [None]:
party_ds_size = train_test_split_idx
num_parties = len(party_datasets)
party_datasets_tens = torch.tensor(party_datasets[:, :train_test_split_idx], device=device, dtype=torch.float32)
reference_dataset_tens = torch.tensor(reference_dataset, device=device, dtype=torch.float32)

In [None]:
party_datasets_test = torch.tensor(party_datasets[:, train_test_split_idx:], device=device, dtype=torch.float32)

In [None]:
num_parties = 5
num_epochs = 50
batch_size = 128

In [None]:
ard_num_dims = 2

In [None]:
class SEKernel():
    """
    Squared exponential kernel with one inverse squared lengthscale per input
    dimension (ARD):

        k(x, y) = exp(-0.5 * sum_d inv_ls_squared[d] * (x_d - y_d) ** 2)

    `inv_ls_squared` is a float32 leaf tensor with requires_grad=True, so losses
    built from this kernel can be back-propagated to it.
    """
    def __init__(self, ard_num_dims, inv_lengthscale_squared, device):
        """
        :param ard_num_dims: number of input dimensions d
        :param inv_lengthscale_squared: initial value shared by every dimension
        :param device: torch device that holds the parameter tensor
        """
        self.ard_num_dims = ard_num_dims
        self.device = device
        self.set_inv_ls_squared_scalar(inv_lengthscale_squared)

    def __call__(self, X, Y=None):
        """
        :param X: torch tensor of size (m, d)
        :param Y: torch tensor of size (n, d); defaults to X
        :return: dense torch tensor of size (m, n) with entries k(X[i], Y[j])
        """
        if Y is None:
            Y = X

        # (m, 1, d) - (n, d) broadcasts to (m, n, d) squared per-dimension differences.
        diff_squared = torch.square(torch.unsqueeze(X, 1) - Y)
        # Weighted sum over the last axis -> (m, n).
        exponent = torch.matmul(diff_squared, self.inv_ls_squared)
        return torch.exp(-0.5 * exponent)

    def parameters(self):
        """Return the trainable tensors as a list (mirrors nn.Module.parameters)."""
        return [self.inv_ls_squared]

    def set_inv_ls_squared_scalar(self, inv_ls):
        """Set every dimension's inverse squared lengthscale to `inv_ls` (as a new leaf tensor)."""
        self.inv_ls_squared = torch.full((self.ard_num_dims,), float(inv_ls), device=self.device,
                                         dtype=torch.float32, requires_grad=True)

    def set_inv_ls_squared(self, inv_ls_squared):
        """
        Replace the inverse squared lengthscales with a new leaf tensor.

        :param inv_ls_squared: sequence, NumPy array or tensor of length ard_num_dims.
            Tensors are detached and copied; passing them to torch.tensor would emit a warning.
        """
        self.inv_ls_squared = (torch.as_tensor(inv_ls_squared, dtype=torch.float32, device=self.device)
                               .detach().clone().requires_grad_(True))

In [None]:
# Two-dimensional SE kernel with every inverse squared lengthscale initialised to 1.
my_kernel = SEKernel(ard_num_dims=2, inv_lengthscale_squared=1, device=device)

In [None]:
def mmd_neg_unbiased_noeval(X, Y, k):
    """
    Negative unbiased MMD^2 estimate between samples X and Y. The result stays a
    differentiable torch scalar so it can be used as a loss.

    Returns S_XY - S_X - S_Y where
        S_X  = mean of k(x_i, x_j) over i != j,
        S_Y  = mean of k(y_i, y_j) over i != j,
        S_XY = 2 * mean of k(x_i, y_j).

    :param X: torch tensor of shape (m, d), m >= 2
    :param Y: torch tensor of shape (n, d), n >= 2
    :param k: kernel callable; k(A) and k(A, B) must return dense torch tensors
        (e.g. SEKernel). Lazy GPyTorch outputs are not evaluated here.
    :return: scalar torch tensor
    :raises ValueError: if X or Y has fewer than 2 rows (the i != j means are undefined)
    """
    m = X.size(0)
    n = Y.size(0)
    if m < 2 or n < 2:
        raise ValueError("X and Y need at least 2 rows each, got m={}, n={}".format(m, n))

    # Evaluate each Gram matrix once and reuse it for both the total and the diagonal.
    K_XX = k(X)
    K_YY = k(Y)

    S_X = (1 / (m * (m - 1))) * (torch.sum(K_XX) - torch.sum(torch.diag(K_XX)))
    S_XY = (2 / (m * n)) * torch.sum(k(X, Y))
    S_Y = (1 / (n * (n - 1))) * (torch.sum(K_YY) - torch.sum(torch.diag(K_YY)))

    return S_XY - S_X - S_Y

In [None]:
def nonneg_lb(n, S, k, eta=None, tol=1e-04):
    """
    Threshold computed from the dataset sizes:

        (n - 2 S) / (2 S (n - S)) * (k + (S - 1) * eta) - tol

    :param n: size of the reference dataset Y
    :param S: smallest party dataset size
    :param k: kernel value on the diagonal (usually 1)
    :param eta: upper bound on kernel values; falls back to k when None
    :param tol: margin subtracted from the bound
    :return: scalar
    """
    eta = k if eta is None else eta
    coefficient = (n - 2 * S) / (2 * S * (n - S))
    return coefficient * (k + (S - 1) * eta) - tol

In [None]:
n = len(reference_dataset)
S = np.min([len(ds) for ds in party_datasets])

In [None]:
lb = nonneg_lb(n, S, 1)

In [None]:
def is_pareto_efficient(costs, return_mask = True):
    """
    Find the Pareto-efficient rows of `costs`, where every column is minimised.

    A row survives only if, for every row kept before it, it is strictly smaller in
    at least one column. Among exact duplicates, only the first occurrence is kept.

    :param costs: An (n_points, n_costs) array
    :param return_mask: True to return a boolean mask, False to return indices
    :return: If return_mask is True, an (n_points, ) boolean array.
        Otherwise an (n_efficient_points, ) integer array of row indices.
    """
    total = costs.shape[0]
    survivors = np.arange(total)
    cursor = 0  # position, within the shrinking `costs`, of the point being tested
    while cursor < len(costs):
        keep = np.any(costs < costs[cursor], axis=1)
        keep[cursor] = True
        survivors = survivors[keep]
        costs = costs[keep]
        cursor = np.sum(keep[:cursor]) + 1
    if not return_mask:
        return survivors
    mask = np.zeros(total, dtype = bool)
    mask[survivors] = True
    return mask

In [None]:
# Select num_val_points random points to check k(x_i, x_j) > lb
num_val_points = 2000
val_points = torch.tensor(reference_dataset[np.random.permutation(np.arange(num_val_points))], device=device, dtype=torch.float32)

In [None]:
def is_all_above_lb(k, val_points, lb):
    """
    Check that every entry of the Gram matrix k(val_points) is strictly greater than lb.

    :param k: kernel callable returning a torch tensor for k(val_points)
    :param val_points: torch tensor of points
    :param lb: scalar threshold
    :return: NumPy boolean, True when all len(val_points) ** 2 entries exceed lb
    """
    gram = k(val_points).cpu().detach().numpy()
    expected = len(val_points) ** 2
    return (gram > lb).sum() == expected

In [None]:
d = 2

In [None]:
val_points_np = val_points.cpu().numpy()

In [None]:
squared_diffs = np.square(np.expand_dims(val_points_np, 1) - val_points_np)  # (m, m, d)
squared_diffs = np.reshape(squared_diffs, [-1, d])
squared_diff_idxs = np.where((np.triu(np.ones((num_val_points, num_val_points))) - np.diag(np.ones(num_val_points))).flatten())[0]
squared_diffs_reduced = squared_diffs[squared_diff_idxs]
reduced_D = squared_diffs_reduced[is_pareto_efficient(-squared_diffs_reduced)]

In [None]:
b = (-2 * np.log(lb)) * np.ones(len(reduced_D), dtype=np.float32)

In [None]:
k = SEKernel(2, 1, device)

In [None]:
# # Do a binary search for a good value of inv_ls_squared, low but above upper bound
# num_iters = 20
# low = 1
# high = 1000
# current = low

# # Check bounds
# k.set_inv_ls_squared_scalar(low)
# if not is_all_above_lb(k, val_points, lb):
#     raise Exception("Low value of inv_ls_squared is already invalid")
    
# k.set_inv_ls_squared_scalar(high)
# if is_all_above_lb(k, val_points, lb):
#     raise Exception("High value of inv_ls_squared is still valid, can be pushed higher")

# for i in range(num_iters):
#     mid = (high + low) / 2
#     print(mid)
#     k.set_inv_ls_squared_scalar(mid)
#     if is_all_above_lb(k, val_points, lb):
#         low = mid
#     else:
#         high = mid

# k.set_inv_ls_squared_scalar(low)
# print("Optimal inverse lengthscale squared: {}".format(low))

In [None]:
# Fit the SE kernel's inverse squared lengthscales theta with conditional-gradient
# (Frank-Wolfe) steps. Each step linearises the summed -MMD^2 training loss and moves
# toward the minimiser of <grad, theta> over the polytope
#     C = {theta >= 0 : reduced_D @ theta <= b},
# i.e. the thetas for which k(x, x') >= lb on every retained validation pair.
t = 0  # global Frank-Wolfe step counter, drives the 2 / (t + 2) step size
patience = 20
averages = []
best_idx = 0
num_epochs = 100

for epoch in range(num_epochs):
    print("Epoch {}".format(epoch))

    print("========= Test -MMD unbiased ===========")
    stats = []
    # Held-out evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(num_parties):
            stat = mmd_neg_unbiased_noeval(party_datasets_test[i], reference_dataset_tens, k).cpu().numpy()
            print("Party {}: {}".format(i+1, stat))
            stats.append(stat)
    avg = np.mean(stats)
    print("Average: {}".format(avg))

    print("========= Kernel parameters ===========")
    print("inv lengthscale squared:")
    print(k.inv_ls_squared)
    print("lengthscale:")
    print(np.sqrt(1 / k.inv_ls_squared.cpu().detach().numpy()))
    print("k still valid (all above upper bound): {}".format(is_all_above_lb(k, val_points, lb)))

    for i in range(party_ds_size // batch_size):
        # Drop any gradient left on the current parameter (e.g. after a skipped LP step).
        # Successful steps replace the parameter with a fresh leaf via set_inv_ls_squared.
        k.inv_ls_squared.grad = None

        m = i * batch_size
        next_m = m + batch_size

        # Random contiguous reference window; randint's upper bound is exclusive,
        # so add 1 to allow the last full window.
        ref_idx = np.random.randint(0, len(reference_dataset) - batch_size + 1)
        next_ref_idx = ref_idx + batch_size

        loss = 0
        for party in range(num_parties):
            loss += mmd_neg_unbiased_noeval(party_datasets_tens[party][m:next_m],
                                            reference_dataset_tens[ref_idx:next_ref_idx],
                                            k)

        loss.backward()

        # Linear minimisation oracle: y_t = argmin_{y in C} <grad, y>.
        # linprog's default bounds (0, None) provide theta >= 0. 'highs' replaces the
        # 'interior-point' method, which was removed in SciPy 1.11.
        grad = k.inv_ls_squared.grad.cpu().numpy()
        print("Actual grad: {}".format(grad))
        res = linprog(grad, A_ub=reduced_D, b_ub=b, method='highs')
        if not res.success:
            print("linprog failed ({}); skipping this update".format(res.message))
            continue
        y_t = res.x
        print("y_t: {}".format(y_t))
        print("inv_ls_squared: {}".format(k.inv_ls_squared))

        # Standard Frank-Wolfe step: convex combination of the current point and the
        # LP vertex, so the iterate stays inside C.
        step_size = 2/(t + 2)
        print("Step size: {}".format(step_size))
        k.set_inv_ls_squared((1 - step_size) * k.inv_ls_squared.cpu().detach().numpy() + step_size * y_t)
        t += 1

    # Early termination when the held-out average has not improved (lower is better)
    # for `patience` epochs. averages[e] is the value measured at the start of epoch e.
    averages.append(avg)
    if avg <= averages[best_idx]:
        best_idx = epoch
    elif epoch - best_idx >= patience:
        print("No improvement for {} epochs, terminating early".format(patience))
        break
    

In [None]:
k2 = get_kernel('se', d, 1)

In [None]:
k2.base_kernel.lengthscale = np.sqrt(1 / k.inv_ls_squared.cpu().detach().numpy())

In [None]:
for i in range(5):
    print(mmd_neg_biased(party_datasets[i], reference_dataset, k2)[0])

## Kernel selection

In [None]:
# Toy 1-D kernel-selection experiment. This rebinds `k` (the SEKernel fitted above)
# to a kernel built by get_kernel('se', 1, 1).
k = get_kernel('se', 1, 1)

In [None]:
k.base_kernel.lengthscale = 0.5
# Small 1-D input (integer dtype); its Gram matrix is materialised as the NumPy array `mat`.
X = torch.tensor([1, 2, 3, 4, -1, 10, 3, 2, 1])
mat = k(X,X).evaluate().detach().numpy()
mat

In [None]:
# Clip every entry of `mat` from below at a fixed threshold, in place.
# A dedicated name keeps the global `lb` from the kernel-optimisation section intact
# (the old loop silently rebound it to 0.4375).
clip_lb = 0.4375
np.maximum(mat, clip_lb, out=mat)

In [None]:
# For each prefix S = X[:i], print the lower bound
#     ((n / (2 s) - 1) / (n - s)) * (1 + (s - 1) * eta),   n = |X|, s = |S|
# next to the value v computed from `mat`.
# The bound is written inline with eta = 1 (the value set in the `eta` cell below)
# because `gamma` and `eta` are only defined in later cells. Calling them here raised
# NameError on a fresh kernel.
bound_eta = 1
for i in range(1, len(X)):
    S = X[:i]
    n_X, s = len(X), len(S)
    lower_bound = ((n_X / (2 * s) - 1) / (n_X - s)) * (1 + (s - 1) * bound_eta)
    print("lb: {}".format(lower_bound))

    pos_term = 2/(s * n_X) * np.sum(mat[s:, :s])
    neg_term = (2/(s * n_X) - 1/(s ** 2)) * np.sum(mat[:s, :s])

    print("v: {}".format(pos_term + neg_term))

In [None]:
for i in range(len(X)):
    for j in range(len(X)):
        if i != j:
            mat[i, j] = 0.125

In [None]:
mat

In [None]:
def gamma(X, S, eta=1):
    """
    Lower bound used in the kernel-selection experiments:

        ((n / (2 s) - 1) / (n - s)) * (1 + (s - 1) * eta),   n = len(X), s = len(S)

    This is algebraically the same as nonneg_lb(n, s, 1, eta, tol=0).
    `eta` used to be read from a global defined in a later cell; it is now an
    explicit parameter whose default is the value that cell sets (1).

    :param X: full sample (only its length is used)
    :param S: subset / prefix of X (only its length is used); 0 < len(S) < len(X)
    :param eta: upper bound on off-diagonal kernel values
    :return: float
    """
    n = len(X)
    s = len(S)
    return ((n/(2*s)-1)/(n-s)) * (1+(s-1)*eta)

In [None]:
eta = 1

In [None]:
lower_bound = ((n/(2*s)-1)/(n-s)) * (1+(s-1)*eta)

In [None]:
lower_bound

In [None]:
# For each prefix S = X[:i], compute v two ways and print both:
#   pos_term + neg_term : split over mat[s:, :s] (rest of X vs S) and mat[:s, :s] (S vs S)
#   SXY - SX            : SXY = 2 / (|S| |X|) * sum(mat[:, :s]),  SX = 1 / |S|^2 * sum(mat[:s, :s])
# The two agree because mat[:, :s] is exactly mat[s:, :s] stacked on mat[:s, :s].
for i in range(1, len(X)):
    S = X[:i]
    pos_term = 2/(len(S) * len(X)) * np.sum(mat[len(S):, :len(S)])
    neg_term = (2/(len(S) * len(X)) - 1/(len(S) ** 2)) * np.sum(mat[:len(S), :len(S)])
    
    SXY = 2 / (len(S) * len(X)) * np.sum(mat[:, :len(S)])
    SX =  1/(len(S) ** 2) * np.sum(mat[:len(S), :len(S)])
    print("SXY: {}".format(SXY))
    print("SX: {}".format(SX))
    
    print(pos_term + neg_term)
    print("SXY - SX: {}".format(SXY - SX))

In [None]:
mmd_neg_biased(S, X, k)[0]

In [None]:
pos_term = 2/(len(S) * 10) * np.sum(mat[len(S):, :len(S)])

In [None]:
pos_term

In [None]:
neg_term = (2/(len(S) * 10) - 1/(len(S) ** 2)) * np.sum(mat[:len(S), :len(S)])

In [None]:
neg_term

In [None]:
pos_term + neg_term

## CIFAR-5

In [None]:
def merge(images, size):
    """
    Tile a batch of RGB images into one grid image, filling it row by row.

    :param images: array-like of shape (N, h, w, 3) with N <= size[0] * size[1]
    :param size: (rows, cols) of the grid
    :return: float64 array of shape (h * rows, w * cols, 3); unused cells stay 0
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_rows, n_cols = size[0], size[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
    for position, tile in enumerate(images):
        row, col = divmod(position, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas

In [None]:
# Raw CIFAR arrays (images and labels for the training set and for the candidate
# samples). The paths are relative to the notebook's working directory.
train_images = np.load("data/cifar/cifar_train_images.npy")
train_labels = np.load("data/cifar/cifar_train_labels.npy")
candidate_images = np.load("data/cifar/cifar_samples.npy")
candidate_labels = np.load("data/cifar/cifar_samples_labels.npy")

In [None]:
# CIFAR-10 label ids kept for CIFAR-5.
classes = [0, 1, 6, 7, 8]  # Classes we want, airplane, automobile, frog, horse, ship

In [None]:
# Map each kept CIFAR-10 label to a contiguous CIFAR-5 label, in listed order
# (0 -> 0, 1 -> 1, 6 -> 2, 7 -> 3, 8 -> 4).
class_dict = dict(zip([0, 1, 6, 7, 8], range(5)))

In [None]:
train_idx = []
for i in range(len(train_images)):
    if train_labels[i] in classes:
        train_idx.append(i)

In [None]:
cifar5_train_images = train_images[train_idx]
cifar5_train_labels = train_labels[train_idx]

In [None]:
cifar5_train_new_labels = [class_dict[label] for label in cifar5_train_labels]

In [None]:
plt.imshow(merge(cifar5_train_images[200:264], [8,8]))

In [None]:
cifar5_train_new_labels[200:264]

In [None]:
candidate_idx = []
for i in range(len(candidate_images)):
    if candidate_labels[i] in classes:
        candidate_idx.append(i)

In [None]:
cifar5_candidate_images = candidate_images[candidate_idx]
cifar5_candidate_labels = candidate_labels[candidate_idx]

In [None]:
cifar5_candidate_new_labels = [class_dict[label] for label in cifar5_candidate_labels]

In [None]:
plt.imshow(merge(cifar5_candidate_images[200:264], [8,8]))

In [None]:
cifar5_candidate_new_labels[200:264]

In [None]:
import os

# np.save does not create missing directories, so make sure the target exists first.
os.makedirs("data/cifar5", exist_ok=True)
np.save("data/cifar5/cifar5_train_images.npy", cifar5_train_images)
np.save("data/cifar5/cifar5_train_labels.npy", cifar5_train_new_labels)
np.save("data/cifar5/cifar5_samples.npy", cifar5_candidate_images)
np.save("data/cifar5/cifar5_samples_labels.npy", cifar5_candidate_new_labels)

In [None]:
# Reload the saved CIFAR-5 data through the project pipeline with an 'unequal' split,
# then eyeball one party's images/labels and the candidate set.
party_datasets, party_labels, candidate_dataset, candidate_labels = get_data_raw(dataset='cifar5',
                                                                                     num_classes=5,
                                                                                     party_data_size=5000,
                                                                                     candidate_data_size=20000,
                                                                                     split='unequal')

In [None]:
party_idx = 4

In [None]:
# 64 images (indices 200..263) tiled as an 8x8 grid.
plt.imshow(merge(party_datasets[party_idx, 200:264], [8,8]))

In [None]:
party_labels[party_idx, 200:264]

In [None]:
plt.imshow(merge(candidate_dataset[200:264], [8,8]))

In [None]:
candidate_labels[200:264]

## MMD GPU batch calculation

In [None]:
# Reload the synthetic 'gmm' data (same call as in the kernel-optimisation section)
# to test the batched MMD computation against the CPU reference.
party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features('gmm',
                                                                                            5,
                                                                                            2,
                                                                                            5,
                                                                                            1000,
                                                                                            5000,
                                                                                            'equaldisjoint',
                                                                                            gamma=0)

In [None]:
# GPyTorch kernel from the project helper; NOTE(review): argument meanings are defined in core.kernel.
kernel = get_kernel('se', 2, 1)

In [None]:
def mmd_neg_biased_batched(X, Y, k, device, batch_size=128):
    """
    Calculates biased MMD^2 without the S_YY term (i.e. S_XY - S_X), where S_X and
    S_XY are the pairwise-XX and pairwise-XY summation terms:

        S_X  = (1 / m^2)      * sum_{i,j} k(x_i, x_j)
        S_XY = (2 / (m * n))  * sum_{i,j} k(x_i, y_j)

    Rows of X are processed on `device` in batches. Raw kernel sums are accumulated
    and normalised once at the end, so the result is correct for any batch_size >= 1,
    including batch_size > m. The earlier running-average version divided the first
    batch by batch_size instead of by its actual size.

    :param X: array of shape (m, d), m >= 1
    :param Y: array of shape (n, d), n >= 1
    :param k: GPyTorch kernel; k(A, B).evaluate() must give a dense tensor.
        It is moved to `device` in place.
    :param device: torch device to compute on
    :param batch_size: number of rows of X per kernel evaluation
    :return: MMD^2 (without S_YY), S_X, S_XY as Python floats
    :raises ValueError: if X or Y is empty, or batch_size < 1
    """
    max_m = X.shape[0]
    n = Y.shape[0]
    if max_m == 0 or n == 0:
        raise ValueError("X and Y must both be non-empty")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    X_tens = torch.tensor(X, device=device)
    Y_tens = torch.tensor(Y, device=device)
    k.to(device)

    with torch.no_grad():
        sum_XY = 0.
        sum_XX = 0.
        for start in range(0, max_m, batch_size):
            X_batch = X_tens[start:start + batch_size]
            sum_XY = sum_XY + torch.sum(k(X_batch, Y_tens).evaluate())
            # This batch's rows of the full m x m Gram matrix.
            sum_XX = sum_XX + torch.sum(k(X_batch, X_tens).evaluate())

        S_XY = (2 / (max_m * n)) * sum_XY
        S_X = sum_XX / (max_m ** 2)

    return (S_XY - S_X).item(), S_X.item(), S_XY.item()

In [None]:
# One party (index 2) vs. the reference set, used to compare CPU and GPU MMD results.
X = party_datasets[2]
Y = reference_dataset

In [None]:
X.shape

In [None]:
# CPU reference value (kernel moved to CPU first).
mmd_neg_biased(X, Y, kernel.to('cpu'))

In [None]:
# Batched GPU value. NOTE(review): batch_size=2000 presumably exceeds X.shape[0] here
# (party_data_size=1000 in the gmm call), i.e. the case the "BUGGED CODE" note refers to.
mmd_neg_biased_batched(X, Y, kernel, 'cuda:0', batch_size=2000)

In [None]:
device = 'cuda:0'

In [None]:
Y_tens = torch.tensor(Y).to(device)

In [None]:
Y_tens.device

In [None]:
X_tens = torch.tensor(X).to(device)

In [None]:
# Sum of the full Y x Y Gram matrix on the GPU.
kernel = kernel.to(device)
with torch.no_grad():
    print(torch.sum(kernel(Y_tens, Y_tens).evaluate()))

In [None]:
# Scaled X x Y kernel sum on the CPU. NOTE(review): 1/100000 is a hard-coded
# normaliser, not derived from len(X) or len(Y) — confirm what it is meant to be.
kernel.to('cpu')
with torch.no_grad():
    print(torch.sum(kernel(torch.tensor(X), torch.tensor(Y)).evaluate() * 1/100000))

In [None]:
kernel(X_tens[:64], Y_tens).evaluate()

In [None]:
batch_size = 32

In [None]:
n = Y_tens.size(0)

In [None]:
max_m = X.shape[0]

In [None]:
Y_tens.shape

In [None]:
kernel.to(device)
X_tens.to(device)
Y_tens.to(device)
with torch.no_grad():
    # first batch
    S_XY = (2/(batch_size * n)) * torch.sum(kernel(X_tens[:batch_size], Y_tens).evaluate())
    S_X = (1/(batch_size ** 2)) * torch.sum(kernel(X_tens[:batch_size]).evaluate())
    
    for i in range(max_m // batch_size):
        idx = i + 2
        m = np.min([idx * batch_size, max_m])
        prev_m = (idx - 1) * batch_size
        c = m - prev_m
        print(prev_m, m)
        S_XY = (prev_m * S_XY + (2/n) * torch.sum(kernel(X_tens[prev_m:m], Y_tens).evaluate())) / (prev_m + c)
        S_X = ((prev_m ** 2) * S_X + 2 * torch.sum(kernel(X_tens[prev_m:m], X_tens[:prev_m]).evaluate()) + 
                torch.sum(kernel(X_tens[prev_m:m]).evaluate())) / ((prev_m + c) ** 2) 
        print((S_XY.item() - S_X.item(), S_X.item(), S_XY.item()))

In [None]:
mmd_neg_biased(X, Y, kernel)

In [None]:
kernel.to('cpu')
for i in range(max_m // batch_size):
    idx = i + 2
    m = np.min([idx * batch_size, max_n])
    print(mmd_neg_biased(X[:m], Y, kernel))

In [None]:
# GPU value for the first batch only, to compare with the CPU result on X[:32] below.
kernel.to(device)
with torch.no_grad():
    # first batch
    S_XY = (2/(batch_size * n)) * torch.sum(kernel(X_tens[:batch_size], Y_tens).evaluate())
    S_X = (1/(batch_size ** 2)) * torch.sum(kernel(X_tens[:batch_size]).evaluate())
    print((S_XY.item() - S_X.item(), S_X.item(), S_XY.item()))

In [None]:
# 32 mirrors batch_size above; mmd_neg_biased is unpacked here as (value, S_X, S_XY).
kernel.to('cpu')
print(mmd_neg_biased(X[:32], Y, kernel))
_, S_X, S_XY = mmd_neg_biased(X[:32], Y, kernel)

In [None]:
# Hand-computed running-average update of S_XY after the second batch X[32:64].
(32 * S_XY + (2/n) * torch.sum(kernel(torch.tensor(X[32:64]), torch.tensor(Y)).evaluate())) / (32 + 32)

In [None]:
S_XY

In [None]:
S_XY - S_X

In [None]:
# 1000 hard-codes the party size — presumably equal to max_m here; NOTE(review): confirm.
1000 // batch_size

In [None]:
# Upper slice bounds visited by the batched loop.
[np.min([(i+2) * batch_size, 1000]) for i in range(1000 // batch_size)]

## Conv net

In [None]:
input_dim = 32
num_channels = 3
hidden_dim = 16

In [None]:
inp = torch.randn(4, num_channels, input_dim, input_dim)

In [None]:
conv1 = Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2)

In [None]:
torch.std(conv1(inp))

In [None]:
conv1(inp).shape

In [None]:
conv1_bn = BatchNorm2d(64)

In [None]:
torch.mean(conv1_bn(conv1(inp)))

In [None]:
conv2 = Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2)

In [None]:
conv2(conv1(inp)).shape

In [None]:
conv3 = Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2)

In [None]:
conv3(conv2(conv1(inp))).shape

In [None]:
conv4 = Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)

In [None]:
conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
that_size = conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
fc1 = Linear(2048, 16)

In [None]:
fc1(conv4(conv3(conv2(conv1(inp)))).view((4, -1))).shape

In [None]:
class CustomView(nn.Module):
    """Reshape layer for use inside nn.Sequential, e.g. to flatten conv features."""

    def __init__(self, shape):
        """:param shape: target shape passed to Tensor.view (may contain -1)"""
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return a view of `x` with shape `self.shape`."""
        target = self.shape
        return x.view(*target)

In [None]:
# Convolutional encoder: four stride-2 convs (spatial 32 -> 16 -> 8 -> 4 -> 2 for
# input_dim=32; 28 -> 14 -> 7 -> 4 -> 2 for 28), each followed by LeakyReLU and
# BatchNorm, then flatten 512 * 2 * 2 = 2048 features and project to hidden_dim.
encoder = nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(512),
            CustomView((-1, 2048)),
            nn.Linear(2048, hidden_dim)
        )

In [None]:
encoder(inp)

In [None]:
# Mirror of the encoder: project hidden_dim back to 512 x 2 x 2, then four stride-2
# transposed convs doubling the spatial size (2 -> 4 -> 8 -> 16 -> 32). For
# input_dim == 28, output_padding=0 in the second layer gives 2 -> 4 -> 7 -> 14 -> 28.
# The final Tanh bounds outputs to (-1, 1).
decoder = nn.Sequential(
            nn.Linear(hidden_dim, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            CustomView((-1, 512, 2, 2)),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=0 if input_dim==28 else 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, num_channels, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
)

In [None]:
# Output is NCHW; the commented permute would convert it to NHWC for plotting.
decoder(encoder(inp)).shape#.permute((0, 2, 3, 1))

In [None]:
# Per-channel scale test: a length-3 tuple expanded to shape (3, 1, 1) broadcasts
# over the channel axis of an NCHW array (when there are 3 channels).
# Note that `means` is rebound later to the real per-channel data means.
means = (1, 2, 3)

In [None]:
np.expand_dims(means, axis=[1, 2]).shape

In [None]:
means = np.expand_dims(means, axis=[1, 2])

In [None]:
decoder(encoder(inp)).detach().numpy() * means

## Dataloaders

In [None]:
def merge(images, size):
    """
    Tile a batch of RGB images into one grid image, filling it row by row.

    :param images: array-like of shape (N, h, w, 3) with N <= size[0] * size[1]
    :param size: (rows, cols) of the grid
    :return: float64 array of shape (h * rows, w * cols, 3); unused cells stay 0
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_rows, n_cols = size[0], size[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
    for position, tile in enumerate(images):
        row, col = divmod(position, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas

In [None]:
# Script-style options, parsed inside the notebook. When run under Jupyter, sys.argv
# holds the kernel's "-f <connection file>" argument; the '-f' option below absorbs it
# so that parse_args() does not fail on an unrecognised argument.
parser = ArgumentParser()
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--hidden_dim', default=16, type=int)
parser.add_argument('--dataset', default='cifar', type=str)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--party_data_size', default=4000, type=int)
parser.add_argument('--candidate_data_size', default=10000, type=int)
parser.add_argument('--split', default='equaldisjoint', type=str)
parser.add_argument('-f', type=str)
args = parser.parse_args()

In [None]:
# Raw (unfeaturised) party and candidate images for the parsed configuration.
party_datasets, party_labels, candidate_dataset, candidate_labels = get_data_raw(dataset=args.dataset,
                                                                                         num_classes=args.num_classes,
                                                                                         party_data_size=args.party_data_size,
                                                                                         candidate_data_size=args.candidate_data_size,
                                                                                         split=args.split)

In [None]:
party = 0

In [None]:
# First 64 images of the chosen party, tiled 8x8.
plt.imshow(merge(party_datasets[party][:64], [8,8]))

In [None]:
party_labels[party][:64]

In [None]:
plt.imshow(merge(candidate_dataset[:64], [8,8]))

In [None]:
candidate_labels[:64]

In [None]:
np.max(party_datasets)

In [None]:
np.min(party_datasets)

In [None]:
np.max(candidate_dataset)

In [None]:
np.min(candidate_labels)

In [None]:
# Channel count from the last axis. This assumes channels-last images (NHWC),
# consistent with merge's [..., 3] layout — NOTE(review): confirm for all datasets.
num_channels = party_datasets.shape[-1]

In [None]:
np.mean(candidate_dataset.reshape([-1, num_channels]), axis=0)

In [None]:
# Pool all party images and the candidate images to compute shared normalisation statistics.
combined = np.concatenate([np.concatenate(party_datasets), candidate_dataset])

In [None]:
combined.shape

In [None]:
# Per-channel mean over every pixel of every image (rebinds the toy `means` above).
means = np.mean(combined.reshape(-1, num_channels), axis=0)

In [None]:
means

In [None]:
stds = np.std(combined.reshape(-1, num_channels), axis=0)

In [None]:
stds

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    """
    Zip several datasets index by index: item i is the tuple of item i from each
    wrapped dataset, and the length is that of the shortest one.
    """
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        items = []
        for ds in self.datasets:
            items.append(ds[i])
        return tuple(items)

    def __len__(self):
        return min(map(len, self.datasets))

In [None]:
datasets = []
for i in range(len(party_datasets)):
    transformed = (party_datasets[i] - means) / stds
    dataset = TensorDataset(torch.tensor(transformed), torch.tensor(party_labels[i]))
    datasets.append(dataset)
transformed = (candidate_dataset - means) / stds
dataset = TensorDataset(torch.tensor(transformed), torch.tensor(candidate_labels))
datasets.append(dataset)

In [None]:
concat_dataset = ConcatDataset(*datasets)

In [None]:
# Shuffled loader over the zipped datasets: each index i picks item i from every
# wrapped dataset, so a batch holds one (images, labels) pair per dataset.
loader = torch.utils.data.DataLoader(
            concat_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True)

In [None]:
my_iterator = iter(loader)

In [None]:
# Position in the zipped batch. With 5 parties, index 5 is the candidate set,
# which was appended after the party datasets.
party = 5

In [None]:
batch_images, batch_labels = next(my_iterator)[party]

In [None]:
# The 8x8 grid fits batch_size <= 64. NOTE(review): the images are standardised,
# so imshow will clip values outside [0, 1].
plt.imshow(merge(batch_images, [8, 8]))

In [None]:
batch_labels