In [1]:
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms
import gpytorch

from data.pipeline import get_data_raw

In [2]:
from torch import nn, tanh
from torch.nn import Linear, Conv2d, ConvTranspose2d, BatchNorm2d, ReLU, LeakyReLU

In [3]:
from core.mmd import mmd_neg_biased
from data.pipeline import get_data_features
from core.kernel import get_kernel
from core.reward_calculation import get_v

## CIFAR-5

In [None]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image

    return img

In [None]:
train_images = np.load("data/cifar/cifar_train_images.npy")
train_labels = np.load("data/cifar/cifar_train_labels.npy")
candidate_images = np.load("data/cifar/cifar_samples.npy")
candidate_labels = np.load("data/cifar/cifar_samples_labels.npy")

In [None]:
classes = [0, 1, 6, 7, 8]  # Classes we want, airplane, automobile, frog, horse, ship

In [None]:
class_dict = {0: 0,
             1: 1,
             6: 2,
             7: 3,
             8: 4}

In [None]:
train_idx = []
for i in range(len(train_images)):
    if train_labels[i] in classes:
        train_idx.append(i)

In [None]:
cifar5_train_images = train_images[train_idx]
cifar5_train_labels = train_labels[train_idx]

In [None]:
cifar5_train_new_labels = [class_dict[label] for label in cifar5_train_labels]

In [None]:
plt.imshow(merge(cifar5_train_images[200:264], [8,8]))

In [None]:
cifar5_train_new_labels[200:264]

In [None]:
candidate_idx = []
for i in range(len(candidate_images)):
    if candidate_labels[i] in classes:
        candidate_idx.append(i)

In [None]:
cifar5_candidate_images = candidate_images[candidate_idx]
cifar5_candidate_labels = candidate_labels[candidate_idx]

In [None]:
cifar5_candidate_new_labels = [class_dict[label] for label in cifar5_candidate_labels]

In [None]:
plt.imshow(merge(cifar5_candidate_images[200:264], [8,8]))

In [None]:
cifar5_candidate_new_labels[200:264]

In [None]:
np.save("data/cifar5/cifar5_train_images.npy", cifar5_train_images)
np.save("data/cifar5/cifar5_train_labels.npy", cifar5_train_new_labels)
np.save("data/cifar5/cifar5_samples.npy", cifar5_candidate_images)
np.save("data/cifar5/cifar5_samples_labels.npy", cifar5_candidate_new_labels)

In [None]:
party_datasets, party_labels, candidate_dataset, candidate_labels = get_data_raw(dataset='cifar5',
                                                                                     num_classes=5,
                                                                                     party_data_size=5000,
                                                                                     candidate_data_size=20000,
                                                                                     split='unequal')

In [None]:
party_idx = 4

In [None]:
plt.imshow(merge(party_datasets[party_idx, 200:264], [8,8]))

In [None]:
party_labels[party_idx, 200:264]

In [None]:
plt.imshow(merge(candidate_dataset[200:264], [8,8]))

In [None]:
candidate_labels[200:264]

## MMD GPU batch calculation

In [23]:
party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features('gmm',
                                                                                            5,
                                                                                            2,
                                                                                            5,
                                                                                            1000,
                                                                                            5000,
                                                                                            'equaldisjoint',
                                                                                            gamma=0)

In [24]:
kernel = get_kernel('se', 2, 1)

In [36]:
##### BUGGED CODE, WILL NOT GIVE CORRECT RESULT WHEN BATCH_SIZE > PARTY DATA SIZE

