In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from data.pipeline import get_data_raw
from core.kernel import get_kernel

In [2]:
class CustomView(nn.Module):
    """Reshape layer usable as a step inside ``nn.Sequential``.

    Applies ``x.view(*shape)`` to its input, so it can flatten conv feature
    maps (e.g. ``(-1, 2048)``) or unflatten a vector back into feature maps
    (e.g. ``(-1, 512, 2, 2)``). Like ``Tensor.view``, it requires the input's
    memory layout to be compatible with the requested shape.
    """

    def __init__(self, shape):
        super().__init__()
        # Target shape handed verbatim to Tensor.view (may contain a single -1).
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(*target_shape)

In [46]:
class LitAutoEncoder(pl.LightningModule):
    """Convolutional autoencoder whose embeddings are regularised with an MMD term.

    Encoder: four stride-2 ``Conv2d`` layers (each followed by LeakyReLU and
    BatchNorm) going from ``num_channels`` to 512 channels, flattened to 2048
    values and projected linearly to ``hidden_dim``. Decoder: the mirror image
    built from ``ConvTranspose2d`` layers, ending in ``Tanh``.

    The hard-coded 2048 (= 512 * 2 * 2) flatten size requires the input to
    shrink to 2x2 after the four stride-2 convolutions. The decoder rebuilds
    28x28 images when ``side_dim == 28`` and 32x32 images otherwise.

    Args:
        num_channels: Image channels of the input and of the reconstruction.
        side_dim: Image height/width; only used to pick the decoder's
            ``output_padding``.
        hidden_dim: Embedding size; also passed to ``get_kernel('rq', ...)``.
        lr: Adam learning rate.
        gamma: Weight of the MMD term in the training objective. It multiplies
            a tensor, so it must be numeric when training.

    NOTE(review): ``mmd_neg_unbiased`` (used by ``mmd_loss``) is neither defined
    nor imported in the visible notebook. Presumably it lives alongside
    ``get_kernel`` in ``core.kernel``; TODO import it before training or validating.
    """

    def __init__(self, num_channels, side_dim, hidden_dim, lr, gamma):
        super().__init__()

        # Submodule names are part of the checkpoint's state_dict keys; keep them stable.
        self.encoder = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, stride=2, padding=2)),
            ('lrelu1', nn.LeakyReLU()),
            ('bn1', nn.BatchNorm2d(64)),
            ('conv2', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2)),
            ('lrelu2', nn.LeakyReLU()),
            ('bn2', nn.BatchNorm2d(128)),
            ('conv3', nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2)),
            ('lrelu3', nn.LeakyReLU()),
            ('bn3', nn.BatchNorm2d(256)),
            ('conv4', nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)),
            ('lrelu4', nn.LeakyReLU()),
            ('bn4', nn.BatchNorm2d(512)),
            ('view1', CustomView((-1, 2048))),  # 512 channels * 2 * 2 spatial
            ('linear1', nn.Linear(2048, hidden_dim))
        ]))

        self.decoder = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(hidden_dim, 2048)),
            ('relu1', nn.ReLU()),
            ('bn1', nn.BatchNorm1d(2048)),
            ('view1', CustomView((-1, 512, 2, 2))),
            # Spatial sizes: 2 -> 4 -> (7 if side_dim == 28 else 8) -> 14/16 -> 28/32
            ('convT1', nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1)),
            ('relu2', nn.ReLU()),
            ('bn2', nn.BatchNorm2d(256)),
            ('convT2', nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2,
                                          output_padding=0 if side_dim == 28 else 1)),
            ('relu3', nn.ReLU()),
            ('bn3', nn.BatchNorm2d(128)),
            ('convT3', nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1)),
            ('relu4', nn.ReLU()),
            ('bn4', nn.BatchNorm2d(64)),
            ('convT4', nn.ConvTranspose2d(64, num_channels, kernel_size=5, stride=2, padding=2, output_padding=1)),
            ('tanh1', nn.Tanh())
        ]))

        # 'rq' presumably selects a rational-quadratic kernel (see core.kernel).
        self.kernel = get_kernel('rq', hidden_dim)
        self.lr = lr
        self.gamma = gamma

    def forward(self, x):
        """Return the latent embedding of ``x``; inference uses the encoder only."""
        return self.encoder(x)

    def autoencoder_loss(self, dataset):
        """Reconstruction MSE for one ``(x, y)`` pair; the labels are ignored."""
        x, _ = dataset
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)

    def mmd_loss(self, p, q):
        """Discrepancy between the embeddings of two ``(x, y)`` datasets.

        Delegates to ``mmd_neg_unbiased`` with this model's kernel; that helper
        is not visible here, so its exact estimator is unconfirmed.
        """
        x_p, _ = p
        x_q, _ = q
        return mmd_neg_unbiased(self.encoder(x_p), self.encoder(x_q), self.kernel)

    def _compute_losses(self, batch):
        """Return the unweighted ``(ae_loss, mmd_loss)`` for one batch.

        ``batch`` is a sequence of ``(x, y)`` datasets. Every element except the
        last is reconstructed. Every element except the last two is compared by
        MMD against the last one, the reference set (all parties + candidates).
        Presumably ``batch[-2]`` is the candidate set; TODO confirm against the
        data pipeline.
        """
        ae_loss = 0
        for dataset in batch[:-1]:
            ae_loss += self.autoencoder_loss(dataset)

        ref_dataset = batch[-1]
        mmd_loss = 0
        for dataset in batch[:-2]:
            mmd_loss += self.mmd_loss(dataset, ref_dataset)
        return ae_loss, mmd_loss

    def training_step(self, batch, batch_idx):
        """Training objective: reconstruction loss + ``gamma`` * MMD loss."""
        ae_loss, mmd_loss = self._compute_losses(batch)
        return ae_loss + self.gamma * mmd_loss

    def validation_step(self, batch, batch_idx):
        """Log the component losses and the weighted total.

        ``total_loss`` applies the same ``gamma`` weighting as ``training_step``
        so that any callback monitoring it (e.g. early stopping) tracks the
        objective actually being optimised.
        """
        ae_loss, mmd_loss = self._compute_losses(batch)
        self.log('autoencoder_loss', ae_loss)
        self.log('mmd_loss', mmd_loss)
        self.log('total_loss', ae_loss + self.gamma * mmd_loss)

    def configure_optimizers(self):
        """Adam over all parameters with the constructor-supplied learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [47]:
# Experiment configuration: which dataset/split/checkpoint to featurize.
dataset, split = 'mnist', 'unequal'
num_channels, side_dim = 1, 28   # single-channel 28x28 images
hidden_dim = 16                  # width of the encoder's embedding
# NOTE(review): gamma is a string tag used to build checkpoint/output file names
# ("...-gamma005..."), but it is also passed to LitAutoEncoder as the numeric
# MMD weight. That is harmless for feature extraction; confirm before training.
gamma = '005'
party_data_size, candidate_data_size = 10_000, 40_000

In [48]:
# Rebuild the trained autoencoder from its checkpoint. LitAutoEncoder does not
# call save_hyperparameters(), so every constructor argument is supplied here.
# NOTE(review): `gamma` is the string tag from the config cell, not a numeric
# weight. It is only read by training_step, so it is harmless here.
model = LitAutoEncoder.load_from_checkpoint(
    "models/{}-{}-gamma{}.ckpt".format(dataset, split, gamma),
    num_channels=num_channels,
    side_dim=side_dim,
    hidden_dim=hidden_dim,
    lr=0.001,
    gamma=gamma,
)
# Inference only: put BatchNorm layers in eval mode so they use their running
# statistics instead of per-batch statistics.
model.eval()

In [49]:
# Raw images and labels for every party plus the shared candidate pool.
# get_features below expects these image arrays in NHWC layout.
(party_datasets, party_labels,
 candidate_dataset, candidate_labels) = get_data_raw(
    dataset=dataset,
    num_classes=10,
    party_data_size=party_data_size,
    candidate_data_size=candidate_data_size,
    split=split,
)

In [50]:
def get_features(model, data, hidden_dim, batch_size=256):
    """Encode a stack of images into latent feature vectors with ``model.encoder``.

    Args:
        model: Module exposing an ``encoder`` submodule that maps an
            ``(N, C, H, W)`` batch to ``(N, hidden_dim)`` embeddings.
        data: Array of images in NHWC layout. Assumed to be already normalised
            the way the model was trained; TODO confirm against get_data_raw.
        hidden_dim: Width of the embedding produced by the encoder.
        batch_size: Number of images encoded per forward pass.

    Returns:
        float64 ``np.ndarray`` of shape ``(len(data), hidden_dim)``.

    The encoder runs in eval mode under ``torch.no_grad()``. BatchNorm therefore
    uses its running statistics, so features do not depend on how the data is
    batched and the running statistics are left untouched. No autograd graph is
    built. The model's previous train/eval mode is restored afterwards.
    """
    n = len(data)
    features = np.zeros((n, hidden_dim))
    # Feed inputs on the encoder's device/dtype (e.g. float64 arrays vs float32 weights).
    param = next(model.encoder.parameters(), None)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                # NHWC -> NCHW per batch, avoiding a transposed copy of the whole dataset.
                chunk = torch.as_tensor(np.transpose(data[start:stop], (0, 3, 1, 2)))
                if param is not None:
                    chunk = chunk.to(device=param.device, dtype=param.dtype)
                features[start:stop] = model.encoder(chunk).cpu().numpy()
    finally:
        model.train(was_training)
    return features

In [51]:
# Dropped: this cell ran exactly the same encoder pass as the `candidate_features`
# cell below (same model, same candidate_dataset), doubling the cost of the
# slowest step. Its `feats` result is not read anywhere in the visible notebook.

In [52]:
# One (n_samples, hidden_dim) feature matrix per party, combined into one array.
# NOTE(review): with split='unequal', if parties ever differ in size this becomes a
# ragged array (an error in recent NumPy); confirm all parties have equal length.
party_features = np.array([
    get_features(model, party_data, hidden_dim)
    for party_data in party_datasets
])

In [53]:
# Latent features for the full candidate pool.
candidate_features = get_features(model=model, data=candidate_dataset, hidden_dim=hidden_dim)

In [54]:
# Pass the path so np.save opens *and closes* the file (the old np.save(open(...))
# left the handle open until garbage collection). The name already ends in .npy.
np.save("data/{}/{}-gamma{}-party_features.npy".format(dataset, split, gamma), party_features)

In [55]:
# Pass the path so np.save opens and closes the file itself (no leaked handle).
np.save("data/{}/{}-gamma{}-party_labels.npy".format(dataset, split, gamma), party_labels)

In [56]:
# Pass the path so np.save opens and closes the file itself (no leaked handle).
np.save("data/{}/{}-gamma{}-cand_features.npy".format(dataset, split, gamma), candidate_features)

In [57]:
# Pass the path so np.save opens and closes the file itself (no leaked handle).
np.save("data/{}/{}-gamma{}-cand_labels.npy".format(dataset, split, gamma), candidate_labels)