# 1-D Visualization

In [2]:
import argparse
import os
import numpy as np
import math
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
import seaborn as sns
import matplotlib.pyplot as plt

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.hidden_layer = nn.Sequential(
            nn.Linear(1, 512), 
            nn.ReLU(inplace=True),
            nn.Linear(512, 1), 
        )

    def forward(self, x):
        out = self.hidden_layer(x)
        return out

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.hidden_layer = nn.Sequential(
            nn.Linear(1, 512), 
            nn.ReLU(inplace=True),
            nn.Linear(512, 1), 
            nn.Sigmoid()
        )
    def forward(self,x):
        out = self.hidden_layer(x)
        return out

In [6]:
# Get data from real distribution, let assume it is gaussian distribution where mean is 5, var is 1
def draw_from_real(size):
    return np.random.normal(loc=5.0, scale=0.5, size=size)

# noise is uniform distibution in [-10, 10]
def draw_from_noise(size):
    noise = np.random.uniform(low=-10.0, high=10.0, size=size)
    return noise

# Get data from generator, where input generator is noise
def draw_from_fake(generator, size):
    noises = draw_from_noise(size)
    with torch.no_grad():
        noises = draw_from_noise(size)
        noises = torch.from_numpy(noises) # Conver numpy to tensor
        noises = Variable(noises.type(Tensor)).view(size, -1)
        fake_datas = generator(noises)
    return fake_datas

real_data = draw_from_real(100000)
def plot_density(generator, discriminator, save_path=None, title=None):
    
    if title == 'generator':
        global fake_data
        fake_data = draw_from_fake(generator, 10000)
    fake_data = fake_data.to('cpu')
    sns.distplot(real_data, hist=False, color='red', label='real')
    sns.distplot(fake_data, hist=False, color='blue', label='fake')
    plt.xlim(0, 10)
    plt.ylim(0, 1)
    
    sample_points = np.linspace(0, 10.0, num=1000)
    with torch.no_grad():
        datas = torch.from_numpy(sample_points) # Conver numpy to tensor
        datas = Variable(datas.type(Tensor)).view(-1, 1)
        output = discriminator(datas)
    plt.plot(sample_points, output.data.numpy(), label='Discriminator', ls='--')
    
    plt.legend(loc='upper right')
    if title:
        plt.title(title)
    if save_path:
        plt.savefig(save_path)
        plt.clf()
        return
    plt.show()
    
# plot_density(generator, discriminator)

In [7]:
Tensor = torch.FloatTensor

In [8]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Network 
generator = Generator()
discriminator = Discriminator()

fake_data = draw_from_fake(generator, 10000)

# Optimizers
lr = 0.0002
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

batch_size = 32

In [None]:
# Training 
real_label = 1
fake_label = 0

d_loss_total = 0
g_loss_total = 0

d_step = 10
g_step = 10

num_iters = 1000

cnt = 0
for iter_ in range(num_iters):
    real_label = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False)
    fake_label = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False)
    
    # --------------------
    #  Train Discriminator
    # ---------------------
    for i in range(d_step):
        optimizer_D.zero_grad()

        # compute loss for real data (discrimator should output 1)
        real_datas = draw_from_real(batch_size)
        real_datas = torch.from_numpy(real_datas) # Conver numpy to tensor
        real_datas = Variable(real_datas.type(Tensor)).view(batch_size, -1)

        real_output = discriminator(real_datas)
        real_loss = adversarial_loss(real_output, real_label)
        real_loss.backward()

        # computer loss for fake data (discrimator should ouput 0)
        noises = draw_from_noise(batch_size)
        noises = torch.from_numpy(noises) # Conver numpy to tensor
        noises = Variable(noises.type(Tensor)).view(batch_size, -1)
        fake_datas = generator(noises)

        fake_output = discriminator(fake_datas.detach()) # don't propagate to generator net
        fake_loss = adversarial_loss(fake_output, fake_label)
        fake_loss.backward()

        # compute total loss and update weight 
        discriminator_loss = real_loss + fake_loss
        optimizer_D.step()
        
        if i % (d_step//10+1) == 0:
            plot_density(generator, discriminator, save_path='./Result/1D/%d.jpg'%(cnt), title='discriminator')
            cnt += 1
    
    d_loss_total += discriminator_loss.item()
        
    # -----------------
    #  Train Generator
    # -----------------
    for i in range(g_step):
        optimizer_G.zero_grad()
        
        # computer loss for fake data (discrimator should ouput 1)
        noises = draw_from_noise(batch_size)
        noises = torch.from_numpy(noises) # Conver numpy to tensor
        noises = Variable(noises.type(Tensor)).view(batch_size, -1)
        fake_datas = generator(noises)

        fake_output = discriminator(fake_datas)
        generator_loss = adversarial_loss(fake_output, real_label)
        generator_loss.backward()
        optimizer_G.step()
        
        if i % (g_step//10+1) == 0:
            plot_density(generator, discriminator, save_path='./Result/1D/%d.jpg'%(cnt), title='generator')
            cnt += 1
            
    if iter_ % (num_iters//10) == 0:
        g_loss_total += generator_loss.item()
        print("[D Loss:%.4f] [G Loss:%.4f]"%(d_loss_total/(iter_+1), g_loss_total/(iter_+1)))

[D Loss:1.1535] [G Loss:0.6151]
[D Loss:1.3710] [G Loss:0.0126]


In [1]:
import glob
import os
def clear_data():
    for img_path in glob.glob('./Result/1D/*.jpg'):
        os.remove(img_path)
clear_data()