def mmd_neg_biased_batched(X, Y, k, device, batch_size=128):
    """
    Calculates biased MMD^2 without the S_YY term, where S_X, S_XY and S_YY are the pairwise-XX, pairwise-XY, pairwise-YY
    summation terms respectively. Does so using the GPU in a batch-wise manner.
    :param X: array of shape (m, d)
    :param Y: array of shape (n, d)
    :param k: GPyTorch kernel
    :param device:
    :param batch_size:
    :return: MMD^2, S_X, S_XY
    """
    max_m = X.shape[0]
    n = Y.shape[0]

    X_tens = torch.tensor(X, device=device)
    Y_tens = torch.tensor(Y, device=device)
    k.to(device)

    with torch.no_grad():
        # first batch
        S_XY = (2 / (batch_size * n)) * torch.sum(k(X_tens[:batch_size], Y_tens).evaluate())
        S_X = (1 / (batch_size ** 2)) * torch.sum(k(X_tens[:batch_size]).evaluate())

        for i in range(max_m // batch_size):
            idx = i + 2
            next_m = np.min([idx * batch_size, max_m])
            m = (idx - 1) * batch_size
            S_XY = (m * S_XY + (2 / n) * torch.sum(k(X_tens[m:next_m], Y_tens).evaluate())) / next_m
            S_X = ((m ** 2) * S_X + 2 * torch.sum(k(X_tens[m:next_m], X_tens[:m]).evaluate()) +
                   torch.sum(k(X_tens[m:next_m]).evaluate())) / (next_m ** 2)
    
    return (S_XY - S_X).item(), S_X.item(), S_XY.item()

In [25]:
X = party_datasets[2]
Y = reference_dataset

In [32]:
X.shape

(1000, 2)

In [27]:
mmd_neg_biased(X, Y, kernel.to('cpu'))

(0.9387285113334656, 0.9873847365379333, 1.926113247871399)

In [38]:
mmd_neg_biased_batched(X, Y, kernel, 'cuda:0', batch_size=2000)

(0.7162104561466199, 0.24684620067407545, 0.9630566568206954)

In [None]:
device = 'cuda:0'

In [None]:
Y_tens = torch.tensor(Y).to(device)

In [None]:
Y_tens.device

In [None]:
X_tens = torch.tensor(X).to(device)

In [None]:
kernel = kernel.to(device)
with torch.no_grad():
    print(torch.sum(kernel(Y_tens, Y_tens).evaluate()))

In [None]:
kernel.to('cpu')
with torch.no_grad():
    print(torch.sum(kernel(torch.tensor(X), torch.tensor(Y)).evaluate() * 1/100000))

In [None]:
kernel(X_tens[:64], Y_tens).evaluate()

In [None]:
batch_size = 32

In [None]:
n = Y_tens.size(0)

In [None]:
max_m = X.shape[0]

In [None]:
Y_tens.shape

In [None]:
kernel.to(device)
X_tens.to(device)
Y_tens.to(device)
with torch.no_grad():
    # first batch
    S_XY = (2/(batch_size * n)) * torch.sum(kernel(X_tens[:batch_size], Y_tens).evaluate())
    S_X = (1/(batch_size ** 2)) * torch.sum(kernel(X_tens[:batch_size]).evaluate())
    
    for i in range(max_m // batch_size):
        idx = i + 2
        m = np.min([idx * batch_size, max_m])
        prev_m = (idx - 1) * batch_size
        c = m - prev_m
        print(prev_m, m)
        S_XY = (prev_m * S_XY + (2/n) * torch.sum(kernel(X_tens[prev_m:m], Y_tens).evaluate())) / (prev_m + c)
        S_X = ((prev_m ** 2) * S_X + 2 * torch.sum(kernel(X_tens[prev_m:m], X_tens[:prev_m]).evaluate()) + 
                torch.sum(kernel(X_tens[prev_m:m]).evaluate())) / ((prev_m + c) ** 2) 
        print((S_XY.item() - S_X.item(), S_X.item(), S_XY.item()))

In [None]:
mmd_neg_biased(X, Y, kernel)

In [None]:
kernel.to('cpu')
for i in range(max_m // batch_size):
    idx = i + 2
    m = np.min([idx * batch_size, max_n])
    print(mmd_neg_biased(X[:m], Y, kernel))

In [None]:
kernel.to(device)
with torch.no_grad():
    # first batch
    S_XY = (2/(batch_size * n)) * torch.sum(kernel(X_tens[:batch_size], Y_tens).evaluate())
    S_X = (1/(batch_size ** 2)) * torch.sum(kernel(X_tens[:batch_size]).evaluate())
    print((S_XY.item() - S_X.item(), S_X.item(), S_XY.item()))

In [None]:
kernel.to('cpu')
print(mmd_neg_biased(X[:32], Y, kernel))
_, S_X, S_XY = mmd_neg_biased(X[:32], Y, kernel)

In [None]:
(32 * S_XY + (2/n) * torch.sum(kernel(torch.tensor(X[32:64]), torch.tensor(Y)).evaluate())) / (32 + 32)

In [None]:
S_XY

In [None]:
S_XY - S_X

In [None]:
1000 // batch_size

In [None]:
[np.min([(i+2) * batch_size, 1000]) for i in range(1000 // batch_size)]

## Conv net

In [None]:
input_dim = 32
num_channels = 3
hidden_dim = 16

In [None]:
inp = torch.randn(4, num_channels, input_dim, input_dim)

In [None]:
conv1 = Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2)

In [None]:
torch.std(conv1(inp))

In [None]:
conv1(inp).shape

In [None]:
conv1_bn = BatchNorm2d(64)

In [None]:
torch.mean(conv1_bn(conv1(inp)))

In [None]:
conv2 = Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2)

In [None]:
conv2(conv1(inp)).shape

In [None]:
conv3 = Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2)

In [None]:
conv3(conv2(conv1(inp))).shape

In [None]:
conv4 = Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)

In [None]:
conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
that_size = conv4(conv3(conv2(conv1(inp)))).shape

In [None]:
fc1 = Linear(2048, 16)

In [None]:
fc1(conv4(conv3(conv2(conv1(inp)))).view((4, -1))).shape

In [None]:
class CustomView(nn.Module):  # Flattening layer for nn.Sequential
    def __init__(self, shape):
        super(CustomView, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

In [None]:
encoder = nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(512),
            CustomView((-1, 2048)),
            nn.Linear(2048, hidden_dim)
        )

In [None]:
encoder(inp)

In [None]:
decoder = nn.Sequential(
            nn.Linear(hidden_dim, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            CustomView((-1, 512, 2, 2)),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=0 if input_dim==28 else 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, num_channels, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
)

In [None]:
decoder(encoder(inp)).shape#.permute((0, 2, 3, 1))

In [None]:
means = (1, 2, 3)

In [None]:
np.expand_dims(means, axis=[1, 2]).shape

In [None]:
means = np.expand_dims(means, axis=[1, 2])

In [None]:
decoder(encoder(inp)).detach().numpy() * means

## Dataloaders

In [None]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image

    return img

In [None]:
parser = ArgumentParser()
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--hidden_dim', default=16, type=int)
parser.add_argument('--dataset', default='cifar', type=str)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--party_data_size', default=4000, type=int)
parser.add_argument('--candidate_data_size', default=10000, type=int)
parser.add_argument('--split', default='equaldisjoint', type=str)
parser.add_argument('-f', type=str)
args = parser.parse_args()

In [None]:
party_datasets, party_labels, candidate_dataset, candidate_labels = get_data_raw(dataset=args.dataset,
                                                                                         num_classes=args.num_classes,
                                                                                         party_data_size=args.party_data_size,
                                                                                         candidate_data_size=args.candidate_data_size,
                                                                                         split=args.split)

In [None]:
party = 0

In [None]:
plt.imshow(merge(party_datasets[party][:64], [8,8]))

In [None]:
party_labels[party][:64]

In [None]:
plt.imshow(merge(candidate_dataset[:64], [8,8]))

In [None]:
candidate_labels[:64]

In [None]:
np.max(party_datasets)

In [None]:
np.min(party_datasets)

In [None]:
np.max(candidate_dataset)

In [None]:
np.min(candidate_labels)

In [None]:
num_channels = party_datasets.shape[-1]

In [None]:
np.mean(candidate_dataset.reshape([-1, num_channels]), axis=0)

In [None]:
combined = np.concatenate([np.concatenate(party_datasets), candidate_dataset])

In [None]:
combined.shape

In [None]:
means = np.mean(combined.reshape(-1, num_channels), axis=0)

In [None]:
means

In [None]:
stds = np.std(combined.reshape(-1, num_channels), axis=0)

In [None]:
stds

In [None]:
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

In [None]:
datasets = []
for i in range(len(party_datasets)):
    transformed = (party_datasets[i] - means) / stds
    dataset = TensorDataset(torch.tensor(transformed), torch.tensor(party_labels[i]))
    datasets.append(dataset)
transformed = (candidate_dataset - means) / stds
dataset = TensorDataset(torch.tensor(transformed), torch.tensor(candidate_labels))
datasets.append(dataset)

In [None]:
concat_dataset = ConcatDataset(*datasets)

In [None]:
loader = torch.utils.data.DataLoader(
            concat_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True)

In [None]:
my_iterator = iter(loader)

In [None]:
party = 5

In [None]:
batch_images, batch_labels = next(my_iterator)[party]

In [None]:
plt.imshow(merge(batch_images, [8, 8]))

In [None]:
batch_labels