In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils

from utils.process_img import Rescale, DynamicCrop, ToTensor, CenterCrop
from utils.func import weights_init, random_annotate, gaussian_noise
from pose_dataset import PoseDataset, print_sample
from model.generator import PoseGeneratorDC, PoseGeneratorL
from model.discriminator import PoseDiscriminatorDC, PoseDiscriminatorL
from utils.process_text import tokenizer, get_embeddings, get_word2idx

ImportError: cannot import name 'random_annotate'

In [None]:
# Parameters:
# Hyperparameters and preprocessing pipeline for the first GAN experiment.
lr = 0.0002  # learning rate handed to both Adam optimizers
beta1 = 0.5  # first Adam beta handed to both optimizers
img_size = 128  # side length (pixels) of the final square rescale
z_size = 16  # channel count of the generator's noise input
batch_size = 3  # batch size handed to the DataLoader
# Per-sample preprocessing; Rescale and DynamicCrop are project helpers from utils.process_img.
composed = transforms.Compose([Rescale(512),
                               DynamicCrop(30),
                               Rescale((img_size, img_size))])

In [None]:
# Build the dataset from the sample CSV and wrap it in a shuffling DataLoader with 4 worker processes.
pose_dataset = PoseDataset('./data/sample_list.csv', './data', transform = composed)
pose_dataloader = DataLoader(pose_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
print(len(pose_dataset))  # number of samples in the dataset

In [None]:
# for i in range(10):
#     sample = pose_dataset[i]
#     print_sample(sample)

In [None]:
# Model, Loss, and Optimizer
embeddings = pose_dataset.embeddings
#netD = PoseDiscriminatorDC(embeddings).cuda()
# NOTE(review): `Discriminator` is not imported in the import cell; the only definition
# visible in this notebook sits in a later cell, so on Restart & Run All this line raises
# NameError. That class's forward() also takes (img, labels), while the training cell
# calls netD with the image only - confirm which discriminator this experiment should use.
netD = Discriminator().cuda()
netD.apply(weights_init)

# netD = PoseDiscriminatorL().cuda()
netG = PoseGeneratorDC(embeddings).cuda()

netG.apply(weights_init)

# BCEWithLogitsLoss applies the sigmoid itself, so it expects raw (un-squashed) scores from netD.
criterion = nn.BCEWithLogitsLoss()
# One Adam optimizer per network, sharing the learning rate and betas from the parameters cell.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Adversarial training loop for netD / netG.
#
# Each iteration first updates the discriminator on a real batch and on a
# detached fake batch, then updates the generator so that the discriminator
# scores its fakes as real. After every epoch one image is rendered from a
# fixed noise/annotation pair so progress can be compared across epochs.
real_label = 1
fake_label = 0

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 50

torch.manual_seed(250)
fixed_noise = torch.randn(1, z_size, 1, 1).cuda()
fixed_annotate = random_annotate(1, pose_dataset, random_state = 1).cuda()
print("Starting Training Loop...")

for epoch in range(num_epochs):
    netG.train()
    netD.train()
    for i, sample in enumerate(pose_dataloader, 0):

        # ---------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # ---------------------------------------------
        # Train with all-real batch
        netD.zero_grad()

        # Size of *this* batch (the last one may be smaller). Kept in its own
        # name so the configured global `batch_size` is not overwritten.
        b_size = sample['parsing'].shape[0]
        # NOTE(review): reshape only reinterprets memory; if sample['pose'] is
        # stored channels-last this does not move channels first - confirm the
        # dataset layout. The preview below applies the inverse reshape.
        real_image = torch.reshape(sample['pose'], (b_size, 3, img_size, img_size)).float().cuda()

        # Add noise to real pose
        #real_image = gaussian_noise(real_image)

        annotate = sample['annotate'].cuda()  # only used by the conditional variant commented out below
        fake_annotate = random_annotate(b_size, pose_dataset).cuda()

        # Targets must be floating point for BCEWithLogitsLoss; with an integer
        # fill value torch.full would infer an integer dtype.
        label = torch.full((b_size,), real_label, dtype=torch.float, device=real_image.device)

        #output = netD(real_image, annotate).view(-1)
        output = netD(real_image).view(-1)

        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean of netD's raw output on the real batch

        # Train with all-fake batch
        noise = torch.randn(b_size, z_size, 1, 1).cuda()
        fake_image = netG(noise, fake_annotate)
        label.fill_(fake_label)

        # detach(): this step must not push gradients into the generator
        output = netD(fake_image.detach()).view(-1)
        #output = netD(fake_image.detach(), fake_annotate).view(-1)

        errD_fake = criterion(output, label)
        errD_fake.backward()

        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # --------------------------------------------- #
        # (2) Update G network: maximize log(D(G(z)))   #
        # --------------------------------------------- #
        netG.zero_grad()
        label.fill_(real_label)  # the generator wants its fakes classified as real

        output = netD(fake_image).view(-1)
        #output = netD(fake_image, fake_annotate).view(-1)

        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i and i%10==0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(pose_dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        iters += 1

    # End-of-epoch preview from the fixed inputs (eval mode, no autograd graph).
    netG.eval()
    with torch.no_grad():
        fake_image = netG(fixed_noise, fixed_annotate)
        img = torch.reshape(fake_image[0], (img_size, img_size, 3)).cpu().numpy()
        plt.imshow(img)
        plt.show()
        

In [None]:
# Hand unoccupied cached GPU memory back to the driver before the next experiment.
torch.cuda.empty_cache()

### CGAN on clustered pose:

In [None]:
# Data configuration for the conditional-GAN experiment.
img_size = 64
batch_size = 64
# Derived from img_size so the two cannot drift apart: (channels, height, width).
img_shape = (3, img_size, img_size)
# Same preprocessing steps as the first experiment, at the smaller resolution.
composed = transforms.Compose([Rescale(512),
                               DynamicCrop(30),
                               Rescale((img_size, img_size))])

# Labelled sample list; colour images are kept (gray_scale = False).
pose_dataset = PoseDataset('./data/data_list_label.csv', './data', transform = composed, gray_scale = False)
pose_dataloader = DataLoader(pose_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
print(len(pose_dataset))

In [None]:
# Peek at one sample to see which fields the dataset provides.
sample = pose_dataset[1]
sample.keys()

In [None]:
# Option table for the conditional GAN.
# In the part of the notebook visible here, only b1, b2, latent_dim, lr and
# n_classes are read by live code; the other entries are carried along unchanged.
opt = dict(
    # Adam betas
    b1=0.5,
    b2=0.999,
    batch_size=64,
    channels=1,
    img_size=32,
    # length of the generator's noise vector
    latent_dim=100,
    # Adam learning rate
    lr=0.0002,
    # number of distinct class labels (size of the label embeddings)
    n_classes=200,
    n_cpu=8,
    n_epochs=1,
    sample_interval=400,
)

In [None]:
class Generator(nn.Module):
    """Fully connected conditional generator: (noise, class label) -> image.

    The label is embedded into an ``n_classes``-dimensional vector,
    concatenated with the noise vector and passed through an MLP ending in
    ``Tanh``, so every output value lies in [-1, 1].

    Args:
        n_classes: number of distinct labels. Defaults to ``opt['n_classes']``.
        latent_dim: length of the noise vector. Defaults to ``opt['latent_dim']``.
        image_shape: shape of one generated image. Defaults to the
            notebook-level ``img_shape``.

    The sizes are resolved once at construction and stored on the instance, so
    later changes to the notebook globals cannot desynchronise ``forward`` from
    the layers that were built.
    """

    def __init__(self, n_classes=None, latent_dim=None, image_shape=None):
        super(Generator, self).__init__()

        # Fall back to the notebook configuration for anything not given explicitly.
        self.n_classes = opt['n_classes'] if n_classes is None else n_classes
        self.latent_dim = opt['latent_dim'] if latent_dim is None else latent_dim
        self.img_shape = tuple(img_shape if image_shape is None else image_shape)

        self.label_emb = nn.Embedding(self.n_classes, self.n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2), as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the original passed 0.8 positionally, which
                # BatchNorm1d interprets as `eps`. Behaviour is kept as-is;
                # confirm whether `momentum=0.8` was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim + self.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: float tensor of shape (batch, latent_dim).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, *image_shape) with values in [-1, 1].
        """
        # Concatenate label embedding and noise to produce the MLP input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)

        img = self.model(gen_input)
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected conditional discriminator: (image, class label) -> score.

    The image is flattened, concatenated with an ``n_classes``-dimensional
    label embedding and passed through an MLP. The last layer is a plain
    ``Linear(512, 1)``, so the output is an unbounded score (no sigmoid).

    Args:
        n_classes: number of distinct labels. Defaults to ``opt['n_classes']``.
        image_shape: shape of one input image; only its element count is used.
            Defaults to the notebook-level ``img_shape``.
    """

    def __init__(self, n_classes=None, image_shape=None):
        super(Discriminator, self).__init__()

        # Fall back to the notebook configuration for anything not given explicitly.
        self.n_classes = opt['n_classes'] if n_classes is None else n_classes
        self.img_shape = tuple(img_shape if image_shape is None else image_shape)

        self.label_embedding = nn.Embedding(self.n_classes, self.n_classes)

        self.model = nn.Sequential(
            nn.Linear(self.n_classes + int(np.prod(self.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        """Score a batch of images under their labels.

        Args:
            img: tensor whose first dimension is the batch; the remaining
                dimensions are flattened and cast to float.
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with one raw score per image.
        """
        # Concatenate label embedding and flattened image to produce the MLP input
        img = img.view(img.shape[0], -1)
        d_in = torch.cat((self.label_embedding(labels), img.float()), -1)

        validity = self.model(d_in)
        return validity


In [None]:
# Adversarial objective: mean squared error between the discriminator's score and its target.
adversarial_loss = torch.nn.MSELoss()
# Instantiate both networks with their default configuration and move them to the GPU.
generator = Generator().cuda()
discriminator = Discriminator().cuda()

In [None]:
# One Adam optimizer per network, both configured from the lr / b1 / b2 entries of `opt`.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))

In [None]:
# Train
#
# Conditional-GAN training loop. Each iteration updates the generator first
# (so its fakes are scored as 1), then the discriminator (real -> 1, fake -> 0).
# Every 50th batch a preview is generated from fixed noise/labels and stored in
# gen_img_list so snapshots are comparable over time.
test_batch_size = 100
fixed_z = torch.randn(test_batch_size, opt['latent_dim']).cuda()
#np.random.seed(250)
#fixed_gen_labels = torch.from_numpy(np.random.randint(0, 200, batch_size)).cuda()
# Preview labels are simply 0 .. test_batch_size-1 (one image per label).
fixed_gen_labels = torch.from_numpy(np.arange(test_batch_size)).cuda()
gen_img_list = []
G_loss, D_loss = [], []

n_epochs = 10
for epoch in range(n_epochs):
    for i, sample in enumerate(pose_dataloader):
        imgs = sample['pose']
        labels = sample['label']

        # Configure input
        real_imgs = imgs.cuda()
        labels = labels.cuda()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise and labels as generator input; the label range comes
        # from the configuration instead of a hard-coded class count.
        z = torch.randn(batch_size, opt['latent_dim']).cuda()
        gen_labels = torch.from_numpy(np.random.randint(0, opt['n_classes'], batch_size)).cuda()

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator.
        # Targets are built with the exact shape of the prediction, so the loss
        # never depends on broadcasting; against a constant target this gives
        # the same value the broadcast version produced.
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, torch.ones_like(validity))

        G_loss.append(g_loss.item())
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, torch.ones_like(validity_real))

        # Loss for fake images (detached: no gradient flows into the generator here)
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, torch.zeros_like(validity_fake))

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        D_loss.append(d_loss.item())
        d_loss.backward()
        optimizer_D.step()

        if i and i % 5 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(pose_dataloader), d_loss.item(), g_loss.item())
            )

        if i and i % 50 == 0:
            # Preview only: skip autograd and keep the snapshot on the CPU so
            # the list does not pin GPU memory or computation graphs.
            with torch.no_grad():
                preview = generator(fixed_z, fixed_gen_labels).cpu()
            gen_img_list.append(preview)
            # NOTE(review): reshape (not transpose) to height x width x channels
            # assumes the flat pixel order learned from the data is
            # channels-last - confirm against the layout of sample['pose'].
            img = np.reshape(preview[0].numpy(), (img_shape[1], img_shape[2], img_shape[0]))
            plt.imshow(img, cmap='gray')
            plt.show()

#         batches_done = epoch * len(dataloader) + i
#         if batches_done % opt.sample_interval == 0:
#             sample_image(n_row=10, batches_done=batches_done)

In [None]:
# torch.save(generator, 'parsing_netG.pth')
# torch.save(discriminator, 'parsing_netD.pth')
# print(discriminator)

In [None]:
# Snapshot grid: one row per saved preview (at most the first 20), showing the
# images generated for the first 10 fixed labels.
snapshots = gen_img_list[:20]  # slicing tolerates fewer than 20 saved previews
# np.random.seed(4995)
#target_labels = np.random.choice(np.arange(100), 10, replace = False)
target_labels = np.arange(10)

if snapshots:
    n_rows = len(snapshots)
    # 2 inches per row keeps the original 20x20 size when all 20 rows are present.
    fig = plt.figure(figsize=(20, 2 * n_rows))

    for k, snapshot in enumerate(snapshots):
        class_images = snapshot.cpu().detach().numpy()

        for i in range(10):
            # Row k, column i of a single n_rows x 10 grid.
            fig.add_subplot(n_rows, 10, k * 10 + i + 1)
            img = np.reshape(class_images[target_labels[i]], (64, 64, -1))
            plt.imshow(img)

    fig.tight_layout()
    plt.show()
else:
    print("gen_img_list is empty - run the training cell first.")

In [None]:
# Display the 'pose' entry of one dataset sample.
sample = pose_dataset[1]
img = sample['pose']
# NOTE(review): matplotlib ignores cmap for RGB(A) input - confirm whether this image is single-channel.
plt.imshow(img, cmap='gray')