In [1]:
import logging
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
import utils
import itertools
from tqdm import tqdm_notebook
import models.dcgan_unet_64 as dcgan_unet_models
import models.dcgan_64 as dcgan_models
import models.classifiers as classifiers
import models.my_model as my_model
from data.moving_mnist import MovingMNIST

In [2]:
torch.cuda.set_device(0)

Constant definition

In [3]:
np.random.seed(1)
random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
lr = 2e-3
seq_len = 12
beta1 = 0.5
content_dim = 128
pose_dim = 10
channels = 3
normalize = False
sd_nf = 100
image_width = 64
batch_size = 100
log_dir = './logs/0522_my_model_CVAE_ourDisc_newPair/'
os.makedirs(os.path.join(log_dir, 'rec'), exist_ok=True)
os.makedirs(os.path.join(log_dir, 'analogy'), exist_ok=True)
logging.basicConfig(filename=os.path.join(log_dir, 'record.txt'), level=logging.DEBUG)

Data Loader

In [5]:
train_data = MovingMNIST(True, '../data_uni/', seq_len=seq_len)
test_data = MovingMNIST(False, '../data_uni/', seq_len=seq_len)

train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=16,
    shuffle=True,
    drop_last=True,
    pin_memory=True
)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True,
    pin_memory=True
)

Model definition

In [6]:
# # netEC = dcgan_unet_models.content_encoder(content_dim, channels).to(device)
# netEC = dcgan_models.content_encoder(content_dim, channels).to(device)
# netEP = dcgan_models.pose_encoder(pose_dim, channels).to(device)
# # netD = dcgan_unet_models.decoder(content_dim, pose_dim, channels).to(device)
# netD = dcgan_models.decoder(content_dim, pose_dim, channels).to(device)
# netC = classifiers.scene_discriminator(pose_dim, sd_nf).to(device)

netEC = my_model.content_encoder(content_dim, channels).to(device)
netEP = my_model.pose_encoder(pose_dim, channels, conditional=True).to(device)
netD = my_model.decoder(content_dim, pose_dim, channels).to(device)
# netC = my_model.scene_discriminator(pose_dim, sd_nf).to(device)
netC = my_model.Discriminator(channels).to(device)

netEC.apply(utils.weights_init)
netEP.apply(utils.weights_init)
netD.apply(utils.weights_init)
netC.apply(utils.weights_init)

print(netEC)
print(netEP)
print(netD)
print(netC)

