In [10]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule used by the training cells

# Use the GPU when one is present, otherwise fall back to the CPU; this matches
# the `torch.cuda.is_available()` guards used by the training cells.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import MultiStepLR

# Preprocessing shared by the train and test splits: Resize(64), tensor
# conversion, then normalisation with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.1307], [0.3081]),
])

train_dataset = dset.MNIST(".data/mnist", train=True, download=True, transform=mnist_transform)
test_dataset = dset.MNIST(".data/mnist", train=False, download=True, transform=mnist_transform)

# Partitioning configuration handed to the project's `split` helper.
n_participants = 5
n_samples = 2000 * n_participants
split_mode = 'disjointclasses'

from utils.utils import split

# Each element of the result is later used as the index set of one participant's sampler.
train_indices_list = split(
    n_samples,
    n_participants,
    train_dataset=train_dataset,
    mode=split_mode,
)


from collections import Counter
import itertools

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

batch_size = 64

# One loader per participant, restricted to that participant's indices.
# Built first: the class statistics below iterate over these loaders.
train_loaders = [
    DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    for train_indices in train_indices_list
]

# Per-participant label histogram: class_counters[p][c] is the number of
# samples with label c seen in participant p's loader.
class_counters = []
for train_loader in train_loaders:
    class_counter = Counter()
    for _, target in train_loader:
        class_counter.update(target.tolist())
    class_counters.append(class_counter)
#     print(class_counter)

# Loader over the union of all participants' indices.
train_indices = list(itertools.chain.from_iterable(train_indices_list))
joint_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
test_loader = DataLoader(dataset=test_dataset, batch_size=10000, shuffle=True)

from itertools import repeat

def repeater(data_loader):
    """Yield the items of ``data_loader`` endlessly, re-iterating it each time it runs out."""
    while True:
        yield from data_loader

# One endless batch generator per participant loader; train_kernel zips these with the joint loader.
repeated_train_loaders = [repeater(train_loader) for train_loader in train_loaders  ]


from math import factorial as fac
def falling_fac(n, b):
    """
    Return the falling factorial n * (n-1) * ... * (n-b+1), i.e. the product of
    ``b`` consecutive integers counting down from ``n``.

    >>> falling_fac(4, 2)  # 4*3
    12
    >>> falling_fac(5, 3)  # 5*4*3
    60
    >>> falling_fac(56, 1)
    56
    >>> falling_fac(56, 0)
    1

    :param n: non-negative integer the product starts from
    :param b: number of factors, with 0 <= b <= n
    :return: the integer product (1 when b == 0)
    :raises ValueError: if n < 0, b < 0 or b > n
    """
    if n < 0 or b < 0 or b > n:
        raise ValueError("falling_fac requires 0 <= b <= n, got n={}, b={}".format(n, b))

    # Multiply only the b factors that survive, instead of building two full factorials.
    r = 1  # Running product
    for i in range(n, n - b, -1):
        r *= i
    return r


Using disjoint classes and partitioning the dataset to 5 participants with each having 2 classes.
participant id: 0 is getting [0, 1] classes.
participant id: 1 is getting [2, 3] classes.
participant id: 2 is getting [4, 5] classes.
participant id: 3 is getting [6, 7] classes.
participant id: 4 is getting [8, 9] classes.


In [26]:
def get_rhos(class_counters, C=10):
    '''
    For each participant, compute the sum of squared class proportions,
    divided by C.

    :param class_counters: iterable of mappings {class label: sample count}, one per participant
    :param C: divisor applied to each sum (defaults to 10)
    :return: list with one value per entry of class_counters; an empty mapping yields 0.0
    '''
    rhos = []
    for class_counter in class_counters:
        total = sum(class_counter.values())
        # Reset per participant so the sums do not accumulate across counters.
        rho = 0.0
        for value in class_counter.values():
            rho += (1.0 * value / total)**2
        rhos.append(rho/C)
    return rhos

def get_varrhos(rhos, class_counters):
    '''
    For each (rho, class_counter) pair, with n the total count in the counter,
    return sum_{i=0}^{n-1} i * rho**(n - i).

    The original note describes this as the area under the curve for the given rhos.
    '''
    def _weighted_power_sum(rho, n):
        return sum(i * rho ** (n - i) for i in range(n))

    sizes = (sum(counter.values()) for counter in class_counters)
    return [_weighted_power_sum(rho, n) for rho, n in zip(rhos, sizes)]


