In [1]:
import logging
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
import utils
import itertools
from tqdm import tqdm_notebook
import models.dcgan_unet_64 as dcgan_unet_models
import models.dcgan_64 as dcgan_models
import models.classifiers as classifiers
import models.my_model as my_model
from data.moving_mnist import MovingMNIST

In [2]:
# Pin work to the first GPU only when CUDA is actually usable; calling
# set_device on a machine without CUDA raises, which would defeat the CPU
# fallback chosen for `device` in the next cell.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

Constant definition

In [3]:
# Seeding is switched off, so consecutive runs are not repeatable.
# Re-enable the four lines below to fix every RNG used in this notebook:
# np.random.seed(1)
# random.seed(1)
# torch.manual_seed(1)
# torch.cuda.manual_seed_all(1)

# Use the GPU when one is visible, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# --- optimiser settings (consumed by the Adam optimisers below) ---
lr, beta1 = 2e-3, 0.5

# --- data / model sizes ---
seq_len = 12                           # passed to MovingMNIST as seq_len
content_dim, pose_dim = 128, 50        # latent sizes handed to the encoders/decoder
channels, image_width = 3, 64
normalize = False
sd_nf = 100                            # only referenced by the commented-out scene discriminator
batch_size = 100

# --- output locations ---
log_dir = './logs/0529_noiseLikePoseVector_advTraining/'
# One sub-folder per kind of figure written by the plotting helpers.
for _subdir in ('rec', 'analogy', 'gen'):
    os.makedirs(os.path.join(log_dir, _subdir), exist_ok=True)

# Per-epoch summaries are appended to record.txt inside log_dir.
logging.basicConfig(level=logging.DEBUG, filename=os.path.join(log_dir, 'record.txt'))

Data Loader

In [5]:
# The first positional flag of MovingMNIST is True for `train_data` and False
# for `test_data` (presumably selecting the split -- confirm in data/moving_mnist.py).
data_root = '../data_uni/'
train_data = MovingMNIST(True, data_root, seq_len=seq_len)
test_data = MovingMNIST(False, data_root, seq_len=seq_len)

# Options common to both loaders; only the worker count differs.
_shared_loader_args = dict(
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,   # every batch therefore holds exactly `batch_size` samples
    pin_memory=True,
)

train_loader = DataLoader(train_data, num_workers=16, **_shared_loader_args)
test_loader = DataLoader(test_data, num_workers=0, **_shared_loader_args)

Model definition

In [6]:
# Alternative architectures tried earlier, kept for reference:
# netEC = dcgan_unet_models.content_encoder(content_dim, channels).to(device)
# netEC = dcgan_models.content_encoder(content_dim, channels).to(device)
# netEP = dcgan_models.pose_encoder(pose_dim, channels).to(device)
# netD = dcgan_unet_models.decoder(content_dim, pose_dim, channels).to(device)
# netD = dcgan_models.decoder(content_dim, pose_dim, channels).to(device)
# netC = classifiers.scene_discriminator(pose_dim, sd_nf).to(device)
# netC = my_model.scene_discriminator(pose_dim, sd_nf).to(device)
# netC = my_model.Discriminator(channels).to(device)

# Networks used in this experiment. Note the naming: here netG is the
# decoder/generator and netD is the conditional discriminator.
netEC = my_model.content_encoder(content_dim, channels).to(device)
netEP = my_model.pose_encoder(pose_dim, channels, conditional=True).to(device)
netG = my_model.decoder(content_dim, pose_dim, channels).to(device)
netD = my_model.CondDiscriminator(channels).to(device)

# Apply the project's weight initialiser to every sub-module of each network.
for _net in (netEC, netEP, netG, netD):
    _net.apply(utils.weights_init)

# Dump the architectures so the notebook output records what was trained.
for _net in (netEC, netEP, netG, netD):
    print(_net)

content_encoder(
  (main): Sequential(
    (0): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (1): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (2): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (3): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

In [7]:
# One Adam optimiser per network, all sharing the same learning rate and betas.
optimizerEC, optimizerEP, optimizerG, optimizerD = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netEC, netEP, netG, netD)
)

Plot function

