In [None]:
import math
from pathlib import Path
import shutil
import colorsys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data

from IPython.display import clear_output
import matplotlib.pyplot as plt
plt.ioff()

%matplotlib inline

## Synthetic data distributions

In [None]:
def two_gaussians_func(size):
    """Sample `size` 2-D points from a mixture of two isotropic Gaussians.

    The first cluster is centred at (0.4, 0.4) with std 0.1, the second at
    (-0.5, -0.5) with std 0.2.

    Args:
        size: total number of points to draw.

    Returns:
        Tensor of shape (size, 2); rows of the first cluster come first.
    """
    # Using `size // 2` for both halves silently drops one sample when `size`
    # is odd, so the second cluster takes the remainder instead.
    n_first = size // 2
    n_second = size - n_first
    a = torch.randn((n_first, 2)) * 0.1 + 0.4
    b = torch.randn((n_second, 2)) * 0.2 - 0.5
    return torch.cat((a, b))

def ring_func(size):
    """Sample `size` noisy 2-D points along a closed curve.

    Angles are drawn uniformly in [0, 2*pi). The curve is
    (2 * cos(angle), sin(angle) + 0.5), and Gaussian jitter with std 0.1 is
    added to both coordinates.

    Returns:
        Tensor of shape (size, 2).
    """
    angle = torch.rand((size, 1)) * math.pi * 2
    curve = torch.cat((2 * torch.cos(angle), torch.sin(angle) + 0.5), dim=1)
    jitter = torch.randn((size, 2)) * 0.1
    return curve + jitter

In [None]:
# Preview 200 samples of the two-Gaussian mixture (columns are x and y).
two_gaussians = two_gaussians_func(200)
plt.scatter(*two_gaussians.T, marker='.', c='g')

In [None]:
# Preview 200 samples of the ring distribution (columns are x and y).
ring = ring_func(200)
plt.scatter(*ring.T, marker='.', c='g')

## Data loaders

In [None]:
class CustomDistribution(data.Dataset):
    """In-memory dataset of samples drawn once from a sampling function.

    Args:
        distribution_func: callable that takes a sample count and returns an
            indexable collection of samples.
        size: sample count passed to `distribution_func`.
    """

    def __init__(self, distribution_func, size):
        super().__init__()
        # Samples are drawn a single time here and then served unchanged.
        samples = distribution_func(size)
        self.data = samples

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored sample at `index`."""
        return self.data[index]

In [None]:
# Number of synthetic samples in the dataset.
dataset_size = 256
# The batch size is a base of 16 scaled by this factor.
bs_multiplier = 2  # 1 / 16, 1, 16, etc
batch_size = int(bs_multiplier * 16)

In [None]:
# Sampling function: two_gaussians_func or ring_func
synthetic_dataset = CustomDistribution(distribution_func=ring_func, size=dataset_size)
synthetic_dl = data.DataLoader(synthetic_dataset, shuffle=True, batch_size=batch_size)
# One shuffled pass over the loader, concatenated into a single tensor
# that the plotting helpers use as the "real" samples.
synthetic_dataset_all = torch.cat([batch for batch in synthetic_dl], dim=0)

## Visualization functions

In [None]:
# Output folders for the frames of the two animations.
animation_ims = Path('animation_ims')
noise2G_imgs = Path('noise2G_imgs')

def clear_folder(folder):
    """Delete `folder` with its contents (if present) and recreate it empty.

    Args:
        folder: directory path, given as a `pathlib.Path` or a string.
    """
    # Accept plain strings as well as Path objects.
    folder = Path(folder)
    # ignore_errors=True makes a missing folder a no-op instead of an error.
    shutil.rmtree(folder, ignore_errors=True)
    # parents=True allows nested output folders to be created in one call.
    folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Subplot grid (rows, columns) shared by the dashboard drawn in `vis`.
n_rows, n_cols = 2, 4

def vis_real():
    """Scatter the real samples (`synthetic_dataset_all`) as black dots."""
    real_x = synthetic_dataset_all[:, 0]
    real_y = synthetic_dataset_all[:, 1]
    plt.scatter(real_x, real_y, marker='.', c='black')

def vis_G():
    """Scatter the generator's output for the fixed `val_noise` as red dots."""
    generated = G(val_noise).detach()
    gen_x, gen_y = generated[:, 0], generated[:, 1]
    plt.scatter(gen_x, gen_y, marker='.', c='r')

