# Convergence and Optimality Analysis of Low-Dimensional Generative Adversarial Networks using Error Function Integrals

## Requirements

Uncomment to install requirements.

In [1]:
# !pip install matplotlib numpy scipy tqdm ipywidgets

## Imports

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import math
import os

from scipy.io import savemat, loadmat
from tqdm.notebook import tqdm
from matplotlib import cm

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

## Utility functions

In [3]:
def Phi(x):
  """Error function erf(x) (scalar); the notebook's Phi is erf, not the normal CDF."""
  return math.erf(x)
def Phi_plus(x):
  """Shifted error function 1 + erf(x), with values in [0, 2]."""
  return Phi(x) + 1
def Phi_minus(x):
  """Complementary shift 1 - erf(x), with values in [0, 2]."""
  return -Phi(x) + 1

In [4]:
def create_3d_plot():
  """Create a 12x10 figure holding a single 3D axes.

  The view is set to azim=45, elev=20 and the x axis is inverted.

  Returns:
    (fig, ax): the new figure and its 3D axes.
  """
  fig = plt.figure(figsize=(12, 10))
  # fig.gca(projection=...) was deprecated in Matplotlib 3.4 and removed in
  # 3.6; add_subplot is the supported way to request a 3D axes.
  ax = fig.add_subplot(projection='3d')
  ax.view_init(azim=45, elev=20)
  ax.invert_xaxis()
  return fig, ax

## Common notes
Objective function (from Goodfellow's paper):
$$
V[G,D] = \mathbb{E}_x \log D(x) + \mathbb{E}_z \log(1 - D(G(z)))
$$

The goal is to optimize
$$
\min_G \max_D V[G, D]
$$

Goodfellow et al.'s paper assumes that there is an algorithm which finds the globally optimal $D^*, G^*$. Let's check whether we can carry out this optimization for a simple case, adding complexity only when it is needed. If the approach of Goodfellow et al. holds for the general case, it should also work for a basic one.

A least squares GAN cost function is:
$$
V[G, D] = \mathbb{E}_x D^2(x) + \mathbb{E}_z [1 - D(G(z))]^2
$$

# Analysis of a specific choice of `x` and `z` distributions


Here we assume
$$
x \sim \text{Exp}(c) \\
z \sim \text{Rayleigh}(1) \\
D(x) = {1 + \Phi(ax + b) \over 2} \\
G(z) = g z^2 + h
$$

Where
$$
\text{Exp}(x|c) = c e^{-c x} \\
\text{Rayleigh}(z|1) = 2z \exp(-z^2)
$$
and $\Phi(\cdot) = \operatorname{erf}(\cdot)$ is the error function, so $D(x) \in (0, 1)$.

Augmented cost function
$$
V[D,G] = \mathbb{E}_x D^2(x) + \mathbb{E}_z [1 - D(G(z))]^2 = J_1[D] + J_2[D,G]
$$

Therefore
$$
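J_1[D] = \frac{1}{4}\int_0^{\infty} (1+\Phi(ax + b))^2 \,\text{Exp}(x|c)\, dx \\
J_2[D, G] = \frac{1}{4}\int_0^{\infty} [1 - \Phi(a\zeta + \eta)]^2 \,\text{Exp}(\zeta|1/g)\, d\zeta
$$
where $\eta=ah+b$ and $\zeta=gz^2$ (assuming $g > 0$). The substitution gives $2z\,dz = d\zeta/g$ and $e^{-z^2} = e^{-\zeta/g}$, so the Rayleigh density turns into $\text{Exp}(\zeta|1/g) = \frac{1}{g}e^{-\zeta/g}$. The $1/g$ factor is already contained in this density, so no separate $1/g$ prefactor appears.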


## Cost function computed using sampling
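The easiest approach (the one Goodfellow et al. also follow) is to replace the expectations with averages over samples $x_i \sim \text{Exp}(c)$ and $z_i \sim \text{Rayleigh}(1)$:
$$
V[G,D] \approx {1 \over m} \sum_{i = 1}^m
\left( D^2(x_i) + [1 - D(G(z_i))]^2 \right)
$$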

In [5]:
def Discriminator(x, a, b):
  """Evaluate D(x) = (1 + Phi(a*x + b)) / 2 elementwise over x.

  np.vectorize is needed because Phi wraps the scalar-only math.erf.
  """
  def _d_scalar(xi):
    return 0.5*(1 + Phi(a*xi + b))
  d_vec = np.vectorize(_d_scalar)
  return d_vec(x)

In [6]:
def Generator(z, g, h):
  """Evaluate G(z) = g*z^2 + h elementwise over z."""
  def _g_scalar(zi):
    return g*(zi**2) + h
  return np.vectorize(_g_scalar)(z)

In [7]:
def Jbh(x, z, a, b, g, h):
  """Sampled least-squares GAN cost.

  Returns mean(D(x)^2) + mean((1 - D(G(z)))^2) over the given samples.
  """
  d_real = Discriminator(x, a, b)
  d_fake = Discriminator(Generator(z, g, h), a, b)
  real_term = np.average(np.square(d_real))
  fake_term = np.average(np.square(1 - d_fake))
  return real_term + fake_term

In [8]:
def sample_Jbh(config, n_samples=1000, rng=None):
  """Draw x ~ Exp(c), z ~ Rayleigh(1) and evaluate Jbh on the (b, h) grid.

  Args:
    config: dict with keys "a", "b" (b grid), "c", "g", "h" (h grid). Values
      are looked up by key, so the key order does not matter.
    n_samples: number of x samples and of z samples.
    rng: optional np.random.Generator for reproducible draws; when None the
      global np.random state is used (the previous behavior).

  Returns:
    (jbh, bg, hg, x, z): jbh[i][j] is Jbh at (b[j], h[i]); bg, hg are
    np.meshgrid(b, h); x and z are the drawn samples.
  """
  a, b, c, g, h = (config[key] for key in ("a", "b", "c", "g", "h"))
  rng = np.random if rng is None else rng
  # numpy parameterizes Exp(c) by its scale 1/c. The Rayleigh(1) density used
  # here is 2 z exp(-z^2), i.e. numpy's Rayleigh with scale sigma = 1/sqrt(2).
  x = rng.exponential(1/c, n_samples)
  z = rng.rayleigh(1/np.sqrt(2), n_samples)
  bg, hg = np.meshgrid(b, h)
  jbh = np.array([[Jbh(x, z, a, b_, g, h_) for b_ in b] for h_ in h])

  return jbh, bg, hg, x, z

## SGD implementation
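Here we implement vanilla SGD for the example above. The samples $x_i, z_i$ are drawn once per run and reused at every step, so each step uses the gradient of the same sampled cost. The discriminator parameter $b$ ascends this gradient and the generator parameter $h$ descends it.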

Define ${\partial J \over \partial b}$, ${\partial J \over \partial h}$:

In [9]:
def J_diff_b_numerical(x, z, a, b, g, h, delta=1e-6):
  """Central finite-difference estimate of dJbh/db (reference for J_diff_b)."""
  j_plus = Jbh(x, z, a, b + delta, g, h)
  j_minus = Jbh(x, z, a, b - delta, g, h)
  return (j_plus - j_minus)/(2*delta)
def J_diff_h_numerical(x, z, a, b, g, h, delta=1e-6):
  """Central finite-difference estimate of dJbh/dh (reference for J_diff_h)."""
  j_plus = Jbh(x, z, a, b, g, h + delta)
  j_minus = Jbh(x, z, a, b, g, h - delta)
  return (j_plus - j_minus)/(2*delta)

In [10]:
def J_diff_b(x, z, a, b, g, h):
  """Analytic dJbh/db for the sampled cost.

  With u = a*x + b and v = a*g*z**2 + eta (eta = a*h + b), and
  dD/du = exp(-u**2)/sqrt(pi):
    d/db mean(D(x)^2)          =  mean((1 + Phi(u)) exp(-u^2)) / sqrt(pi)
    d/db mean((1 - D(G(z)))^2) = -mean((1 - Phi(v)) exp(-v^2)) / sqrt(pi)
  The two sample means are taken separately (as in Jbh), so x and z may
  have different lengths.
  """
  eta = a*h + b
  f1 = np.vectorize(lambda x_: Phi_plus(a*x_ + b)*np.exp(-(a*x_ + b)**2))
  f2 = np.vectorize(lambda z_: Phi_minus(a*g*z_**2 + eta)* \
                               np.exp(-(a*g*z_**2 + eta)**2))
  return 1/math.sqrt(math.pi)*(np.average(f1(x)) - np.average(f2(z)))
def J_diff_h(x, z, a, b, g, h):
  """Analytic dJbh/dh; only the generator term depends on h.

  x is unused but kept so the signature matches J_diff_b.
  """
  eta = a*h + b
  def _term(zi):
    v = a*g*zi**2 + eta
    return a*Phi_minus(v)*np.exp(-v**2)
  return -1/math.sqrt(math.pi)*np.average(np.vectorize(_term)(z))

In [11]:
def _get_parameters(config, b2, h2):
  a, b, c, g, h = [v[1] for v in config.items()]
  b = b2 if b2 is not None else b
  h = h2 if h2 is not None else h
  return a, b, c, g, h

def calc_J_diff(config, x, z, method=J_diff_b, b2=None, h2=None):
  """Evaluate `method(x, z, a, b, g, h)` over the (b, h) grid.

  Returns an array whose row i / column j corresponds to h[i] / b[j];
  b2 / h2 optionally override the config grids.
  """
  a, b_grid, c, g, h_grid = _get_parameters(config, b2, h2)
  rows = []
  for h_val in h_grid:
    rows.append([method(x, z, a, b_val, g, h_val) for b_val in b_grid])
  return np.array(rows)

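Single SGD step. The discriminator parameter $b$ follows gradient ascent, since $D$ maximizes $V$. The generator parameter $h$ follows gradient descent, since $G$ minimizes $V$: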
$$
b \leftarrow b + \epsilon_b {\partial J \over \partial b} \\
h \leftarrow h - \epsilon_h {\partial J \over \partial h}
$$

In [12]:
def SGD_step_bh(x, z, a, b0, g, h0, eps_b, eps_h):
  """One simultaneous step: gradient ascent in b, gradient descent in h.

  Both gradients are evaluated at the current point (b0, h0).
  """
  grad_b = J_diff_b(x, z, a, b0, g, h0)
  grad_h = J_diff_h(x, z, a, b0, g, h0)
  return b0 + eps_b*grad_b, h0 - eps_h*grad_h

SGD procedure:

In [13]:
def SGD_bh(x, z, a, b0, g, h0, eps_b, eps_h, stop):
  """Yield the (b, h) iterate after each of `stop` steps.

  The starting point (b0, h0) itself is not yielded.
  """
  state = (b0, h0)
  for _ in range(stop):
    state = SGD_step_bh(x, z, a, state[0], g, state[1], eps_b, eps_h)
    yield state

In [14]:
def run_SGD(config_sgd, config_jbh, jbh_surface, bg, hg, x, z, plot_result=True):
  """Run the (b, h) gradient iteration from the configured start point.

  Args:
    config_sgd: dict with keys "b_initial", "h_initial", "eps_b", "eps_h"
      and "n_steps" (looked up by key, so order does not matter).
    config_jbh: dict with keys "a", "c" and "g"; its "b"/"h" grids are not
      used here because the surface is passed in pre-computed.
    jbh_surface, bg, hg: cost surface and its meshgrid; only used for the plot.
    x, z: real-data and latent samples, kept fixed for every step.
    plot_result: if True, draw the surface together with the iterate path.

  Returns:
    (b_path, h_path): lists holding the iterate after each step; the initial
    point itself is not included.
  """
  a, c, g = config_jbh["a"], config_jbh["c"], config_jbh["g"]
  b_initial, h_initial = config_sgd["b_initial"], config_sgd["h_initial"]
  eps_b, eps_h = config_sgd["eps_b"], config_sgd["eps_h"]
  n_steps = config_sgd["n_steps"]

  path = list(SGD_bh(x, z, a, b_initial, g, h_initial, eps_b, eps_h, n_steps))
  b_path = [b for b, _ in path]
  h_path = [h for _, h in path]

  if plot_result:
    # Prepend the start point so the "b_0, h_0" label really marks it (the
    # returned paths begin after the first step). This also keeps the plot
    # valid for n_steps == 0.
    plot_b = [b_initial] + b_path
    plot_h = [h_initial] + h_path
    # The cost along the path is only needed for the plot, so batch runs
    # (plot_result=False) skip this evaluation.
    jbh_path = [Jbh(x, z, a, b_, g, h_) for b_, h_ in zip(plot_b, plot_h)]

    fig, ax = create_3d_plot()
    ax.plot_wireframe(bg, hg, jbh_surface)
    ax.plot(plot_b, plot_h, jbh_path, 'r--')
    ax.text(plot_b[0], plot_h[0], jbh_path[0],
            r'$\leftarrow b_0, h_0 = (%1.1f, %1.1f)$'
            % (plot_b[0], plot_h[0]), fontsize=16)
    ax.text(plot_b[-1], plot_h[-1], jbh_path[-1],
            r'$\leftarrow \hat b, \hat h = (%1.1f, %1.1f)$'
            % (plot_b[-1], plot_h[-1]), fontsize=16)
    ax.set_xlabel('b')
    ax.set_ylabel('h')
    # Both literals are raw: '\e' in a plain string is an invalid escape
    # sequence (SyntaxWarning on recent Python versions).
    ax.set_title(
        r'SGD w.r.t. $(b,h)$ $a = %1.1f, c = %1.1f, g = %1.1f, '
        r'\epsilon_b = %1.1f, \epsilon_h = %1.1f$, %d steps'
        % (a, c, g, eps_b, eps_h, n_steps))
    plt.show()
  return b_path, h_path

In [15]:
def run_SGD_batch(jbh_config, sgd_config, n_runs=100, n_samples=1000):
  """Repeat run_SGD on `n_runs` independently drawn sample sets.

  Returns:
    (b_paths, h_paths): lists with one entry per run; each entry is the list
    of iterates returned by run_SGD for that run.
  """
  # Plotting is off here, so the J(b, h) surface that sample_Jbh evaluates
  # over the whole (b, h) grid (one Jbh call per grid point) would be
  # discarded. Empty grids skip that work; x and z are still drawn in the
  # same order from the same RNG.
  sampling_config = {**jbh_config, "b": [], "h": []}
  b_paths = []
  h_paths = []
  for _ in tqdm(range(n_runs)):
    jbh, bg, hg, x, z = sample_Jbh(sampling_config, n_samples=n_samples)
    b_path, h_path = run_SGD(sgd_config,
                             jbh_config,
                             jbh, bg, hg, x, z,
                             plot_result=False)
    b_paths.append(b_path)
    h_paths.append(h_path)
  return b_paths, h_paths

In [16]:
def show_paths(b_paths, h_paths):
  """Plot the per-step mean +/- 3 sigma of b and h across runs.

  b_paths, h_paths: one equal-length path per run (e.g. as returned by
  run_SGD_batch). Also prints the final mean and 3-sigma spread.
  """
  b_runs, h_runs = np.asarray(b_paths), np.asarray(h_paths)
  mean_b, std_b = b_runs.mean(axis=0), b_runs.std(axis=0)
  mean_h, std_h = h_runs.mean(axis=0), h_runs.std(axis=0)
  # path[k] is the iterate after step k + 1 (the start point is not stored).
  steps_b = np.arange(1, len(mean_b) + 1)
  steps_h = np.arange(1, len(mean_h) + 1)

  _, ax = plt.subplots(figsize=(10, 10))
  ax.plot(steps_b, mean_b, label=r'$\bar b \pm 3\sigma$')
  ax.fill_between(steps_b, mean_b - 3*std_b, mean_b + 3*std_b, alpha=0.7)
  ax.plot(steps_h, mean_h, label=r'$\bar h \pm 3\sigma$')
  ax.fill_between(steps_h, mean_h - 3*std_h, mean_h + 3*std_h, alpha=0.7)
  ax.set(xlabel='step', ylabel='parameter value',
         title='SGD paths across runs')
  ax.legend()
  plt.show()
  print('solution b = %1.4f +/- %1.4f, h = %1.4f +/- %1.4f' %
        (mean_b[-1], 3*std_b[-1], mean_h[-1], 3*std_h[-1]))

In [17]:
def _set_subplot(title):
  """Label the current pyplot axes with `title` and the shared 'B'/'H' axis names."""
  plt.xlabel('B')
  plt.ylabel('H')
  plt.title(title)

def show_comparative(b_paths, h_paths, thetas):
  """Compare Monte-Carlo SGD results with analytical trajectories in (B, H).

  Top row: SGD trajectories and their end points. Bottom row, drawn only when
  `thetas` is not None: analytical trajectories and end points. The analytical
  end-point panel reuses the SGD end-point axis limits.

  Args:
    b_paths, h_paths: one path per run (e.g. from run_SGD_batch).
    thetas: indexed as thetas[i][k]; rows k=1 and k=3 are plotted as the B and
      H trajectories of run i. NOTE(review): the meaning of the other rows is
      not visible here; confirm against the code producing the .mat files.
  """
  # A grid of exactly the panels that are drawn; the previous 4x2 grid left
  # empty axes underneath the 2x2 panels.
  n_rows = 1 if thetas is None else 2
  _, axes = plt.subplots(n_rows, 2, figsize=(15, 7.5*n_rows), squeeze=False)

  def _label(ax, title):
    # _set_subplot works on the current axes, so select `ax` first.
    plt.sca(ax)
    _set_subplot(title)

  ax_sgd_traj, ax_sgd_endpoints = axes[0]
  for b_path, h_path in zip(b_paths, h_paths):
    ax_sgd_traj.plot(b_path, h_path, color='tab:blue')
  _label(ax_sgd_traj, 'SGD Trajectories')

  ax_sgd_endpoints.scatter([p[-1] for p in b_paths], [p[-1] for p in h_paths],
                           color='tab:blue')
  _label(ax_sgd_endpoints, 'SGD End Points')

  if thetas is not None:
    ax_analytical_traj, ax_analytical_endpoints = axes[1]
    for theta in thetas:
      ax_analytical_traj.plot(theta[1], theta[3], color='tab:orange')
    _label(ax_analytical_traj, 'Analytical Trajectories')

    ax_analytical_endpoints.scatter([theta[1][-1] for theta in thetas],
                                    [theta[3][-1] for theta in thetas],
                                    color='tab:orange')
    _label(ax_analytical_endpoints, 'Analytical End Points')
    ax_analytical_endpoints.set_ylim(ax_sgd_endpoints.get_ylim())
    ax_analytical_endpoints.set_xlim(ax_sgd_endpoints.get_xlim())

  plt.show()

### Utility functions to compare with analytical implementation

In [18]:
def load_analytical(fname, param):
  """Load variable `param` from the MATLAB file `fname` and transpose it.

  Returns None when `fname` does not exist, so the analytical comparison can
  be skipped.
  """
  if os.path.isfile(fname):
    return loadmat(fname)[param].transpose()
  return None

### Case A implementation
Configuration:

In [19]:
# Case A: J(b, h) model parameters and the (b, h) plotting grids.
case_a_jbh_config = dict(
    a=2.2,
    b=np.linspace(-4.0, 2.0, 61),   # 61 points -> step 0.1
    c=0.5,
    g=1.5,
    h=np.linspace(-2.1, 3.1, 41),   # 41 points -> step 0.13
)
# Case A: start point, step sizes and number of steps.
case_a_sgd_config = dict(
    b_initial=-3,
    h_initial=-2,
    eps_b=0.4,
    eps_h=0.4,
    n_steps=250,
)

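Monte Carlo SGD experiment. It runs `n_runs` independent sample sets and compares them with the analytical trajectories, if those are available: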

In [20]:
@interact_manual(n_runs = widgets.IntSlider(3, min=1, max=100))
def case_a(n_runs):
  """Run the case A Monte-Carlo SGD experiment.

  Plots the runs against the analytical trajectories when
  content/Thetas_A.mat exists, and saves the paths to content/case_a.mat.
  """
  b_paths, h_paths = run_SGD_batch(case_a_jbh_config, case_a_sgd_config,
                                   n_runs=n_runs)
  show_comparative(b_paths, h_paths,
                   thetas=load_analytical('content/Thetas_A.mat', 'Thetas_A'))
  out_file = 'content/case_a.mat'
  # The analytical input is optional, so content/ may not exist yet, and
  # savemat does not create missing directories.
  os.makedirs(os.path.dirname(out_file), exist_ok=True)
  savemat(out_file, {'b_paths': b_paths, 'h_paths': h_paths})

interactive(children=(IntSlider(value=3, description='n_runs', min=1), Button(description='Run Interact', styl…

### Case B implementation

In [21]:
case_b_jbh_config = {
    "a": 1.21,
    # Endpoint-inclusive point counts (range/step + 1), as in case A, so the
    # grid step is exactly 0.2; floor(range/step) points gave steps of
    # ~0.208 (b) and ~0.207 (h).
    "b": np.linspace(-2.0, 3.0, round(5/0.2) + 1),
    "c": 4.04,
    "g": 0.35,
    "h": np.linspace(-5.0, 1.0, round(6/0.2) + 1),
}
case_b_sgd_config = {
    "b_initial": -1,
    "h_initial": -4,
    "eps_b": 0.4,
    "eps_h": 0.4,
    "n_steps": 250
}

In [22]:
@interact_manual(n_runs = widgets.IntSlider(3, min=1, max=100))
def case_b(n_runs):
  """Run the case B Monte-Carlo SGD experiment.

  Plots the runs against the analytical trajectories when
  content/Thetas_B.mat exists, and saves the paths to content/case_b.mat.
  """
  b_paths, h_paths = run_SGD_batch(case_b_jbh_config, case_b_sgd_config,
                                   n_runs=n_runs)
  show_comparative(b_paths, h_paths,
                   thetas=load_analytical('content/Thetas_B.mat', 'Thetas_B'))
  out_file = 'content/case_b.mat'
  # The analytical input is optional, so content/ may not exist yet, and
  # savemat does not create missing directories.
  os.makedirs(os.path.dirname(out_file), exist_ok=True)
  savemat(out_file, {'b_paths': b_paths, 'h_paths': h_paths})

interactive(children=(IntSlider(value=3, description='n_runs', min=1), Button(description='Run Interact', styl…