Generative Adversarial Networks
===========================

#### Credits to Tudor Berariu for the tutorial on GAN visualization


## Imports

In [1]:
import io
import numpy as np
from scipy import stats
import time
import base64
import IPython
import matplotlib.pyplot as plt
import seaborn as sns
sns.set("notebook")

In [2]:
import torch
from torch import distributions, optim, nn
from torch.nn import functional as F

## The target distribution: A Mixture of Gaussians

In [3]:
def to_tensor(x, dtype=torch.float, device="cpu"):
    """Convert ``x`` to a tensor with the requested dtype on the given device.

    Args:
        x: A tensor, or anything ``torch.tensor`` accepts (number, nested
            list, ...).
        dtype: Desired dtype of the result (``torch.float`` by default).
        device: Device on which the result should live.

    Returns:
        A tensor holding the data of ``x``. If ``x`` is already a tensor
        with the requested dtype and device, ``x`` itself is returned (no
        copy is made), so callers must not modify the result in place.
    """
    if torch.is_tensor(x):
        # Honour ``dtype`` for tensor inputs too; this branch used to force
        # float32 regardless of the argument.
        return x.to(device=device, dtype=dtype)
    return torch.tensor(x, dtype=dtype, device=device)


class GaussianMixture:
    """ Mixture of Gaussians. It uses Normal or MultivariateNormal depending
        on the tensor you feed for the covariance.

        If the covariances have shape ncomponents x nvars, then it assumes
        covariances are diagonal and the variables are independent.
        If the covariances are ncomponents x nvars x nvars, a
        MultivariateNormal is used instead.

        Args:
            means: ncomponents (x nvars) component means.
            covariances: per-component variances (diagonal case) or full
                covariance matrices.
            weights: optional unnormalised mixing weights, one per component;
                uniform when omitted. They are normalised to sum to one.
            device: device on which the parameters live and samples are drawn.
    """

    def __init__(self, means, covariances, weights=None, device="cpu"):
        means = to_tensor(means, device=device)
        covariances = to_tensor(covariances, device=device)

        if weights is None:
            weights = torch.ones(len(means), device=device)
        weights = to_tensor(weights, device=device)
        # Out-of-place on purpose: to_tensor may hand back the caller's own
        # tensor, and normalising in place would silently modify it.
        weights = weights / weights.sum()

        self.__ncomponents = ncomponents = len(means)
        assert ncomponents == len(weights) == len(covariances)

        # 1-D inputs describe univariate components; add the nvars axis
        # (again out-of-place, so the caller's tensors keep their shape).
        if means.ndimension() < 2:
            means = means.unsqueeze(1)
        if covariances.ndimension() < 2:
            covariances = covariances.unsqueeze(1)

        self.__nvars = nvars = means.shape[1]
        assert covariances.shape[1] == nvars == covariances.shape[-1]

        self.__weights = weights
        self.__full_covariance = covariances.ndimension() == 3
        if self.__full_covariance:
            self.__dist = distributions.MultivariateNormal(means, covariances)
        else:
            # Normal is parameterised by the standard deviation, while the
            # constructor receives variances.
            self.__dist = distributions.Normal(means, covariances.sqrt())
        self.__mixture_dist = distributions.Categorical(weights)

        self.device = device

    @property
    def ncomponents(self):
        """ Number of mixture components. """
        return self.__ncomponents

    @property
    def nvars(self):
        """ Dimensionality of each sample. """
        return self.__nvars

    def sample(self, nsamples):
        """ Here we sample from the mixture.

            One draw is taken from every component for each sample, then a
            categorical draw picks which component's value is kept.
            Returns a tensor of shape nsamples x nvars.
        """
        idxs = self.__mixture_dist.sample((nsamples, 1, 1))
        return torch.gather(
            self.__dist.sample((nsamples,)),  # nsamples x ncomponents x nvars
            1,
            idxs.expand(nsamples, 1, self.nvars)  # nsamples x 1 x nvars
        ).squeeze(1)

    def log_prob(self, xs):
        """ Here we compute the log-probability of examples in xs under the
            mixture of gaussians.

            xs has shape nsamples x nvars (a 1-D tensor is accepted when
            nvars == 1). Returns a tensor with nsamples entries.
        """
        if xs.ndimension() == 1 and self.nvars == 1:
            xs = xs.unsqueeze(1)
        assert xs.shape[1] == self.__nvars

        # Use the mixture's own device rather than a module-level global.
        xs = xs.to(self.device)
        xs = xs.unsqueeze(1).expand(-1, self.__ncomponents, -1)
        log_probs = self.__dist.log_prob(xs)
        if not self.__full_covariance:
            assert log_probs.shape == (len(xs), self.ncomponents, self.nvars)
            # Independent variables: the joint log-density is the sum.
            log_probs = log_probs.sum(dim=2)
        # log sum_k w_k * exp(log_probs_k), computed in log-space so that
        # points far from every component do not underflow to log(0).
        return torch.logsumexp(log_probs + self.__weights.log(), dim=1)