In [8]:
# --------- plotting functions ------------------------------------
def plot_rec(x, epoch, dtype):
    """Save a grid of (content frame, pose frame, reconstruction) triples.

    The content code comes from the first entry of ``x``; the pose code comes
    from a randomly chosen later entry. The image is written to
    ``<log_dir>/rec/<dtype>-<epoch>.png``.

    Args:
        x: indexable sequence of frame batches; callers in this notebook pass
            a tensor whose leading axis is the time step (assumed layout
            (time, batch, C, H, W) -- confirm against MovingMNIST).
        epoch: value embedded in the file name.
        dtype: file-name prefix such as ``'test'``.
    """
    content_frames = x[0]
    pose_frames = x[np.random.randint(1, len(x))]

    content_code = netEC(content_frames)
    pose_code = netEP(pose_frames, content_code)
    recon = netG([content_code, pose_code])

    content_frames, pose_frames, recon = content_frames.data, pose_frames.data, recon.data

    row_sz = 5
    nplot = 20
    grid = []
    # The upper bound is nplot - row_sz, so three rows of five samples are drawn.
    for start in range(0, nplot - row_sz, row_sz):
        stop = start + row_sz
        cells = []
        for triple in zip(content_frames[start:stop], pose_frames[start:stop], recon[start:stop]):
            # Flatten each (content, pose, reconstruction) triple into the row.
            cells.extend(triple)
        grid.append(cells)

    out_path = os.path.join(log_dir, 'rec', '{}-{}.png'.format(dtype, epoch))
    utils.save_tensors_image(out_path, grid)

def plot_analogy(x, epoch, dtype):
    """Save an "analogy" grid that applies one sample's pose to ten contents.

    Layout of the saved image:
      * header row: a blank tile followed by sample 0 at every time step;
      * each of the next ten rows: the first frame of sample ``k`` followed by
        decodings of sample ``k``'s content code with sample 0's pose code at
        every time step.

    The image is written to ``<log_dir>/analogy/<dtype>-<epoch>.png``.

    Args:
        x: sequence of frame batches indexed by time step (assumed layout
            (time, batch, C, H, W) -- confirm against MovingMNIST).
        epoch: value embedded in the file name.
        dtype: file-name prefix such as ``'test'``.
    """
    nrow = 10
    content_code = netEC(x[0])

    blank = torch.zeros(channels, image_width, image_width)
    header = [blank] + [frames[0].data for frames in x]
    # Every body row starts with that sample's first frame.
    grid = [header] + [[x[0][k].data] for k in range(nrow)]

    for frames in x:
        # Pose codes for this time step.
        pose_code = netEP(frames, content_code).data
        # Overwrite the first `nrow` pose codes with the pose of sample 0.
        for k in range(nrow):
            pose_code[k] = pose_code[0]
        generated = netG([content_code, pose_code])
        for k, cells in enumerate(grid[1:nrow + 1]):
            cells.append(generated[k].data.clone())

    out_path = os.path.join(log_dir, 'analogy', '{}-{}.png'.format(dtype, epoch))
    utils.save_tensors_image(out_path, grid)
    
def plot_gen(x, epoch, dtype):
    """Save a grid of (content frame, generated frame) pairs.

    The content code of the first entry of ``x`` is decoded together with a
    freshly sampled standard-normal vector in place of a pose code. The image
    is written to ``<log_dir>/gen/<dtype>-<epoch>.png``.

    Args:
        x: sequence of frame batches indexed by time step (assumed layout
            (time, batch, C, H, W) -- confirm against MovingMNIST).
        epoch: value embedded in the file name.
        dtype: file-name prefix such as ``'test'``.
    """
    x_c = x[0]
    # Size and place the noise from the input itself rather than the global
    # `batch_size`/`device`, so batches of any size can be visualised.
    noise = torch.randn((x_c.size(0), pose_dim, 1, 1), device=x_c.device)

    h_c = netEC(x_c)
    gen = netG([h_c, noise])

    x_c, gen = x_c.data, gen.data
    fname = '{}-{}.png'.format(dtype, epoch)
    fname = os.path.join(log_dir, 'gen', fname)
    to_plot = []
    row_sz = 5
    nplot = 15

    # Three rows of five samples; each sample contributes two tiles.
    for i in range(0, nplot, row_sz):
        row = [[xc, xg] for xc, xg in zip(x_c[i:i+row_sz], gen[i:i+row_sz])]
        to_plot.append(list(itertools.chain(*row)))

    utils.save_tensors_image(fname, to_plot)

Training function