def vis_real_fake():
    """Draw the real samples, then the generated samples, on the current axes."""
    for draw in (vis_real, vis_G):
        draw()
    
def vis_grad_norms():
    """Plot the recorded gradient-norm histories of D and G."""
    for label, history in (('D', d_grad_norms), ('G', g_grad_norms)):
        plt.plot(history, label=label)
    plt.legend()
    plt.title('Grad norms')
    
def vis_losses():
    """Plot the recorded loss histories: D on fake, D on real, and G."""
    curves = (
        ('d_fake', d_losses_fakes),
        ('d_real', d_losses_reals),
        ('g', g_losses),
    )
    for label, history in curves:
        plt.plot(history, label=label)
    plt.legend()
    plt.title('Losses')
    
def vis_preds():
    """Plot the recorded mean D predictions on real and on fake batches."""
    for label, history in (('real', real_preds), ('fake', fake_preds)):
        plt.plot(history, label=label)
    plt.legend()
    plt.title('D preds')
    
def get_colors_from_matrix(x):
    """Map each row of a two-column tensor to an RGB colour.

    Column 0 is min-max normalised and drives the hue, column 1 is min-max
    normalised and drives the value; both are rescaled to [0.1, 1.0].
    Saturation is always 1.

    Args:
        x: tensor with exactly two columns.

    Returns:
        List with one (r, g, b) tuple of Python floats per row of `x`.
    """
    x = x.detach()
    col_min = x.min(axis=0)[0]
    col_max = x.max(axis=0)[0]

    # A constant column has zero range; dividing by it would give NaN and make
    # colorsys.hsv_to_rgb fail. Use a range of 1 there, which maps the whole
    # column to the lower bound 0.1.
    col_range = col_max - col_min
    col_range = torch.where(col_range > 0, col_range, torch.ones_like(col_range))

    scaled = (x - col_min) / col_range * 0.9 + 0.1

    # Plain floats (rather than 0-dim tensors) are what colorsys and
    # matplotlib expect.
    return [colorsys.hsv_to_rgb(h, 1, v) for h, v in scaled.tolist()]
    
# Random projections (hidden width -> 2) used to draw the hidden activations.
# They are created lazily and reused between calls so that successive plots
# use the same projection.
random_matrix_l1, random_matrix_l2 = None, None

def _project_to_2d(activations, projection):
    """Project 2-D `activations` onto two columns with a cached random matrix.

    A fresh matrix is drawn when none is cached or when the cached one no
    longer matches the activation width (for example after G was rebuilt with
    a different hidden size).

    Returns:
        Tuple (projected activations, projection matrix to cache).
    """
    width = activations.shape[1]
    if projection is None or projection.shape[0] != width:
        projection = torch.randn((width, 2))
    return torch.mm(activations, projection), projection

def vis_G_intermediate():
    """Plot the input noise and the three generator layer outputs.

    Fills subplots 5-8 of the `n_rows` x `n_cols` grid. Every point is coloured
    by its position in the input noise, so the same sample can be followed
    through the layers. Hidden activations are reduced to 2-D with the cached
    random projections; the last layer is plotted directly.
    """
    global random_matrix_l1, random_matrix_l2

    l1, l2, l3 = (t.detach() for t in G.forward_intermediate(val_noise))

    cs = get_colors_from_matrix(val_noise)

    plt.subplot(n_rows, n_cols, 5)
    plt.scatter(val_noise[:, 0], val_noise[:, 1], marker='.',  c=cs)
    plt.title('Input noise')

    plt.subplot(n_rows, n_cols, 6)
    l1, random_matrix_l1 = _project_to_2d(l1, random_matrix_l1)
    plt.scatter(l1[:, 0], l1[:, 1], marker='.',  c=cs)
    plt.title('Layer 1 activations')

    plt.subplot(n_rows, n_cols, 7)
    l2, random_matrix_l2 = _project_to_2d(l2, random_matrix_l2)
    plt.scatter(l2[:, 0], l2[:, 1], marker='.',  c=cs)
    plt.title('Layer 2 activations')

    plt.subplot(n_rows, n_cols, 8)
    plt.scatter(l3[:, 0], l3[:, 1], marker='.',  c=cs)
    plt.title('Layer 3 activations')

