In [1]:
import logging
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
import utils
import itertools
from tqdm import tqdm_notebook
import models.dcgan_unet_64 as dcgan_unet_models
import models.dcgan_64 as dcgan_models
import models.classifiers as classifiers
import models.my_model as my_model
from data.moving_mnist import MovingMNIST

In [2]:
# Pin CUDA work to GPU 0. Guarded so CPU-only machines do not crash here:
# torch.cuda.set_device raises without CUDA, while the device selection
# further down falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

Random seeds and constants

In [3]:
# Seed every random number generator this notebook touches (numpy, Python's
# `random`, torch on the CPU and on all CUDA devices) with the same value so
# that runs are repeatable.
_seed = 1
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(_seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Optimisation hyper-parameters.
lr = 2e-3
beta1 = 0.5
batch_size = 100

# Data and model settings.
seq_len = 12
channels = 3
image_width = 64
content_dim = 128
pose_dim = 50
sd_nf = 100
normalize = False

# Output locations: the per-plot image folders and the text log live under log_dir.
log_dir = './logs/0610_newCVAEStructure_singleFeature_dualPred_identityMapping_bigTripletLen/'
for _sub in ('rec', 'analogy', 'eval'):
    os.makedirs(os.path.join(log_dir, _sub), exist_ok=True)
logging.basicConfig(filename=os.path.join(log_dir, 'record.txt'), level=logging.DEBUG)

Data loaders

In [5]:
def _make_loader(dataset, num_workers):
    """Wrap ``dataset`` in a shuffled, pinned DataLoader that drops the last partial batch."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )


train_data = MovingMNIST(True, '../data_uni/', seq_len=seq_len)
test_data = MovingMNIST(False, '../data_uni/', seq_len=seq_len)

# Training reads with 16 worker processes; the test set is read in the main process.
train_loader = _make_loader(train_data, num_workers=16)
test_loader = _make_loader(test_data, num_workers=0)

Model definitions

In [6]:
# The motion encoder receives two frames concatenated on the channel axis,
# hence 2 * channels input channels; the generator outputs `channels` channels.
netEM = my_model.motion_encoder(pose_dim, 2 * channels).to(device)
netG = my_model.Generator(content_dim, pose_dim, channels).to(device)

# Initialise weights (encoder first, then generator), then show both architectures.
for net in (netEM, netG):
    net.apply(utils.weights_init)
for net in (netEM, netG):
    print(net)

motion_encoder(
  (main): Sequential(
    (0): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (1): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (2): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (3): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  

In [7]:
# One Adam optimiser per network, sharing the same learning rate and betas.
adam_kwargs = dict(lr=lr, betas=(beta1, 0.999))
optimizerEM = optim.Adam(netEM.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)

Plotting functions

In [8]:
# --------- plotting functions ------------------------------------
def plot_rec(x, epoch, dtype, nplot=20, row_sz=5):
    """Save a grid of (content, target, reconstruction) image triples.

    Two time steps are drawn independently at random from ``x`` (they may
    coincide): ``x_c`` gives the content frame and ``x_p`` the target frame.
    For every batch sample, netEM encodes the motion between the two frames
    (stacked on dim 1), and netG re-renders the target from ``x_c`` and that
    code.

    Args:
        x: sequence indexable by time step; ``x[t]`` is a batch of frames.
        epoch: tag used in the output filename.
        dtype: filename prefix (e.g. 'test').
        nplot: number of batch samples to show (capped at the batch size).
        row_sz: samples per grid row; each sample contributes three images.

    Writes ``<log_dir>/rec/<dtype>-<epoch>.png``.
    """
    x_c = x[np.random.randint(len(x))]
    x_p = x[np.random.randint(len(x))]

    h_m = netEM(torch.cat([x_c, x_p], dim=1))
    rec = netG(x_c, h_m)

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = os.path.join(log_dir, 'rec', '{}-{}.png'.format(dtype, epoch))

    nplot = min(nplot, len(x_c))
    row_sz = min(row_sz, nplot)
    to_plot = []
    # Only full rows are emitted so the grid stays rectangular; the upper bound
    # `nplot - row_sz + 1` includes the last full row (the old bound
    # `nplot - row_sz` silently dropped it).
    for i in range(0, nplot - row_sz + 1, row_sz):
        row = [[xc, xp, xr] for xc, xp, xr in zip(x_c[i:i+row_sz], x_p[i:i+row_sz], rec[i:i+row_sz])]
        to_plot.append(list(itertools.chain(*row)))
    utils.save_tensors_image(fname, to_plot)

def plot_analogy(x, epoch, dtype, nrow=10):
    """Save a motion-analogy grid: sample 0's motion applied to other samples.

    The first grid row is a blank cell followed by sample 0's frames over
    all time steps. Each following row ``i`` starts with the first frame of
    sample ``i``. For every time step ``j`` it then shows netG applied to
    that frame with the motion code netEM infers for *sample 0* between its
    frames 0 and ``j``.

    The model in this notebook is netEM/netG only (there is no
    netEC/netEP/netD), so the analogy is built from motion codes.

    Args:
        x: sequence indexable by time step; ``x[t]`` is a batch of frames.
        epoch: tag used in the output filename.
        dtype: filename prefix (e.g. 'test').
        nrow: number of batch samples to show (capped at the batch size).

    Writes ``<log_dir>/analogy/<dtype>-<epoch>.png``.
    """
    x_c = x[0]
    nrow = min(nrow, len(x_c))

    # Blank corner cell with the same shape/device/dtype as a real frame.
    zeros = torch.zeros_like(x_c[0].data)
    to_plot = [[zeros] + [xi[0].data for xi in x]]
    for i in range(nrow):
        to_plot.append([x_c[i].data])

    for j in range(len(x)):
        # Motion from frame 0 to frame j for each sample (frames stacked on dim 1).
        h_m = netEM(torch.cat([x_c, x[j]], dim=1))
        # Broadcast sample 0's motion code to the whole batch; contiguous() in
        # case netG reshapes its input with view().
        h_m = h_m[:1].expand_as(h_m).contiguous()
        rec = netG(x_c, h_m)
        for i in range(nrow):
            to_plot[i + 1].append(rec[i].data.clone())

    fname = os.path.join(log_dir, 'analogy', '{}-{}.png'.format(dtype, epoch))
    utils.save_tensors_image(fname, to_plot)
    
def plot_eval(x, epoch, dtype, triplet_len=None, nplot=10):
    """Save a grid showing reconstruction and past/future prediction on a triplet.

    A triplet ``(x_past, x_c, x_future)`` is taken from ``x`` around a random
    centre index, ``triplet_len`` steps on either side. Each grid row is one
    batch sample and holds three image triples:

      1. x_past,    x_c, x_future     -- ground truth
      2. pred_past, x_c, rec_future   -- h_future = netEM(cat[x_c, x_future]);
                                         pred_past uses -h_future
      3. rec_past,  x_c, pred_future  -- h_past = netEM(cat[x_c, x_past]);
                                         pred_future uses -h_past

    Args:
        x: sequence indexable by time step; ``x[t]`` is a batch of frames.
        epoch: tag used in the output filename.
        dtype: filename prefix (e.g. 'test').
        triplet_len: distance from the centre frame to each end; drawn
            uniformly from {1, 2, 3} when None.
        nplot: maximum number of batch samples (rows) to plot.

    Raises:
        ValueError: if ``triplet_len < 1`` or the triplet does not fit in ``x``
            (requires ``2 * triplet_len < len(x)``).

    Writes ``<log_dir>/eval/<dtype>-<epoch>.png``.
    """
    if triplet_len is None:
        triplet_len = np.random.randint(1, 4)
    if triplet_len < 1 or 2 * triplet_len >= len(x):
        raise ValueError(
            'triplet_len must satisfy 1 <= triplet_len < len(x) / 2, '
            'got triplet_len={} with len(x)={}'.format(triplet_len, len(x)))
    # np.random.randint's upper bound is exclusive, so idx_c +/- triplet_len stays in range.
    idx_c = np.random.randint(triplet_len, len(x) - triplet_len)
    x_c = x[idx_c]
    x_future = x[idx_c + triplet_len]
    x_past = x[idx_c - triplet_len]

    # task 1: reconstruct x_future; use -h_future to predict x_past
    h_future = netEM(torch.cat([x_c, x_future], dim=1))
    rec_future = netG(x_c, h_future)
    pred_past = netG(x_c, -h_future)

    # task 2: reconstruct x_past; use -h_past to predict x_future
    h_past = netEM(torch.cat([x_c, x_past], dim=1))
    rec_past = netG(x_c, h_past)
    pred_future = netG(x_c, -h_past)

    x_c, x_future, x_past = x_c.data, x_future.data, x_past.data
    rec_future, pred_past = rec_future.data, pred_past.data
    rec_past, pred_future = rec_past.data, pred_future.data
    fname = os.path.join(log_dir, 'eval', '{}-{}.png'.format(dtype, epoch))

    # One row per sample; capped at the batch size to avoid IndexError.
    to_plot = [
        [x_past[i], x_c[i], x_future[i],
         pred_past[i], x_c[i], rec_future[i],
         rec_past[i], x_c[i], pred_future[i]]
        for i in range(min(nplot, len(x_c)))
    ]
    utils.save_tensors_image(fname, to_plot)

Training function

In [9]:
def train(x):
    """Perform one optimisation step of netEM and netG on a single batch.

    ``x[t]`` is treated as a batch of frames at step ``t``. The training loop
    transposes each DataLoader batch (dims 0 and 1) before calling this;
    presumably the first axis is then time -- TODO confirm against the
    dataset. Four equally weighted MSE terms are minimised:

    * rec_loss:     netG(x_c, netEM(cat[x_c, x_p])) vs. a random frame x_p;
    * pred_loss:    netG(x_c, -h_future) vs. x_past, plus
                    netG(x_c, -h_past) vs. x_future;
    * id_loss:      netG(x_c, 0) vs. x_c (zero motion should be the identity);
    * reverse_loss: h_past vs. -h_future, with h_future detached.

    Returns:
        tuple of floats ``(rec_loss, pred_loss, id_loss, reverse_loss)``.
    """
    for optimizer in (optimizerEM, optimizerG):
        optimizer.zero_grad()

    # Triplet (centre - gap, centre, centre + gap) plus an unrelated frame x_p.
    # The np.random calls keep their order, so the sampled indices are unchanged.
    n_steps = len(x)
    gap = np.random.randint(1, n_steps // 2)
    centre = np.random.randint(gap, n_steps - gap)
    x_c = x[centre]
    x_p = x[np.random.randint(n_steps)]
    before, after = x[centre - gap], x[centre + gap]
    # About half of the time the triplet is reversed in time, so the model
    # learns both directions as "past -> future".
    if np.random.rand() > 0.5:
        x_past, x_future = before, after
    else:
        x_past, x_future = after, before

    def motion(target):
        # Motion code from the centre frame to `target` (frames stacked on dim 1).
        return netEM(torch.cat([x_c, target], dim=1))

    # Forward passes keep their original order: netEM contains BatchNorm layers,
    # whose running statistics depend on call order in train mode.
    rec = netG(x_c, motion(x_p))
    h_future = motion(x_future)
    pred_past = netG(x_c, -h_future)
    h_past = motion(x_past)
    pred_future = netG(x_c, -h_past)
    rec_id = netG(x_c, torch.zeros_like(h_past))

    rec_loss = F.mse_loss(rec, x_p)
    pred_loss = F.mse_loss(pred_future, x_future) + F.mse_loss(pred_past, x_past)
    id_loss = F.mse_loss(rec_id, x_c)
    reverse_loss = F.mse_loss(h_past, -h_future.detach())
    (rec_loss + pred_loss + id_loss + reverse_loss).backward()

    for optimizer in (optimizerEM, optimizerG):
        optimizer.step()

    return tuple(term.item() for term in (rec_loss, pred_loss, id_loss, reverse_loss))

In [None]:
# Optimisation steps per epoch, used to average the summed per-batch losses.
epoch_size = len(train_loader)
# One fixed test batch, transposed to the layout train() receives, reused for
# every epoch's plots.
test_x = torch.transpose(next(iter(test_loader)), 0, 1).to(device)

for epoch in tqdm_notebook(range(200), desc='EPOCH'):
    netEM.train()
    netG.train()

    # Running sums, in the order train() returns them: rec, pred, identity, reverse.
    loss_sums = [0, 0, 0, 0]
    for x in tqdm_notebook(train_loader, desc='BATCH'):
        x = torch.transpose(x, 0, 1).to(device)
        loss_sums = [total + value for total, value in zip(loss_sums, train(x))]

    log_str = ('[%02d]rec loss: %.4f| pred loss: %.4f| identity loss: %.4f| reverse loss: %.4f'
               % ((epoch,) + tuple(total / epoch_size for total in loss_sums)))
    print(log_str)
    logging.info(log_str)

    # Visual check on the fixed test batch: inference-mode layers, no autograd.
    netEM.eval()
    netG.eval()
    with torch.no_grad():
        plot_rec(test_x, epoch, 'test')
        plot_eval(test_x, epoch, 'test')
        # plot_analogy(test_x, epoch, 'test')

    # Overwrite the checkpoint every epoch (whole module objects, not state_dicts).
    torch.save({'netEM': netEM, 'netG': netG}, '%s/model.pth' % log_dir)

HBox(children=(IntProgress(value=0, description='EPOCH', max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[00]rec loss: 0.0140| pred loss: 0.0279| identity loss: 0.0100| reverse loss: 0.0660


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[01]rec loss: 0.0106| pred loss: 0.0221| identity loss: 0.0042| reverse loss: 0.0083


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[02]rec loss: 0.0089| pred loss: 0.0202| identity loss: 0.0032| reverse loss: 0.0029


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[03]rec loss: 0.0078| pred loss: 0.0184| identity loss: 0.0027| reverse loss: 0.0009


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[04]rec loss: 0.0070| pred loss: 0.0180| identity loss: 0.0024| reverse loss: 0.0005


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[05]rec loss: 0.0065| pred loss: 0.0170| identity loss: 0.0022| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[06]rec loss: 0.0061| pred loss: 0.0168| identity loss: 0.0020| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[07]rec loss: 0.0058| pred loss: 0.0163| identity loss: 0.0019| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[08]rec loss: 0.0054| pred loss: 0.0162| identity loss: 0.0018| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[09]rec loss: 0.0054| pred loss: 0.0157| identity loss: 0.0017| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[10]rec loss: 0.0051| pred loss: 0.0156| identity loss: 0.0017| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[11]rec loss: 0.0050| pred loss: 0.0152| identity loss: 0.0016| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[12]rec loss: 0.0049| pred loss: 0.0149| identity loss: 0.0016| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[13]rec loss: 0.0048| pred loss: 0.0147| identity loss: 0.0015| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[14]rec loss: 0.0047| pred loss: 0.0154| identity loss: 0.0015| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[15]rec loss: 0.0046| pred loss: 0.0150| identity loss: 0.0015| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[16]rec loss: 0.0045| pred loss: 0.0149| identity loss: 0.0014| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[17]rec loss: 0.0043| pred loss: 0.0144| identity loss: 0.0014| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[18]rec loss: 0.0043| pred loss: 0.0147| identity loss: 0.0014| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[19]rec loss: 0.0042| pred loss: 0.0140| identity loss: 0.0014| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[20]rec loss: 0.0041| pred loss: 0.0143| identity loss: 0.0014| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[21]rec loss: 0.0042| pred loss: 0.0141| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[22]rec loss: 0.0040| pred loss: 0.0139| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[23]rec loss: 0.0040| pred loss: 0.0142| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[24]rec loss: 0.0039| pred loss: 0.0137| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[25]rec loss: 0.0039| pred loss: 0.0133| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[26]rec loss: 0.0039| pred loss: 0.0135| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[27]rec loss: 0.0038| pred loss: 0.0132| identity loss: 0.0013| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[28]rec loss: 0.0036| pred loss: 0.0139| identity loss: 0.0013| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[29]rec loss: 0.0037| pred loss: 0.0132| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[30]rec loss: 0.0036| pred loss: 0.0134| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[31]rec loss: 0.0035| pred loss: 0.0137| identity loss: 0.0012| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[32]rec loss: 0.0036| pred loss: 0.0135| identity loss: 0.0012| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[33]rec loss: 0.0035| pred loss: 0.0129| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[34]rec loss: 0.0034| pred loss: 0.0132| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[35]rec loss: 0.0035| pred loss: 0.0130| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[36]rec loss: 0.0034| pred loss: 0.0131| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[37]rec loss: 0.0034| pred loss: 0.0128| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[38]rec loss: 0.0033| pred loss: 0.0127| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[39]rec loss: 0.0033| pred loss: 0.0133| identity loss: 0.0012| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[40]rec loss: 0.0032| pred loss: 0.0129| identity loss: 0.0012| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[41]rec loss: 0.0032| pred loss: 0.0131| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[42]rec loss: 0.0031| pred loss: 0.0127| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[43]rec loss: 0.0031| pred loss: 0.0129| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[44]rec loss: 0.0032| pred loss: 0.0124| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[45]rec loss: 0.0031| pred loss: 0.0127| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[46]rec loss: 0.0031| pred loss: 0.0125| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[47]rec loss: 0.0031| pred loss: 0.0126| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[48]rec loss: 0.0031| pred loss: 0.0127| identity loss: 0.0011| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[49]rec loss: 0.0031| pred loss: 0.0123| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[50]rec loss: 0.0030| pred loss: 0.0125| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[51]rec loss: 0.0030| pred loss: 0.0126| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[52]rec loss: 0.0030| pred loss: 0.0123| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[53]rec loss: 0.0030| pred loss: 0.0125| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[54]rec loss: 0.0030| pred loss: 0.0123| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[55]rec loss: 0.0029| pred loss: 0.0126| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[56]rec loss: 0.0030| pred loss: 0.0122| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[57]rec loss: 0.0029| pred loss: 0.0122| identity loss: 0.0011| reverse loss: 0.0003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[58]rec loss: 0.0029| pred loss: 0.0127| identity loss: 0.0011| reverse loss: 0.0004


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

Final evaluation

In [None]:
# Final evaluation grids on the fixed test batch, one per triplet length 1-3.
with torch.no_grad():
    for gap in (1, 2, 3):
        plot_eval(test_x, 'final-{}'.format(gap), 'eval', gap)