In [9]:
def train(x):
    """Run one adversarial update on a single batch of sequences.

    Step 1 updates the conditional discriminator ``netD`` on a real frame and
    on a frame generated from noise, both conditioned on a content code.
    Step 2 updates the content encoder ``netEC`` and the generator ``netG``
    with the adversarial loss plus a content-similarity loss between two
    randomly chosen frames. The pose encoder ``netEP`` is not trained here
    (its noise-reconstruction term is commented out).

    Args:
        x: sequence of frame batches indexed by time step (assumed layout
            (time, batch, C, H, W) -- the training loop transposes axes 0 and 1
            of the loader output before calling).

    Returns:
        Tuple of Python floats
        ``(errD, D_x, D_G_z1, errG, errSim, D_G_z2)``: discriminator loss,
        mean D output on real data, mean D output on fakes before and after
        the D update, generator adversarial loss, and similarity loss.
    """
    x_c1 = x[np.random.randint(len(x))]
    x_c2 = x[np.random.randint(len(x))]
    x_p = x[np.random.randint(len(x))]

    # Take the batch size from the data instead of the global constant so the
    # function also works for a batch of a different size.
    n_batch = x_p.size(0)

    # Float targets: binary_cross_entropy requires the target to have the same
    # floating dtype as the prediction, and torch.full with an integer fill
    # value produces an integer tensor on current PyTorch.
    # netD is assumed to return one probability per sample, shape (n_batch,).
    real_lbl = torch.ones(n_batch, device=device)
    fake_lbl = torch.zeros(n_batch, device=device)

    # ------------------------------------------------------------------
    # Train Discriminator
    # ------------------------------------------------------------------
    optimizerD.zero_grad()

    h_c1 = netEC(x_c1)

    # Real sample; the content code is detached so D's loss never reaches netEC.
    real_x = x_p
    out_real = netD(real_x, h_c1.detach())
    errD_real = F.binary_cross_entropy(out_real, real_lbl)
    errD_real.backward()
    D_x = out_real.mean().item()

    # Fake sample; detached so D's loss never reaches netG.
    noise = torch.randn((n_batch, pose_dim, 1, 1), device=device)
    fake_x = netG([h_c1, noise])
    out_fake = netD(fake_x.detach(), h_c1.detach())
    errD_fake = F.binary_cross_entropy(out_fake, fake_lbl)
    errD_fake.backward()
    D_G_z1 = out_fake.mean().item()

    # Summed for reporting only; both terms were already back-propagated.
    errD = errD_fake + errD_real
    optimizerD.step()

    # ------------------------------------------------------------------
    # Train EC & G (EP is currently disabled)
    # ------------------------------------------------------------------
    optimizerEC.zero_grad()
#     optimizerEP.zero_grad()
    optimizerG.zero_grad()

    # Adversarial loss: the generator wants the updated D to label fakes real.
    out_gen = netD(fake_x, h_c1)
    errG = F.binary_cross_entropy(out_gen, real_lbl)
    D_G_z2 = out_gen.mean().item()

    # Noise reconstruction loss (disabled)
#     noise_rec = netEP(fake_x, h_c1)
#     errEP = F.mse_loss(noise_rec, noise)

    # Similarity loss ||h_c1 - h_c2||: the second code is a fixed target.
    h_c2 = netEC(x_c2).detach()
    errSim = F.mse_loss(h_c1, h_c2)

    # Full loss
#     errTotal = errG + errEP + errSim
    errTotal = errG + errSim
    errTotal.backward()

    optimizerEC.step()
#     optimizerEP.step()
    optimizerG.step()

#     return errD.item(), D_x, D_G_z1, errG.item(), errEP.item(), errSim.item(), D_G_z2
    return errD.item(), D_x, D_G_z1, errG.item(), errSim.item(), D_G_z2

In [None]:
epoch_size = len(train_loader)

# A single test batch is drawn once and reused for every epoch's figure.
# Loader batches are transposed so that axis 0 indexes the time step.
test_x = torch.transpose(next(iter(test_loader)), 0, 1).to(device)

for epoch in tqdm_notebook(range(200), desc='EPOCH'):
    for net in (netEP, netEC, netD, netG):
        net.train()

    # Running sums over the epoch (epoch_errEP stays 0 while the EP loss is disabled).
    epoch_errSim, epoch_errG, epoch_errEP, epoch_errD = 0, 0, 0, 0
    epoch_D_x, epoch_D_G_z1, epoch_D_G_z2 = 0, 0, 0

    for i, x in enumerate(tqdm_notebook(train_loader, desc='BATCH')):
        x = torch.transpose(x, 0, 1).to(device)

        # One discriminator step followed by one encoder/generator step.
        errD, D_x, D_G_z1, errG, errSim, D_G_z2 = train(x)

        epoch_errD += errD
        epoch_D_x += D_x
        epoch_D_G_z1 += D_G_z1
        epoch_errG += errG
        epoch_errSim += errSim
        epoch_D_G_z2 += D_G_z2

    # Report per-batch averages, in the same order as the format string.
    totals = (epoch_errD, epoch_D_x, epoch_D_G_z1, epoch_errG, epoch_errSim, epoch_D_G_z2)
    log_str = '[%02d]errD: %.4f| D(x): %.4f| D(G(z1)): %.4f| errG: %.4f| errSim: %.4f| D(G(z2)): %.4f' % (
        (epoch,) + tuple(total / epoch_size for total in totals)
    )
    print(log_str)
    logging.info(log_str)

    for net in (netEP, netEC, netG, netD):
        net.eval()

    # Only the generation figure is produced; reconstruction and analogy
    # plots are disabled for this experiment.
    with torch.no_grad():
        plot_gen(test_x, epoch, 'test')

    # Overwrite the checkpoint every epoch; whole modules are pickled, so
    # loading requires the same model classes to be importable.
    checkpoint = {'netG': netG, 'netEP': netEP, 'netEC': netEC, 'netD': netD}
    torch.save(checkpoint, '%s/model.pth' % log_dir)

