In [None]:
# Install Opacus from source: clone the GitHub repository, then pip-install the checkout in editable mode.
!git clone https://github.com/pytorch/opacus.git
!cd opacus && pip install -e . --user

In [None]:
# Install the pinned Opacus release 0.10.0.
# NOTE(review): the preceding cell also installs Opacus (editable, from a git clone) -- confirm which install is intended.
pip install opacus==0.10.0

In [None]:

import time as t
import os
#from utils.tensorboard_logger import Logger
from itertools import chain
#from torchvision import utils

# NOTE(review): not referenced by any code visible in this notebook -- confirm it is still needed.
SAVE_PER_TIMES = 100

In [None]:
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
#from torchvision.utils import make_grid
from torch.nn import functional as F

from scipy.linalg import sqrtm
#from opacus import PrivacyEngine
%matplotlib inline

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
def load_train_data():
    """Load the five CIFAR-10 training batches as one (N, 3, 32, 32) array.

    Reads ``data/cifar-10-batches-py/data_batch_1`` .. ``data_batch_5`` through
    ``unpickle`` and keeps only the first element of each returned pair.

    NOTE(review): ``unpickle`` is not defined in the visible part of the
    notebook -- it must be provided elsewhere before this is called.
    """
    batches = []
    for batch_no in range(1, 6):
        images, _ = unpickle('data/cifar-10-batches-py/data_batch_%d' % batch_no)
        batches.append(images)
    stacked = np.concatenate(batches)
    return stacked.reshape((stacked.shape[0], 3, 32, 32))

def load_test_data():
    """Load the CIFAR-10 test batch as an (N, 3, 32, 32) array.

    Reads ``data/cifar-10-batches-py/test_batch`` through ``unpickle`` and
    discards the second element of the returned pair.

    NOTE(review): ``unpickle`` is not defined in the visible part of the
    notebook -- it must be provided elsewhere before this is called.
    """
    images, _ = unpickle('data/cifar-10-batches-py/test_batch')
    n_images = images.shape[0]
    return images.reshape((n_images, 3, 32, 32))

def set_seed(seed):
    """Seed every random number generator this notebook draws from.

    The data generators use ``np.random`` while network initialisation and the
    latent noise use ``torch``; seeding NumPy alone leaves training runs
    non-reproducible, so Python's ``random``, NumPy and PyTorch are all seeded.

    Args:
        seed: integer seed shared by all generators.
    """
    import random  # local import: only needed here

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(seed)

def ScaledSamples(N,size = 256):
  """Draw N Gaussian vectors with a random per-dimension scale and offset.

  Each of the ``size`` dimensions gets one scale (standard normal / 10) and one
  bias (standard normal); every standard-normal sample is multiplied by the
  scale and shifted by the bias of its dimension.

  Returns:
    (samples, scales, biases): an (N, size) array and the two (size,) vectors
    that were applied to it.
  """
  base = np.random.randn(N, size)
  scales = np.random.randn(size) / 10
  biases = np.random.randn(size)
  # Broadcast the per-dimension vectors across the N rows.
  scaled = base * scales[np.newaxis, :] + biases[np.newaxis, :]
  return scaled, scales, biases

def ScaledExtremeSamples(N,size = 256):
  """Draw N Gaussian vectors with very large random per-dimension scales.

  Same construction as ``ScaledSamples`` except that the per-dimension scale is
  a standard normal multiplied by 100, and only the samples are returned.

  Returns:
    An (N, size) array.
  """
  base = np.random.randn(N, size)
  wide_scales = np.random.randn(size) * 100
  offsets = np.random.randn(size)
  return base * wide_scales[np.newaxis, :] + offsets[np.newaxis, :]

def GenerateScaledSamples(N,size = 256):
  """Return sigmoid-squashed ``ScaledSamples`` as an (N, size) array in (0, 1).

  ``ScaledSamples`` returns ``(samples, scales, biases)``. Multiplying that
  tuple by -1 yields an empty tuple, so the tuple must be unpacked before the
  sigmoid is applied -- otherwise an empty array comes back.
  """
  samples, _scales, _biases = ScaledSamples(N, size)
  sig_samp = 1/(1+np.exp(-1*samples))
  return sig_samp

def GenerateScaledExtremeSamples(N,size = 256):
  """Return sigmoid-squashed ``ScaledExtremeSamples`` as an (N, size) array.

  The logistic function is applied element-wise to the raw samples.
  """
  raw = ScaledExtremeSamples(N, size)
  squashed = 1 / (1 + np.exp(-1 * raw))
  return squashed