In [None]:
import gpytorch

# for MNIST 28*28
class MLP_MNIST(nn.Module):
    """
    Four-layer fully connected network: in_dim -> 512 -> 128 -> 32 -> out_dim,
    with ReLU after every layer except the last.

    :param in_dim: number of input values per sample after flattening (784 by default)
    :param out_dim: size of the output vector
    :param device: accepted for interface compatibility; not used by this class
    """

    def __init__(self, in_dim=784, out_dim=2, device=None):
        super(MLP_MNIST, self).__init__()
        # Remembered so forward() flattens to the width fc1 was built for.
        self.in_dim = in_dim
        self.fc1 = nn.Linear(in_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, out_dim)

    def forward(self, x):
        """Flatten x to rows of ``in_dim`` values and return the raw (un-normalised) outputs."""
        x = x.view(-1, self.in_dim)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x


class GaussianProcessLayer(gpytorch.models.ApproximateGP):
    """
    Approximate GP layer with a batch of ``num_dim`` variational distributions,
    grid-interpolated inducing points and a scaled RBF kernel.
    """

    def __init__(self, num_dim, grid_bounds=(-10., 10.), grid_size=64):
        batch_shape = torch.Size([num_dim])
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=grid_size, batch_shape=batch_shape
        )

        # The base strategy places the variational inducing points on a grid ...
        grid_strategy = gpytorch.variational.GridInterpolationVariationalStrategy(
            self,
            grid_size=grid_size,
            grid_bounds=[grid_bounds],
            variational_distribution=var_dist,
        )
        # ... and the multitask wrapper turns the batch into a vector-valued GP.
        multitask_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            grid_strategy, num_tasks=num_dim
        )
        super().__init__(multitask_strategy)

        lengthscale_prior = gpytorch.priors.SmoothedBoxPrior(
            math.exp(-1), math.exp(1), sigma=0.1, transform=torch.exp
        )
        rbf_kernel = gpytorch.kernels.RBFKernel(lengthscale_prior=lengthscale_prior)
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        self.grid_bounds = grid_bounds

    def forward(self, x):
        """Return the multivariate normal given by the constant mean and the scaled RBF kernel at x."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

# Width of the feature vector produced by the MLP (passed to it as `out_dim`).
num_features = 5
# Number of label classes.
num_classes = 10
# NOTE(review): MLP_MNIST accepts `device` but does not use it, so this call does not move the module.
feature_extractor = MLP_MNIST(out_dim=num_features,device=device)

def mmd(X, Y, k):
    """
    Calculates unbiased MMD^2 as A + B + C, where A, B and C are the
    pairwise-XX, pairwise-XY and pairwise-YY summation terms respectively.

    :param X: tensor of shape (n, d)
    :param Y: tensor of shape (m, d)
    :param k: GPyTorch kernel; ``k(a)`` / ``k(a, b)`` must return an object with ``.evaluate()``
    :return: (MMD^2, Kxx_, Kxy, Kyy_) where Kxx_ and Kyy_ are the within-sample
             kernel matrices with their diagonals set to zero and Kxy is the
             cross kernel matrix
    """
    n = X.shape[0]
    m = Y.shape[0]

    # NOTE(review): detaching means gradients of the result reach the kernel's
    # parameters only, not whatever produced X and Y — confirm this is intended.
    X_tens = X.clone().detach().requires_grad_(True)
    Y_tens = Y.clone().detach().requires_grad_(True)

    # Evaluate each Gram matrix once and reuse it for both the sums and the return values.
    Kxx = k(X_tens).evaluate()
    Kyy = k(Y_tens).evaluate()
    Kxy = k(X_tens, Y_tens).evaluate()

    A = (1 / (n * (n - 1))) * (torch.sum(Kxx) - torch.sum(torch.diag(Kxx)))
    B = -(2 / (n * m)) * torch.sum(Kxy)
    C = (1 / (m * (m - 1))) * (torch.sum(Kyy) - torch.sum(torch.diag(Kyy)))

    # Zero the diagonals out of place so Kxx / Kyy used above stay untouched.
    Kxx_ = Kxx - torch.diag(torch.diag(Kxx))
    Kyy_ = Kyy - torch.diag(torch.diag(Kyy))

    return (A + B + C), Kxx_, Kxy, Kyy_


# from utils.mmd import mmd

from torch.linalg import norm

def t_statistic(mmd_2, Kxx_, Kxy, Kyy_):
    """
    Return ``mmd_2 / sqrt(vhat)``, where ``vhat`` is an eight-term variance
    estimate assembled from the kernel matrices.

    Kxy[ij] = k(X_i, Y_j)

    Kxx_[ij] =  0 if i == j
                k(X_i, X_j) o/w
    Kyy_[ij] is similar

    fro_norm = torch.linalg.norm(matrix, ord='fro')

    :param mmd_2: MMD^2 estimate (scalar tensor)
    :param Kxx_: within-X kernel matrix with zero diagonal
    :param Kxy: cross kernel matrix
    :param Kyy_: within-Y kernel matrix with zero diagonal
    :return: scalar tensor; NaN when ``vhat`` is negative (a warning is printed)

    NOTE(review): every constant uses m = Kxx_.size(0) only, so the formula
    appears to assume both samples have the same size — confirm against callers.
    """
    m = Kxx_.size(0)
    # All-ones vectors in the kernels' dtype/device so the products below are valid.
    ex = torch.ones(Kxx_.size(0), device=Kxx_.device, dtype=Kxx_.dtype)
    ey = torch.ones(Kyy_.size(0), device=Kyy_.device, dtype=Kyy_.dtype)

    # Grand sums of each matrix, shared by terms 4-6. `ex`/`ey` are 1-D, so no
    # transpose is needed for the vector-matrix products.
    sum_xx = ex @ Kxx_ @ ex
    sum_yy = ey @ Kyy_ @ ey
    sum_xy = ex @ Kxy @ ey

    vhat = 0

    #1st term
    constant = 4 / falling_fac(m, 4)
    a = torch.square(norm(Kxx_ @ ex)) + torch.square(norm(Kyy_ @ ey))
    vhat += constant * a

    #2nd term
    constant = 4*(m**2 - m - 1) / (m**3 * (m - 1)**2)
    a = torch.square(norm(Kxy @ ey)) + torch.square(norm(Kxy.T @ ex))
    vhat += constant * a

    # 3rd term
    constant = - 8/ (m**2 * (m**2 - 3 * m + 2))
    a = ex @ Kxx_ @ Kxy @ ey + ey @ Kyy_ @ Kxy.T @ ex
    vhat += constant * a

    # 4th term
    constant = 8 / (m**2 * falling_fac(m, 3))
    a = (sum_xx + sum_yy) * sum_xy
    vhat += constant * a

    #5th term
    constant = - 2*(2*m -3)/(falling_fac(m, 2) * falling_fac(m, 4))
    a = torch.square(sum_xx) + torch.square(sum_yy)
    vhat += constant * a

    #6th term
    constant = -4 * (2*m - 3) / (m**3 * (m - 1)**3)
    a = torch.square(sum_xy)
    vhat += constant * a

    #7th term
    constant = - 2/ (m* ( m**3 - 6 * m**2 + 11*m - 6 ))
    a = torch.square(norm(Kxx_, ord='fro')) + torch.square(norm(Kyy_, ord='fro'))
    vhat += constant * a

    #8th term
    constant = 4 * (m-2) / (m**2 *(m-1)**3)
    a = torch.square(norm(Kxy, ord='fro'))
    vhat += constant * a

    # The estimate is not guaranteed to be non-negative; sqrt of a negative value gives NaN.
    if vhat < 0:
        print('vhat is negative:', vhat.item())
        print('this leads to NaN at t_stat')

    return torch.div(mmd_2, torch.sqrt(vhat))

class DKLModel(gpytorch.Module):
    """
    Deep-kernel model: a feature extractor followed by a GaussianProcessLayer.
    The forward pass also computes an MMD-based t-statistic between the
    features of two batches, using the GP layer's kernel.
    """

    def __init__(self, feature_extractor, num_dim, grid_bounds=(-10., 10.)):
        super(DKLModel, self).__init__()
        self.feature_extractor = feature_extractor
        self.gp_layer = GaussianProcessLayer(num_dim=num_dim, grid_bounds=grid_bounds)
        self.grid_bounds = grid_bounds
        self.num_dim = num_dim

    def forward(self, x1, x2):
        """
        :param x1: first input batch
        :param x2: second input batch
        :return: (t-statistic between the two batches' features,
                  GP output for x1, GP output for x2)
        """
        features1 = self.get_features(x1)
        features2 = self.get_features(x2)
        # get_features ends with transpose(-1, -2).unsqueeze(-1); invert exactly
        # that to get one row per sample. A plain reshape(len(x), -1) would keep
        # the transposed memory order and mix values from different samples.
        per_sample1 = features1.squeeze(-1).transpose(-1, -2)
        per_sample2 = features2.squeeze(-1).transpose(-1, -2)
        mmd_2, Kxx_, Kxy, Kyy_ = mmd(per_sample1, per_sample2, k=self.gp_layer.covar_module)
        t_stat = t_statistic(mmd_2, Kxx_, Kxy, Kyy_)
        return t_stat, self.gp_layer(features1), self.gp_layer(features2)

    def get_features(self, x):
        """Extract features, scale them into ``grid_bounds`` and lay them out as one GP input per feature."""
        features = self.feature_extractor(x)
        features = gpytorch.utils.grid.scale_to_bounds(features, self.grid_bounds[0], self.grid_bounds[1])
        # This next line makes it so that we learn a GP for each feature
        features = features.transpose(-1, -2).unsqueeze(-1)
        return features


In [None]:
# need a kernel collectively defined by gpytorch and a DNN
def train_kernel(epoch):
    """
    Run one epoch over the joint loader. Every step pairs the joint batch with
    one batch from each participant and minimises the negated t-statistics
    plus the negated variational ELBO terms.

    Relies on the module-level ``model``, ``likelihood``, ``optimizer``,
    ``mll``, ``joint_loader``, ``repeated_train_loaders`` and ``n_participants``.
    Returns early (after printing a message) if a t-statistic is NaN.
    """
    model.train()
    likelihood.train()

    joint_minibatch_iter = tqdm.notebook.tqdm(joint_loader, desc=f"(Epoch {epoch}) Minibatch")
    batch_streams = [joint_minibatch_iter, *repeated_train_loaders]
    with gpytorch.settings.num_likelihood_samples(8):

        # Each zipped item is [(data, target), (data1, target1) ... ] with the joint batch first.
        for batches in zip(*batch_streams):
            (data_j, target_j), *participant_batches = batches

            if torch.cuda.is_available():
                data_j, target_j = data_j.cuda(), target_j.cuda()
                for idx in range(n_participants):
                    batch = participant_batches[idx]
                    batch[0], batch[1] = batch[0].cuda(), batch[1].cuda()
            optimizer.zero_grad()

            mmd_loss, likelihood_loss = 0, 0
            for idx in range(n_participants):
                data_i = participant_batches[idx][0]
                target_i = participant_batches[idx][1]

                t_stat, res_j, res_i = model(data_j, data_i)
                mmd_loss += -t_stat
                if torch.isnan(t_stat):
                    print("got nan value, t_stat is nan")
                    return

                # The joint batch contributes to the likelihood loss only once.
                if idx == 0:
                    likelihood_loss += -mll(res_j, target_j)

                likelihood_loss += -mll(res_i, target_i)
            loss = mmd_loss + likelihood_loss

            # For mmd_loss, one should consider the entire data_loader and not
            # the batches; to do so, use the mmd_update method.
            #
            # As a result, if we ignore likelihood_loss altogether, we can
            # incrementally update the loss over the data_loader and only do
            # step() at the last batch.
            #
            # However, if we wish to include the likelihood_loss, we need to
            # design the loss function more carefully.

            loss.backward()
            optimizer.step()
            joint_minibatch_iter.set_postfix(loss=loss.item())



In [None]:
# Run train_kernel for 50 epochs.
# NOTE(review): train_kernel reads the globals `model`, `likelihood`, `optimizer` and `mll`,
# which are created in the following cell — that cell has to be executed before this one.
for epoch in range(50):        
    train_kernel(epoch)

In [None]:
# Deep-kernel model and the softmax likelihood that maps its num_dim GP outputs to class scores.
model = DKLModel(feature_extractor, num_dim=num_features)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_dim, num_classes=num_classes)

# If you run this example without CUDA, I hope you like waiting!
if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import MultiStepLR

n_epochs = 20
lr = 0.1
# Parameter groups: weight decay only on the feature extractor, and a 100x
# smaller learning rate for the GP hyperparameters.
optimizer = SGD([
    {'params': model.feature_extractor.parameters(), 'weight_decay': 1e-4},
    {'params': model.gp_layer.hyperparameters(), 'lr': lr * 0.01},
    {'params': model.gp_layer.variational_parameters()},
    {'params': likelihood.parameters()},
], lr=lr, momentum=0.9, nesterov=True, weight_decay=0)
# Milestones are epoch indices, so make them integers (a fractional value would never be hit).
scheduler = MultiStepLR(optimizer, milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)], gamma=0.1)
# joint_loader only draws the indices held by its sampler, so the number of
# training points is the sampler's length, not the size of the full dataset.
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(joint_loader.sampler))

In [None]:

# Module-level sizes read by the Generator / Discriminator definitions below.
ngpu = 1
# Number of channels of the latent input to the generator's first layer.
nz = 100
# Base channel width of the generator.
ngf = 64
# Base channel width of the discriminator.
ndf = 64
# Number of image channels produced by the generator and consumed by the discriminator.
nc = 1

class Generator(nn.Module):
    """Transposed-convolution generator mapping an (nz x 1 x 1) latent to an (nc x 64 x 64) image in [-1, 1]."""

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) of the normalised stages;
        # spatial size goes 1 -> 4 -> 8 -> 16 -> 32.
        stage_specs = [
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, padding in stage_specs:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the network, using data_parallel only for CUDA inputs when more than one GPU is configured."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))

class Discriminator(nn.Module):
    """Convolutional discriminator mapping an (nc x 64 x 64) image to one sigmoid score per sample."""

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        negative_slope = 0.2
        # Input stage (no batch norm): (nc) x 64 x 64 -> (ndf) x 32 x 32.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope, inplace=True),
        ]
        # Three halving stages: 32 -> 16 -> 8 -> 4, doubling the channels each time.
        for c_in, c_out in [(ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)]:
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(negative_slope, inplace=True))
        # Output stage: (ndf*8) x 4 x 4 -> 1 x 1 x 1 probability.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor of scores, using data_parallel only for CUDA inputs when ngpu > 1."""
        if input.is_cuda and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        return scores.view(-1, 1).squeeze(1)


