# Train a GAN on a One-Dimensional Function

**With Extensions**

In [48]:
#import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
import matplotlib.pyplot as plt
import torch.optim as optim

In [49]:
# Sample n points from the "real" data distribution: (x, x**2) with x ~ U(-0.5, 0.5).
def generate_real_data(n):
    """Draw ``n`` real samples lying on the parabola y = x**2.

    x is drawn uniformly from [-0.5, 0.5]; each row of the result is (x, x**2).

    Args:
        n: number of samples to draw.

    Returns:
        A float32 tensor of shape (n, 2).
    """
    inputs = np.random.uniform(-0.5, 0.5, size=(n, 1))
    points = np.concatenate([inputs, np.square(inputs)], axis=1)
    return torch.from_numpy(points).float()

**Define Extended Functions for Scaling Data**

In [50]:
# define a function to scale real data, so that we can use a tanh activation on generator output
def scale_real_data(real_data):
    """Map real samples onto [-1, 1] so they match the generator's tanh range.

    Column 0 (x, drawn from [-0.5, 0.5] by ``generate_real_data``) is
    multiplied by 2. Column 1 (x**2, which lies in [0, 0.25]) is mapped with
    8 * y - 1. If a different range is used for the real data, these factors
    have to change.

    Args:
        real_data: 2-column torch tensor (or numpy array) of real samples.

    Returns:
        A scaled copy of ``real_data``; the argument itself is not modified.
    """
    # Copy first so the caller's tensor/array is not silently mutated.
    scaled = real_data.clone() if torch.is_tensor(real_data) else real_data.copy()
    scaled[:, 0] *= 2
    scaled[:, 1] = scaled[:, 1] * 8 - 1
    return scaled

# define a function that maps scaled data back to the original range (used for plotting)
def recreate_data(array):
    """Invert ``scale_real_data``: map [-1, 1]-scaled samples back to the real range.

    Column 0 is divided by 2; column 1 is mapped with (y + 1) / 8.

    Args:
        array: 2-column torch tensor (or numpy array), e.g. generator output.

    Returns:
        A rescaled copy; the argument itself is not modified. This matters
        when ``array`` is a generator output that autograd still tracks,
        because modifying it in place would alter a tensor saved for backward.
    """
    restored = array.clone() if torch.is_tensor(array) else array.copy()
    restored[:, 0] /= 2
    restored[:, 1] = (restored[:, 1] + 1) / 8
    return restored

In [51]:
# define discriminator model
class Discriminator(nn.Module):
    """MLP that maps each input vector to a single unnormalised logit.

    Layers: input_dim -> 16 -> 16 (leaky-ReLU after each) -> 2 -> 1.
    NOTE: ``fc3`` and ``output`` are applied back-to-back with no
    nonlinearity in between, so together they are equivalent to a single
    ``Linear(16, 1)``.
    """

    def __init__(self, input_dim):
        super().__init__()
        width = 16
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 2)
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        # two leaky-ReLU hidden layers
        features = x
        for hidden_layer in (self.fc1, self.fc2):
            features = F.leaky_relu(hidden_layer(features))
        # no sigmoid here: the returned value is a raw logit
        return self.output(self.fc3(features))