## The Generative Adversarial Networks


In [4]:
class GAN:
    """ The standard GAN.

        The generator maps latent vectors of size z_sz to samples of size
        x_sz; the discriminator maps samples of size x_sz to a single logit.
        Both networks are MLPs with one hidden layer of h_sz LeakyReLU units
        (an optional second hidden layer is left commented out below).

        Args:
            z_sz: size of the latent vector fed to the generator.
            h_sz: number of hidden units in both networks.
            x_sz: size of a data sample.
    """

    def __init__(self, z_sz, h_sz, x_sz=1):
        self.__x_sz = x_sz = int(x_sz)
        self.__z_sz = z_sz = int(z_sz)
        # Cast like the other sizes so that e.g. a float coming from a
        # notebook form is accepted by nn.Linear.
        h_sz = int(h_sz)
        self.__device = torch.device("cpu")

        self.generator = nn.Sequential(
            nn.Linear(z_sz, h_sz), nn.LeakyReLU(),
            # nn.Linear(h_sz, h_sz), nn.LeakyReLU(),
            nn.Linear(h_sz, x_sz)
        )
        self.discriminator = nn.Sequential(
            nn.Linear(x_sz, h_sz), nn.LeakyReLU(),
            # nn.Linear(h_sz, h_sz), nn.LeakyReLU(),
            nn.Linear(h_sz, 1)
        )

    @property
    def z_sz(self):
        """ Size of the latent vector. """
        return self.__z_sz

    @property
    def device(self):
        """ Device the networks were last moved to (CPU initially). """
        return self.__device

    def to(self, device):
        """ Move both networks to the given device and remember it, so that
            sample() draws its latent vectors there.

            Returns self, so calls can be chained like nn.Module.to.
        """
        self.generator.to(device)
        self.discriminator.to(device)
        self.__device = device
        return self

    def sample(self, nsamples):
        """ Here we produce samples using our generator.

            Latent vectors are drawn from a standard normal distribution.
            Returns a tensor of shape nsamples x x_sz.
        """
        zs = torch.randn(nsamples, self.z_sz, device=self.__device)
        return self.generator(zs)

    def discriminate(self, data):
        """ Here we ask the discriminator whether the data it sees is real or
            fake.

            Returns the sigmoid of the discriminator logits, i.e. values in
            (0, 1), with one row per row of data.
        """
        return torch.sigmoid(self.discriminator(data))

## Reporting

In [5]:
def fig2b64(figure):
    """ Renders the given figure as a png image and returns the image bytes
        as a base64-encoded string (ready to be embedded in a data URI).
    """
    with io.BytesIO() as buffer:
        figure.savefig(buffer, format='png')
        png_bytes = buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()