def GenerateBoxedSamples(N,size=256,boxes=2): #only works 16x16 for right now
  """Return sigmoid-squashed scaled samples with random 3x3 "boxes" stamped in.

  Every sample is treated as a flattened image with rows of 16 entries. For
  each of ``boxes`` random positions, +10 or -10 (random sign) is added to the
  3x3 neighbourhood of that position before the sigmoid is applied.

  ``ScaledSamples`` returns a ``(samples, scales, biases)`` tuple; it is
  unpacked here because indexing the tuple with ``[n, index]`` raises.

  NOTE(review): neighbours are addressed by flat index, so a box centred on the
  first or last column spills into the adjacent row; only indices outside
  ``[0, size)`` are skipped.
  """
  samples, _scales, _biases = ScaledSamples(N, size)
  # Offsets (dx, dy) of the 3x3 neighbourhood around a centre pixel.
  offsets = [(dx - 1, dy - 1) for dx in range(3) for dy in range(3)]

  for n in range(N):
    for _ in range(boxes):
      place = np.random.randint(size)
      sign = np.random.randint(2)*2-1
      for dx, dy in offsets:
        index = place + dx + dy * 16
        if 0 <= index < size:
          samples[n,index] += sign * 10
  sig_samp = 1/(1+np.exp(-1*samples))
  return sig_samp

def MixtureGaussian(N,K):
  """Sample N points from K unit-variance 2-D Gaussians placed on a circle.

  The K component means are evenly spaced on a circle of radius 5 around the
  origin. Each sample picks one component uniformly at random and adds
  standard-normal noise to that component's mean.

  Returns:
    An (N, 2) array.
  """
  centers = np.array([
      [5*np.cos(k / K * 2 * np.pi), 5*np.sin(k / K * 2 * np.pi)]
      for k in range(K)
  ]).reshape(K, 2)
  component = np.random.randint(K,size=(N,))
  jitter = np.random.randn(N,2)
  return centers[component] + jitter




# Module-level torch device: the first CUDA GPU when one is available, otherwise the CPU.
# The optimizers and the SimpWGANGP class defined in this notebook read this global when creating tensors.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
#device = torch.device('cpu') 


In [None]:
#train_samples = GenerateScaledSamples(30000)
#train_samples = GenerateScaledExtremeSamples(30000)
#train_samples = GenerateBoxedSamples(30000)
#train_samples = GenerateScaledSamples(1000,size=2)
# Draw 1000 two-dimensional samples; `scales` and `biases` are the per-dimension factors ScaledSamples applied.
train_samples,scales,biases = ScaledSamples(1000,size=2)
# true_means / true_vars are notebook globals that SimpWGANGP.train reads when compute_exact_w is enabled.
true_means = biases
# Diagonal matrix holding the scale factors themselves (not their squares).
true_vars = np.diag(scales)
print(true_means)
print(true_vars)

In [None]:
#fig = plt.figure(figsize = (8, 8))   
#ax1 = plt.subplot(111)
#ax1.imshow(make_grid(torch.from_numpy(train_samples[0:64]).view(-1,1,16,16), padding=1).numpy().transpose((1, 2, 0)))
#plt.show()
# Scatter the first two coordinates of the synthetic training samples.
plt.scatter(train_samples[:,0],train_samples[:,1])
plt.show()

In [None]:
# Mixture-of-Gaussians datasets with 3 and 8 components, divided by 10 so the
# component means lie on a circle of radius 0.5.
train_samples2 = MixtureGaussian(2000,3)/10
train_samples3 = MixtureGaussian(2000,8)/10

# One scatter plot per dataset.
plt.scatter(train_samples2[:,0],train_samples2[:,1])
plt.show()
plt.scatter(train_samples3[:,0],train_samples3[:,1])
plt.show()



In [None]:
from torch.optim.optimizer import Optimizer, required