HBox(children=(IntProgress(value=0, description='EPOCH', max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[00]errD: 27.5884| D(x): 0.9987| D(G(z1)): 0.9987| errG: 0.0000| errSim: 0.0331| D(G(z2)): 1.0000


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[01]errD: 3.0280| D(x): 0.7330| D(G(z1)): 0.3196| errG: 3.0660| errSim: 0.0775| D(G(z2)): 0.2195


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[02]errD: 1.5185| D(x): 0.6973| D(G(z1)): 0.3091| errG: 2.4128| errSim: 0.0787| D(G(z2)): 0.2131


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[03]errD: 0.8136| D(x): 0.7798| D(G(z1)): 0.2262| errG: 2.9123| errSim: 0.0559| D(G(z2)): 0.1516


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[04]errD: 0.4716| D(x): 0.8618| D(G(z1)): 0.1399| errG: 3.8142| errSim: 0.0630| D(G(z2)): 0.0979


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[05]errD: 0.3708| D(x): 0.8970| D(G(z1)): 0.1041| errG: 4.5672| errSim: 0.0756| D(G(z2)): 0.0671


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[06]errD: 0.3107| D(x): 0.9117| D(G(z1)): 0.0890| errG: 4.5765| errSim: 0.0880| D(G(z2)): 0.0593


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[07]errD: 0.2719| D(x): 0.9270| D(G(z1)): 0.0735| errG: 5.0836| errSim: 0.0987| D(G(z2)): 0.0480


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[08]errD: 0.2020| D(x): 0.9465| D(G(z1)): 0.0541| errG: 5.7003| errSim: 0.1175| D(G(z2)): 0.0303


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[09]errD: 0.1517| D(x): 0.9632| D(G(z1)): 0.0368| errG: 6.0414| errSim: 0.1345| D(G(z2)): 0.0197


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[10]errD: 0.2529| D(x): 0.9453| D(G(z1)): 0.0550| errG: 5.6012| errSim: 0.1449| D(G(z2)): 0.0325


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[11]errD: 0.1493| D(x): 0.9630| D(G(z1)): 0.0373| errG: 5.9786| errSim: 0.1650| D(G(z2)): 0.0197


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[12]errD: 0.1379| D(x): 0.9683| D(G(z1)): 0.0313| errG: 6.5786| errSim: 0.1827| D(G(z2)): 0.0158


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[13]errD: 0.2000| D(x): 0.9623| D(G(z1)): 0.0377| errG: 6.2313| errSim: 0.1880| D(G(z2)): 0.0197


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[14]errD: 0.1493| D(x): 0.9664| D(G(z1)): 0.0340| errG: 6.1228| errSim: 0.1926| D(G(z2)): 0.0183


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[15]errD: 0.1286| D(x): 0.9698| D(G(z1)): 0.0302| errG: 6.0716| errSim: 0.2105| D(G(z2)): 0.0153


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[16]errD: 0.1183| D(x): 0.9739| D(G(z1)): 0.0262| errG: 6.7545| errSim: 0.2101| D(G(z2)): 0.0121


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

In [None]:
# Number of batches the training loader yields per epoch (the value stored in `epoch_size`).
len(train_loader)

In [None]:
# Render reconstruction and analogy figures for one training batch.
# Only the first batch is needed, so fetch it directly instead of iterating
# over the whole loader and ignoring everything after index 0.
x = next(iter(train_loader))
with torch.no_grad():
    # Loader batches are transposed so that axis 0 indexes the time step.
    x = torch.transpose(x, 0, 1).to(device)
    # Both helpers require the file-name prefix as a third argument; the
    # original two-argument calls raised TypeError.
    plot_rec(x, 200, 'train')
    plot_analogy(x, 200, 'train')