class Reporter:
    """ Live training dashboard for a notebook.

        On construction it creates a matplotlib figure with three stacked
        axes (pdfs, discriminator accuracy, losses) and displays an HTML
        table plus an <img> holding the rendered figure. Each call to
        report() redraws the figure and patches that HTML through a small
        JavaScript snippet.

        Args:
            figsize: size of the matplotlib figure.
            pdfs_xlim: x-limits of the probability density plot.
            model: name of the model; currently unused because the figure
                title that would show it is disabled.
    """

    # Id of the most recently created figure; each new reporter takes the
    # next integer.
    fig_id = 0

    def __init__(self, figsize=(16, 10), pdfs_xlim=(-1, 1), model="GAN"):
        Reporter.fig_id = self.__fig_id = Reporter.fig_id + 1
        # Id of the HTML wrapper element of THIS reporter. The JavaScript in
        # report() is scoped to it, so several reporters on the same page do
        # not overwrite each other's widgets.
        self.__dom_id = dom_id = f"reporter-{self.__fig_id}"
        self.__figure = figure = plt.figure(num=self.__fig_id, figsize=figsize)

        self.__pdfs_ax = pdfs_ax = figure.add_subplot(3, 1, 1)
        pdfs_ax.set_xlim(pdfs_xlim)
        pdfs_ax.set_ylim((-0.1, 1.6))

        self.__accs_ax = accs_ax = figure.add_subplot(3, 1, 2)
        self.__loss_ax = loss_ax = figure.add_subplot(3, 1, 3, sharex=accs_ax)

        # TODO: this title is not displayed properly
        # figure.suptitle(f"{model:s} learning a distribution")

        pdfs_ax.set_title("Probability Density Functions")
        accs_ax.set_title("Discriminator accuracy")
        loss_ax.set_title("Loss functions")

        figure.tight_layout()

        raw_img = fig2b64(figure)

        IPython.display.display_html(
            f'<div id="{dom_id}">'
            f'<table>'
            f'<tr><td>D. accuracy</td><td>:</td><td class="acc"></td></tr>'
            f'<tr><td>D. loss</td><td>:</td><td class="dloss"></td></tr>'
            f'<tr><td>G. loss</td><td>:</td><td class="gloss"></td></tr>'
            f'</table>'
            f'<img class="plots" src="data:image/png;base64,{raw_img}"></img>'
            f'</div>',
            raw=True
        )

        # Line artists keyed by series name; filled on the first report.
        self.__lines = {}

    def __first_update(self, trace, pdfs):
        """ Creates every line artist (and the legends) from the first batch
            of data.
        """
        steps = trace["steps"]

        # Here we plot the accuracies.

        self.__lines["real_acc"], = self.__accs_ax.plot(
            steps, trace["real_acc"], label='real data',
        )
        self.__lines["fake_acc"], = self.__accs_ax.plot(
            steps, trace["fake_acc"], label='fake data',
        )
        self.__lines["overall_acc"], = self.__accs_ax.plot(
            steps, trace["overall_acc"], label='overall',
        )
        self.__accs_ax.legend()

        # Below we plot the loss functions.

        self.__lines["d_loss"], = self.__loss_ax.plot(
            steps, trace["d_loss"], label='D loss',
        )
        self.__lines["g_loss"], = self.__loss_ax.plot(
            steps, trace["g_loss"], label='G loss',
        )
        self.__loss_ax.legend()

        # Here we plot the pdfs, and the discriminator

        self.__lines["target_pdf"], = self.__pdfs_ax.plot(
            pdfs["support"], pdfs["target"], label="target pdf",
        )
        self.__lines["generator_pdf"], = self.__pdfs_ax.plot(
            pdfs["support"], pdfs["generator"], label="generator pdf",
        )
        self.__lines["decision"], = self.__pdfs_ax.plot(
            pdfs["support"], pdfs["decision"], label="discriminator",
            linestyle='--',
        )
        # Guide lines at y = 0.5 and y = 1.
        self.__pdfs_ax.axhline(y=0.5, alpha=0.75, linestyle=':')
        self.__pdfs_ax.axhline(y=1, alpha=0.25, linestyle=':')
        self.__pdfs_ax.legend()

    def __update_plots(self, trace, pdfs):
        """ Pushes the latest data into the existing line artists, creating
            them on the first call.
        """
        if not self.__lines:
            self.__first_update(trace, pdfs)
            return

        steps = trace["steps"]

        self.__lines["real_acc"].set_data(steps, trace["real_acc"])
        self.__lines["fake_acc"].set_data(steps, trace["fake_acc"])
        self.__lines["overall_acc"].set_data(steps, trace["overall_acc"])
        # set_data does not rescale the axes, so recompute the limits.
        self.__accs_ax.relim()
        self.__accs_ax.autoscale_view()

        self.__lines["d_loss"].set_data(steps, trace["d_loss"])
        self.__lines["g_loss"].set_data(steps, trace["g_loss"])
        self.__loss_ax.relim()
        self.__loss_ax.autoscale_view()

        # The target pdf never changes; the pdfs axis keeps the fixed limits
        # set in the constructor.
        self.__lines["generator_pdf"].set_data(pdfs["support"], pdfs["generator"])
        self.__lines["decision"].set_data(pdfs["support"], pdfs["decision"])

    def report(self, trace, pdfs):
        """ Redraws the figure and refreshes the displayed image and the
            table with the latest accuracy and losses.

            Args:
                trace: dict of lists with the keys "steps", "real_acc",
                    "fake_acc", "overall_acc", "d_loss" and "g_loss".
                pdfs: dict with the keys "support", "target", "generator"
                    and "decision".
        """
        self.__update_plots(trace, pdfs)
        raw_img = fig2b64(self.__figure)
        # Every selector is scoped to this reporter's wrapper element.
        scope = f"#{self.__dom_id} "
        IPython.display.display_javascript(
            f"document.querySelector('{scope}.plots').src = "
            f"'data:image/png;base64,{raw_img}';"
            f"document.querySelector('{scope}.acc').innerHTML = "
            f"'{trace['overall_acc'][-1]:.3f}%';"
            f"document.querySelector('{scope}.dloss').innerHTML = "
            f"'{trace['d_loss'][-1]:.3f}';"
            f"document.querySelector('{scope}.gloss').innerHTML = "
            f"'{trace['g_loss'][-1]:.3f}';",
            raw=True
        )

    def close(self):
        """ Closes the matplotlib figure owned by this reporter. """
        plt.close(self.__fig_id)


