# Toy Wasserstein GAN (Neural Net Generator)

In [1]:
import torch
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import pandas as pd
from my_optimizers import GD, Adam, Cata
%config InlineBackend.figure_format = 'svg'
from torch.utils.data import Dataset, DataLoader
plt.rcParams.update({'font.size': 15})
import random
import math

### Parameters

In [2]:
# Device on which the generator and critic are placed.
device = "cpu"

# Experiment size: optimization steps per run, samples per batch,
# and number of independent repetitions of each configuration.
nit = 50_000
bs = 100
nexp = 3

# Target distribution: real samples are drawn as real_mu + real_sigma * z
# with z standard normal.
real_mu = 0
real_sigma = 0.1

### Simple generator and critic

In [3]:
class Generator(nn.Module):
    """One-hidden-layer generator mapping 1-D noise to 1-D samples.

    Architecture: Linear(1 -> 5), ReLU, Linear(5 -> 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the random number generator in the same sequence.
        self.fc1 = nn.Linear(1, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        """Return the generated samples for the noise batch ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
    
class Critic(nn.Module):
    """Quadratic critic ``f(x) = theta1 * x + theta2 * x^2``.

    Both coefficients are scalar parameters initialised uniformly in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # theta1 is drawn before theta2; keep this order for reproducibility.
        self.theta1 = nn.Parameter(torch.zeros(1).uniform_(-1, 1))
        self.theta2 = nn.Parameter(torch.zeros(1).uniform_(-1, 1))

    def forward(self, x):
        """Evaluate the critic element-wise on ``x``."""
        linear_term = self.theta1 * x
        quadratic_term = self.theta2 * x * x
        return linear_term + quadratic_term

### Training Function

In [4]:
def train(settings):
    """Train the toy WGAN ``nexp`` times with the optimizer given by ``settings``.

    Every experiment builds a fresh ``Generator``/``Critic`` pair and performs
    ``nit`` alternating critic/generator updates on batches of ``bs`` samples.

    Args:
        settings: dict whose ``"optim"`` entry selects the optimizer; the
            remaining entries are its hyper-parameters:

            * ``"Adam"``: ``"tau"`` (step size), ``"beta1"``, ``"beta2"``.
            * ``"RMSprop"``: ``"tau"``, ``"beta2"`` (run as Adam with
              ``beta1 = 0``).
            * ``"SAGDA"``: ``"tau1"`` (generator step), ``"tau2"`` (critic step).
            * ``"Smooth-SAGDA"``: ``"tau1"``, ``"tau2"``, ``"beta"``, ``"P"``
              (``Cata`` on the generator, ``GD`` on the critic).

    Returns:
        ``[name, loss_hist, grad_x, grad_y]``. ``name`` is a LaTeX-formatted
        label describing the configuration. The other three are tensors of
        shape ``(nit, nexp)``, indexed by iteration and experiment:
        ``grad_x`` holds the norm of the gradients stored on the critic
        parameters, ``grad_y`` the norm of those stored on the generator
        parameters, and ``loss_hist`` the squared error of the batch mean
        plus the squared error of the batch standard deviation of the
        generated samples with respect to ``real_mu`` and ``real_sigma``.

    Raises:
        ValueError: if ``settings["optim"]`` is not a supported optimizer.
    """

    grad_x = torch.zeros((nit, nexp), requires_grad=False)
    grad_y = torch.zeros((nit, nexp), requires_grad=False)
    loss_hist = torch.zeros((nit, nexp), requires_grad=False)

    for e in range(nexp):

        ### Init model
        gen = Generator().to(device)
        critic = Critic().to(device)

        ### Init optimizers
        if settings["optim"] == 'Adam':
            tau = settings["tau"]
            beta1_s = settings["beta1"]
            beta2_s = settings["beta2"]
            opt_gen = Adam(gen.parameters(), lr=tau, betas=(beta1_s, beta2_s), eps=1e-8)
            opt_critic = Adam(critic.parameters(), lr=tau, betas=(beta1_s, beta2_s), eps=1e-8)
            name = r""+settings["optim"]+', $\\tau = '+str(settings["tau"])+', \\beta_1 = '+str(settings["beta1"])+', \\beta_2 = '+str(settings["beta2"])+'$'

        elif settings["optim"] == 'RMSprop':
            # RMSprop is obtained from Adam by switching off the first moment.
            tau = settings["tau"]
            beta1_s = 0
            beta2_s = settings["beta2"]
            opt_gen = Adam(gen.parameters(), lr=tau, betas=(beta1_s, beta2_s), eps=1e-8)
            opt_critic = Adam(critic.parameters(), lr=tau, betas=(beta1_s, beta2_s), eps=1e-8)
            name = r""+settings["optim"]+', $\\tau = '+str(settings["tau"])+', \\beta_2 = '+str(settings["beta2"])+'$'

        elif settings["optim"] == 'SAGDA':
            tau_1 = settings["tau1"]
            tau_2 = settings["tau2"]
            opt_gen = GD(gen.parameters(), lr=tau_1)
            opt_critic = GD(critic.parameters(), lr=tau_2)
            name = r""+settings["optim"]+', $\\tau_1 = '+str(settings["tau1"])+', \\tau_2 = '+str(settings["tau2"])+'$'

        elif settings["optim"] == 'Smooth-SAGDA':
            tau_1 = settings["tau1"]
            tau_2 = settings["tau2"]
            beta_s = settings["beta"]
            P_s = settings["P"]
            opt_gen = Cata(gen.parameters(), lr=tau_1, beta=beta_s, P=P_s)
            opt_critic = GD(critic.parameters(), lr=tau_2)
            name = r""+settings["optim"]+', $\\tau_1 = '+str(settings["tau1"])+', \\tau_2 = '+str(settings["tau2"])+', \\beta = '+str(settings["beta"])+', P = '+str(settings["P"])+'$'

        else:
            # Fail immediately instead of continuing with undefined optimizers.
            raise ValueError(
                'ERROR, optimizer not defined: ' + repr(settings["optim"])
            )

        ### Optimization

        for i in range(nit):

            # The same noise batch feeds the generator and, after an affine
            # map, provides the real samples.
            z = torch.zeros(bs).normal_(0, 1).reshape(-1, 1)
            real = real_mu + real_sigma*z

            # critic update (loss includes a small L2 penalty on theta1, theta2)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(gen(z)).reshape(-1)
            loss_critic = -torch.mean(critic_real) + torch.mean(critic_fake)+0.001*(critic.theta1**2+critic.theta2**2)

            def closure_critic():
                gen.zero_grad()
                critic.zero_grad()
                # retain_graph=True keeps the graph alive for repeated backward
                # calls on this loss.
                # NOTE(review): how often the optimizer invokes the closure
                # depends on my_optimizers - confirm.
                loss_critic.backward(retain_graph=True)
            opt_critic.step(closure_critic)

            # generator update (the critic is re-evaluated after its step)
            gen_fake = critic(gen(z)).reshape(-1)
            loss_gen = -torch.mean(gen_fake)

            def closure_gen():
                gen.zero_grad()
                critic.zero_grad()
                loss_gen.backward()
            opt_gen.step(closure_gen)

            # saving gradients
            # NOTE(review): this reads whatever .grad holds after the generator
            # step; for the critic that appears to be the gradient of loss_gen
            # (closure_gen zeroes both models first) - confirm this is intended.
            with torch.no_grad():
                gx = [p.grad.data.detach().numpy() for p in critic.parameters()]
                gx = np.concatenate(gx, axis=None)
                grad_x[i, e] = torch.tensor(np.linalg.norm(gx))
                gy = [p.grad.data.detach().numpy() for p in gen.parameters()]
                gy = np.concatenate(gy, axis=None)
                grad_y[i, e] = torch.tensor(np.linalg.norm(gy))
                # One forward pass is enough: the generator is deterministic.
                fake = gen(z)
                est_mu = torch.mean(fake)
                est_sigma = torch.std(fake)
                loss_hist[i, e] = torch.abs(est_mu-real_mu)**2+torch.abs(torch.abs(est_sigma)-real_sigma)**2

    return [name, loss_hist, grad_x, grad_y]

### Running the Experiments

In [None]:
# One settings dict per optimizer configuration to benchmark; each dict
# carries the "optim" name plus the hyper-parameters train() reads for it.
to_run = []

to_run.append({"optim":"SAGDA", "tau1": 1e-1, "tau2": 5e-1})
to_run.append({"optim":"RMSprop", "tau": 1e-3, "beta2": 0.9 })
to_run.append({"optim":"Adam", "tau": 1e-3, "beta1": 0.5, "beta2": 0.9 })
to_run.append({"optim":"Smooth-SAGDA", "tau1": 1e-1, "tau2": 1e-1, "beta":0.5, "P":10})

# Results, index-aligned with to_run.
gx = []      # gradient-norm histories stored on the critic parameters
gy = []      # gradient-norm histories stored on the generator parameters
names = []   # legend labels
loss = []    # loss histories

for run_settings in to_run:
    names_c, loss_c, gx_c, gy_c = train(run_settings)
    names.append(names_c)
    # Store the loss history (not the label) so loss[k] matches names[k].
    loss.append(loss_c)
    gx.append(gx_c)
    gy.append(gy_c)

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


### Plotting

In [None]:
# Half-width of the shaded band, as a fraction of the smoothed standard
# deviation of the log10 gradient norms across experiments.
scale = 0.4
plt.rcParams["figure.figsize"] = (10,5)

markers = ["v","^","<",">","o","s","p","P","*"]
colors = sns.color_palette('colorblind')

for i in range(len(to_run)):
    # Work in log10 space; statistics are taken across experiments (axis 1).
    log_gx = np.log10(gx[i].detach().numpy())
    log_gy = np.log10(gy[i].detach().numpy())

    mean_x_log = np.mean(log_gx, 1)
    std_x_log = np.std(log_gx, 1)

    mean_y_log = np.mean(log_gy, 1)
    std_y_log = np.std(log_gy, 1)

    # Savitzky-Golay smoothing (window 601, cubic) along the iteration axis.
    mean_x_log_s = savgol_filter(mean_x_log, 601, 3, mode='nearest')
    mean_y_log_s = savgol_filter(mean_y_log, 601, 3, mode='nearest')

    std_x_log_s = savgol_filter(std_x_log, 601, 3, mode='nearest')
    std_y_log_s = savgol_filter(std_y_log, 601, 3, mode='nearest')

    # The last configuration is highlighted with its own colour and line style.
    if i==(len(to_run)-1):
        cc='#695025'
        ls='dotted'
    else:
        cc = colors[i % len(colors)]
        ls='-'
    mk = markers[i % len(markers)]

    ## Plotting x
    ax = plt.subplot(121)
    plt.plot(range(nit),np.power(10,mean_x_log_s), linestyle=ls, marker = mk, label=names[i],linewidth=3, color = cc, markevery=10000, markersize = 12)
    plt.fill_between(range(nit),np.power(10,mean_x_log_s-scale*std_x_log_s) , np.power(10,mean_x_log_s+scale*std_x_log_s), alpha=0.5, fc=cc)
    plt.xlabel("Iterations")
    plt.title(r'$\Vert\nabla_x F(x,y)\Vert$')
    plt.yscale("log")
    # Explicit True: a bare plt.grid() toggles the grid, so calling it once
    # per configuration would switch it off again on every second pass.
    plt.grid(True)


    ## Plotting y
    ax = plt.subplot(122)
    plt.plot(range(nit),np.power(10,mean_y_log_s), linestyle=ls, marker = mk, label=names[i],linewidth=3,color =  cc, markevery=10000, markersize = 12)
    plt.fill_between(range(nit),np.power(10,mean_y_log_s-scale*std_y_log_s) , np.power(10,mean_y_log_s+scale*std_y_log_s), alpha=0.5, fc= cc)
    plt.xlabel("Iterations")
    plt.title(r'$\Vert\nabla_y F(x,y)\Vert$')
    plt.yscale("log")
    plt.legend(loc='center left', bbox_to_anchor=(1.1, 0.5))
    plt.grid(True)


