In [1]:
import jax
import jax.numpy as jnp
from jax.experimental.optimizers import adam
from jax.config import config
import numpy as np
import matplotlib.pyplot as plt
from jax.nn import sigmoid
from jax.nn.initializers import normal
from jax.nn import leaky_relu, sigmoid
from jax.experimental import stax
from jax.example_libraries.stax import (BatchNorm, Conv, ConvTranspose, Dense,
                                   Tanh, Relu, Flatten, Sigmoid)
from jax.experimental.stax import (BatchNorm, Conv, ConvTranspose, Dense,
                                   Tanh, Relu, Flatten, Sigmoid)
from jax.experimental.optimizers import pack_optimizer_state, unpack_optimizer_state
from jax.example_libraries.optimizers import pack_optimizer_state, unpack_optimizer_state
import jax.random as random

from jax.lax import sort

from jax import value_and_grad, jit
from functools import partial
import pickle
import os

# Machine limits of float32, kept as module-level constants for numerical guards.
finfo = jnp.finfo(jnp.float32)
EPS = finfo.eps  # gap between 1.0 and the next larger representable float32
EPSNEG = finfo.epsneg  # gap between 1.0 and the next smaller representable float32

import argparse
import time


def mlp_discriminator():
    """Build the stax ``(init_fun, apply_fun)`` pair of the MLP discriminator.

    The network is four ``Dense(256)`` + ``Relu`` stages followed by a single
    output unit. No sigmoid is applied at the end, so the output is a raw
    score; BatchNorm is not used between the stages.
    """
    hidden_stack = []
    for _ in range(4):
        hidden_stack.extend((Dense(out_dim=256), Relu))
    return stax.serial(*hidden_stack, Dense(1))


def mlp_generator_2d():
    """Build the stax ``(init_fun, apply_fun)`` pair of the 2-D MLP generator.

    Four ``Dense(256)`` + ``Relu`` stages feed a final ``Dense(2)`` layer, so
    every latent vector is mapped to a point in the plane. BatchNorm is not
    used between the stages.
    """
    hidden_stack = [layer
                    for _ in range(4)
                    for layer in (Dense(out_dim=256), Relu)]
    output_layer = Dense(2)
    return stax.serial(*hidden_stack, output_layer)