def vis_D_decision_boundary(steps=100):
    """Draw D's output as filled contours over the area covered by the data.

    The grid spans the bounding box of the real samples together with the
    samples generated from `val_noise`.

    Args:
        steps: number of grid points per axis. The default of 100 is the value
            torch.linspace used to fall back to when `steps` was omitted.
    """
    g_fake = G(val_noise).detach()
    fake_and_real = torch.cat((synthetic_dataset_all, g_fake))

    x_min, y_min = fake_and_real.min(axis=0)[0]
    x_max, y_max = fake_and_real.max(axis=0)[0]

    # `steps` is passed explicitly because current PyTorch requires it.
    xs = torch.linspace(x_min.item(), x_max.item(), steps)
    ys = torch.linspace(y_min.item(), y_max.item(), steps)

    # Grid built by broadcasting: xx[i, j] = xs[i], yy[i, j] = ys[j]. This is
    # the layout torch.meshgrid produced by default, without its
    # missing-`indexing` deprecation warning.
    xx = xs.view(-1, 1).expand(steps, steps)
    yy = ys.view(1, -1).expand(steps, steps)

    # One (x, y) row per grid node, in the same order as xx/yy flattened.
    mesh_points = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=1)
    mesh_pred = D(mesh_points).detach().view(steps, steps)

    plt.contourf(xx, yy, mesh_pred, cmap='coolwarm')
    plt.colorbar()

def vis_G_grad():
    """Draw one arrow per generated point along the gradient of D's mean output.

    The gradient is taken with respect to the generated points themselves
    and scaled by a fixed factor so that the arrows are visible.
    """
    scale = 5e1
    generated = G(val_noise)
    mean_score = D(generated).mean()
    (grad,) = torch.autograd.grad(mean_score, generated)

    origins = generated.detach()
    for (x0, y0), (gx, gy) in zip(origins, grad):
        plt.arrow(x0, y0, gx * scale, gy * scale,
                  head_width=4e-2, color='g', alpha=0.4, width=5e-4)

def vis(ep, with_grad=False):
    """Draw the full training dashboard for epoch `ep`.

    Panel 1 shows D's decision surface with real and generated points (plus
    gradient arrows when `with_grad` is true); panels 2-4 show gradient norms,
    losses and D predictions; panels 5-8 are filled by `vis_G_intermediate`.
    Both networks are switched to eval mode and left that way.
    """
    for net in (G, D):
        net.eval()

    plt.figure(figsize=(16, 8))
    plt.suptitle(f'Epoch: {ep}')

    # Panel 1: decision surface with the data drawn on top.
    plt.subplot(n_rows, n_cols, 1)
    vis_D_decision_boundary()
    vis_real_fake()
    if with_grad:
        vis_G_grad()

    # Panels 2-4: training histories.
    for slot, draw in enumerate((vis_grad_norms, vis_losses, vis_preds), start=2):
        plt.subplot(n_rows, n_cols, slot)
        draw()

    # Panels 5-8: noise and generator activations.
    vis_G_intermediate()

In [None]:
# Sanity check: plot the real samples on their own before building the models.
vis_real()

## Models