class SGDPriv(Optimizer):
    """SGD (momentum / Nesterov / weight decay) with clamped, noised gradients.

    Before the usual SGD update every gradient is
      1. clamped element-wise to ``[-max_grad_norm, max_grad_norm]``, then
      2. perturbed with Gaussian noise of standard deviation
         ``noise_mult * max_grad_norm``.

    The clamp runs first so the noise is calibrated to the bounded gradient and
    is not itself truncated. ``p.grad`` is never modified; the perturbed copy is
    only used for the update.

    NOTE(review): the clamp is element-wise on the already-aggregated gradient,
    not per-sample norm clipping -- confirm this matches the privacy analysis
    the experiments rely on.

    Arguments:
        params (iterable): parameters to optimize or dicts defining groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): L2 penalty (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        noise_mult (float, optional): noise multiplier; ``None`` or 0 disables
            the noise (default: 1.0)
        max_grad_norm (float, optional): clamp bound; ``None`` disables both the
            clamp and the noise (default: 1.0)
        batch_size (int, optional): stored in the param group; not used by the
            update (default: 1)
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False,
                 noise_mult = 1.0, max_grad_norm = 1.0, batch_size = 1):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        noise_mult=noise_mult, max_grad_norm=max_grad_norm, batch_size=batch_size)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDPriv, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDPriv, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @staticmethod
    def _privatize(grad, noise_mult, max_grad_norm):
        """Return a clamped and noised version of ``grad`` without mutating it."""
        if max_grad_norm is None:
            # No bound means no clamp, and no scale to calibrate the noise to.
            return grad
        # torch.clamp is out-of-place: the result must be kept.
        grad = grad.clamp(-max_grad_norm, max_grad_norm)
        if noise_mult is not None and noise_mult * max_grad_norm > 0:
            # randn_like draws on the gradient's own device and dtype.
            grad = grad + torch.randn_like(grad) * (noise_mult * max_grad_norm)
        return grad

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The loss returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            nm = group['noise_mult']
            mg = group['max_grad_norm']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = self._privatize(p.grad, nm, mg)

                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group['lr'])

        return loss

In [None]:
import torch
from torch.optim.optimizer import Optimizer

from torch import Tensor
from typing import List

def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Apply one Adam update, in place, to every tensor in ``params``.

    The moment buffers in ``exp_avgs`` / ``exp_avg_sqs`` (and, with ``amsgrad``,
    ``max_exp_avg_sqs``) are updated in place as well. ``state_steps[i]`` is the
    1-based step count used for bias correction of parameter ``i``.
    See :class:`~torch.optim.Adam` for the algorithm.
    """
    for idx, weight in enumerate(params):
        gradient = grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        step_count = state_steps[idx]
        if amsgrad:
            running_max = max_exp_avg_sqs[idx]

        correction1 = 1 - beta1 ** step_count
        correction2 = 1 - beta2 ** step_count

        # Classic L2 penalty: fold the weights into the gradient.
        if weight_decay != 0:
            gradient = gradient.add(weight, alpha=weight_decay)

        # Exponential moving averages of the gradient and its square.
        first_moment.mul_(beta1).add_(gradient, alpha=1 - beta1)
        second_moment.mul_(beta2).addcmul_(gradient, gradient, value=1 - beta2)

        if amsgrad:
            # Keep the element-wise maximum of all second moments seen so far
            # and normalise with it instead of the current average.
            torch.maximum(running_max, second_moment, out=running_max)
            normaliser = running_max
        else:
            normaliser = second_moment
        denom = (normaliser.sqrt() / math.sqrt(correction2)).add_(eps)

        weight.addcdiv_(first_moment, denom, value=-(lr / correction1))

class AdamPriv(Optimizer):
    r"""Adam whose gradients are clamped and perturbed with Gaussian noise.

    Before the usual Adam update every gradient is
      1. clamped element-wise to ``[-max_grad_norm, max_grad_norm]``, then
      2. perturbed with Gaussian noise of standard deviation
         ``noise_mult * max_grad_norm``.

    The clamp runs first so the noise is calibrated to the bounded gradient and
    is not itself truncated. ``p.grad`` is never modified; the perturbed copy is
    only used for the update.

    NOTE(review): the clamp is element-wise on the already-aggregated gradient,
    not per-sample norm clipping -- confirm this matches the privacy analysis
    the experiments rely on.

    The Adam part follows `Adam: A Method for Stochastic Optimization`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        noise_mult (float, optional): noise multiplier; ``None`` or 0 disables
            the noise (default: 1.0)
        max_grad_norm (float, optional): clamp bound; ``None`` disables both the
            clamp and the noise (default: 1.0)
        batch_size (int, optional): stored in the param group; not used by the
            update (default: 1)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False,
                 noise_mult = 1.0, max_grad_norm = 1.0, batch_size = 1):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        noise_mult=noise_mult, max_grad_norm=max_grad_norm, batch_size=batch_size)
        super(AdamPriv, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamPriv, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @staticmethod
    def _privatize(grad, noise_mult, max_grad_norm):
        """Return a clamped and noised version of ``grad`` without mutating it."""
        if max_grad_norm is None:
            # No bound means no clamp, and no scale to calibrate the noise to.
            return grad
        # torch.clamp is out-of-place: the result must be kept.
        grad = grad.clamp(-max_grad_norm, max_grad_norm)
        if noise_mult is not None and noise_mult * max_grad_norm > 0:
            # randn_like draws on the gradient's own device and dtype.
            grad = grad + torch.randn_like(grad) * (noise_mult * max_grad_norm)
        return grad

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The loss returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []

            nm = group['noise_mult']
            mg = group['max_grad_norm']

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                params_with_grad.append(p)
                grads.append(self._privatize(p.grad, nm, mg))

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # The step counter is advanced before it is recorded, so the
                # bias correction in `adam` sees a 1-based step.
                state['step'] += 1
                state_steps.append(state['step'])

            beta1, beta2 = group['betas']
            adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   group['amsgrad'],
                   beta1,
                   beta2,
                   group['lr'],
                   group['weight_decay'],
                   group['eps']
                   )
        return loss

In [None]:

class SimpGen(nn.Module):
    """Fully connected generator: Linear layers with ReLU between them.

    ``layers`` lists the layer widths, input first. One ``nn.Linear`` is built
    per consecutive pair; ReLU follows every layer except the last, whose
    output is returned unchanged.
    """

    def __init__(self, layers):
        super(SimpGen, self).__init__()

        self.length = len(layers) - 1
        self.sizes = layers
        self.hiddens = nn.ModuleList(
            [nn.Linear(layers[k], layers[k + 1]) for k in range(self.length)]
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map a batch of latent codes ``x`` to generated samples."""
        final_idx = self.length - 1
        out = x
        for idx, layer in enumerate(self.hiddens):
            out = layer(out)
            if idx < final_idx:
                out = self.activation(out)
        return out

    def getGrad(self):
        """Return all existing parameter gradients flattened into one vector."""
        flat = [
            param.grad.view(-1).detach()
            for param in self.parameters()
            if param.grad is not None
        ]
        return torch.cat(flat)