class GANReporter(Reporter):
    """ Reporter that tracks the training of a GAN on a 1-D target
        distribution.

        Per-step metrics are buffered by tick(); every reporting_freq steps
        their means are appended to the trace, the generator pdf and the
        discriminator decision curve are re-estimated on a fixed support,
        and the plots are refreshed.

        Args:
            target_dist: object with sample(n) and log_prob(xs) methods.
            gan: object with sample(n), discriminate(xs) and a device
                attribute.
            reporting_freq: number of ticks between two reports. When
                omitted, the module-level variable reporting_freq is used
                if it exists (the behaviour this class had before the
                parameter was added), otherwise 10.
            *args, **kwargs: forwarded to Reporter.
    """

    def __init__(self, target_dist, gan, *args, reporting_freq=None, **kwargs):
        self.__gan = gan
        self.__target_dist = target_dist
        if reporting_freq is None:
            # Backward compatibility with callers that configure the
            # frequency through a notebook-level variable.
            reporting_freq = globals().get("reporting_freq", 10)
        self.__reporting_freq = int(reporting_freq)
        if self.__reporting_freq < 1:
            raise ValueError("reporting_freq must be a positive integer")

        self.__names = (
            ["real_acc", "fake_acc", "overall_acc"] +
            ["d_loss", "g_loss", "d_real", "d_fake"]
        )

        self.__trace = {n: [] for n in self.__names}
        self.__buffers = {n: [] for n in self.__names}
        self.__pdfs = pdfs = {}

        # Plot range: the span of a sample from the target, widened by 10%
        # of its width on each side.
        target_samples = target_dist.sample(5000)
        min_val = target_samples.min().cpu().item()
        max_val = target_samples.max().cpu().item()
        x_lim = 1.1 * min_val - 0.1 * max_val, 1.1 * max_val - 0.1 * min_val

        self.__support = support = torch.linspace(*x_lim, 1000)
        pdfs["support"] = support.numpy()
        pdfs["target"] = target_dist.log_prob(support).exp().cpu().numpy()

        self.__step = 0
        self.__trace["steps"] = []

        super().__init__(*args, pdfs_xlim=x_lim, **kwargs)

    def tick(self, d_loss, g_loss, real_acc, fake_acc):
        """ Records the metrics of one training step and, every
            reporting_freq steps, refreshes the report.

            Args:
                d_loss, g_loss: discriminator / generator loss (floats).
                real_acc, fake_acc: discriminator accuracy on real and on
                    generated samples.
        """
        buffers = self.__buffers
        buffers["g_loss"].append(g_loss)
        buffers["d_loss"].append(d_loss)
        buffers["real_acc"].append(real_acc)
        buffers["fake_acc"].append(fake_acc)
        buffers["overall_acc"].append((real_acc + fake_acc) / 2.0)

        if self.__step % self.__reporting_freq == 0:
            for name in self.__names:
                values = buffers[name]
                # "d_real" / "d_fake" are never fed by tick(); record NaN for
                # them directly instead of letting np.mean([]) emit a
                # RuntimeWarning on every report.
                mean = np.mean(values) if values else float("nan")
                self.__trace[name].append(mean)
                values.clear()
            self.__trace["steps"].append(self.__step)

            # Use the GAN given to the constructor, not a module-level name.
            gan = self.__gan
            with torch.no_grad():
                support = self.__support
                self.__pdfs["generator"] = self.__approximate_generator_pdf()(
                    support.numpy()
                )
                xs = support.to(gan.device).unsqueeze(1)
                self.__pdfs["decision"] = gan.discriminate(xs).cpu().numpy()
            self.report(self.__trace, self.__pdfs)
        self.__step += 1

    def __approximate_generator_pdf(self):
        """ Returns a gaussian kernel density estimate fitted on 20000
            samples drawn from the generator.
        """
        # detach() keeps this usable outside a torch.no_grad() block as well.
        samples = self.__gan.sample(20000).detach().view(-1).cpu().numpy()
        kernel = stats.gaussian_kde(samples, bw_method=0.08)
        return kernel
        
        

        