In [None]:
def get_grad_norm(model):
    """Return the L2 norm of all gradients of `model`, taken as one vector.

    Parameters whose `.grad` is None are skipped. The result is a Python
    float (0.0 when no parameter has a gradient).
    """
    squared_norms = (
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return sum(squared_norms) ** (1. / 2)

class Generator(nn.Module):
    """Three-layer MLP mapping 2-D noise to 2-D points.

    Each of the two hidden layers is Linear (no bias) -> activation ->
    BatchNorm1d; the output layer is a plain Linear with bias.

    Args:
        hidden: width of both hidden layers.
        act: activation applied after the first two linear layers.
    """

    def __init__(self, hidden=8, act=torch.tanh):
        super().__init__()
        self.act = act
        self.l1 = nn.Linear(2, hidden, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.l2 = nn.Linear(hidden, hidden, bias=False)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.l3 = nn.Linear(hidden, 2)

    def forward_intermediate(self, x):
        """Run the network and return the outputs of all three stages.

        Returns:
            Tuple (h1, h2, out): the normalised activations of the two hidden
            stages and the final output.
        """
        h1 = self.bn1(self.act(self.l1(x)))
        h2 = self.bn2(self.act(self.l2(h1)))
        return h1, h2, self.l3(h2)

    def forward(self, x):
        """Return only the final output for input `x`."""
        *_, out = self.forward_intermediate(x)
        return out

class Discriminator(nn.Module):
    """Three-layer MLP mapping 2-D points to a probability in (0, 1).

    Each of the two hidden layers is Linear (no bias) -> activation ->
    BatchNorm1d; the output layer is a Linear with bias followed by a sigmoid.

    Args:
        hidden: width of both hidden layers.
        act: activation applied after the first two linear layers.
    """

    def __init__(self, hidden=8, act=torch.tanh):
        super().__init__()
        self.act = act
        self.l1 = nn.Linear(2, hidden, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.l2 = nn.Linear(hidden, hidden, bias=False)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.l3 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return D's prediction for each row of `x`, with shape (N, 1)."""
        x = self.act(self.l1(x))
        x = self.bn1(x)

        x = self.act(self.l2(x))
        # The normalised activations must be assigned back. Calling bn2 and
        # discarding its result leaves the layer without any effect on the
        # output, so its weights never receive a gradient.
        x = self.bn2(x)

        x = torch.sigmoid(self.l3(x))
        return x

def get_noise(bs, distribution='norm'):
    """Draw a (bs, 2) tensor of generator input noise.

    Args:
        bs: number of noise vectors.
        distribution: 'norm' for N(0, 1) or 'uniform' for uniform on [-1, 1].

    Raises:
        ValueError: if `distribution` is neither 'norm' nor 'uniform'.
    """
    if distribution == 'norm':
        # normal N(0, 1)
        return torch.randn(bs, 2)
    if distribution == 'uniform':
        # uniform [-1, 1]
        return torch.rand(bs, 2) * 2 - 1
    raise ValueError('Wrong distribution parameter')

In [None]:
# Learning rates are scaled by the same factor as the batch size.
g_lr = 1e-2 * bs_multiplier
d_lr = 2e-2 * bs_multiplier
epochs = 30

# NOTE: statement order matters in this cell. The networks are initialised and
# val_noise is drawn from the global torch RNG in the order written here.
G = Generator(hidden=16, act=torch.relu) # torch.tanh or torch.relu
D = Discriminator(hidden=16, act=torch.relu) # torch.tanh or torch.relu

# Loss applied to D's outputs against all-ones / all-zeros targets.
criterion = nn.BCELoss() # nn.BCELoss() or nn.MSELoss()

g_optim = optim.SGD(G.parameters(), g_lr, weight_decay=1e-2)
d_optim = optim.SGD(D.parameters(), d_lr, weight_decay=1e-2)

# Alternative optimizer setup:
# g_optim = optim.Adam(G.parameters(), g_lr, betas=(0.5, 0.999))
# d_optim = optim.Adam(D.parameters(), d_lr, betas=(0.5, 0.999))

# Fixed noise reused by every visualisation call; its size is a quarter of
# batch_size * number of batches in the loader.
val_noise = get_noise(batch_size * len(synthetic_dl) // 4)

### Weights initialization

In [None]:
# def init_weights(m):
#     if type(m) == nn.Linear:
#         nn.init.xavier_uniform_(m.weight)
#         nn.init.normal_(m.bias, std=0.1)

# G.apply(init_weights)
# D.apply(init_weights)

# list(G.parameters())

## Training

In [None]:
# Snapshot of the untrained networks. Switch both to eval mode first, as
# `vis` does. In train mode these forward passes would normalise with the
# statistics of the plotted batch and would also update the BatchNorm running
# statistics before training has started.
G.eval()
D.eval()

vis_D_decision_boundary()
vis_real_fake()
vis_G_grad()

In [None]:
%%time
# Per-step histories read by the vis_* plotting helpers.
d_grad_norms = []
g_grad_norms = []

d_losses_fakes = []
d_losses_reals = []
g_losses = []

fake_preds = []
real_preds = []

clear_folder(animation_ims)
save_for_animation = False

for ep in range(epochs):
    for i, batch_data in enumerate(synthetic_dl):
        # Size targets and noise from the batch actually received, so that a
        # smaller final batch (dataset size not divisible by batch_size) does
        # not cause a shape mismatch in the loss.
        cur_bs = len(batch_data)
        real_target = torch.ones(cur_bs, 1)
        fake_target = torch.zeros(cur_bs, 1)

        # Train D
        # vis() and the animation branch leave both networks in eval mode, so
        # set the modes explicitly at the start of every step. Otherwise G's
        # mode during the D step depends on what ran just before.
        D.train()
        G.train()
        d_optim.zero_grad()

        # on real
        real_pred = D(batch_data)
        d_loss_real = criterion(real_pred, real_target)
        # Histories hold plain floats so no tensors are kept alive.
        d_losses_reals.append(d_loss_real.item())
        real_preds.append(real_pred.mean().item())

        # on fake
        fake = G(get_noise(cur_bs))
        # detach(): the D step must not backpropagate into G.
        fake_pred = D(fake.detach())
        d_loss_fake = criterion(fake_pred, fake_target)
        d_losses_fakes.append(d_loss_fake.item())
        fake_preds.append(fake_pred.mean().item())

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_grad_norms.append(get_grad_norm(D))
        d_optim.step()

        # Train G
        D.eval()
        G.train()
        g_optim.zero_grad()

        fake = G(get_noise(cur_bs))
        fake_pred = D(fake)
        # G is trained to make D label its samples as real.
        g_loss = criterion(fake_pred, real_target)
        g_losses.append(g_loss.item())

        g_loss.backward()
        g_grad_norms.append(get_grad_norm(G))
        g_optim.step()

        # Save for animation (slow), disable this for experiments
        if save_for_animation:
            G.eval()
            D.eval()

            if i % 3 == 0:
                f = plt.figure()
                vis_D_decision_boundary()
                vis_real_fake()
                vis_G_grad()

                current_step = ep * len(synthetic_dl) + i
                total_steps = epochs * len(synthetic_dl)
                plt.title(f'{int(current_step / total_steps * 100)}%')
                f.savefig(f'{animation_ims}/{current_step:05}.jpg')
                # Close this specific figure so frames do not accumulate.
                plt.close(f)

    # Visualization
    clear_output(True)
    vis(ep, with_grad=True)
    plt.show()

In [None]:
# !convert -delay 10 -layers optimize animation_ims/*.jpg anim.gif

## Noise after generator step by step

In [None]:
# Colour each noise point by its position, push the noise through G, then
# interpolate linearly in 50 steps between every noise point and its image.
noise = get_noise(512)
fake = G(noise)
cs = get_colors_from_matrix(noise)
transition_space = np.linspace(start=noise, stop=fake.detach(), num=50)

# Start from an empty output folder for the frames.
clear_folder(noise2G_imgs)

In [None]:
# Show (and optionally save) one scatter plot per interpolation step.
n_frames = len(transition_space)
for i, frame in enumerate(transition_space):
    clear_output(True)
    f = plt.figure()
    plt.scatter(frame[:, 0], frame[:, 1], c=cs, marker='.')
    plt.title(f'{int(i / n_frames * 100)}%')

    if save_for_animation:
        f.savefig(f'{noise2G_imgs}/{i:03}.jpg')

    plt.show()

In [None]:
# !convert -delay 8 -layers optimize noise2G_imgs/*.jpg anim2.gif