In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """
    Initialise a module in place, selected by its class name: names containing
    'Conv' get N(0, 0.02) weights; otherwise names containing 'BatchNorm' get
    N(1, 0.02) weights and zero bias. Other modules are left untouched.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# load the models
# from models.Conditional_DCGAN_MNIST import Discriminator, Generator

def train_individual_gan(data_loader, ngpu=1, nz=100, nclass=10, nepochs=50, lr=0.0002, beta1=0.5, device=None):
    """
    Train a Generator / Discriminator pair on ``data_loader`` with the
    alternating BCE updates used below, and return both networks.

    :param data_loader: iterable of (images, targets) batches; targets are ignored
    :param ngpu: forwarded to the Generator and Discriminator constructors
    :param nz: number of latent channels of the sampled noise; it has to match
               the module-level ``nz`` the Generator's first layer is built with
    :param nclass: kept for interface compatibility; unused by this unconditional GAN
    :param nepochs: number of passes over ``data_loader``
    :param lr: Adam learning rate for both networks
    :param beta1: first Adam beta for both networks
    :param device: device the networks and batches are moved to
    :return: (netG, netD)
    """
    netG = Generator(ngpu=ngpu).to(device)
    netD = Discriminator(ngpu=ngpu).to(device)

    netG.apply(weights_init)
    netD.apply(weights_init)

    criterion = nn.BCELoss()

    real_label, fake_label = 1, 0

    # setup optimizer
    optimizerD = Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    epochs = tqdm.notebook.tqdm(range(nepochs), desc=f" Epoch")
    for epoch in epochs:
        for data, _ in data_loader:

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_data = data.to(device)
            n_real = real_data.size(0)

            label = torch.full((n_real,), real_label, dtype=real_data.dtype, device=device)
            errD_real = criterion(netD(real_data), label)
            errD_real.backward()

            # train with fake; detach so this step does not backpropagate into G
            noise = torch.randn(n_real, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            errD_fake = criterion(netD(fake.detach()), label)
            errD_fake.backward()
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            errG = criterion(netD(fake), label)
            errG.backward()
            optimizerG.step()

    return netG, netD

In [None]:
# Train one GAN per participant loader and keep only the generators.
Gs = []
for i, train_loader in enumerate(train_loaders):
    print("{}-participant".format(i+1), end='  ')
    # lr=0.001 overrides train_individual_gan's default of 0.0002.
    G, D =  train_individual_gan(train_loader, device=device, lr=0.001, nclass=10, nepochs=50)
    Gs.append(G)

In [None]:

import matplotlib.pyplot as plt

class_number = 10
latent_size = 100

batch_size = 10
# Subplot grid; R * C has to cover batch_size images.
R, C = 5, 2

# Draw a fresh batch of samples from every participant's generator.
for G in Gs:
    fixed_noise = torch.randn(batch_size, latent_size, 1, 1)
    if torch.cuda.is_available():
        G = G.cuda()
        fixed_noise = fixed_noise.cuda()
    fake_images = G(fixed_noise)

    fake_images_np = fake_images.cpu().detach().numpy()
    # Drop the channel axis using the generator's actual output size instead of
    # assuming 28x28 (the Generator defined in this notebook emits 64x64 images).
    fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], *fake_images_np.shape[-2:])
    for img_idx in range(batch_size):
        plt.subplot(R, C, img_idx + 1)
        plt.imshow(fake_images_np[img_idx], cmap='gray')
    plt.show()
    

In [None]:
def train(epoch):
    """
    Run one epoch over ``joint_loader``, minimising the negated variational
    ELBO of the GP output against the targets.

    Relies on the module-level ``model``, ``likelihood``, ``optimizer``,
    ``mll`` and ``joint_loader``.
    """
    model.train()
    likelihood.train()

    minibatch_iter = tqdm.notebook.tqdm(joint_loader, desc=f"(Epoch {epoch}) Minibatch")
    with gpytorch.settings.num_likelihood_samples(8):
        for data, target in minibatch_iter:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Only the GP output is needed here, so call the GP layer directly
            # rather than model(data, data), which would also compute an
            # MMD t-statistic of the batch against itself that is thrown away.
            output = model.gp_layer(model.get_features(data))
            loss = -mll(output, target)
            loss.backward()
            optimizer.step()
            minibatch_iter.set_postfix(loss=loss.item())
        
def test():
    """
    Evaluate on ``test_loader`` and print the accuracy, predicting the class
    with the highest mean probability over 16 likelihood samples.

    Relies on the module-level ``model``, ``likelihood`` and ``test_loader``.
    """
    model.eval()
    likelihood.eval()

    correct = 0
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(16):
        for data, target in test_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            # DKLModel.forward needs two batches and returns a tuple, so query the
            # GP layer directly for the predictive distribution of this batch.
            gp_output = model.gp_layer(model.get_features(data))
            output = likelihood(gp_output)  # This gives us 16 samples from the predictive distribution
            pred = output.probs.mean(0).argmax(-1)  # Taking the mean over all of the sample we've drawn
            # .item() keeps `correct` a plain int so the report prints numbers, not tensor reprs.
            correct += pred.eq(target.view_as(pred)).sum().item()
    print('Test set: Accuracy: {}/{} ({}%)'.format(
        correct, len(test_loader.dataset), 100. * correct / float(len(test_loader.dataset))
    ))



# Train and evaluate once per epoch with the Toeplitz setting switched off,
# then advance the learning-rate scheduler.
for epoch in range(1, n_epochs + 1):
    with gpytorch.settings.use_toeplitz(False):
        train(epoch)
        test()
    scheduler.step()
    # state_dict = model.state_dict()
    # likelihood_state_dict = likelihood.state_dict()
