In [1]:
import logging
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
import utils
import itertools
from tqdm import tqdm_notebook
import models.dcgan_unet_64 as dcgan_unet_models
import models.dcgan_64 as dcgan_models
import models.classifiers as classifiers
import models.my_model as my_model
from data.moving_mnist import MovingMNIST

In [2]:
# Pin this process to GPU index 1 when the machine really has a second GPU.
# On CPU-only or single-GPU hosts the unconditional call raises, which defeats
# the CPU fallback used when `device` is selected, so it is skipped there and
# PyTorch's default device is left untouched.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

Constant definition

In [3]:
# Seed every random number generator the notebook touches (NumPy, Python's
# `random`, PyTorch CPU, PyTorch CUDA) with the same value, in that order.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(1)

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# --- optimisation ---
lr = 2e-3            # learning rate given to each Adam optimizer
beta1 = 0.5          # first Adam beta coefficient
batch_size = 100     # batch size of both data loaders

# --- data ---
seq_len = 12         # forwarded to MovingMNIST as `seq_len`
channels = 3         # passed to the encoder / decoder constructors
image_width = 64     # side length of the blank tile drawn by plot_analogy
normalize = False    # not referenced by the code visible in this notebook

# --- model sizes ---
content_dim = 128    # passed to the content encoder and the decoder
pose_dim = 50        # passed to the pose encoder and the decoder; also the noise size in train()
sd_nf = 100          # only used by the currently disabled scene discriminator

# --- output locations ---
log_dir = './logs/0529_noiseLikePoseVector/'
for _subdir in ('rec', 'analogy'):
    # One folder per kind of visualisation; exist_ok makes re-runs harmless.
    os.makedirs(os.path.join(log_dir, _subdir), exist_ok=True)
logging.basicConfig(level=logging.DEBUG, filename=os.path.join(log_dir, 'record.txt'))

Data Loader

In [5]:
# Build the two MovingMNIST datasets (train first, then test). The first
# positional argument presumably selects the split - confirm against
# data.moving_mnist.MovingMNIST.
train_data, test_data = (
    MovingMNIST(_flag, '../data_uni/', seq_len=seq_len)
    for _flag in (True, False)
)

# Options common to both loaders; only the worker count differs.
_shared_loader_opts = dict(
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,   # every batch then has exactly `batch_size` samples
    pin_memory=True,
)

train_loader = DataLoader(train_data, num_workers=16, **_shared_loader_opts)
# The test loader loads in the main process (num_workers=0).
test_loader = DataLoader(test_data, num_workers=0, **_shared_loader_opts)

Model definition

In [6]:
# Networks used in this experiment, all taken from models.my_model:
#   netEC - content encoder
#   netEP - pose encoder, built with conditional=True (it is called with a
#           content code as second argument elsewhere in this notebook)
#   netD  - decoder taking [content, pose]
# Earlier variants built from dcgan_models / dcgan_unet_models and the
# discriminator `netC` (classifiers.scene_discriminator,
# my_model.scene_discriminator or my_model.Discriminator) are disabled, so
# `netC` is NOT defined by this cell.
netEC = my_model.content_encoder(content_dim, channels).to(device)
netEP = my_model.pose_encoder(pose_dim, channels, conditional=True).to(device)
netD = my_model.decoder(content_dim, pose_dim, channels).to(device)

# Re-initialise the weights with the project helper, in construction order.
for _net in (netEC, netEP, netD):
    _net.apply(utils.weights_init)

# Show each architecture.
for _net in (netEC, netEP, netD):
    print(_net)

