In [1]:
import time
import os
import os.path as osp
import sys
import io
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

from models import *

import matplotlib
import matplotlib.pyplot as plt
#get_ipython().run_line_magic('matplotlib', 'inline')

os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device Used: " + str(device))

# set random seed
manual_seed = 17 #42
random.seed(manual_seed)
torch.manual_seed(manual_seed)

Device Used: cpu


<torch._C.Generator at 0x7f97ccc9de50>

In [2]:
## Data and Experiment Config

# Set Result Directory
# this folder's name should describe the configuration
result_folder = './WGAN'
if not os.path.exists(result_folder):
    os.mkdir(result_folder)

# Load Data
data_folder = "./Data/Dataset_5_35"
training_data_complete = np.load(osp.join(data_folder, "spectra_complete_training.npy")).astype(np.float32) #"log_training.npy"
NLA_max = 205 
training_data = training_data_complete[:,:NLA_max] #manually delete very high wavelengths
TRAINING_DATA = training_data.shape[0]
FEATURE_SIZE = training_data.shape[1]
print("#Signal ", TRAINING_DATA)
print("#Measurements ", FEATURE_SIZE)


# Hyperparameters (could receive as arguments along with data/res directory)
LATENT_SIZE = 50

BATCH_SIZE = 100 #32,64,128
NUM_EPOCHS = 5000
D_ITERS = 1

LOG_EVERY = 10   #error logging
SAVE_EVERY = 100  #checkpoints and model saving

LR_D = 0.01
LR_G = 0.01

#Signal  9000
#Measurements  223
(9000, 223)
float32
#Signal  9000
#Measurements  205
(9000, 205)
float32


In [3]:
## Model Definition

# ~~toy code without dataloaders

# Initialize model
d = 6 #num layers
netG = Generator(nz=LATENT_SIZE, nf=FEATURE_SIZE, num_hidden_layers=d).to(device)
netD = Discriminator(nz=LATENT_SIZE, nf=FEATURE_SIZE, num_hidden_layers=d).to(device)
model_size = sum(p.numel() for p in netG.parameters())
print("Generative Model size (total): {:.3f}M".format(model_size/1e6))

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.01)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

#netG.apply(weights_init)
#netD.apply(weights_init)


## Training Config


# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0
D_label_smoothing = 0.005
print("Label convention:")
print("real: "+ str(real_label) + "  fake: " + str(fake_label))

# Setup Adam optimizers for both G and D
beta1 = 0.5 # Beta1 hyperparam for Adam optimizers
optimizerD = optim.Adam(netD.parameters(), lr=LR_D, weight_decay=0.001) #SGD(netD.parameters(), lr=LR_D)
optimizerG = optim.Adam(netG.parameters(), lr=LR_G, weight_decay=0.001) #, betas=(beta1, 0.999))
clip_value = 100

# Learning rate decay
decayRate = 0.5
step = 1000 #decay every 1000 epochs
lr_scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer=optimizerD, step_size=step, gamma=decayRate)
lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer=optimizerG, step_size=step, gamma=decayRate)


criterion = nn.BCEWithLogitsLoss(reduction='mean') #nn.SoftMarginLoss(reduction='mean'),nn.BCELoss(reduction='mean')
sigmoid_response = nn.Sigmoid()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(30, LATENT_SIZE, device=device)

print("Architecture:")
print(netG)