## Training the GANs

In [None]:
#@title Configure training

nsteps = 10000 #@param {type:"integer"}
batch_size = 64 #@param {type: "integer"}
latent_size = 3 #@param {type: "integer"}
hidden_size = 20 #@param {type: "integer"}
gm_locs = [1.0, 4.0, 7.0] #@param
gm_covars = [0.2, 0.05, 0.3] #@param
gm_weights = [4, 1, 3] #@param
learning_rate = 0.001 #@param {type: "number"}
reporting_freq = 10 #@param {type: "integer"}
discriminator_speedup = 5 #@param {type: "integer"}
model = "NS-GAN" #@param ["NS-GAN", "M-GAN", "W-GAN-GP", "W-GAN-PC"]
device = "cpu" #@param ["cpu", "cuda"]

# -- Here the madness begins

# Only the two classic objectives are implemented below. Fail early for the
# other form options instead of hitting an undefined g_loss mid-loop.
if model not in ("NS-GAN", "M-GAN"):
    raise NotImplementedError(f"Model {model!r} is not implemented")

device = torch.device(device)
gan = GAN(latent_size, hidden_size)
gan.to(device)
mixture = GaussianMixture(gm_locs, gm_covars, weights=gm_weights, device=device)

d_optimizer = optim.Adam(gan.discriminator.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(gan.generator.parameters(), lr=learning_rate)

# Discriminator targets: 1 for real samples, 0 for generated ones.
ones = torch.ones(batch_size, 1, device=device)
zeros = torch.zeros_like(ones)

reporter = GANReporter(mixture, gan)

try:
    for step in range(nsteps):
        real_xs = mixture.sample(batch_size)
        fake_xs = gan.sample(batch_size)

        # Here we optimize the discriminator. The fake batch is detached so
        # that this backward pass stops at the discriminator's input and no
        # gradient is computed for the generator.
        real_logits = gan.discriminator(real_xs)
        fake_logits = gan.discriminator(fake_xs.detach())

        bce_real = F.binary_cross_entropy_with_logits(real_logits, ones)
        bce_fake = F.binary_cross_entropy_with_logits(fake_logits, zeros)
        d_loss = (bce_real + bce_fake) / 2.0

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Here we optimize the generator (once every discriminator_speedup
        # discriminator steps). On the other steps g_loss keeps the value
        # from the last generator update, which is what gets reported.

        if step % discriminator_speedup == 0:
            # Score the fake batch again with the freshly updated
            # discriminator. Reusing the logits computed above would
            # backpropagate through weights that d_optimizer.step() has just
            # changed in place, which autograd rejects.
            g_logits = gan.discriminator(fake_xs)

            if model == "NS-GAN":
                # Non-saturating loss: maximise log D(G(z)).
                g_loss = F.binary_cross_entropy_with_logits(g_logits, ones)
            else:  # "M-GAN"
                # Minimax loss: minimise log(1 - D(G(z))).
                g_loss = -F.binary_cross_entropy_with_logits(g_logits, zeros)

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # Here we report progress (accuracies are percentages, measured on
        # the logits the discriminator was just trained on).
        real_acc = (torch.sigmoid(real_logits) >= .5).float().mean().mul(100).item()
        fake_acc = (torch.sigmoid(fake_logits) < .5).float().mean().mul(100).item()
        reporter.tick(d_loss.item(), g_loss.item(), real_acc, fake_acc)
finally:
    # Release the matplotlib figure even if training is interrupted.
    reporter.close()