class SimpDisc(nn.Module):
    """Fully connected critic: Linear layers with ReLU between them.

    ``layers`` lists the layer widths, input first; the last width must be 1 so
    the network produces one raw score per sample. ReLU follows every layer
    except the last.
    """

    def __init__(self, layers):
        super(SimpDisc, self).__init__()

        self.length = len(layers) - 1
        self.sizes = layers
        assert layers[self.length] == 1
        self.hiddens = nn.ModuleList(
            [nn.Linear(layers[k], layers[k + 1]) for k in range(self.length)]
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the raw (un-squashed) score for each row of ``x``."""
        final_idx = self.length - 1
        out = x
        for idx, layer in enumerate(self.hiddens):
            out = layer(out)
            if idx < final_idx:
                out = self.activation(out)
        return out

    def getGrad(self):
        """Return all existing parameter gradients flattened into one vector."""
        flat = [
            param.grad.view(-1).detach()
            for param in self.parameters()
            if param.grad is not None
        ]
        return torch.cat(flat)




import torch.autograd as autograd
# Turn on autograd anomaly detection for the whole notebook session.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution -- confirm it should stay enabled for full training runs.
torch.autograd.set_detect_anomaly(True)

class SimpWGANGP(nn.Module):

    def __init__(self,dlayers,glayers,dloss,gloss,epochs,sample_size=None,compute_exact_w=False,privacy=False,nm=None,mg=None):
        """Build the generator, the critic and their optimizers.

        Args:
            dlayers: layer widths of the critic (``SimpDisc``), input first.
            glayers: layer widths of the generator (``SimpGen``), input first.
            dloss: learning rate of the critic optimizer.
            gloss: learning rate of the generator optimizer.
            epochs: number of passes over the training set in ``train``.
            sample_size: accepted for compatibility; currently unused.
            compute_exact_w: when true, ``train`` records a closed-form
                Wasserstein estimate after every generator step.
            privacy: when true, the logged critic loss excludes the gradient
                penalty term. It does not select the optimizer.
            nm: noise multiplier handed to ``AdamPriv``.
            mg: gradient clamp bound handed to ``AdamPriv``.

        The critic is optimized with ``AdamPriv`` whenever ``nm`` or ``mg`` is
        given (a missing ``nm`` means no noise, a missing ``mg`` falls back to
        1.0). With both left at ``None`` a plain ``torch.optim.Adam`` is used,
        because ``AdamPriv`` cannot step with ``None`` multipliers.
        """
        super(SimpWGANGP, self).__init__()
        self.num_epoch = epochs
        self.batch_size = 128
        self.log_step = 100
        self.visualize_step = 20
        # The latent code is the generator's *input*, so its width is the first
        # entry of glayers (the last entry is the width of generated samples).
        self.code_size = glayers[0]
        self.g_learning_rate = gloss
        self.d_learning_rate = dloss
        self.vis_learning_rate = 1e-2

        self.compute_exact_w = compute_exact_w
        self.privacy = privacy

        # Fixed IID N(0, 1) codes, reused so the visualisations are comparable.
        self.tracked_noise = torch.randn([64, self.code_size], device=device)
        self._actmax_label = torch.ones([64, 1], device=device)

        self._discriminator = SimpDisc(dlayers).to(device)
        self._generator = SimpGen(glayers).to(device)

        self._l2_loss = nn.MSELoss()
        self._classification_loss = nn.BCEWithLogitsLoss()

        betas = (0.5, 0.9)
        self._generator_optimizer = torch.optim.Adam(self._generator.parameters(),lr=self.g_learning_rate,betas=betas)
        if nm is None and mg is None:
            self._discriminator_optimizer = torch.optim.Adam(self._discriminator.parameters(),lr=self.d_learning_rate,betas=betas)
        else:
            self._discriminator_optimizer = AdamPriv(self._discriminator.parameters(),
                                                     lr=self.d_learning_rate,betas=betas,
                                                     noise_mult=0.0 if nm is None else nm,
                                                     max_grad_norm=1.0 if mg is None else mg,
                                                     batch_size=self.batch_size)

    def _loss(self, real_log, fake_log):
      D_fake = fake_log
      D_real = real_log
      loss = D_fake - D_real
      return loss
      
    def _reconstruction_loss(self, generated, target):
        return self._l2_loss(generated, target)

    def LikelihoodMOG(self, x, y):
        likeli = 0
        K=3
        for k in range(K):
            theta = k / K * 2 * np.pi
            mean = [5*np.cos(theta)/10,5*np.sin(theta)/10]
            like = np.exp( -1 * ((x-mean[0])**2 + (y-mean[1])**2) )
            #likeli += like
            if like>likeli:
                likeli = like
        likeli /= K
        return likeli
    
    def FullDistLikeliMOG(self):
        """Average ``LikelihoodMOG`` over 1000 freshly generated samples.

        Draws 1000 two-dimensional standard-normal codes, pushes them through
        the generator and returns the mean score of the generated points.
        """
        TEST_N = 1000
        latent = torch.randn((TEST_N,2)).to(device)
        generated = self._generator(latent).cpu().detach().numpy()
        total = sum(self.LikelihoodMOG(point[0], point[1]) for point in generated)
        return total/TEST_N
    
    def calc_gradient_penalty(self, real_data, fake_data): #from wgan github
        LAMBDA = 10

        alpha = torch.rand(self.batch_size, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.to(device)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = interpolates.to(device)

        interpolates = autograd.Variable(interpolates, requires_grad=True)

        #print('interpolating')
        disc_interpolates = self._discriminator(interpolates)
        
        #print('gradienting')
        self.gping = True
        for layer in self.modules():
            layer.gping = True
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                  create_graph=True, retain_graph=True, only_inputs=True,allow_unused=True)[0]
        #print('grad shape',gradients.shape)
        #print(gradients)
        #from torchviz import make_dot
        #make_dot(gradients).view()
        
        
        self.gping = False
        for layer in self.modules():
            layer.gping = False
        
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
        return gradient_penalty

    # Training function
    def train(self, train_samples):
        """Run WGAN-GP training on ``train_samples`` and plot diagnostics.

        Every iteration performs five critic updates followed by one generator
        update. Every ``visualize_step`` epochs the loss / gradient-norm curves
        and a view of the critic over the plane are plotted.

        Args:
            train_samples: NumPy array with one sample per row. It is copied
                before shuffling, so the caller's array is left untouched.

        NOTE: this method shadows ``nn.Module.train(mode)``. ``nn.Module.eval()``
        calls ``self.train(False)``, so do not call ``.eval()`` or
        ``.train(mode)`` on a ``SimpWGANGP`` instance itself.
        """
        DISC_STEPS = 5

        # Shuffle a private copy instead of reordering the caller's data.
        train_samples = np.array(train_samples, copy=True)
        num_train = train_samples.shape[0]
        batches_per_epoch = num_train // self.batch_size
        max_steps = int(self.num_epoch * batches_per_epoch)
        step = 0

        # Exponential smoothing of the logged generator loss; 0 disables it.
        smooth_factor = 0
        plot_gen_s = 0
        plot_ws = 0

        # Plain Python floats are stored so plotting also works when the
        # models live on a GPU.
        dis_losses = []   # one entry per critic update
        gen_losses = []   # one entry per generator update
        dis_grads = []
        gen_grads = []
        w_dists = []

        self._generator.train()
        self._discriminator.train()
        print('Start training ...')
        for epoch in range(self.num_epoch):
            np.random.shuffle(train_samples)
            for i in range(batches_per_epoch):
                step += 1

                batch_samples = train_samples[i * self.batch_size : (i + 1) * self.batch_size]
                batch_samples = torch.Tensor(batch_samples).to(device)

                # ---- critic updates ----
                for _ in range(DISC_STEPS):
                    self._discriminator_optimizer.zero_grad()

                    real_dis_out = self._discriminator(batch_samples)

                    noise = torch.randn((self.batch_size,self.code_size),device=device)
                    # The generator is not updated here; cut its graph at once.
                    fake_samples = self._generator(noise).detach()
                    fake_dis_out = self._discriminator(fake_samples)

                    disc_loss1 = (fake_dis_out - real_dis_out).mean()
                    disc_loss1.backward()

                    gp = self.calc_gradient_penalty(batch_samples, fake_samples)
                    gp.backward()

                    dis_grad = torch.norm(self._discriminator.getGrad()).item()
                    self._discriminator_optimizer.step()

                    # With `privacy` set, the logged loss omits the penalty.
                    dis_loss = disc_loss1.item()
                    if not self.privacy:
                        dis_loss += gp.item()
                    dis_losses.append(dis_loss)

                # ---- generator update (reuses the last critic noise batch) ----
                self._generator_optimizer.zero_grad()

                fake_dis_out2 = self._discriminator(self._generator(noise))
                gen_loss1 = (-fake_dis_out2).mean()
                gen_loss1.backward()

                gen_grad = torch.norm(self._generator.getGrad()).item()
                self._generator_optimizer.step()
                gen_loss = gen_loss1.item()

                if self.compute_exact_w:
                    w_dists.append(self._exact_w2_distance())

                plot_gen_s = plot_gen_s * smooth_factor + gen_loss * (1 - smooth_factor)
                plot_ws = plot_ws * smooth_factor + (1 - smooth_factor)
                gen_losses.append(plot_gen_s / plot_ws)

                dis_grads.append(dis_grad)
                gen_grads.append(gen_grad)

                if step % self.log_step == 0:
                    print('Iteration {0}/{1}: dis loss = {2:.4f}, gen loss = {3:.4f}'.format(step, max_steps, dis_loss, gen_loss))

            if epoch % self.visualize_step == 0:
                if self.compute_exact_w:
                    self._plot_curve(w_dists, 'exact wasserstein distance')
                self._plot_curve(dis_losses, 'discriminator loss', 'loss')
                self._plot_curve(gen_losses, 'generator loss', 'loss')
                self._plot_curve(dis_grads, 'discriminator grads', 'grad norm')
                self._plot_curve(gen_grads, 'generator grads', 'grad norm')
                self._plot_sample_views(train_samples)

        print('... Done!')

    def _exact_w2_distance(self):
        """Closed-form squared 2-Wasserstein distance between two Gaussians.

        Assumes the generator is a single linear layer ``x = W z + b`` fed with
        standard-normal codes, so generated samples follow ``N(b, W W^T)``.
        The data distribution is read from the notebook globals ``true_means``
        and ``true_vars``, where ``true_vars`` is the symmetric matrix of scale
        factors ``A`` and the data covariance is ``A A``.

        Uses ``|m1 - m2|^2 + tr(S1) + tr(S2) - 2 tr((A S2 A)^(1/2))``.
        """
        first_layer = self._generator.hiddens[0]
        weight = first_layer.weight.cpu().detach().numpy()
        fake_means = first_layer.bias.cpu().detach().numpy()

        true_cov = np.matmul(true_vars, true_vars)
        fake_cov = np.matmul(weight, weight.T)

        mean_term = np.sum(np.square(true_means - fake_means))
        cross = sqrtm(np.matmul(true_vars, np.matmul(fake_cov, true_vars)))
        # sqrtm may return tiny imaginary parts for near-singular input.
        cov_term = np.trace(true_cov + fake_cov) - 2 * np.real(np.trace(cross))
        return mean_term + cov_term

    @staticmethod
    def _plot_curve(values, title, ylabel=None):
        """Plot one training curve; axes are labelled only if ylabel is given."""
        plt.plot(values)
        plt.title(title)
        if ylabel is not None:
            plt.xlabel('iterations')
            plt.ylabel(ylabel)
        plt.show()

    def _plot_sample_views(self, train_samples):
        """Scatter the tracked codes, real vs. generated points and the critic's view of the plane."""
        A=-1;B=1;D=21

        self._generator.eval()
        self._discriminator.eval()

        noise_np = self.tracked_noise.cpu().detach().numpy()
        plt.scatter(noise_np[:,0],noise_np[:,1],c='r')
        plt.show()

        # Up to 64 real points are shown next to the 64 tracked generations.
        real_shown = train_samples[:64]
        with torch.no_grad():
            torch_gen_samp = self._generator(self.tracked_noise)
            true_beliefs = torch.sigmoid(self._discriminator(torch.Tensor(real_shown).to(device))).cpu().numpy()
            fake_beliefs = torch.sigmoid(self._discriminator(torch_gen_samp)).cpu().numpy()
        generated_samples = torch_gen_samp.cpu().numpy()

        # The critic's sigmoid score drives blue for real, red for generated.
        true_colors = np.zeros((len(real_shown),3))
        fake_colors = np.zeros((len(generated_samples),3))
        true_colors[:,2] = true_beliefs[:,0]
        fake_colors[:,0] = fake_beliefs[:,0]

        plt.scatter(real_shown[:,0],real_shown[:,1],s=50.,c=true_colors,alpha=0.3)
        plt.scatter(generated_samples[:,0],generated_samples[:,1],s=50.,c=fake_colors,alpha=0.3)
        plt.xlim(A,B)
        plt.ylim(A,B)
        plt.gca().set_aspect('equal','box')
        plt.show()

        # Evaluate the critic on a D x D grid over [A, B]^2.
        axis = np.linspace(A,B,D)
        xx,yy = np.meshgrid(axis,axis)
        test_grid = np.array((xx.ravel(), yy.ravel())).T
        with torch.no_grad():
            test_beliefs = torch.sigmoid(self._discriminator(torch.Tensor(test_grid).to(device))).cpu().numpy()
        scores = test_beliefs[:,0]
        tmin = scores.min()
        span = scores.max() - tmin

        test_colors = np.zeros((test_grid.shape[0],3))
        test_colors2 = np.zeros((test_grid.shape[0],3))
        test_colors3 = np.zeros((test_grid.shape[0],3))
        test_colors[:,1] = scores
        # Min-max normalised view; a constant critic is drawn as all zeros
        # instead of dividing by zero.
        if span > 0:
            test_colors2[:,1] = (scores - tmin) / span
        test_colors3[:,1] = (scores > .5).astype(float)

        plt.scatter(test_grid[:,0],test_grid[:,1],s=88,c=test_colors,marker='s',alpha=0.3)
        plt.gca().set_aspect('equal','box')
        plt.show()

        plt.scatter(test_grid[:,0],test_grid[:,1],s=88,c=test_colors2,marker='s',alpha=0.3)
        plt.gca().set_aspect('equal','box')
        plt.show()

        plt.scatter(test_grid[:,0],test_grid[:,1],s=88,c=test_colors3,marker='s',alpha=0.3)
        plt.scatter(real_shown[:,0],real_shown[:,1],s=50.,c='b',alpha=0.3)
        plt.scatter(generated_samples[:,0],generated_samples[:,1],s=50.,c='r',alpha=0.3)
        plt.gca().set_aspect('equal','box')
        plt.show()

        self._generator.train()
        self._discriminator.train()

    # Find the reconstruction of a batch of samples
    def reconstruct(self, samples):
        """Search for latent codes whose generations match ``samples``.

        Starts from all-zero codes (one per sample) and runs 500 Adam steps
        with ``vis_learning_rate`` on the reconstruction loss between the
        generated and the target samples. The generator is put in eval mode.

        Returns:
            (loss, reconstructions): the loss tensor and the generated samples
            (detached, on CPU), both taken from the last forward pass, i.e.
            before the final optimizer step was applied.
        """
        n_steps = 500
        recon_code = torch.zeros([samples.shape[0], self.code_size], device=device, requires_grad=True)
        targets = torch.tensor(samples, device=device, dtype=torch.float32)

        self._generator.eval()

        code_optimizer = torch.optim.Adam([recon_code], lr=self.vis_learning_rate)

        for _ in range(n_steps):
            code_optimizer.zero_grad()
            recon_samples = self._generator(recon_code)
            recon_loss = self._reconstruction_loss(recon_samples, targets)
            recon_loss.backward()
            code_optimizer.step()

        return recon_loss, recon_samples.detach().cpu()
        

    # Perform activation maximization on a batch of different initial codes
    def actmax(self, actmax_code):
        self._generator.eval()
        self._discriminator.eval() 
        ################################################################################
        # Prob 2-4: check this function                                                #
        # skip this part when working on problem 2-1 and come back for problem 2-4     #
        ################################################################################
        actmax_code = torch.tensor(actmax_code, device=device, dtype=torch.float32, requires_grad=True)
        actmax_optimizer = torch.optim.Adam([actmax_code], lr=self.vis_learning_rate) 
        for i in range(500):
            actmax_optimizer.zero_grad()
            actmax_sample = self._generator(actmax_code)
            actmax_dis = self._discriminator(actmax_sample)
            actmax_loss = self._loss(actmax_dis, self._actmax_label)
            actmax_loss.backward()
            actmax_optimizer.step()
        return actmax_sample.detach().cpu()

In [None]:
set_seed(42)

# Single-Gaussian experiment: a one-layer linear generator against a small
# critic, repeated for several noise multipliers of the critic optimizer.
EP = 2000

glayers = [2,2]
gloss = 3e-2

dlayers = [2,16,16,1]
dloss = 3e-4

mg = 1.0

for nm in [0.0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    simpgan = SimpWGANGP(
        dlayers, glayers, dloss, gloss, EP,
        sample_size=train_samples.shape[0],
        compute_exact_w=True,
        privacy=False,
        nm=nm,
        mg=mg,
    )
    simpgan.train(train_samples)
    # One checkpoint per noise multiplier.
    torch.save(simpgan.state_dict(), "simpgan_"+str(nm)+".pth")
    print('END')
    print()
    print()

In [None]:
# Noise-multiplier sweep on `train_samples2`: one SimpWGANGP is trained per value
# of `nm` and its state_dict is written to "simpgan_mog_<nm>.pth".

# Network layouts and the values passed as the 3rd/4th constructor arguments
# (dloss/gloss look like discriminator/generator learning rates -- confirm
# against SimpWGANGP).
dlayers = [2,64,64,64,1]
dloss = 1e-4
glayers = [2,16,16,2]
gloss = 1e-4
EP = 2000  # fifth positional argument of SimpWGANGP -- presumably the number of epochs; confirm

mg = 1.0

for nm in [0.0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    # sample_size is taken from the array that is actually trained on
    # (train_samples2), matching the other simpgan_mog sweeps in this notebook.
    simpgan = SimpWGANGP(dlayers, glayers,dloss,gloss,EP,
                         sample_size=train_samples2.shape[0],compute_exact_w=False,privacy=False,nm=nm,mg=mg)
    simpgan.train(train_samples2)
    # str(nm) is part of the file name, e.g. "simpgan_mog_0.0001.pth".
    torch.save(simpgan.state_dict(), "simpgan_mog_"+str(nm)+".pth")
    print('END MIXED')
    print()
    print()
    
    









In [None]:
# Sweep over noise multipliers on `train_samples2`; each trained model's
# state_dict is saved as "simpgan_mog_<nm>.pth" (an existing file of the same
# name is overwritten).
EP = 1000
mg = 1.0
dlayers, glayers = [2, 64, 64, 64, 1], [2, 16, 16, 2]
dloss = gloss = 1e-4

for nm in (10, 0.0001, 0.001, 0.01, 0.1, 1.0):
    simpgan = SimpWGANGP(
        dlayers, glayers, dloss, gloss, EP,
        sample_size=train_samples2.shape[0],
        compute_exact_w=False,
        privacy=False,
        nm=nm,
        mg=mg,
    )
    simpgan.train(train_samples2)
    # The f-string renders nm exactly like str(nm), so the file names are unchanged.
    torch.save(simpgan.state_dict(), f"simpgan_mog_{nm}.pth")
    # Marker line followed by two empty lines.
    print('END MIXED', end='\n\n\n')
    

In [None]:
# Train on `train_samples2` for noise multipliers 1.0, 0.1 and 0.01 (in that
# order) and save each state_dict as "simpgan_mog_<nm>.pth", replacing any
# existing file of that name.
EP = 1000
mg = 1.0
dlayers, glayers = [2, 64, 64, 64, 1], [2, 16, 16, 2]
dloss = gloss = 1e-4

for nm in (1.0, 0.1, 0.01):
    simpgan = SimpWGANGP(
        dlayers, glayers, dloss, gloss, EP,
        sample_size=train_samples2.shape[0],
        compute_exact_w=False,
        privacy=False,
        nm=nm,
        mg=mg,
    )
    simpgan.train(train_samples2)
    # The f-string renders nm exactly like str(nm), so the file names are unchanged.
    torch.save(simpgan.state_dict(), f"simpgan_mog_{nm}.pth")
    # Marker line followed by two empty lines.
    print('END MIXED', end='\n\n\n')
    

In [None]:
# Single run on `train_samples2` with noise multiplier 3.0; the state_dict is
# saved as "simpgan_mog_3.0.pth". The loop form is kept so `nm` stays bound
# afterwards, as in the other sweep cells.
EP = 1000
mg = 1.0
dlayers, glayers = [2, 64, 64, 64, 1], [2, 16, 16, 2]
dloss = gloss = 1e-4

for nm in (3.0,):
    simpgan = SimpWGANGP(
        dlayers, glayers, dloss, gloss, EP,
        sample_size=train_samples2.shape[0],
        compute_exact_w=False,
        privacy=False,
        nm=nm,
        mg=mg,
    )
    simpgan.train(train_samples2)
    # The f-string renders nm exactly like str(nm), so the file name is unchanged.
    torch.save(simpgan.state_dict(), f"simpgan_mog_{nm}.pth")
    # Marker line followed by two empty lines.
    print('END MIXED', end='\n\n\n')
    

In [None]:
# Single run on `train_samples2` with noise multiplier 30 (an int, so the
# checkpoint is named "simpgan_mog_30.pth"). The loop form is kept so `nm`
# stays bound afterwards, as in the other sweep cells.
EP = 1000
mg = 1.0
dlayers, glayers = [2, 64, 64, 64, 1], [2, 16, 16, 2]
dloss = gloss = 1e-4

for nm in (30,):
    simpgan = SimpWGANGP(
        dlayers, glayers, dloss, gloss, EP,
        sample_size=train_samples2.shape[0],
        compute_exact_w=False,
        privacy=False,
        nm=nm,
        mg=mg,
    )
    simpgan.train(train_samples2)
    # The f-string renders nm exactly like str(nm), so the file name is unchanged.
    torch.save(simpgan.state_dict(), f"simpgan_mog_{nm}.pth")
    # Marker line followed by two empty lines.
    print('END MIXED', end='\n\n\n')
    

In [None]:
# Load the "simpgan_mog_<nm>.pth" checkpoints, evaluate each model with
# FullDistLikeliMOG(), and plot the result against the noise multiplier on
# log-log axes.
#
# Relies on dlayers / glayers / dloss / gloss / EP / mg from the most recently
# run training cell; they must describe the architecture the checkpoints were
# saved with, otherwise load_state_dict() will fail.
#
# Every entry has to be spelled exactly as in the training loops because the
# file name is built with str(nm): 10 -> "simpgan_mog_10.pth", whereas 10.
# would look for "simpgan_mog_10.0.pth".
nms = [0.0001, 0.001, 1.0, 3.0,10,30]

likelis = []
for nm in nms:
    path = "simpgan_mog_"+str(nm)+".pth"
    # sample_size mirrors the train_samples2 sweeps that produced these
    # checkpoints, so the model is rebuilt with the training-time value.
    # NOTE(review): compute_exact_w=True differs from the training cells
    # (False); left as is -- confirm it is intended for evaluation.
    simpgan = SimpWGANGP(dlayers, glayers,dloss,gloss,EP,
                         sample_size=train_samples2.shape[0],compute_exact_w=True,privacy=False,nm=nm,mg=mg)
    simpgan.load_state_dict(torch.load(path))
    like = simpgan.FullDistLikeliMOG()
    likelis.append(like)
    print(nm, like)

# Both axes are log-transformed once and reused for the line and the markers.
goodness = np.log(likelis)
log_nms = np.log(nms)
plt.plot(log_nms, goodness)
plt.scatter(log_nms, goodness)
plt.title('privacy-quality tradeoff (log-log scale)')
plt.xlabel('privacy (noise multiplier)')
plt.ylabel('quality (likelihood)')
plt.show()