content_encoder(
  (main): Sequential(
    (0): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (1): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (2): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (3): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

In [7]:
optimizerEC = optim.Adam(netEC.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerEP = optim.Adam(netEP.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerC = optim.Adam(netC.parameters(), lr=lr, betas=(beta1, 0.999))

Plot function

In [8]:
# --------- plotting funtions ------------------------------------
def plot_rec(x, epoch, dtype):
    x_c = x[0]
    x_p = x[np.random.randint(1, len(x))]

    h_c = netEC(x_c)
    h_p = netEP(x_p, h_c)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '{}-{}.png'.format(dtype, epoch)
    fname = os.path.join(log_dir, 'rec', fname)
    to_plot = []
    row_sz = 5
    nplot = 20
    for i in range(0, nplot-row_sz, row_sz):
        row = [[xc, xp, xr] for xc, xp, xr in zip(x_c[i:i+row_sz], x_p[i:i+row_sz], rec[i:i+row_sz])]
        to_plot.append(list(itertools.chain(*row)))
    utils.save_tensors_image(fname, to_plot)

def plot_analogy(x, epoch, dtype):
    x_c = x[0]
    h_c = netEC(x_c)
    
    nrow = 10
    row_sz = len(x)
    to_plot = []
    row = [xi[0].data for xi in x]
    zeros = torch.zeros(channels, image_width, image_width)
    to_plot.append([zeros] + row)
    for i in range(nrow):
        to_plot.append([x[0][i].data])

    for j in range(0, row_sz):
        # for each time step
        h_p = netEP(x[j], h_c).data
        # first 10 pose vector, equal to first pose vector
        for i in range(nrow):
            h_p[i] = h_p[0]
        rec = netD([h_c, h_p])
        for i in range(nrow):
            to_plot[i+1].append(rec[i].data.clone())

    fname = '{}-{}.png'.format(dtype, epoch)
    fname = os.path.join(log_dir, 'analogy', fname)
    utils.save_tensors_image(fname, to_plot)

Training function

In [9]:
def train(x):
    optimizerEC.zero_grad()
    optimizerEP.zero_grad()
    optimizerD.zero_grad()

#     x_c1 = x[0]
#     x_c2 = x[1]
#     x_p1 = x[2]
#     x_p2 = x[3]
    x_c1 = x[np.random.randint(len(x))]
    x_c2 = x[np.random.randint(len(x))]
    x_p1 = x[np.random.randint(len(x))]
    x_p2 = x[np.random.randint(len(x))]

    h_c1 = netEC(x_c1)
#     h_c2 = netEC(x_c2)[0].detach()
    h_c2 = netEC(x_c2).detach()
    h_p1 = netEP(x_p1, h_c1.detach()) # used for scene discriminator
    h_p2 = netEP(x_p2, h_c1.detach())


    # similarity loss: ||h_c1 - h_c2||
#     sim_loss = F.mse_loss(h_c1[0], h_c2)
    sim_loss = F.mse_loss(h_c1, h_c2)


    # reconstruction loss: ||D(h_c1, h_p1), x_p1|| 
    rec = netD([h_c1, h_p1])
    rec_loss = F.mse_loss(rec, x_p1)

    # scene discriminator loss: maximize entropy of output
    # target = torch.FloatTensor(batch_size, 1).fill_(0.5).to(device)
    # out = netC([h_p1, h_p2])
    # sd_loss = F.binary_cross_entropy(out, target)
    
    # Swap pose vector to train the discriminator
    target = torch.FloatTensor(batch_size, 1).fill_(1).to(device)
    idx = torch.randperm(batch_size)
    h_p2 = h_p2[idx]
    rec_swap = netD([h_c1, h_p2])
    out = netC([x_c1.detach(), rec_swap]).view(-1, 1)
    D_G_fake = out.mean().item()
    adv_loss = F.binary_cross_entropy(out, target)
    
    # full loss
    loss = sim_loss + rec_loss + 0.1 * adv_loss
    loss.backward()

    optimizerEC.step()
    optimizerEP.step()
    optimizerD.step()

    return sim_loss.item(), rec_loss.item(), adv_loss.item(), D_G_fake

In [10]:
def train_scene_discriminator(x):
    optimizerC.zero_grad()

    target = torch.FloatTensor(batch_size, 1).to(device)
    
    # condition
    h_c = netEC(x[np.random.randint(len(x))]).detach()
    
    x1 = x[0]
    x2 = x[1]
    h_p1 = netEP(x1, h_c).detach()
    h_p2 = netEP(x2, h_c).detach()

    half = batch_size // 2
    rp = torch.randperm(half).cuda()
    h_p2[:half] = h_p2[rp]
    target[:half] = 0
    target[half:] = 1

    out = netC([h_p1, h_p2])
    bce = F.binary_cross_entropy(out, target)

    bce.backward()
    optimizerC.step()

    acc =out[:half].le(0.5).sum() + out[half:].gt(0.5).sum()
    return bce.data.cpu().numpy(), acc.data.cpu().numpy() / batch_size

In [11]:
def train_discriminator(x):
    optimizerC.zero_grad()

    real_lbl = torch.FloatTensor(batch_size, 1).fill_(1).to(device)
    fake_lbl = torch.FloatTensor(batch_size, 1).fill_(0).to(device)
    
    x1 = x[np.random.randint(len(x))]
    x2 = x[np.random.randint(len(x))]
    x3 = x[np.random.randint(len(x))]

    # real pair
    # 1. x1
    # 2. reconstructed frames by pose(x2) and content(x1)
    h_c = netEC(x1).detach()
    h_p = netEP(x3, h_c)
    x_rec = netD([h_c, h_p]).detach()
    out_real = netC([x1, x_rec]).view(-1, 1)
    loss_real = F.binary_cross_entropy(out_real, real_lbl)
    D_real = loss_real.mean().item()
    
    # fake pair
    # 1. x1
    # 2. swapped reconstructed frames
    #    by swapped pose(x3) and content(x1)
    idx = torch.randperm(batch_size)
    h_p = netEP(x3, h_c)
    h_p = h_p[idx]
    x_swap = netD([h_c, h_p]).detach()
    out_fake = netC([x1, x_swap]).view(-1, 1)
    loss_fake = F.binary_cross_entropy(out_fake, fake_lbl)
    D_fake = loss_fake.mean().item()
    
    
    bce = 0.5*loss_real + 0.5*loss_fake
    bce.backward()
    optimizerC.step()

    
    return bce.item(), D_real, D_fake

In [None]:
epoch_size = len(train_loader)
test_x = next(iter(test_loader))
test_x = torch.transpose(test_x, 0, 1)
test_x = test_x.to(device)

for epoch in tqdm_notebook(range(200), desc='EPOCH'):
    netEP.train()
    netEC.train()
    netD.train()
    netC.train()
    epoch_sim_loss, epoch_rec_loss, epoch_adv_loss, epoch_sd_loss = 0, 0, 0, 0
    epoch_D_real, epoch_D_fake, epoch_D_G_fake = 0, 0, 0
    
    for i, x in enumerate(tqdm_notebook(train_loader, desc='BATCH')):
        # x to device
        x = torch.transpose(x, 0, 1)
        x = x.to(device)
        
        # train scene discriminator
        # sd_loss, sd_acc = train_scene_discriminator(x)
        sd_loss, D_real, D_fake = train_discriminator(x)
        epoch_sd_loss += sd_loss
        epoch_D_real += D_real
        epoch_D_fake += D_fake
        
        # train main model
        sim_loss, rec_loss, adv_loss, D_G_fake = train(x)
        epoch_sim_loss += sim_loss
        epoch_rec_loss += rec_loss
        epoch_adv_loss += adv_loss
        epoch_D_G_fake += D_G_fake
    
    log_str='[%02d]rec loss: %.4f |sim loss: %.4f|adv loss: %.4f |sd loss: %.4f \
|D(real): %.2f |D(fake): %.2f |D(G(fake)): %.2f' %\
    (epoch,
     epoch_rec_loss/epoch_size,
     epoch_sim_loss/epoch_size,
     epoch_adv_loss/epoch_size,
     epoch_sd_loss/epoch_size,
     epoch_D_real/epoch_size,
     epoch_D_fake/epoch_size,
     epoch_D_G_fake/epoch_size)
    
    print(log_str)
    logging.info(log_str)
    
    netEP.eval()
    netEC.eval()
    netD.eval()
    
    with torch.no_grad():
        plot_rec(test_x, epoch, 'test')
        plot_analogy(test_x, epoch, 'test')

    # save the model
    torch.save({
        'netD': netD,
        'netEP': netEP,
        'netEC': netEC},
        '%s/model.pth' % log_dir)

HBox(children=(IntProgress(value=0, description='EPOCH', max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[00]rec loss: 0.0172 |sim loss: 0.0325|adv loss: 0.7971 |sd loss: 0.8049     |D(real): 0.78 |D(fake): 0.83 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[01]rec loss: 0.0156 |sim loss: 0.0042|adv loss: 0.7159 |sd loss: 0.7048     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[02]rec loss: 0.0166 |sim loss: 0.0022|adv loss: 0.7083 |sd loss: 0.6993     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[03]rec loss: 0.0162 |sim loss: 0.0008|adv loss: 0.7488 |sd loss: 0.7212     |D(real): 0.72 |D(fake): 0.72 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[04]rec loss: 0.0203 |sim loss: 0.0008|adv loss: 0.7076 |sd loss: 0.6990     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[05]rec loss: 0.0157 |sim loss: 0.0011|adv loss: 0.6952 |sd loss: 0.6937     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[06]rec loss: 0.0135 |sim loss: 0.0005|adv loss: 0.6987 |sd loss: 0.6945     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[07]rec loss: 0.0126 |sim loss: 0.0006|adv loss: 0.7058 |sd loss: 0.6960     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[08]rec loss: 0.0140 |sim loss: 0.0011|adv loss: 0.7060 |sd loss: 0.6952     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[09]rec loss: 0.0133 |sim loss: 0.0010|adv loss: 0.7034 |sd loss: 0.6950     |D(real): 0.69 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[10]rec loss: 0.0120 |sim loss: 0.0008|adv loss: 0.7064 |sd loss: 0.6956     |D(real): 0.70 |D(fake): 0.70 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[11]rec loss: 0.0126 |sim loss: 0.0006|adv loss: 0.7079 |sd loss: 0.6951     |D(real): 0.70 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[12]rec loss: 0.0117 |sim loss: 0.0008|adv loss: 0.7084 |sd loss: 0.6950     |D(real): 0.70 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[13]rec loss: 0.0112 |sim loss: 0.0010|adv loss: 0.7099 |sd loss: 0.6941     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[14]rec loss: 0.0117 |sim loss: 0.0007|adv loss: 0.7095 |sd loss: 0.6937     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[15]rec loss: 0.0114 |sim loss: 0.0010|adv loss: 0.7152 |sd loss: 0.6933     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[16]rec loss: 0.0117 |sim loss: 0.0011|adv loss: 0.7180 |sd loss: 0.6926     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[17]rec loss: 0.0114 |sim loss: 0.0011|adv loss: 0.7257 |sd loss: 0.6917     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[18]rec loss: 0.0119 |sim loss: 0.0011|adv loss: 0.7275 |sd loss: 0.6910     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[19]rec loss: 0.0138 |sim loss: 0.0010|adv loss: 0.7094 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[20]rec loss: 0.0119 |sim loss: 0.0011|adv loss: 0.7299 |sd loss: 0.6914     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[21]rec loss: 0.0118 |sim loss: 0.0010|adv loss: 0.7249 |sd loss: 0.6909     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[22]rec loss: 0.0119 |sim loss: 0.0012|adv loss: 0.7284 |sd loss: 0.6916     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[23]rec loss: 0.0119 |sim loss: 0.0010|adv loss: 0.7249 |sd loss: 0.6913     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[24]rec loss: 0.0131 |sim loss: 0.0012|adv loss: 0.7219 |sd loss: 0.6915     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[25]rec loss: 0.0130 |sim loss: 0.0006|adv loss: 0.7058 |sd loss: 0.6936     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[26]rec loss: 0.0116 |sim loss: 0.0009|adv loss: 0.7254 |sd loss: 0.6910     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[27]rec loss: 0.0115 |sim loss: 0.0009|adv loss: 0.7216 |sd loss: 0.6928     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[28]rec loss: 0.0114 |sim loss: 0.0011|adv loss: 0.7347 |sd loss: 0.6899     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[29]rec loss: 0.0111 |sim loss: 0.0010|adv loss: 0.7287 |sd loss: 0.6909     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[30]rec loss: 0.0112 |sim loss: 0.0011|adv loss: 0.7346 |sd loss: 0.6899     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[31]rec loss: 0.0114 |sim loss: 0.0011|adv loss: 0.7371 |sd loss: 0.6891     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[32]rec loss: 0.0114 |sim loss: 0.0012|adv loss: 0.7402 |sd loss: 0.6872     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[33]rec loss: 0.0119 |sim loss: 0.0013|adv loss: 0.7436 |sd loss: 0.6889     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[34]rec loss: 0.0117 |sim loss: 0.0013|adv loss: 0.7426 |sd loss: 0.6881     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[35]rec loss: 0.0116 |sim loss: 0.0013|adv loss: 0.7548 |sd loss: 0.6857     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[36]rec loss: 0.0117 |sim loss: 0.0014|adv loss: 0.7382 |sd loss: 0.6887     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[37]rec loss: 0.0114 |sim loss: 0.0012|adv loss: 0.7500 |sd loss: 0.6851     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[38]rec loss: 0.0121 |sim loss: 0.0014|adv loss: 0.7422 |sd loss: 0.6883     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[39]rec loss: 0.0114 |sim loss: 0.0012|adv loss: 0.7321 |sd loss: 0.6910     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[40]rec loss: 0.0119 |sim loss: 0.0013|adv loss: 0.7437 |sd loss: 0.6890     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[41]rec loss: 0.0117 |sim loss: 0.0012|adv loss: 0.7482 |sd loss: 0.6873     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[42]rec loss: 0.0116 |sim loss: 0.0013|adv loss: 0.7470 |sd loss: 0.6896     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[43]rec loss: 0.0117 |sim loss: 0.0013|adv loss: 0.7535 |sd loss: 0.6846     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[44]rec loss: 0.0117 |sim loss: 0.0015|adv loss: 0.7586 |sd loss: 0.6838     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.47


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[45]rec loss: 0.0116 |sim loss: 0.0012|adv loss: 0.7484 |sd loss: 0.6876     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[46]rec loss: 0.0113 |sim loss: 0.0013|adv loss: 0.7497 |sd loss: 0.6865     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[47]rec loss: 0.0114 |sim loss: 0.0013|adv loss: 0.7531 |sd loss: 0.6855     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[48]rec loss: 0.0113 |sim loss: 0.0013|adv loss: 0.7538 |sd loss: 0.6844     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[49]rec loss: 0.0115 |sim loss: 0.0015|adv loss: 0.7592 |sd loss: 0.6861     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.47


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[50]rec loss: 0.0116 |sim loss: 0.0014|adv loss: 0.7818 |sd loss: 0.6761     |D(real): 0.68 |D(fake): 0.67 |D(G(fake)): 0.47


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[51]rec loss: 0.0119 |sim loss: 0.0015|adv loss: 0.7665 |sd loss: 0.6850     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.47


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[52]rec loss: 0.0120 |sim loss: 0.0013|adv loss: 0.7544 |sd loss: 0.6823     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[53]rec loss: 0.0116 |sim loss: 0.0013|adv loss: 0.7576 |sd loss: 0.6869     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[54]rec loss: 0.0114 |sim loss: 0.0013|adv loss: 0.7566 |sd loss: 0.6829     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[55]rec loss: 0.0117 |sim loss: 0.0015|adv loss: 0.7431 |sd loss: 0.6867     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[56]rec loss: 0.0113 |sim loss: 0.0010|adv loss: 0.7482 |sd loss: 0.6852     |D(real): 0.69 |D(fake): 0.68 |D(G(fake)): 0.48


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[57]rec loss: 0.0110 |sim loss: 0.0008|adv loss: 0.7085 |sd loss: 0.6907     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[58]rec loss: 0.0106 |sim loss: 0.0007|adv loss: 0.7124 |sd loss: 0.6901     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[59]rec loss: 0.0103 |sim loss: 0.0007|adv loss: 0.7129 |sd loss: 0.6918     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[60]rec loss: 0.0103 |sim loss: 0.0007|adv loss: 0.7131 |sd loss: 0.6921     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[61]rec loss: 0.0099 |sim loss: 0.0006|adv loss: 0.7084 |sd loss: 0.6928     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[62]rec loss: 0.0097 |sim loss: 0.0006|adv loss: 0.7077 |sd loss: 0.6926     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[63]rec loss: 0.0099 |sim loss: 0.0007|adv loss: 0.7123 |sd loss: 0.6917     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[64]rec loss: 0.0096 |sim loss: 0.0007|adv loss: 0.7100 |sd loss: 0.6917     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[65]rec loss: 0.0097 |sim loss: 0.0006|adv loss: 0.7126 |sd loss: 0.6927     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[66]rec loss: 0.0096 |sim loss: 0.0006|adv loss: 0.7041 |sd loss: 0.6919     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[67]rec loss: 0.0095 |sim loss: 0.0007|adv loss: 0.7107 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[68]rec loss: 0.0094 |sim loss: 0.0006|adv loss: 0.7107 |sd loss: 0.6919     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[69]rec loss: 0.0094 |sim loss: 0.0006|adv loss: 0.7100 |sd loss: 0.6927     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[70]rec loss: 0.0094 |sim loss: 0.0006|adv loss: 0.7089 |sd loss: 0.6925     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[71]rec loss: 0.0094 |sim loss: 0.0006|adv loss: 0.7142 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[72]rec loss: 0.0093 |sim loss: 0.0006|adv loss: 0.7043 |sd loss: 0.6928     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[73]rec loss: 0.0092 |sim loss: 0.0006|adv loss: 0.7053 |sd loss: 0.6927     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[74]rec loss: 0.0093 |sim loss: 0.0006|adv loss: 0.7063 |sd loss: 0.6931     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[75]rec loss: 0.0091 |sim loss: 0.0006|adv loss: 0.7041 |sd loss: 0.6926     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[76]rec loss: 0.0091 |sim loss: 0.0006|adv loss: 0.7061 |sd loss: 0.6929     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[77]rec loss: 0.0091 |sim loss: 0.0006|adv loss: 0.7065 |sd loss: 0.6926     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[78]rec loss: 0.0090 |sim loss: 0.0006|adv loss: 0.7024 |sd loss: 0.6927     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[79]rec loss: 0.0091 |sim loss: 0.0006|adv loss: 0.7057 |sd loss: 0.6928     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[80]rec loss: 0.0092 |sim loss: 0.0007|adv loss: 0.7092 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[81]rec loss: 0.0092 |sim loss: 0.0007|adv loss: 0.7102 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[82]rec loss: 0.0089 |sim loss: 0.0006|adv loss: 0.7024 |sd loss: 0.6932     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[83]rec loss: 0.0090 |sim loss: 0.0006|adv loss: 0.7061 |sd loss: 0.6924     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[84]rec loss: 0.0090 |sim loss: 0.0006|adv loss: 0.7092 |sd loss: 0.6926     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[85]rec loss: 0.0089 |sim loss: 0.0006|adv loss: 0.7056 |sd loss: 0.6923     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[86]rec loss: 0.0089 |sim loss: 0.0006|adv loss: 0.7063 |sd loss: 0.6925     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.49


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[87]rec loss: 0.0089 |sim loss: 0.0006|adv loss: 0.7057 |sd loss: 0.6927     |D(real): 0.69 |D(fake): 0.69 |D(G(fake)): 0.50


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

In [None]:
len(train_loader)

In [None]:
for i, x in enumerate(train_loader):
    if i == 0:
        with torch.no_grad():
            x = torch.transpose(x, 0, 1)
            x = x.to(device)
            plot_rec(x, 200)
            plot_analogy(x, 200)