content_encoder(
  (main): Sequential(
    (0): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (1): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (2): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (3): dcgan_conv(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

In [7]:
# One Adam optimizer per network, all with identical hyper-parameters.
# No optimizer is created for the disabled discriminator (`optimizerC`).
_adam_opts = dict(lr=lr, betas=(beta1, 0.999))
optimizerEC, optimizerEP, optimizerD = (
    optim.Adam(_net.parameters(), **_adam_opts)
    for _net in (netEC, netEP, netD)
)

Plot function

In [8]:
# --------- plotting functions ------------------------------------
def plot_rec(x, epoch, dtype, nplot=20, row_sz=5):
    """Save a grid of (content frame, pose frame, reconstruction) triplets.

    The content code comes from ``x[0]`` and the pose code from a randomly
    chosen ``x[k]`` with ``k >= 1``; the decoder output for that pair is shown
    next to the two input frames.

    Args:
        x: batch indexed along its first axis (the training cell transposes
            loader batches, so that axis is presumably the time step -
            confirm against MovingMNIST). Needs at least two entries.
        epoch: value embedded in the file name.
        dtype: tag embedded in the file name (the training cell passes 'test').
        nplot: number of samples to draw; clamped to the batch size.
        row_sz: number of samples (triplets) per grid row.

    Side effects:
        Writes ``<log_dir>/rec/<dtype>-<epoch>.png`` via
        ``utils.save_tensors_image``.
    """
    x_c = x[0]
    x_p = x[np.random.randint(1, len(x))]

    h_c = netEC(x_c)
    h_p = netEP(x_p, h_c)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data

    # Never ask for more samples than the batch holds, otherwise trailing
    # rows would be empty.
    nplot = min(nplot, len(x_c))

    to_plot = []
    # The previous bound `nplot - row_sz` stopped one row early and drew only
    # 15 of the 20 requested samples; iterate up to `nplot` itself. If `nplot`
    # is not a multiple of `row_sz` the final row is simply shorter.
    for start in range(0, nplot, row_sz):
        stop = min(start + row_sz, nplot)
        triplets = zip(x_c[start:stop], x_p[start:stop], rec[start:stop])
        # Flatten [(c, p, r), (c, p, r), ...] into c, p, r, c, p, r, ...
        to_plot.append(list(itertools.chain.from_iterable(triplets)))

    fname = os.path.join(log_dir, 'rec', '{}-{}.png'.format(dtype, epoch))
    utils.save_tensors_image(fname, to_plot)

def plot_analogy(x, epoch, dtype):
    """Save an analogy grid: fixed content per row, shared pose per column.

    The header row holds a blank tile followed by sample 0 taken from every
    entry of ``x``. Each of the next ten rows starts with one sample of
    ``x[0]`` and is followed, for every entry of ``x``, by the decoder output
    for that sample's content code combined with the pose code of sample 0.

    Args:
        x: batch indexed along its first axis (presumably time - confirm).
        epoch: value embedded in the file name.
        dtype: tag embedded in the file name.

    Side effects:
        Writes ``<log_dir>/analogy/<dtype>-<epoch>.png`` via
        ``utils.save_tensors_image``.
    """
    n_examples = 10
    content_frames = x[0]
    h_c = netEC(content_frames)

    # Header row: blank corner tile, then sample 0 of every entry of x.
    blank = torch.zeros(channels, image_width, image_width)
    grid = [[blank] + [frames[0].data for frames in x]]
    # One row per example, beginning with the frame that provides its content.
    grid.extend([content_frames[k].data] for k in range(n_examples))

    for frames in x:
        h_p = netEP(frames, h_c).data
        # Give the first `n_examples` samples the pose code of sample 0.
        # Clone first so the source is not a view of the memory being written.
        h_p[:n_examples] = h_p[0].clone()
        rec = netD([h_c, h_p])
        for k in range(n_examples):
            grid[k + 1].append(rec[k].data.clone())

    out_path = os.path.join(log_dir, 'analogy', '{}-{}.png'.format(dtype, epoch))
    utils.save_tensors_image(out_path, grid)

Training function

In [9]:
def train(x):
    """Run one optimisation step on the content encoder, pose encoder and decoder.

    Three entries of ``x`` are drawn at random (with replacement): two feed
    the content encoder and the third is the reconstruction target. The pose
    input of the decoder is Gaussian noise, not an encoded frame.

    Losses:
        sim_loss   - MSE between the two content codes (the second is detached).
        rec_loss   - MSE between the decoder output and the target entry.
        noise_loss - MSE between the noise and what the pose encoder recovers
                     from the detached decoder output; weighted by 0.1 in the
                     total.

    Args:
        x: batch indexed along its first axis (presumably time - confirm).

    Returns:
        Tuple ``(sim_loss, rec_loss, noise_loss)`` of Python floats.
    """
    optimizers = (optimizerEC, optimizerEP, optimizerD)
    for opt in optimizers:
        opt.zero_grad()

    # Three independent draws; the order matters for RNG reproducibility.
    n_steps = len(x)
    frame_a = x[np.random.randint(n_steps)]
    frame_b = x[np.random.randint(n_steps)]
    target = x[np.random.randint(n_steps)]

    h_c1 = netEC(frame_a)
    h_c2 = netEC(frame_b).detach()  # acts as a fixed target for h_c1
    noise = torch.randn((h_c1.size(0), pose_dim, 1, 1), device=device)

    # Content codes of two entries of the same sequence should match.
    sim_loss = F.mse_loss(h_c1, h_c2)

    # Decode content + noise and compare with the randomly chosen target.
    rec = netD([h_c1, noise])
    rec_loss = F.mse_loss(rec, target)

    # The pose encoder has to recover the noise from the generated frame.
    # `rec` is detached, so this term does not reach the decoder; `h_c1` is
    # not, so it still reaches the content encoder.
    noise_rec = netEP(rec.detach(), h_c1)
    noise_loss = F.mse_loss(noise_rec, noise)

    total = sim_loss + rec_loss + 0.1 * noise_loss
    total.backward()

    for opt in optimizers:
        opt.step()

    return tuple(term.item() for term in (sim_loss, rec_loss, noise_loss))

In [10]:
def train_scene_discriminator(x):
    """Run one optimisation step on the scene discriminator ``netC``.

    Pose codes of ``x[0]`` and ``x[1]`` are computed, conditioned on the
    content code of a random entry of ``x``. In the first half of the batch
    the second pose code is shuffled among the samples (label 0); the second
    half keeps the original pairing (label 1).

    NOTE: relies on the globals ``netC`` and ``optimizerC``, which are
    commented out in the model / optimizer cells of this notebook; they must
    be defined before this function can be called.

    Args:
        x: batch indexed along its first axis (presumably time - confirm);
            needs at least two entries.

    Returns:
        ``(bce, acc)``: the binary cross-entropy as a NumPy value and the
        fraction of the batch classified correctly at a 0.5 threshold.
    """
    optimizerC.zero_grad()

    # Content code used as the condition for the pose encoder.
    h_c = netEC(x[np.random.randint(len(x))]).detach()

    h_p1 = netEP(x[0], h_c).detach()
    h_p2 = netEP(x[1], h_c).detach()

    # Use the real batch size instead of the global constant so a batch of a
    # different size still gets matching targets.
    n = h_p1.size(0)
    half = n // 2

    # Shuffle the second pose code within the first half. The permutation is
    # moved with `.to(device)` rather than a hard-coded `.cuda()`, so the
    # function follows the same device selection as the rest of the notebook.
    rp = torch.randperm(half).to(device)
    h_p2[:half] = h_p2[rp]

    # Build the target fully initialised (1 = original pairing) instead of
    # starting from an uninitialised FloatTensor.
    target = torch.ones(n, 1, device=device)
    target[:half] = 0  # 0 = shuffled pairing

    out = netC([h_p1, h_p2])
    bce = F.binary_cross_entropy(out, target)

    bce.backward()
    optimizerC.step()

    acc = out[:half].le(0.5).sum() + out[half:].gt(0.5).sum()
    return bce.data.cpu().numpy(), acc.data.cpu().numpy() / n

In [11]:
def train_discriminator(x):
    """Run one optimisation step on the image-pair discriminator ``netC``.

    Real pair: ``x1`` together with the decoder output for content(x1) and
    pose(x3). Fake pair: ``x1`` together with the decoder output for
    content(x1) and pose(x3) permuted across the batch.

    NOTE: relies on the globals ``netC`` and ``optimizerC``, which are
    commented out in the model / optimizer cells of this notebook.

    Args:
        x: batch indexed along its first axis (presumably time - confirm).

    Returns:
        ``(bce, D_real, D_fake)`` as Python floats: the averaged loss and its
        real / fake components.
    """
    optimizerC.zero_grad()

    ones = torch.ones(batch_size, 1, dtype=torch.float).to(device)
    zeros = torch.zeros(batch_size, 1, dtype=torch.float).to(device)

    n_steps = len(x)
    x1 = x[np.random.randint(n_steps)]
    # The second draw (formerly `x2`) is never used; it is still performed so
    # the NumPy random stream stays the same as before.
    np.random.randint(n_steps)
    x3 = x[np.random.randint(n_steps)]

    h_c = netEC(x1).detach()

    # --- real pair: x1 and the decoder output for its own pairing ---
    pose = netEP(x3, h_c)
    x_rec = netD([h_c, pose]).detach()
    loss_real = F.binary_cross_entropy(netC([x1, x_rec]).view(-1, 1), ones)
    D_real = loss_real.item()

    # --- fake pair: x1 and the decoder output for poses swapped across the batch ---
    perm = torch.randperm(batch_size)
    pose = netEP(x3, h_c)[perm]
    x_swap = netD([h_c, pose]).detach()
    loss_fake = F.binary_cross_entropy(netC([x1, x_swap]).view(-1, 1), zeros)
    D_fake = loss_fake.item()

    bce = 0.5 * loss_real + 0.5 * loss_fake
    bce.backward()
    optimizerC.step()

    return bce.item(), D_real, D_fake

In [12]:
epoch_size = len(train_loader)

# One batch from the test loader, reused for every visualisation. Batches are
# transposed so the loader's second axis becomes the leading one.
test_x = next(iter(test_loader)).transpose(0, 1).to(device)

for epoch in tqdm_notebook(range(200), desc='EPOCH'):
    for _net in (netEP, netEC, netD):
        _net.train()

    # Running sums for this epoch. The sd / D_* counters belong to the
    # discriminator training that is currently disabled and stay at zero.
    epoch_sim_loss = epoch_rec_loss = epoch_noise_loss = epoch_sd_loss = 0
    epoch_D_real = epoch_D_fake = epoch_D_G_fake = 0

    for i, x in enumerate(tqdm_notebook(train_loader, desc='BATCH')):
        x = x.transpose(0, 1).to(device)

        # Discriminator training (train_scene_discriminator /
        # train_discriminator) is switched off in this experiment.

        # Main model update.
        sim_loss, rec_loss, noise_loss = train(x)
        epoch_sim_loss += sim_loss
        epoch_rec_loss += rec_loss
        epoch_noise_loss += noise_loss

    # Per-epoch averages, printed and appended to the log file.
    log_str = '[%02d]rec loss: %.4f |sim loss: %.4f|noise loss: %.4f' % (
        epoch,
        epoch_rec_loss / epoch_size,
        epoch_sim_loss / epoch_size,
        epoch_noise_loss / epoch_size,
    )
    print(log_str)
    logging.info(log_str)

    # Visualise on the fixed test batch in eval mode, without gradients.
    for _net in (netEP, netEC, netD):
        _net.eval()
    with torch.no_grad():
        plot_rec(test_x, epoch, 'test')
        plot_analogy(test_x, epoch, 'test')

    # Overwrite the checkpoint after every epoch. The module objects
    # themselves are pickled (not their state_dicts), so loading requires the
    # same model classes to be importable.
    checkpoint = {'netD': netD, 'netEP': netEP, 'netEC': netEC}
    torch.save(checkpoint, '%s/model.pth' % log_dir)

HBox(children=(IntProgress(value=0, description='EPOCH', max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[00]rec loss: 0.0159 |sim loss: 0.0354|noise loss: 1.2883


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[01]rec loss: 0.0140 |sim loss: 0.0067|noise loss: 0.9991


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[02]rec loss: 0.0140 |sim loss: 0.0030|noise loss: 0.9715


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[03]rec loss: 0.0139 |sim loss: 0.0015|noise loss: 0.9492


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[04]rec loss: 0.0139 |sim loss: 0.0011|noise loss: 0.9326


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[05]rec loss: 0.0139 |sim loss: 0.0008|noise loss: 0.9317


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[06]rec loss: 0.0139 |sim loss: 0.0007|noise loss: 0.9127


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[07]rec loss: 0.0139 |sim loss: 0.0005|noise loss: 0.8989


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[08]rec loss: 0.0139 |sim loss: 0.0005|noise loss: 0.8863


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[09]rec loss: 0.0140 |sim loss: 0.0004|noise loss: 0.8773


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[10]rec loss: 0.0139 |sim loss: 0.0004|noise loss: 0.8661


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[11]rec loss: 0.0139 |sim loss: 0.0004|noise loss: 0.8602


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[12]rec loss: 0.0140 |sim loss: 0.0003|noise loss: 0.8686


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[13]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.8509


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[14]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.8439


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[15]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.8678


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[16]rec loss: 0.0139 |sim loss: 0.0002|noise loss: 0.8464


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[17]rec loss: 0.0139 |sim loss: 0.0002|noise loss: 0.8003


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[18]rec loss: 0.0139 |sim loss: 0.0002|noise loss: 0.7682


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[19]rec loss: 0.0140 |sim loss: 0.0003|noise loss: 0.7610


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[20]rec loss: 0.0140 |sim loss: 0.0002|noise loss: 0.7402


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[21]rec loss: 0.0139 |sim loss: 0.0002|noise loss: 0.7236


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[22]rec loss: 0.0139 |sim loss: 0.0002|noise loss: 0.7076


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[23]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.6986


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[24]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.7141


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[25]rec loss: 0.0139 |sim loss: 0.0003|noise loss: 0.7046


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[26]rec loss: 0.0139 |sim loss: 0.0004|noise loss: 0.7352


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[27]rec loss: 0.0138 |sim loss: 0.0003|noise loss: 0.7451


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[28]rec loss: 0.0137 |sim loss: 0.0002|noise loss: 0.7619


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[29]rec loss: 0.0136 |sim loss: 0.0002|noise loss: 0.7387


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[30]rec loss: 0.0136 |sim loss: 0.0002|noise loss: 0.7381


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[31]rec loss: 0.0134 |sim loss: 0.0002|noise loss: 0.7498


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[32]rec loss: 0.0134 |sim loss: 0.0001|noise loss: 0.7293


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[33]rec loss: 0.0133 |sim loss: 0.0001|noise loss: 0.7320


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[34]rec loss: 0.0131 |sim loss: 0.0001|noise loss: 0.7563


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[35]rec loss: 0.0132 |sim loss: 0.0001|noise loss: 0.7296


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[36]rec loss: 0.0132 |sim loss: 0.0001|noise loss: 0.7178


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[37]rec loss: 0.0131 |sim loss: 0.0001|noise loss: 0.7143


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[38]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.7217


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[39]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.7105


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[40]rec loss: 0.0131 |sim loss: 0.0001|noise loss: 0.6963


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[41]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6990


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[42]rec loss: 0.0131 |sim loss: 0.0001|noise loss: 0.6915


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[43]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6877


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[44]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6874


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[45]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6858


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[46]rec loss: 0.0131 |sim loss: 0.0001|noise loss: 0.6775


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[47]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6961


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[48]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6819


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[49]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6955


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[50]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6896


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[51]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6782


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[52]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6859


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[53]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6833


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[54]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6843


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[55]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6918


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[56]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6853


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[57]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6872


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[58]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6928


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[59]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7057


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[60]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6961


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[61]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6915


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[62]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6958


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[63]rec loss: 0.0130 |sim loss: 0.0001|noise loss: 0.6902


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[64]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7014


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[65]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7109


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[66]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.7061


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[67]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6971


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[68]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6933


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[69]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6883


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[70]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6729


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[71]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6851


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[72]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6872


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[73]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6774


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[74]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6714


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[75]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6770


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[76]rec loss: 0.0129 |sim loss: 0.0002|noise loss: 0.7106


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[77]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6743


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[78]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7236


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[79]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6709


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[80]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6758


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[81]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6861


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[82]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.7014


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[83]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.7127


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[84]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6646


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[85]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7139


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[86]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6636


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[87]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6899


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[88]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7388


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[89]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6696


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[90]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6815


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[91]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6777


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[92]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6522


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[93]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6771


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[94]rec loss: 0.0128 |sim loss: 0.0003|noise loss: 0.7817


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[95]rec loss: 0.0127 |sim loss: 0.0003|noise loss: 0.7229


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[96]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6651


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[97]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6806


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[98]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6623


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[99]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6584


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[100]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6474


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[101]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7046


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[102]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.7398


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[103]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6695


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[104]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7118


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[105]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6993


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[106]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6564


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[107]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6640


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[108]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6819


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[109]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6541


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[110]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6796


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[111]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6550


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[112]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.7016


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[113]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6797


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[114]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6655


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[115]rec loss: 0.0127 |sim loss: 0.0002|noise loss: 0.7091


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[116]rec loss: 0.0128 |sim loss: 0.0009|noise loss: 0.9997


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[117]rec loss: 0.0127 |sim loss: 0.0006|noise loss: 0.9380


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[118]rec loss: 0.0128 |sim loss: 0.0004|noise loss: 0.7403


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[119]rec loss: 0.0127 |sim loss: 0.0003|noise loss: 0.6892


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[120]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.6883


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[121]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.6578


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[122]rec loss: 0.0129 |sim loss: 0.0002|noise loss: 0.6477


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[123]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6854


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[124]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7190


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[125]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6897


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[126]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6530


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[127]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6589


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[128]rec loss: 0.0128 |sim loss: 0.0003|noise loss: 0.7490


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[129]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6678


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[130]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6715


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[131]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6591


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[132]rec loss: 0.0126 |sim loss: 0.0001|noise loss: 0.7172


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[133]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6878


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[134]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6796


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[135]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6754


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[136]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6717


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[137]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.6939


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[138]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6585


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[139]rec loss: 0.0129 |sim loss: 0.0003|noise loss: 0.7087


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[140]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.7081


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[141]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6813


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[142]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6825


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[143]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6618


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[144]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6676


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[145]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6694


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[146]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7311


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[147]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6771


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[148]rec loss: 0.0127 |sim loss: 0.0002|noise loss: 0.7519


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[149]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.6871


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[150]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6838


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[151]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6767


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[152]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6841


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[153]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6932


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[154]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6934


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[155]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6960


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[156]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6975


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[157]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6763


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[158]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6994


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[159]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6874


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[160]rec loss: 0.0127 |sim loss: 0.0002|noise loss: 0.7133


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[161]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6946


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[162]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7178


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[163]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7002


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[164]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7065


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[165]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.6894


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[166]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6836


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[167]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6881


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[168]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6840


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[169]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6721


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[170]rec loss: 0.0129 |sim loss: 0.0001|noise loss: 0.6682


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[171]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6763


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[172]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6885


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[173]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6852


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[174]rec loss: 0.0126 |sim loss: 0.0001|noise loss: 0.6809


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[175]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6910


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[176]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6839


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[177]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7542


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[178]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7968


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[179]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7134


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[180]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7011


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[181]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7054


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[182]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.7155


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[183]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6848


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[184]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.6866


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[185]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6871


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[186]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6909


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[187]rec loss: 0.0128 |sim loss: 0.0002|noise loss: 0.7358


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[188]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7187


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[189]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7464


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[190]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7021


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[191]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7156


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[192]rec loss: 0.0126 |sim loss: 0.0001|noise loss: 0.7049


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[193]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.6982


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[194]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7163


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[195]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7031


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[196]rec loss: 0.0127 |sim loss: 0.0001|noise loss: 0.7101


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[197]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7188


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[198]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7102


HBox(children=(IntProgress(value=0, description='BATCH', max=600), HTML(value='')))

[199]rec loss: 0.0128 |sim loss: 0.0001|noise loss: 0.7139



In [None]:
# Number of batches train_loader yields per pass (shown as this cell's output).
len(train_loader)

In [None]:
# Render the reconstruction and analogy plots for a single training batch.
# Only the first batch is used, so we leave the loop right after it instead of
# iterating (and discarding) every remaining batch in train_loader.
# An empty loader still results in a silent no-op.
for x in train_loader:
    # Inference only: no gradients are needed for plotting.
    with torch.no_grad():
        # Swap the first two dimensions before plotting
        # (presumably batch-major -> time-major; confirm against MovingMNIST).
        x = torch.transpose(x, 0, 1)
        x = x.to(device)
        # 200 is the tag passed to the plot helpers
        # (presumably the epoch count reached by the training run above; verify).
        plot_rec(x, 200)
        plot_analogy(x, 200)
    break  # one batch is enough; do not consume the rest of the loader