In [76]:
# Make IPython re-import changed modules before each cell runs, so edits to the
# local modules imported below (utils, model, pose_dataset) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [109]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable

from utils.process_img import Rescale, DynamicCrop, ToTensor, CenterCrop
from utils.func import weights_init, random_annotate, gaussian_noise
from pose_dataset import PoseDataset, print_sample
from model.generator import PoseGeneratorDC, PoseGeneratorL
from model.discriminator import PoseDiscriminatorDC
from utils.process_text import tokenizer, get_embeddings, get_word2idx

In [100]:
# for i in range(10):
#     sample = pose_dataset[i]
#     print_sample(sample)

In [101]:
# Test:
# embeddings = pose_dataset.embeddings

# # Generator
# netG = PoseGenerator(embeddings).cuda()
# netG.apply(weights_init)
# # Discriminator
# netD = PoseDiscriminator(embeddings).cuda()
# netD.apply(weights_init)

# for i, sample in enumerate(pose_dataloader):
#     annotate = sample['annotate'].cuda()
#     noise = torch.randn(2, 64, 1, 1).cuda()
#     fake_img = netG(noise, annotate)
#     print(fake_img.shape)
    
#     pred = netD(fake_img, annotate)
#     print(pred)
#     break

In [110]:
# Hyper-parameters / sizes.
# NOTE(review): nc, nz, ngf and ndf are not referenced by any visible cell; the
# generator is actually driven with noise of size z_size (not nz) — confirm intent.
nc, nz = 3, 100
ngf = ndf = 64
lr, beta1 = 0.0002, 0.5
img_size, z_size = 128, 64

# Preprocessing: Rescale(512) -> DynamicCrop(30) -> Rescale to (img_size, img_size).
# See utils.process_img for the exact semantics of each transform.
composed = transforms.Compose(
    [
        Rescale(512),
        DynamicCrop(30),
        Rescale((img_size, img_size)),
    ]
)

In [111]:
# Dataset and loader. NOTE(review): the batch size (3) is hard-coded here.
pose_dataset = PoseDataset('./data/truncate_data_list.csv', './data', transform=composed)
pose_dataloader = DataLoader(pose_dataset, batch_size=3, shuffle=True, num_workers=4)

# Both networks are built from the dataset's embedding table
# (presumably the word embeddings for the text annotations — confirm in PoseDataset).
embeddings = pose_dataset.embeddings
netD = PoseDiscriminatorDC(embeddings).cuda()
netG = PoseGeneratorDC(embeddings).cuda()
netD.apply(weights_init)
netG.apply(weights_init)

# BCEWithLogitsLoss applies the sigmoid internally, so netD is expected to return
# raw logits (no final sigmoid) — TODO confirm in model.discriminator.
criterion = nn.BCEWithLogitsLoss()

# Take lr / beta1 from the config cell rather than duplicating (and contradicting)
# them here. NOTE: the previously hard-coded learning rate was 0.001.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [114]:
# Target values for the BCE criterion.
real_label = 1.0
fake_label = 0.0

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 50

# Fixed inputs so the per-epoch preview always shows the same (noise, annotation) pair.
torch.manual_seed(250)
fixed_noise = torch.randn(1, z_size, 1, 1).cuda()
fixed_annotate = random_annotate(1, pose_dataset, random_state=1).cuda()

print("Starting Training Loop...")