class GAN:
    """GAN wrapper around stax-style networks and optimizer triples.

    Every ``*_creator`` is a zero-argument callable: the network creators must
    return ``(init_fun, apply_fun)`` and the optimizer creators must return
    ``(opt_init, opt_update, get_params)``. ``loss_function(predictions,
    targets)`` is used for both the discriminator and the generator objective.

    The jitted methods declare ``self`` as a static argument, so everything
    they read from ``self`` (``batch_size``, ``g_input_shape``, the apply
    functions, the loss) is read while the function is being traced.
    """

    def __init__(self, d_creator, g_creator, d_opt_creator, g_opt_creator, loss_function):
        """Instantiate both networks and both optimizers from their factories."""
        disc_init, disc_apply = d_creator()
        gen_init, gen_apply = g_creator()
        disc_opt_init, disc_opt_update, disc_opt_get_params = d_opt_creator()
        gen_opt_init, gen_opt_update, gen_opt_get_params = g_opt_creator()

        # Keep the factories themselves around as attributes.
        self.d_creator = d_creator
        self.g_creator = g_creator
        self.d_opt_creator = d_opt_creator
        self.g_opt_creator = g_opt_creator

        self.d = dict(init=disc_init, apply=disc_apply)
        self.g = dict(init=gen_init, apply=gen_apply)
        self.d_opt = dict(init=disc_opt_init, update=disc_opt_update,
                          get_params=disc_opt_get_params)
        self.g_opt = dict(init=gen_opt_init, update=gen_opt_update,
                          get_params=gen_opt_get_params)
        self.loss_function = loss_function

        # Shapes and batch size stay unknown until init() is called.
        self.d_output_shape = self.g_output_shape = None
        self.d_input_shape = self.g_input_shape = None
        self.batch_size = None

    def init(self, prng_d, prng_g, d_input_shape, g_input_shape, batch_size):
        """Initialise network parameters and optimizer states.

        ``d_input_shape`` / ``g_input_shape`` are per-sample shapes; the batch
        dimension is prepended here. Returns ``(d_state, g_state)``, the two
        optimizer states wrapping the freshly initialised parameters.
        """
        self.g_input_shape, self.d_input_shape = g_input_shape, d_input_shape

        disc_batch_shape = (batch_size,) + tuple(d_input_shape)
        self.d_output_shape, disc_params = self.d['init'](prng_d, disc_batch_shape)
        gen_batch_shape = (batch_size,) + tuple(g_input_shape)
        self.g_output_shape, gen_params = self.g['init'](prng_g, gen_batch_shape)

        self.batch_size = batch_size
        return self.d_opt['init'](disc_params), self.g_opt['init'](gen_params)

    @partial(jit, static_argnums=(0,))
    def _d_loss(self, d_params, g_params, z, real_samples):
        """Discriminator loss: generated samples target 0, real samples target 1."""
        discriminate = self.d['apply']
        criterion = self.loss_function

        scores_on_fakes = discriminate(d_params, self.g['apply'](g_params, z))
        scores_on_reals = discriminate(d_params, real_samples)

        loss_on_fakes = criterion(scores_on_fakes, jnp.zeros_like(scores_on_fakes))
        loss_on_reals = criterion(scores_on_reals, jnp.ones_like(scores_on_reals))
        return loss_on_fakes + loss_on_reals

    @partial(jit, static_argnums=(0, 4))
    def _g_loss(self, g_params, d_params, z, k):
        """Generator loss over a subset of the generated batch (target 1).

        The discriminator scores are sorted in ascending order along axis 0.
        For ``k > 0`` they are flipped first, so the slice ``[:k]`` keeps the
        ``k`` highest scores. For ``k < 0`` no flip happens and ``[:k]`` drops
        the ``|k|`` highest scores, i.e. the lowest ``batch + k`` are kept.
        ``k`` is a static (compile-time) argument.
        """
        generated = self.g['apply'](g_params, z)
        ranked_scores = sort(self.d['apply'](d_params, generated), 0)
        if k > 0:
            # descending order, so the slice below selects the top-k scores
            ranked_scores = jnp.flip(ranked_scores, 0)
        selected_scores = ranked_scores[:k]
        return self.loss_function(selected_scores, jnp.ones_like(selected_scores))

    @partial(jit, static_argnums=(0, 6))
    def train_step(self, i, prng_key, d_state, g_state, real_samples, k):
        """Run one discriminator update followed by one generator update.

        A falsy ``k`` (``0`` / ``None``) is replaced by ``self.batch_size``.
        The generator gradient is evaluated with the discriminator parameters
        read at the start of the step, i.e. before this step's discriminator
        update is applied.

        Returns ``(d_state, g_state, d_loss_value, g_loss_value)``.
        """
        k = k or self.batch_size
        disc_key, gen_key = random.split(prng_key, 2)
        latent_shape = (self.batch_size, *self.g_input_shape)

        d_params = self.d_opt['get_params'](d_state)
        g_params = self.g_opt['get_params'](g_state)

        # Discriminator update on a fresh latent batch.
        disc_latents = random.normal(disc_key, latent_shape)
        d_loss_value, d_grads = value_and_grad(self._d_loss)(
            d_params, g_params, disc_latents, real_samples)
        d_state = self.d_opt['update'](i, d_grads, d_state)

        # Generator update on an independent latent batch.
        gen_latents = random.normal(gen_key, latent_shape)
        g_loss_value, g_grads = value_and_grad(self._g_loss)(
            g_params, d_params, gen_latents, k)
        g_state = self.g_opt['update'](i, g_grads, g_state)

        return d_state, g_state, d_loss_value, g_loss_value

    @partial(jit, static_argnums=(0,))
    def generate_samples(self, z, g_state):
        """Map latent vectors ``z`` through the generator stored in ``g_state``."""
        generator_params = self.g_opt['get_params'](g_state)
        return self.g['apply'](generator_params, z)

    @partial(jit, static_argnums=(0,))
    def rate_samples(self, samples, d_state):
        """Score ``samples`` with the discriminator stored in ``d_state``."""
        discriminator_params = self.d_opt['get_params'](d_state)
        return self.d['apply'](discriminator_params, samples)