In [52]:
# define generator model
class Generator(nn.Module):
    """MLP mapping latent vectors to points in (-1, 1)^output_dim.

    Layers: latent_dim -> 16 -> 16 (leaky-ReLU after each) -> output_dim,
    followed by tanh, so samples share the [-1, 1] range that the real data
    is scaled to.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        width = 16
        self.fc1 = nn.Linear(latent_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.out = nn.Linear(width, output_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.leaky_relu(layer(hidden))
        # squash into (-1, 1)
        return torch.tanh(self.out(hidden))

In [53]:
def _constant_target_bce(d_out, target):
    """Mean BCE-with-logits of ``d_out`` against a constant ``target`` label.

    The label tensor is built with ``full_like`` so it always matches the
    logits' shape, dtype and device (CPU or GPU, float32 or float64).
    """
    labels = torch.full_like(d_out, target)
    return F.binary_cross_entropy_with_logits(d_out, labels)


# define real loss
def real_loss(d_out):
    """BCE-with-logits loss of discriminator logits against the "real" label 1.

    Used for the discriminator on real samples and for the generator on its
    own samples (the generator tries to make D output "real").

    Args:
        d_out: discriminator logits, e.g. of shape (batch, 1).

    Returns:
        Scalar mean loss tensor.
    """
    return _constant_target_bce(d_out, 1.0)


# define fake loss
def fake_loss(d_out):
    """BCE-with-logits loss of discriminator logits against the "fake" label 0.

    Args:
        d_out: discriminator logits, e.g. of shape (batch, 1).

    Returns:
        Scalar mean loss tensor.
    """
    return _constant_target_bce(d_out, 0.0)

In [54]:
# set hyperparameters
latent_dim = 5   # size of the noise vector fed to the generator
lr = 0.001       # learning rate shared by both optimizers
data_dim = 2     # each sample is a 2-D point (x, x**2)

# instantiate Discriminator and Generator (D is constructed first, then G)
D = Discriminator(data_dim)
G = Generator(latent_dim, data_dim)

# define an Adam optimizer for each network, both with beta1 = 0.5
adam_betas = (0.5, 0.999)
d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)
g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=adam_betas)

# put both networks in training mode; the resulting tuple is the cell's displayed output
D.train(), G.train()

(Discriminator(
   (fc1): Linear(in_features=2, out_features=16, bias=True)
   (fc2): Linear(in_features=16, out_features=16, bias=True)
   (fc3): Linear(in_features=16, out_features=2, bias=True)
   (output): Linear(in_features=2, out_features=1, bias=True)
 ),
 Generator(
   (fc1): Linear(in_features=5, out_features=16, bias=True)
   (fc2): Linear(in_features=16, out_features=16, bias=True)
   (out): Linear(in_features=16, out_features=2, bias=True)
 ))

In [55]:
# define a function to evaluate the GAN performance subjectively by examining plots
def summarize_performance(samples, epoch, real_list=None, synthetic_list=None, save_fig=False):
    """Scatter-plot real samples against generator samples for visual inspection.

    Uses the module-level generator ``G`` and ``latent_dim``.

    Args:
        samples: number of points drawn from each distribution.
        epoch: training iteration; only used in the saved file name.
        real_list: optional list; when both lists are given, the real batch
            is appended to it.
        synthetic_list: optional list; when both lists are given, the
            generated batch is appended to it.
        save_fig: if True, save the figure as ``generated_plot{epoch}``.
    """
    # sample from the real data distribution (original, unscaled range)
    real_data = generate_real_data(samples)

    # Inference only: no_grad keeps autograd from building a graph for these
    # samples. Without it, the tensors appended to synthetic_list would keep
    # that graph alive.
    with torch.no_grad():
        latent = np.random.normal(0, 1, size=(samples, latent_dim))
        latent = torch.from_numpy(latent).float()
        # map the generator's [-1, 1] output back to the real data's range
        synthetic_data = recreate_data(G(latent))

    # optionally keep the batches for later visualization
    if real_list is not None and synthetic_list is not None:
        real_list.append(real_data)
        synthetic_list.append(synthetic_data)

    # plot both samples in one scatter plot
    plt.scatter(real_data[:, 0], real_data[:, 1], color='blue', label='Real Distribution')
    plt.scatter(synthetic_data[:, 0], synthetic_data[:, 1], color='red', label='Generated Distribution')
    plt.legend()

    if save_fig:
        filename = 'generated_plot{}'.format(epoch)
        plt.savefig(filename)

    plt.show()
    plt.close()

In [None]:
%matplotlib inline

# set parameters related to training
epochs = 10000      # number of training iterations (one batch per iteration)
batch_size = 128
show_every = 2000   # plot progress every this many iterations

# per-iteration losses, stored as Python floats for the final plot
gen_losses = []
dis_losses = []

# real / generated batches collected by summarize_performance
r_list = []
s_list = []

# implement training loop
for i in range(1, epochs + 1):

    #=========================================
    #        Train Discriminator
    #=========================================

    d_optimizer.zero_grad()

    # get a batch from the real distribution, scaled to [-1, 1] to match the tanh output
    real_data = scale_real_data(generate_real_data(batch_size))

    # loss on real samples (target label 1)
    d_r_loss = real_loss(D(real_data))

    # generate fake samples from standard-normal latent vectors; detach so
    # this discriminator step does not backpropagate into the generator
    latent = np.random.normal(0, 1, size=(batch_size, latent_dim))
    latent = torch.from_numpy(latent).float()
    fake_data = G(latent).detach()

    # loss on fake samples (target label 0)
    d_f_loss = fake_loss(D(fake_data))

    # accumulate losses
    d_loss = d_r_loss + d_f_loss

    # Store a float, not the tensor. A stored tensor keeps its autograd graph
    # alive, and plt.plot cannot convert tensors that require grad.
    dis_losses.append(d_loss.item())

    d_loss.backward()
    d_optimizer.step()

    #===========================================
    #         Train Generator
    #===========================================
    g_optimizer.zero_grad()

    # generate fresh latent samples for the generator
    latent = np.random.normal(0, 1, size=(batch_size, latent_dim))
    latent = torch.from_numpy(latent).float()
    fake_data = G(latent)

    # the generator wants D to label its samples as real
    g_loss = real_loss(D(fake_data))

    gen_losses.append(g_loss.item())

    g_loss.backward()
    g_optimizer.step()

    if i % show_every == 0:
        print(f'Epoch: {i} | d_loss: {d_loss.item():.4f} | g_loss: {g_loss.item():.4f}')
        summarize_performance(100, i, r_list, s_list)

plt.plot(gen_losses, label='g_loss')
plt.plot(dis_losses, label='d_loss')
plt.title('Discriminator Loss and Generator Loss')
plt.xlabel('iteration')
plt.legend()
plt.show()