In [1]:
# Enable IPython's autoreload extension. Mode 2 re-imports every loaded module
# before each cell runs, so edits to the local project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [9]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
import torch.optim as optim

from utils.process_img import Rescale, DynamicCrop, ToTensor, CenterCrop
from utils.func import weights_init
from pose_dataset import PoseDataset, print_sample
from model.generator import PoseGenerator
from model.discriminator import PoseDiscriminator

In [3]:
# Per-sample preprocessing, applied in the order listed. The last step
# rescales to (128, 128).
composed = transforms.Compose(
    [
        Rescale(512),
        DynamicCrop(30),
        Rescale((128, 128)),
    ]
)

# Dataset described by the CSV index, with files rooted at ./data; every
# sample goes through the pipeline above.
pose_dataset = PoseDataset(
    './data/data_list.csv',
    './data',
    transform=composed,
)

# Shuffled mini-batches of 4 samples, loaded by 4 worker processes.
pose_dataloader = DataLoader(
    pose_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Optional sanity check (disabled): uncomment to pass the first 10 dataset
# samples to print_sample for inspection.
# for i in range(10):
#     sample = pose_dataset[i]
#     print_sample(sample)

In [None]:
# Smoke test: run a single batch through freshly initialised networks and
# print the generator output shape and the discriminator prediction.
embeddings = pose_dataset.embeddings

# Generator
netG = PoseGenerator(embeddings).cuda()
netG.apply(weights_init)
# Discriminator
netD = PoseDiscriminator(embeddings).cuda()
netD.apply(weights_init)

for i, sample in enumerate(pose_dataloader):
    annotate = sample['annotate'].cuda()
    # The noise batch must have as many rows as the conditioning batch. The
    # loader yields batches of 4, so a hard-coded size of 2 would not line up.
    noise = torch.randn(annotate.shape[0], 64, 1, 1).cuda()

    # Forward-only check: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        fake_img = netG(noise, annotate)
        print(fake_img.shape)

        pred = netD(fake_img, annotate)
        print(pred)
    # One batch is enough for a shape check.
    break

In [10]:
# Both networks are built from the embeddings exposed by the dataset.
embeddings = pose_dataset.embeddings

# Generator: construct, move to the GPU, then run weights_init over every
# sub-module (Module.apply returns the module itself, so the calls chain).
netG = PoseGenerator(embeddings).cuda().apply(weights_init)

# Discriminator: same construction and initialisation sequence.
netD = PoseDiscriminator(embeddings).cuda().apply(weights_init)

# Training settings: binary cross-entropy loss, plus one Adam optimizer per
# network, each using the library's default hyper-parameters.
criterion = nn.BCELoss()
optimizerD = optim.Adam(params=netD.parameters())
optimizerG = optim.Adam(params=netG.parameters())

In [14]:
# Training loop: for every batch, take one discriminator step (real batch plus
# detached fake batch) followed by one generator step.
real_label = 1
fake_label = 0

# Number of channels of the noise tensor fed to the generator; the noise is
# shaped (batch, LATENT_DIM, 1, 1).
LATENT_DIM = 64
# Debug cap: only the first MAX_DEBUG_BATCHES batches of each epoch are used.
MAX_DEBUG_BATCHES = 6

# Lists to keep track of progress
img_list = []  # reserved for generated-image snapshots; nothing is appended yet
G_losses = []
D_losses = []
iters = 0      # global iteration counter; not advanced by this loop
num_epochs = 1

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, sample in enumerate(pose_dataloader, 0):
        if i >= MAX_DEBUG_BATCHES:
            break
        batch_size = sample['raw'].shape[0]
        # Bring the pose images to (batch_size, 3, 128, 128) float tensors on the GPU.
        # NOTE(review): reshape only reinterprets the memory layout. If
        # sample['pose'] is channels-last (batch, 128, 128, 3), a
        # permute(0, 3, 1, 2) is required instead; confirm against PoseDataset.
        real_pose = torch.reshape(sample['pose'], (batch_size, 3, 128, 128)).float().cuda()
        annotate = sample['annotate'].cuda()

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()

        # BCELoss needs floating-point targets. Without an explicit dtype,
        # torch.full infers an integer tensor from the integer fill value.
        label = torch.full((batch_size,), real_label, dtype=torch.float).cuda()
        output = netD(real_pose, annotate).view(-1)
        errD_real = criterion(output, label)  # loss on the all-real batch
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(batch_size, LATENT_DIM, 1, 1).cuda()
        fake_pose = netG(noise, annotate)
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake_pose.detach(), annotate).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()

        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()  # Update D

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost

        # D was just updated, so run the fake batch through it again. This
        # backward pass also leaves gradients on D, which are cleared by
        # netD.zero_grad() at the start of the next iteration.
        output = netD(fake_pose, annotate).view(-1)
        errG = criterion(output, label)
        errG.backward()

        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats (every batch while the debug cap is in place)
        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(pose_dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # TODO: periodically snapshot generator output on a fixed noise batch
        # into img_list and advance iters; neither is implemented yet.
        

Starting Training Loop...
[0/1][0/86]	Loss_D: 1.5465	Loss_G: 0.9717	D(x): 0.5079	D(G(z)): 0.4991 / 0.5473
[0/1][1/86]	Loss_D: 1.5159	Loss_G: 0.6813	D(x): 0.5107	D(G(z)): 0.5169 / 0.5163
[0/1][2/86]	Loss_D: 1.4964	Loss_G: 0.6867	D(x): 0.5181	D(G(z)): 0.5142 / 0.5167
[0/1][3/86]	Loss_D: 1.7075	Loss_G: 0.6641	D(x): 0.4613	D(G(z)): 0.5181 / 0.5181
[0/1][4/86]	Loss_D: 1.4257	Loss_G: 0.6638	D(x): 0.5183	D(G(z)): 0.5173 / 0.5162
[0/1][5/86]	Loss_D: 1.4464	Loss_G: 0.6676	D(x): 0.5133	D(G(z)): 0.5158 / 0.5135


In [6]:
# Release unused memory held by PyTorch's CUDA caching allocator back to the
# driver. Memory still referenced by live tensors is unaffected.
torch.cuda.empty_cache()