Generative Model size (total): 0.251M
Label convention:
real: 1  fake: 0
Architecture:
Generator(
  (layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=50, out_features=50, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Linear(in_features=50, out_features=101, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(101, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): Sequential(
      (0): Linear(in_features=101, out_features=153, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Sequential(
      (0): Linear(in_features=153, out_features=204, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm1d(204, eps=1e-05, momentum=0.1, affine=True, track_running_

In [4]:
## Training

# Preliminaries : 
# 1) training allows restriction of latent codes in the l2 unit ball 
def project_l2_ball(z):
    """ project the vectors in z onto the l2 unit norm ball"""
    z_l2_norm = torch.norm(z, p=2, dim=0).detach()
    if z_l2_norm.item() > 1:
        z = z.div(z_l2_norm.expand_as(z))
    return z


# Training Loop

# Lists to keep track of progress
G_losses = []
D_losses = []

print("Starting Training Loop...")
netG.train()
netD.train()
# For each epoch
for epoch in range(NUM_EPOCHS):
    # Manually build batches for current epoch. Toy code without dataloaders.
    ids = np.asarray(np.random.permutation(np.arange(TRAINING_DATA)))
    num_batches = TRAINING_DATA//BATCH_SIZE
    batches = np.split(ids, num_batches)
    # initialize loss accumulator 
    #errG_acc = 0.0
    #errD_acc = 0.0
    # For each batch
    for i, ids in enumerate(batches):   
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_batch = torch.tensor(training_data[ids], device=device)
        b_size = real_batch.size(0)
        label_real = torch.full((b_size,), real_label, device=device) #- D_label_smoothing*torch.randn((b_size,), device=device)
        # Forward pass real batch through D
        output_real = netD(real_batch).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output_real, label_real)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        nn.utils.clip_grad_norm_(netD.parameters(), clip_value)
        # Update D
        optimizerD.step()
        D_x = sigmoid_response(output_real).mean().item()


        ## Train with all-fake batch
        netD.zero_grad()
        # Generate batch of latent vectors
        noise = torch.randn(b_size, LATENT_SIZE, device=device)
        #for idx in range(b_size):
        #    noise[idx] = project_l2_ball(noise[idx]) #constrain latent codes in the unit ball
        # Generate fake image batch with G
        fake_batch = netG(noise)
        label_fake = torch.full((b_size,), fake_label, device=device) #- D_label_smoothing*torch.randn((b_size,), device=device)
        # Classify all fake batch with D
        output_fake = netD(fake_batch.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output_fake, label_fake)
        # Calculate the gradients for this batch
        errD_fake.backward()
        nn.utils.clip_grad_norm_(netD.parameters(), clip_value)
        # Update D
        optimizerD.step()
        D_G_z1 = sigmoid_response(output_fake).mean().item()
        # Get current error of D: Add the errors from the all-real and all-fake batches
        errD = errD_real.item() + errD_fake.item()
        #errD.backward()
        # Update D
        #optimizerD.step()

        if (i+1) % D_ITERS == 0:
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            # Generate batch of latent vectors
            noise = torch.randn(b_size, LATENT_SIZE, device=device)
            #for idx in range(b_size):
            #    noise[idx] = project_l2_ball(noise[idx]) #constrain latent codes in the unit ball
            # Generate fake image batch with G
            fake_batch = netG(noise)
            label_G = torch.full((b_size,), real_label, device=device) # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output_G = netD(fake_batch).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output_G, label_G)
            # Calculate gradients for G
            errG.backward()
            #nn.utils.clip_grad_norm_(netG.parameters(), clip_value)
            # Update G
            optimizerG.step()
            D_G_z2 = sigmoid_response(output_G).mean().item()
            # Get current error of G
            errG = errG.item()
        
        #errG_acc += errG.item()
        #errD_acc += errD.item()
        
        # Output training stats, for last batch
        if ((epoch==0 or (epoch+1) % LOG_EVERY == 0)) and (i == num_batches-1):
            print('Epoch[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            % (epoch+1, NUM_EPOCHS, errD, errG, D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG)
            D_losses.append(errD)

    # Check how the generator is doing by saving G's output on fixed_noise
    # plus Checkpoint generator
    if ((epoch+1) % SAVE_EVERY == 0):
        with torch.no_grad():
            fake_batch_checkpoint = netG(fixed_noise).detach().cpu()
        # save checkpoints
        #torch.save(fake, osp.join(result_folder,"checkpoint_sample_%d.pth")%(epoch+1))
        # save learned model so far
        torch.save(netG.state_dict(), osp.join(result_folder,"learned_generator_%d.pth")%(epoch+1))
        
# save the learned model and stats at the end of training 
torch.save(netG.state_dict(), osp.join(result_folder,"learned_model.pth"))
np.save(osp.join(result_folder,"G_losses.npy"), np.asarray(G_losses))
np.save(osp.join(result_folder,"D_losses.npy"), np.asarray(D_losses))

Starting Training Loop...
Epoch[1/2000]	Loss_D: 1.0556	Loss_G: 6.9914	D(x): 0.6232	D(G(z)): 0.4122 / 0.0961
Epoch[10/2000]	Loss_D: 20.2988	Loss_G: 1.5156	D(x): 0.6564	D(G(z)): 0.9282 / 0.5360
Epoch[20/2000]	Loss_D: 1.1398	Loss_G: 0.8238	D(x): 0.6375	D(G(z)): 0.4612 / 0.4429
Epoch[30/2000]	Loss_D: 1.0949	Loss_G: 0.9283	D(x): 0.6629	D(G(z)): 0.4408 / 0.4152
Epoch[40/2000]	Loss_D: 2.6874	Loss_G: 530.1486	D(x): 0.7255	D(G(z)): 0.0000 / 0.0000
Epoch[50/2000]	Loss_D: 0.9274	Loss_G: 18.5595	D(x): 0.8147	D(G(z)): 0.5013 / 0.0000
Epoch[60/2000]	Loss_D: 0.0576	Loss_G: 4.7849	D(x): 0.9597	D(G(z)): 0.0129 / 0.0155
Epoch[70/2000]	Loss_D: 0.0443	Loss_G: 4.8281	D(x): 0.9796	D(G(z)): 0.0135 / 0.0135
Epoch[80/2000]	Loss_D: 0.0050	Loss_G: 5.8531	D(x): 0.9996	D(G(z)): 0.0045 / 0.0075
Epoch[90/2000]	Loss_D: 0.0271	Loss_G: 6.8314	D(x): 0.9787	D(G(z)): 0.0051 / 0.0030
Epoch[100/2000]	Loss_D: 0.0296	Loss_G: 6.4845	D(x): 0.9847	D(G(z)): 0.0125 / 0.0023
Epoch[110/2000]	Loss_D: 0.0054	Loss_G: 7.2110	D(x): 0.995

Epoch[980/2000]	Loss_D: 0.0616	Loss_G: 8.5835	D(x): 0.9784	D(G(z)): 0.0111 / 0.0016
Epoch[990/2000]	Loss_D: 0.2908	Loss_G: 16.1616	D(x): 0.8696	D(G(z)): 0.0106 / 0.0359
Epoch[1000/2000]	Loss_D: 0.0303	Loss_G: 10.8323	D(x): 0.9705	D(G(z)): 0.0003 / 0.0002
Epoch[1010/2000]	Loss_D: 0.0570	Loss_G: 8.1054	D(x): 0.9560	D(G(z)): 0.0067 / 0.0087
Epoch[1020/2000]	Loss_D: 0.3145	Loss_G: 3.5435	D(x): 0.9012	D(G(z)): 0.1519 / 0.0550
Epoch[1030/2000]	Loss_D: 1.4408	Loss_G: 1.2038	D(x): 0.6992	D(G(z)): 0.5746 / 0.4262
Epoch[1040/2000]	Loss_D: 1.5651	Loss_G: 0.4856	D(x): 0.5784	D(G(z)): 0.6250 / 0.6221
Epoch[1050/2000]	Loss_D: 0.8365	Loss_G: 4.2335	D(x): 0.5069	D(G(z)): 0.0937 / 0.0802
Epoch[1060/2000]	Loss_D: 0.6741	Loss_G: 5.0413	D(x): 0.6955	D(G(z)): 0.0867 / 0.0943
Epoch[1070/2000]	Loss_D: 0.0753	Loss_G: 9.4763	D(x): 0.9524	D(G(z)): 0.0232 / 0.0035
Epoch[1080/2000]	Loss_D: 1.8939	Loss_G: 1.0097	D(x): 0.6129	D(G(z)): 0.4450 / 0.4775
Epoch[1090/2000]	Loss_D: 0.6721	Loss_G: 1.1538	D(x): 0.5625	D(G(z

Epoch[1950/2000]	Loss_D: 0.2932	Loss_G: 7.1243	D(x): 0.8239	D(G(z)): 0.0035 / 0.0094
Epoch[1960/2000]	Loss_D: 174.7217	Loss_G: 0.9223	D(x): 1.0000	D(G(z)): 1.0000 / 0.4395
Epoch[1970/2000]	Loss_D: 1.3858	Loss_G: 0.6814	D(x): 0.5095	D(G(z)): 0.5074 / 0.5062
Epoch[1980/2000]	Loss_D: 0.0000	Loss_G: 314.4329	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000
Epoch[1990/2000]	Loss_D: 0.0001	Loss_G: 33.0599	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000
Epoch[2000/2000]	Loss_D: 0.0003	Loss_G: 11.5901	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000