for epoch in range(num_epochs):
    netG.train()
    netD.train()
    for i, sample in enumerate(pose_dataloader):
        # ---------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # ---------------------------------------------
        netD.zero_grad()

        batch_size = sample['parsing'].shape[0]
        # NOTE(review): reshape only reinterprets memory. If 'parsing' arrives as
        # (B, H, W, 3), this scrambles the pixels and permute(0, 3, 1, 2) is needed
        # instead — confirm the layout PoseDataset produces.
        real_image = torch.reshape(sample['parsing'], (batch_size, 3, img_size, img_size)).float().cuda()
        annotate = sample['annotate'].cuda()
        # Fake images are conditioned on annotations drawn by random_annotate,
        # not on this batch's own annotations.
        fake_annotate = random_annotate(batch_size, pose_dataset).cuda()

        # BCEWithLogitsLoss needs a floating-point target. torch.full infers an integer
        # dtype from an int fill value on recent PyTorch, so pin the dtype explicitly.
        label = torch.full((batch_size,), real_label, dtype=torch.float32, device=real_image.device)
        output = netD(real_image, annotate).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        # Assuming netD returns logits (as the criterion requires), report D(x) as a probability.
        D_x = torch.sigmoid(output).mean().item()

        # Train with an all-fake batch; detach so this pass does not backprop into G.
        noise = torch.randn(batch_size, z_size, 1, 1).cuda()
        fake_image = netG(noise, fake_annotate)
        label.fill_(fake_label)
        output = netD(fake_image.detach(), fake_annotate).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = torch.sigmoid(output).mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()

        # ---------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        # ---------------------------------------------
        netG.zero_grad()
        label.fill_(real_label)  # fake images are labelled real for the generator cost
        # D was just updated, so run the fake batch through it again (no detach: grads reach G).
        output = netD(fake_image, fake_annotate).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = torch.sigmoid(output).mean().item()
        optimizerG.step()

        # Output training stats
        if i and i % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(pose_dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        iters += 1

    # Per-epoch preview of the generator on the fixed inputs.
    netG.eval()
    with torch.no_grad():
        fake_image = netG(fixed_noise, fixed_annotate)
    # fake_image is consumed by netD in the same (N, 3, H, W) layout as real_image.
    # imshow wants (H, W, 3), which needs a permute; a reshape would scramble pixels.
    img = fake_image[0].permute(1, 2, 0).cpu().numpy()
    plt.imshow(img)
    plt.show()
        

Starting Training Loop...
tensor([[154,  39, 155,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 68, 215,  47,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [142,  25,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-7.0989e-02, -7.7011e-01, -2.6233e-01, -5.1319e-01, -4.6055e-02,
         -7.9337e-02,  5.9567e-01, -1.1256e-01, -3.2252e-01,  4.3202e-02,
          3.3053e-01,  4.6034e-01,  2.9963e-01, -2.7683e-01, -6.2079e-01,
          1.9298e-01, -8.8039e-02,  5.1626e-01, -6.4537e-01, -3.8270e-01,
         -1.6256e-01,  3.9393e-01,  4.8675e-01, -3.0187e-01,  5.6805e-01,
          1.1506e-01,  6.5116e-01, -5.1588e-01, -1.3991e-01, -3.6249e-01,
          5.4783e-01, -2.4878e-01,  5.0145e-01, -8.5865e-01, -8.3975e-01,
          8.9064e-01,  8.8714e-01, -3.8600e-02, -8.1121e-01,  1.6001e-01,
         -4.3772e-01,  1.8673e-01,  6.4578e-01, -8.2886e-01, -7.2701e-03,
        

tensor([[ 69,  35,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [148, 149,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 76,  77,  25,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-2.9176e-02, -1.1258e-01, -5.5850e-02, -3.5744e-02,  4.0919e-03,
         -1.2134e-01,  4.1192e-02, -5.6778e-02, -1.3943e-02,  7.2167e-02,
         -1.6008e-01, -1.5983e-01,  4.3843e-02, -2.3189e-02, -2.0468e-02,
         -2.8180e-02, -4.6338e-03, -1.5087e-01, -7.2086e-02,  1.0661e-01,
          7.6719e-02,  1.0090e-02, -7.8539e-02,  1.2746e-01,  9.1149e-02,
         -2.2507e-02,  6.6849e-02,  3.6009e-02, -1.5949e-01, -1.6144e-02,
         -1.0712e-02, -2.2072e-02, -1.2104e-02,  1.5395e-02,  4.3361e-02,
          3.1495e-02,  2.8493e-02,  9.4792e-02,  2.5795e-02,  1.1651e-02,
         -8.7455e-02, -1.6125e-01, -1.4923e-02,  2.8898e-02,  3.7966e-02,
          7.0812e-03, -3.8691e-02,

tensor([[ 30,  49, 117,  87,  17,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [190,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [212, 213,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-2.8759e-02, -1.1285e-01, -6.9378e-02, -6.4089e-02,  2.1981e-02,
         -3.6440e-01,  5.1344e-02, -7.4223e-02, -2.0891e-02,  7.6353e-02,
         -1.6159e-01, -1.2695e-01,  5.1960e-02, -2.8841e-02, -1.6441e-02,
         -2.3775e-02,  1.5362e-03, -1.5077e-01, -7.8924e-02,  9.4096e-02,
          6.7130e-02,  1.5834e-02, -7.6889e-02,  1.3279e-01,  9.5150e-02,
         -1.3255e-02,  7.3880e-02,  2.7728e-02, -8.6258e-02, -2.9466e-02,
         -3.1215e-03, -4.8247e-02, -6.0024e-03,  3.3038e-04,  3.4530e-02,
          3.3084e-02,  2.7012e-02,  3.2631e-01,  2.3603e-02, -9.2701e-03,
         -1.8407e-01, -1.5477e-01, -2.4017e-02,  2.1251e-02,  7.5840e-02,
          2.8630e-02, -3.0764e-02,

tensor([[156, 240,  14, 241, 242,  14, 117,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 51, 671,  30, 292,  14, 293,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 88, 244, 167,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-1.4297e-02, -1.1669e-01, -9.8769e-02, -4.1726e-03,  4.1098e-02,
         -1.8968e-01, -6.9399e-04, -9.7294e-02, -1.3085e-01,  1.4321e-01,
         -2.7247e-01, -1.6832e-01,  6.0496e-02, -5.0569e-02,  1.5484e-02,
         -2.4610e-02,  1.2448e-02, -2.0411e-01, -6.3422e-02,  1.2311e-01,
          1.0915e-01, -5.7172e-02, -1.3996e-01,  2.0341e-01,  7.6245e-02,
          7.9779e-03,  4.6306e-02,  4.8778e-02,  1.0965e-01, -2.5456e-02,
         -3.1318e-02, -3.6950e-02, -4.7457e-02,  3.8702e-02,  6.8894e-02,
          5.1826e-04,  1.4215e-03,  4.4187e-01,  5.6334e-02, -1.6037e-01,
         -1.7474e-01, -1.2369e-01, -4.1850e-02,  5.1571e-02,  1.3334e-01,
         -3.2540e-02, -7.9561e-02,

tensor([[-1.5166e-02, -7.1795e-01, -4.5701e-02, -3.8443e-01, -9.7231e-02,
         -8.4440e-02,  4.8659e-01, -1.4147e-01, -2.1873e-01,  2.2909e-02,
          2.5185e-01,  3.6815e-01, -5.8098e-03, -1.8679e-01, -4.8267e-01,
          7.9149e-02, -4.9819e-02,  4.3462e-01, -5.5864e-01, -2.9403e-01,
         -1.1245e-01,  2.5679e-01,  3.8199e-01, -1.6292e-01,  4.5220e-01,
          1.8462e-01,  5.8242e-01, -4.1017e-01, -2.7906e-01, -2.7116e-01,
          4.6384e-01, -2.0197e-01,  7.0499e-02, -7.3410e-01, -7.1518e-01,
          8.3337e-01,  7.7338e-01, -2.5060e-02, -7.0580e-01,  1.1073e-01,
         -2.5061e-01,  1.3945e-01,  7.5126e-02, -7.3463e-01, -3.3567e-03,
          7.6118e-01,  5.8073e-01, -6.8166e-02,  7.2377e-01, -5.8140e-01,
         -4.3111e-01, -2.6855e-01, -8.9896e-01,  1.7093e-01, -7.9805e-01,
          5.8083e-02,  1.7066e-01, -6.8511e-01, -3.8797e-02, -8.1204e-01,
         -7.1453e-01,  7.9105e-01, -3.5113e-01, -7.0886e-01],
        [-2.5109e-02, -8.6782e-02, -6.5516e-02,  6

tensor([[154,  39, 155,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 26,  27, 671,  28,  29,  30,  31,  32,   0,   0,   0,   0,   0,   0,
           0],
        [302,  25,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-0.0528, -0.7477, -0.2440, -0.4842, -0.0630, -0.0865,  0.5488, -0.1051,
         -0.3183,  0.0377,  0.2833,  0.4162,  0.2595, -0.2718, -0.5890,  0.1563,
         -0.0834,  0.4601, -0.6259, -0.3537, -0.1437,  0.3456,  0.4219, -0.2357,
          0.5448,  0.1551,  0.6333, -0.4880, -0.1297, -0.3403,  0.4992, -0.2214,
          0.3336, -0.8083, -0.7796,  0.8456,  0.8079, -0.0195, -0.7664,  0.1821,
         -0.3722,  0.1454,  0.3257, -0.7834, -0.0048,  0.7704,  0.5856, -0.0954,
          0.7422, -0.6149, -0.3979, -0.1712, -0.9303,  0.1731, -0.7300,  0.0259,
          0.3151, -0.6954, -0.3532, -0.8272, -0.7236,  0.8248, -0.1342, -0.7434],
        [-0.5402, -0.8752, -0.7969, -0.7397,  0.699

tensor([[426, 148,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [161, 671,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0],
        [ 74,  75,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0]], device='cuda:0')
tensor([[-3.0599e-01, -7.5503e-01, -5.6411e-01, -5.7333e-01,  4.9553e-01,
         -1.0069e-01,  6.5203e-01, -4.7329e-01, -3.8592e-01,  9.0926e-02,
          3.1469e-01,  4.6879e-01,  4.7236e-01, -2.4681e-01, -5.4725e-01,
          3.9689e-01, -1.3842e-01,  4.3162e-01, -6.8118e-01, -4.3418e-01,
         -2.2201e-01,  4.7448e-01,  4.8855e-01, -3.5766e-01,  6.3383e-01,
         -4.0405e-01,  6.2130e-01, -4.9100e-01,  1.5983e-01, -5.4580e-01,
          5.3196e-01, -4.9572e-01,  6.1453e-02, -9.1183e-01, -8.7480e-01,
          9.6508e-01,  9.5258e-01,  2.5688e-02, -8.5110e-01,  2.1029e-02,
         -5.8746e-01,  6.2755e-01,  3.4720e-01, -9.1051e-01,  9.5004e-03,
          9.6857e-01,  8.3268e-01,

KeyboardInterrupt: 

In [None]:
# img = torch.reshape(img_list[0], (128, 128, 3)).cpu().detach().numpy()
# plt.imshow(img)
# Grab a single batch for inspection without iterating the DataLoader.
# Raises StopIteration if the loader is empty, instead of silently leaving `img` undefined.
sample = next(iter(pose_dataloader))
img = sample['parsing'][0]

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# (useful after an interrupted training run). Tensors still referenced are not freed.
torch.cuda.empty_cache()

In [5]:
# Train Settings 1:
# embeddings = pose_dataset.embeddings
# netG = PoseGeneratorDC(embeddings).cuda()
# netG.apply(weights_init)
# netD = PoseDiscriminatorDC(embeddings).cuda()
# netD.apply(weights_init)

# adversarial_loss = torch.nn.MSELoss()
# optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
# optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [115]:
# real_label = 1
# fake_label = 0

# # Lists to keep track of progress
# img_list = []
# G_losses = []
# D_losses = []
# iters = 0
# num_epochs = 1

# torch.manual_seed(250)
# fixed_noise = torch.randn(1, 64, 1, 1).cuda()
# fixed_annotate = random_annotate(1, pose_dataset, random_state = 1).cuda()

# print("Starting Training Loop...")
# # For each epoch

# for epoch in range(num_epochs):
#     netG.train()
#     netD.train()
#     # For each batch in the dataloader
#     for i, sample in enumerate(pose_dataloader, 0):
#         batch_size = sample['raw'].shape[0]
#         # reformat the shape to be (batch_size, 3, 128, 128)
#         real_image = torch.reshape(sample['parsing'], (batch_size, 3, 128, 128)).float().cuda()
#         real_image = gaussian_noise(real_image)
        
#         label = torch.full((batch_size, ), real_label).cuda()
#         # -----------------
#         #  Train Generator
#         # -----------------
#         optimizerG.zero_grad()
#         noise = torch.randn(batch_size, 64, 1, 1).cuda()
#         annotate = sample['annotate'].cuda()
#         fake_annotate = random_annotate(batch_size, pose_dataset).cuda()
        
#         fake_image = netG(noise, fake_annotate)
#         validity = netD(fake_image, fake_annotate).view(-1)
        
        
#         break
#         g_loss = adversarial_loss(validity, label)
#         g_loss.backward()
#         optimizerG.step()
        
#         # ---------------------
#         #  Train Discriminator
#         # ---------------------
#         optimizerD.zero_grad()
        
#         validity_real = netD(real_image, annotate).view(-1)
#         d_real_loss = adversarial_loss(validity_real, Variable(label))
#         d_real_loss.backward()
        
# #         validity_fake = netD(fake_image.detach(), fake_annotate).view(-1) ### should fake_annotate be used here?
#         label.fill_(fake_label)
#         d_fake_loss = adversarial_loss(validity_fake, Variable(label))
#         d_fake_loss.backward()
        
#         d_loss = d_real_loss + d_fake_loss
#         optimizerD.step()

#         # Output training stats
#         if i and i%10==0:
#             print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
#                       % (epoch, num_epochs, i, len(pose_dataloader),
#                          d_loss.item(), g_loss.item()))

#         # Save Losses for plotting later
#         G_losses.append(g_loss.item())
#         D_losses.append(d_loss.item())

#         # Check how the generator is doing by saving G's output on fixed_noise
# #         if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
# #             with torch.no_grad():
# #                 fake = netG(fixed_noise).detach().cpu()
# #             img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

#         iters += 1
#     netG.eval()
#     with torch.no_grad():
#         fake_image = netG(fixed_noise, fixed_annotate)
#         img = torch.reshape(fake_image[0], (128, 128, 3)).cpu().data.detach().numpy()
#         plt.imshow(img)
#         plt.show()