In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable

from utils.process_img import Rescale, DynamicCrop, ToTensor, CenterCrop
from utils.func import weights_init, random_annotate, gaussian_noise
from pose_dataset import PoseDataset, print_sample
from model.generator import PoseGeneratorDC, PoseGeneratorL
from model.discriminator import PoseDiscriminatorDC
from utils.process_text import tokenizer, get_embeddings, get_word2idx

In [12]:
# Parameters:
lr = 0.0002 
beta1 = 0.5 
img_size = 128
z_size = 64
batch_size = 16
composed = transforms.Compose([Rescale(512),
                               DynamicCrop(30),
                               Rescale((img_size, img_size))])

In [13]:
pose_dataset = PoseDataset('./data/data_list.csv', './data', transform = composed)
pose_dataloader = DataLoader(pose_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
print(len(pose_dataset))
# for i in range(10):
#     sample = pose_dataset[i]
#     print_sample(sample)

8708


In [14]:
# Model, Loss, and Optimizer
embeddings = pose_dataset.embeddings
netD = PoseDiscriminatorDC(embeddings).cuda()
netG = PoseGeneratorDC(embeddings).cuda()
netD.apply(weights_init)
netG.apply(weights_init)

criterion = nn.BCEWithLogitsLoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [15]:
real_label = 1
fake_label = 0

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 1

torch.manual_seed(250)
fixed_noise = torch.randn(1, z_size, 1, 1).cuda()
fixed_annotate = random_annotate(1, pose_dataset, random_state = 1).cuda()

print("Starting Training Loop...")
# For each epoch

for epoch in range(num_epochs):
    netG.train()
    netD.train()
    for i, sample in enumerate(pose_dataloader, 0):
        # Add noise to real pose
        # real_image = gaussian_noise(real_image)
        
        # ---------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # ---------------------------------------------
        # Train with all-real batch
        netD.zero_grad()
        
        batch_size = sample['parsing'].shape[0]
        real_image = torch.reshape(sample['parsing'], (batch_size, 3, img_size, img_size)).float().cuda()
        annotate = sample['annotate'].cuda()
        fake_annotate = random_annotate(batch_size, pose_dataset).cuda()
        
        label = torch.full((batch_size, ), real_label).cuda()     
        output = netD(real_image, annotate).view(-1)
        errD_real = criterion(output, label) 
        errD_real.backward()
        D_x = output.mean().item()
        
        # Train with all-fake batch
        noise = torch.randn(batch_size, z_size, 1, 1).cuda()
        fake_image = netG(noise, fake_annotate)
        label.fill_(fake_label)
        output = netD(fake_image.detach(), fake_annotate).view(-1)   
        errD_fake = criterion(output, label)
        errD_fake.backward()

        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step() # Update D
        
        
        # --------------------------------------------- #
        # (2) Update G network: maximize log(D(G(z)))   #
        # --------------------------------------------- #
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost 
        output = netD(fake_image, fake_annotate).view(-1)  # Since we just updated D, perform another forward pass of all-fake batch through D
        errG = criterion(output, label)
        errG.backward()
        
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i and i%10==0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(pose_dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
     
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        iters += 1

        
    netG.eval()
    with torch.no_grad():
        fake_image = netG(fixed_noise, fixed_annotate)
        img = torch.reshape(fake_image[0], (img_size, img_size, 3)).cpu().data.detach().numpy()
        plt.imshow(img)
        plt.show()
        

Starting Training Loop...
[0/1][10/545]	Loss_D: 1.3996	Loss_G: 0.7867	D(x): -0.0999	D(G(z)): -0.1172 / -0.1490


KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()

In [None]:
# Train Settings 1:
# embeddings = pose_dataset.embeddings
# netG = PoseGeneratorDC(embeddings).cuda()
# netG.apply(weights_init)
# netD = PoseDiscriminatorDC(embeddings).cuda()
# netD.apply(weights_init)

# adversarial_loss = torch.nn.MSELoss()
# optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
# optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [None]:
# real_label = 1
# fake_label = 0

# # Lists to keep track of progress
# img_list = []
# G_losses = []
# D_losses = []
# iters = 0
# num_epochs = 1

# torch.manual_seed(250)
# fixed_noise = torch.randn(1, 64, 1, 1).cuda()
# fixed_annotate = random_annotate(1, pose_dataset, random_state = 1).cuda()

# print("Starting Training Loop...")
# # For each epoch

# for epoch in range(num_epochs):
#     netG.train()
#     netD.train()
#     # For each batch in the dataloader
#     for i, sample in enumerate(pose_dataloader, 0):
#         batch_size = sample['raw'].shape[0]
#         # reformat the shape to be (batch_size, 3, 128, 128)
#         real_image = torch.reshape(sample['parsing'], (batch_size, 3, 128, 128)).float().cuda()
#         real_image = gaussian_noise(real_image)
        
#         label = torch.full((batch_size, ), real_label).cuda()
#         # -----------------
#         #  Train Generator
#         # -----------------
#         optimizerG.zero_grad()
#         noise = torch.randn(batch_size, 64, 1, 1).cuda()
#         annotate = sample['annotate'].cuda()
#         fake_annotate = random_annotate(batch_size, pose_dataset).cuda()
        
#         fake_image = netG(noise, fake_annotate)
#         validity = netD(fake_image, fake_annotate).view(-1)
        
        
#         break
#         g_loss = adversarial_loss(validity, label)
#         g_loss.backward()
#         optimizerG.step()
        
#         # ---------------------
#         #  Train Discriminator
#         # ---------------------
#         optimizerD.zero_grad()
        
#         validity_real = netD(real_image, annotate).view(-1)
#         d_real_loss = adversarial_loss(validity_real, Variable(label))
#         d_real_loss.backward()
        
#         validity_fake = netD(fake_image.detach(), fake_annotate).view(-1) ### 这里放fake annotate?
#         label.fill_(fake_label)
#         d_fake_loss = adversarial_loss(validity_fake, Variable(label))
#         d_fake_loss.backward()
        
#         d_loss = d_real_loss + d_fake_loss
#         optimizerD.step()

#         # Output training stats
#         if i and i%10==0:
#             print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
#                       % (epoch, num_epochs, i, len(pose_dataloader),
#                          d_loss.item(), g_loss.item()))

#         # Save Losses for plotting later
#         G_losses.append(g_loss.item())
#         D_losses.append(d_loss.item())

#         # Check how the generator is doing by saving G's output on fixed_noise
# #         if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
# #             with torch.no_grad():
# #                 fake = netG(fixed_noise).detach().cpu()
# #             img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

#         iters += 1
#     netG.eval()
#     with torch.no_grad():
#         fake_image = netG(fixed_noise, fixed_annotate)
#         img = torch.reshape(fake_image[0], (128, 128, 3)).cpu().data.detach().numpy()
#         plt.imshow(img)
#         plt.show()