def BCE_from_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits.

    Args:
        logits: unnormalised scores (any shape).
        targets: array broadcastable against ``logits`` with values in [0, 1].

    Returns:
        Scalar mean of ``-(t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)))``.

    The log-probabilities are taken with ``log_sigmoid`` instead of
    ``log(sigmoid(x))``. In float32 the sigmoid saturates to exactly 0 or 1 for
    large ``|x|``; going through it clamps the loss and makes the gradient of
    confidently wrong predictions exactly zero. ``log(1 - sigmoid(x))`` equals
    ``log_sigmoid(-x)``, which stays finite and keeps a useful gradient.
    """
    log_prob_one = jax.nn.log_sigmoid(logits)
    log_prob_zero = jax.nn.log_sigmoid(-logits)
    loss_array = -(targets * log_prob_one + (1 - targets) * log_prob_zero)
    return jnp.mean(loss_array)



In [2]:
# Default PRNG seed; used as the default `seed` of get_gaussian_mixture.
SEED_DEFAULT = 100


class DataLoader:
    """Minimal interface for objects that hand out training batches."""

    def get_next_batch(self):
        """Return the next batch.

        The base implementation does nothing and returns ``None``; subclasses
        override it.
        """
        return None


class GaussianMixture(DataLoader):
    """Sampler for a mixture of Gaussians with uniformly chosen components.

    The component of every sample is drawn with ``random.randint`` over
    ``[0, num_modes)``; the samples of each component are then drawn with
    ``random.multivariate_normal`` and the result is shuffled. ``self.prng`` is
    replaced by a fresh key on every draw.
    """

    @staticmethod
    def create_2d_mean_matrix(num_components):
        """Lay ``num_components`` means out on an integer ``a x b`` grid.

        ``a`` is the first divisor of ``num_components`` found when counting up
        from ``floor(sqrt(num_components))`` and ``b = num_components // a``.
        Coordinates are consecutive integers placed around the origin.

        Returns:
            Integer array of shape ``(num_components, 2)``.

        Raises:
            ValueError: if ``num_components`` is smaller than 1.
        """
        if num_components < 1:
            raise ValueError("num_components must be at least 1")
        a = int(np.sqrt(num_components))
        while a < num_components:
            if num_components % a == 0:
                break
            a += 1
        b = num_components // a
        return np.array([[i, j]
                         for i in range(-a // 2 + 1, a // 2 + 1, 1)
                         for j in range(-b // 2 + 1, b // 2 + 1, 1)])

    @staticmethod
    def create_2d_covariance_matrix(variance, num_components):
        """Return ``num_components`` isotropic 2x2 covariances ``variance * I``."""
        return np.array([np.identity(2) * variance for _ in range(num_components)])

    def __init__(self, prng, batch_size, num_modes=None, variance=None, means=None, covariances=None):
        """Set up the mixture.

        Args:
            prng: JAX PRNG key that drives all sampling.
            batch_size: number of samples per batch.
            num_modes: number of components; only used when ``means`` is None.
            variance: isotropic variance; only used when ``covariances`` is None.
            means: optional explicit component means, one row per component.
            covariances: optional explicit covariance matrices, one per component.
        """
        self.prng = prng
        self.batch_size = batch_size
        if means is not None:
            # Accept plain lists as well; the shape checks below need arrays.
            self.means = np.asarray(means)
            self.num_modes = len(self.means)
        else:
            self.means = self.create_2d_mean_matrix(num_modes)
            self.num_modes = num_modes
        if covariances is not None:
            self.covariances = np.asarray(covariances)
        else:
            self.covariances = self.create_2d_covariance_matrix(variance, self.num_modes)
        assert self.means.shape[0] == self.covariances.shape[0], "means and covariances must have equal length"
        assert self.means.shape[1] == self.covariances.shape[1], "means and covariances must have corresponding " \
                                                                 "dimensionality "

    def _draw_shuffled_samples(self, num_samples, verbose=False):
        """Draw ``num_samples`` mixture samples in shuffled order as a NumPy array.

        Advances ``self.prng``. With ``verbose`` the progress messages of
        ``get_iteration_samples`` are printed.
        """
        self.prng, counts_key, shuffle_key, *keys = random.split(self.prng, self.num_modes + 3)
        assignments = np.asarray(random.randint(counts_key, (num_samples,), 0, self.num_modes))
        numbs, counts = np.unique(assignments, return_counts=True)
        if verbose:
            print(counts)

        chunks = []
        for i, comp_ind in enumerate(numbs):
            if verbose:
                print(f"creating {i}th components: {self.means[comp_ind]}->{counts[i]}")
            samples = random.multivariate_normal(keys[i], self.means[comp_ind], self.covariances[comp_ind],
                                                 (int(counts[i]),), jnp.float32)
            # Convert each chunk once instead of iterating the device array row by row.
            chunks.append(np.asarray(samples))

        if not chunks:
            # Nothing was requested: return an empty array with the right width.
            return np.empty((0, self.means.shape[1]), dtype=np.float32)

        drawn = np.concatenate(chunks)
        if verbose:
            print(f"shuffling")
        order = np.asarray(random.permutation(shuffle_key, len(drawn)))
        return drawn[order]

    def get_next_batch(self):
        """Return one shuffled batch of shape ``(batch_size, dim)``."""
        return self._draw_shuffled_samples(self.batch_size)

    def get_iteration_samples(self, num_iter):
        """Return ``num_iter`` batches at once, shape ``(num_iter, batch_size, dim)``.

        All ``batch_size * num_iter`` samples are drawn in one pass, shuffled
        and then split into batches. Progress is printed along the way.
        """
        batches = self._draw_shuffled_samples(self.batch_size * num_iter, verbose=True)
        # The sample width comes from the means matrix itself, so a mixture
        # with a single component works too.
        return batches.reshape((num_iter, self.batch_size, self.means.shape[1]))


In [3]:
def get_gaussian_mixture(batch_size, num_iters, components, variance, seed=SEED_DEFAULT,
                         save=True, from_file=True):
    """Return Gaussian-mixture training data, loading a cached file if present.

    The cache file lives under ``./temp`` and is named after ``components``,
    ``variance``, ``batch_size`` and ``num_iters``; ``seed`` is not part of the
    name. When ``from_file`` is true and that file exists it is loaded with
    ``np.load``. Otherwise the data is generated with a ``GaussianMixture``
    seeded from ``seed``.

    ``save`` is accepted but currently not acted upon: freshly generated data
    is never written to disk.
    """
    prng = random.PRNGKey(seed)
    path = f"./temp/GaussianMixture-{components}-{variance}-{batch_size}-{num_iters}.npy"

    if from_file and os.path.exists(path):
        print("file exists")
        return np.load(path)

    print("file didn't exist, creating")
    mixture = GaussianMixture(prng, batch_size, components, variance)
    return mixture.get_iteration_samples(num_iters)

In [4]:
def plot_samples_scatter(samples, samples2=None, samples_ratings=None, save_adr=None, show=True, cmap=None):
    """Scatter-plot 2-D samples, optionally coloured by a rating per sample.

    Args:
        samples: array indexed as ``samples[:, 0]`` (x) and ``samples[:, 1]`` (y).
        samples2: optional second set of 2-D points, overlaid in red.
        samples_ratings: optional values passed as ``c`` to colour ``samples``;
            when given, a colorbar is added.
        save_adr: optional path handed to ``plt.savefig``.
        show: if true the figure is shown, otherwise it is cleared.
        cmap: colormap for the ratings; defaults to ``'cividis'`` when ratings
            are given.
    """
    # Plain 1-D coordinate arrays (a stray trailing comma used to wrap X in a tuple).
    X = samples[:, 0]
    Y = samples[:, 1]
    if samples_ratings is not None:
        cmap = cmap or 'cividis'

    plt.scatter(X, Y, c=samples_ratings, alpha=0.2, cmap=cmap)
    if samples_ratings is not None:
        plt.colorbar()

    if samples2 is not None:
        X2 = samples2[:, 0]
        Y2 = samples2[:, 1]

        plt.scatter(X2, Y2, color='red', alpha=0.2)

    if save_adr is not None:
        plt.savefig(save_adr)
    if show:
        plt.show()
    else:
        # Not displayed: clear the figure so the next plot starts empty.
        plt.clf()

In [10]:
from jax._src.lax.lax import top_k


# Make JAX raise an exception as soon as a NaN is produced.
# NOTE(review): the assignment below only binds a Python name; presumably it is
# the config.update(...) call that actually enables the check — confirm.
JAX_DEBUG_NANS = True
config.update("jax_debug_nans", True)


# ~~~~~~~~~~~ Stax GAN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Default hyperparameters; most are used as default argument values of train().
dataset_default = 'gaussian_mixture'  # key into the `datasets` registry below
seed_default = 20
num_components_default = 25
gaussian_variance_default = 0.0025
prior_dim_default = 2  # size of the latent vector fed to the generator

# Optimizer settings: create_and_initialize_gan passes (lr, momentum, momentum2)
# positionally to adam, once for the discriminator and once for the generator.
d_lr_default = 0.0001
d_momentum_default = 0.9
d_momentum2_default = 0.99
g_lr_default = 0.0001
g_momentum_default = 0.9
g_momentum2_default = 0.99
loss_function_default = BCE_from_logits
batch_size_default = 256
batch_size_min_default = 192  # lower bound for k when train() decays it
decay_rate_default = 0.99  # factor applied to k on each decay in train()

num_iter_default = 10000

# Registry mapping a dataset name to the function that produces its batches.
datasets = {'gaussian_mixture':get_gaussian_mixture}


def create_and_initialize_gan(prng, d_lr, d_momentum, d_momentum2, g_lr, g_momentum, g_momentum2, loss_function,
                              d_input_shape, g_input_shape, batch_size):
    """Build a GAN from the MLP networks with adam optimizers and initialise it.

    ``prng`` is split into one key for the discriminator and one for the
    generator. The three optimizer numbers of each network are passed
    positionally to ``adam``.

    Returns:
        ``(gan, d_state, g_state)`` where the states are the optimizer states
        returned by ``GAN.init``.
    """
    gan = GAN(
        mlp_discriminator,
        mlp_generator_2d,
        partial(adam, d_lr, d_momentum, d_momentum2),
        partial(adam, g_lr, g_momentum, g_momentum2),
        loss_function,
    )

    disc_key, gen_key = jax.random.split(prng, 2)
    d_state, g_state = gan.init(disc_key, gen_key, d_input_shape, g_input_shape, batch_size)
    return gan, d_state, g_state


def train(num_components, variance=gaussian_variance_default,
          batch_size=batch_size_default,
          num_iter=num_iter_default,
          batch_size_min=batch_size_min_default,
          dataset=dataset_default, loss_function=loss_function_default,
          prior_dim=prior_dim_default, d_lr=d_lr_default, d_momentum=d_momentum_default,
          d_momentum2=d_momentum2_default, g_lr=g_lr_default, g_momentum=g_momentum_default,
          g_momentum2=g_momentum2_default, top_k=1,
          show_plots=True,
          save_adr_plots_folder=None,
          seed=seed_default
          ):
    """Train the 2-D GAN and return its loss history and final state.

    Args:
        num_components, variance: forwarded to the dataset function.
        batch_size: samples per step; also the initial value of ``k``.
        num_iter: number of batches requested from the dataset function.
        batch_size_min: lower bound for ``k`` when it is decayed.
        dataset: key into the module-level ``datasets`` registry.
        loss_function: loss used by the GAN for both networks.
        prior_dim: size of the latent vectors fed to the generator.
        d_lr, d_momentum, d_momentum2, g_lr, g_momentum, g_momentum2: optimizer
            settings forwarded to ``create_and_initialize_gan``.
        top_k: ``k`` is decayed only when this equals 1; with any other value
            every step uses ``k == batch_size``.
        show_plots: whether the periodic scatter plots are shown.
        save_adr_plots_folder: if given, every periodic scatter plot is also
            saved in this folder (created when missing) as ``samples-<i>.png``.
        seed: seed for network initialisation, latent sampling and the plot
            latents. The dataset function is called without a seed and
            therefore uses its own default.

    Returns:
        ``(d_losses, g_losses, d_state, g_state, gan)``; the two loss lists
        hold one value per training step.
    """
    prng = jax.random.PRNGKey(seed)
    im_shape = (2,)
    prng_to_use, prng = jax.random.split(prng, 2)
    gan, d_state, g_state = create_and_initialize_gan(prng_to_use,
                                                      d_lr, d_momentum, d_momentum2,
                                                      g_lr, g_momentum, g_momentum2,
                                                      loss_function, im_shape, (prior_dim,), batch_size)
    data = datasets[dataset](batch_size, num_iter, num_components, variance)

    d_losses = []
    g_losses = []

    # Fixed latent batch reused for every progress plot. Its width must follow
    # the prior_dim argument (not the module default) to match the generator.
    prng_images, prng = jax.random.split(prng, 2)
    r = jax.random.normal(prng_images, (10000, prior_dim))

    if save_adr_plots_folder is not None:
        os.makedirs(save_adr_plots_folder, exist_ok=True)

    start_time = time.time()
    prev_time = time.time()
    k = batch_size
    for i, real_ims in enumerate(data):
        # print info and show results
        if i % 1000 == 0:
            print(f"{i}/{num_iter} took {time.time() - prev_time}")
            prev_time = time.time()
            fake_imgs = gan.generate_samples(r, g_state)
            save_adr_plot = None
            if save_adr_plots_folder is not None:
                save_adr_plot = os.path.join(save_adr_plots_folder, f"samples-{i}.png")

            plot_samples_scatter(fake_imgs, real_ims,
                                 save_adr=save_adr_plot,
                                 samples_ratings=gan.rate_samples(fake_imgs, d_state),
                                 show=show_plots)

        # ------------- actual training starts -----------------------------
        # decay k (k is a static argument of train_step, so each new value
        # triggers a recompilation of the step)
        if top_k == 1 and i % 2000 == 1999:
            k = int(k * decay_rate_default)
            k = max(batch_size_min, k)
            print(f"iter:{i}/{num_iter}, updated k: {k}")

        # train one step
        prng, prng_to_use = jax.random.split(prng, 2)
        d_state, g_state, d_loss_value, g_loss_value = gan.train_step(i, prng_to_use, d_state, g_state, real_ims, k)
        # ------------- actual training ends --------------------------------

        # keep the iteration loss
        d_losses.append(d_loss_value)
        g_losses.append(g_loss_value)

    print(f'finished, took{time.time() - start_time}')
    return d_losses, g_losses, d_state, g_state, gan

In [None]:
# Train on the 25-component mixture. top_k is not 1 here, so train() never
# decays k and every generator step uses k == batch_size.
d_losses, g_losses, d_state, g_state, gan = train(num_components=25,
                                                      batch_size=256, top_k = -1)

In [12]:
def _get_random(num, seed):
    prng = jax.random.PRNGKey(seed)
    r = jax.random.normal(prng, (num, 2))
    return r


def eval_gauss(gan, d_state, g_state, num_modes, var, seed=0):
    """Report mode coverage and sample quality of the generator.

    10000 latent vectors are drawn from ``seed``, pushed through the generator
    and matched to the nearest mean of the ``num_modes`` grid. A sample counts
    as high quality when it lies closer than four standard deviations
    (``4 * sqrt(var)``) to its nearest mean. The statistics are printed and a
    bar chart of high-quality samples per mode is shown.

    ``d_state`` is accepted but not used.

    Returns:
        ``(mode_inds, dists)``: for every sample the index of its nearest mode
        and the distance to it.
    """
    r = _get_random(10000, seed)
    fake_samples = np.array(gan.generate_samples(r, g_state))
    modes = GaussianMixture.create_2d_mean_matrix(num_modes)
    sd = np.sqrt(var)

    mode_inds, dists = _get_nearest_modes(fake_samples, modes)
    # One quality criterion for the statistics and the histogram; the histogram
    # previously used a fixed 0.2, which only matches 4 * sd for var == 0.0025.
    high_quality_mask = dists < 4 * sd
    high_quality_samples = mode_inds[high_quality_mask]
    recovered_modes = np.unique(high_quality_samples)
    print(f"num of recovered modes:{len(recovered_modes)}")
    print(f"high guality samples:{len(high_quality_samples) / len(dists) * 100} %")
    print(f"samples within sd: "
          f"[0,1):{len(dists[dists < sd]) / len(dists) * 100}, "
          f"[1,2):{len(dists[(dists >= sd) & (dists < sd * 2)]) / len(dists) * 100}, "
          f"[2,3):{len(dists[(dists >= sd * 2) & (dists < sd * 3)]) / len(dists) * 100}, "
          f"[3,4):{len(dists[(dists >= sd * 3) & (dists < sd * 4)]) / len(dists) * 100}, "
          f"[4,inf):{len(dists[(dists >= sd * 4)]) / len(dists) * 100}")

    mode, counts = np.unique(high_quality_samples, return_counts=True)
    plt.bar(mode, counts)
    #plt.savefig("./distribution-of-high-quality-samples-per-mode")
    plt.show()

    return mode_inds, dists


def _get_nearest_modes(samples, modes):
    mode_inds = [-1 for _ in range(len(samples))]
    dists = [-1 for _ in range(len(samples))]
    for i, sample in enumerate(samples):
        mode_inds[i], dists[i] = _get_nearest_mode(sample, modes)
    return np.array(mode_inds), np.array(dists)


def _get_nearest_mode(sample, modes):
    dist = np.sqrt((sample[0] - modes[0][0]) ** 2 + (sample[1] - modes[0][1]) ** 2)
    mode_ind = 0
    for i, mode in enumerate(modes):
        d = np.sqrt((sample[0] - mode[0]) ** 2 + (sample[1] - mode[1]) ** 2)
        if d < dist:
            dist = d
            mode_ind = i
    return mode_ind, dist


In [None]:
# Evaluate the trained generator against a 25-mode grid with variance 0.0025,
# the same values the training cell above uses (num_components=25, default variance).
mode_index, dist = eval_gauss(gan, d_state, g_state, 25, 0.0025, seed=0)

In [None]:
import matplotlib.pyplot as plt
# Plot the generator loss recorded at every training step.
g_losses = np.reshape(np.array(g_losses), (-1,1))
# One x-value per recorded loss, so the plot also works when training ran for
# a number of iterations other than 10000.
i = np.arange(len(g_losses))
print(i)
plt.plot(i,g_losses)