In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import itertools


import torch.optim as optim
from torch.autograd import Variable

from PIL import Image
import matplotlib.pyplot as plt

import copy

from pytorch_datasetloader import *

# The Discriminator

In [2]:
class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)



# The Generator

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block       
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [4]:
#using cuda if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

netD_A = Discriminator(3)
netD_A = netD_A.to(device)
netD_B = Discriminator(3)
netD_B = netD_B.to(device)

netG_A2B = Generator(3,3)
netG_A2B = netG_A2B.to(device)
netG_B2A = Generator(3,3)
netG_B2A = netG_B2A.to(device)

cuda:0


# Loading Dataset

In [7]:
batch_size_Realpix = 10
batch_size_mixed = 10
dataset_len =  350




path = 'RealFaces_350'
trainloader_nonpixel, validloader_nonpixel = get_loaders(path,split_perc=1.0,batch_size=batch_size_Realpix)

path = 'PixelFaces_350'
trainloader_pixel, validloader_pixel = get_loaders(path,split_perc=1.0,batch_size=batch_size_mixed)


trainiter_nonpixel = iter(trainloader_nonpixel)
trainIter_pixel = iter(trainloader_pixel)




350 images from the dataset
350 images from the dataset


# Selecting Optimizer 

In [8]:
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()


# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training CycleGan

In [9]:
#Patch Version

# Train the model
total_step = len(trainiter_nonpixel)
d_loss_list = []
g_loss_list = []
acc_list = []
num_epochs = 2500
sample_imflag = 0
imsize = 128

Tensor = torch.cuda.FloatTensor

# A: Non-pixel, B: pixel

for epoch in range(num_epochs):
    
    trainiter_nonpixel = iter(trainloader_nonpixel)
    trainIter_pixel = iter(trainloader_pixel)
    
    for i in range(5):
        
        real_pixel, input_lbls = trainIter_pixel.next()
        real_nonpixel, out_lbls = trainiter_nonpixel.next()
        
        
        real_pixel = real_pixel.to(device)
        real_nonpixel = real_nonpixel.to(device)
        
        target_real = Variable(Tensor(len(input_lbls)).fill_(1.0), requires_grad=False)
        target_fake = Variable(Tensor(len(out_lbls)).fill_(0.0), requires_grad=False)
        
        
        netG_A2B.zero_grad()
        netG_B2A.zero_grad()
        
        #self identity loss:
        
        pixel_gen = netG_A2B(real_pixel)
        loss_identity_B = criterion_identity(pixel_gen, real_pixel)*20.0
        

        nonpixel_gen = netG_B2A(real_nonpixel)
        loss_identity_A = criterion_GAN(nonpixel_gen, real_nonpixel)*20.0

        
        # GAN loss
        actual_pixel_gen = netG_A2B(real_nonpixel)
        pred_pixel = netD_B(actual_pixel_gen)
        loss_GAN_B2A = criterion_GAN(pred_pixel, target_real)
        

        actual_nonpixel_gen = netG_B2A(real_pixel)
        pred_nonpixel = netD_B(actual_nonpixel_gen)
        loss_GAN_A2B = criterion_GAN(pred_nonpixel, target_real)        
        
        
        # Cycle loss
        recovered_nonpixel = netG_B2A(actual_pixel_gen)
        loss_cycle_ABA = criterion_cycle(recovered_nonpixel, real_nonpixel)*10.0

        recovered_pixel = netG_A2B(actual_nonpixel_gen)
        loss_cycle_BAB = criterion_cycle(recovered_pixel, real_pixel)*10.0
        
        
        
        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()
        
        
        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_nonpixel)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss
        actual_nonpixel_gen = netG_B2A(real_pixel)
        pred_fake = netD_A(actual_nonpixel_gen)
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()
        optimizer_D_A.step()
        
        
        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_pixel)
        loss_D_real = criterion_GAN(pred_real, target_real)
        
        # Fake loss
        actual_pixel_gen = netG_A2B(real_nonpixel)
        pred_fake = netD_B(actual_pixel_gen)
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake)*0.5
        loss_D_B.backward()

        optimizer_D_B.step()        
        

 
        

      
    if (epoch+1) % 10 == 0:
        print('Epoch No [{}/{}] Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(epoch+1,num_epochs,loss_D_B.item(),loss_G.item()))
        #appending loss every 10 epochs
        d_loss_list.append(loss_D_B.item())
        g_loss_list.append(loss_G.item())
    
        # Save fake images
        fake_images = actual_pixel_gen.view(actual_pixel_gen.size(0), 3, imsize, imsize)
        save_image(fake_images.data, 'images/GAN_IO/fakes/fake_images-%d.png' %(epoch+1))
        
        
        # Save input images
        if sample_imflag == 0:
            inputImages = real_nonpixel.view(real_nonpixel.size(0), 3, imsize, imsize)
            save_image(inputImages.data, 'images/GAN_IO/input_images/inputImages-%d.png' %(epoch+1))
            sample_imflag = 1
    



Epoch No [10/2500] Discriminator Loss: 0.0530, Generator Loss: 6.4598
Epoch No [20/2500] Discriminator Loss: 0.1142, Generator Loss: 5.6953
Epoch No [30/2500] Discriminator Loss: 0.0715, Generator Loss: 5.1546
Epoch No [40/2500] Discriminator Loss: 0.2964, Generator Loss: 6.8656
Epoch No [50/2500] Discriminator Loss: 0.0362, Generator Loss: 6.0577
Epoch No [60/2500] Discriminator Loss: 0.0837, Generator Loss: 5.3438
Epoch No [70/2500] Discriminator Loss: 0.0568, Generator Loss: 4.2689
Epoch No [80/2500] Discriminator Loss: 0.0461, Generator Loss: 4.5756
Epoch No [90/2500] Discriminator Loss: 0.0467, Generator Loss: 5.4504
Epoch No [100/2500] Discriminator Loss: 0.0679, Generator Loss: 3.5014
Epoch No [110/2500] Discriminator Loss: 0.0578, Generator Loss: 4.1587
Epoch No [120/2500] Discriminator Loss: 0.0239, Generator Loss: 3.6526
Epoch No [130/2500] Discriminator Loss: 0.0636, Generator Loss: 4.0301
Epoch No [140/2500] Discriminator Loss: 0.0768, Generator Loss: 3.9887
Epoch No [150/2

KeyboardInterrupt: 

In [10]:
    
# Save models checkpoints
torch.save(netG_A2B.state_dict(), 'GAN_OUTs/netG_A2B_HigherCyclce.pth')
torch.save(netG_B2A.state_dict(), 'GAN_OUTs/netG_B2A_HigherCylce.pth')
torch.save(netD_A.state_dict(), 'GAN_OUTs/netD_A_HigherCycle.pth')
torch.save(netD_B.state_dict(), 'GAN_OUTs/netD_B_Cylce.pth')

In [6]:
# # Save models checkpoints
# torch.save(netG_A2B.state_dict(), 'GAN_OUTs/netG_A2B.pth')
# torch.save(netG_B2A.state_dict(), 'GAN_OUTs/netG_B2A.pth')
# torch.save(netD_A.state_dict(), 'GAN_OUTs/netD_A.pth')
# torch.save(netD_B.state_dict(), 'GAN_OUTs/netD_B.pth')



plt.figure(figsize=[12,8])
plt.plot(d_loss_list)
plt.plot(g_loss_list)
plt.ylabel('Loss')
plt.xlabel('x10 Epochs')
plt.savefig('images/GAN_LOSS_PLOT/GAN_Loss_CycleGAN.png')

NameError: name 'd_loss_list' is not defined

<Figure size 864x576 with 0 Axes>

# Load and Inference

In [12]:
G = Generator(3,3)
G = G.to(device)
G.load_state_dict(torch.load('GAN_OUTs/netG_A2B.pth'))

batch_size = 10

path = 'RealFaces_350'
trainloader_pixel, _ = get_loaders(path,split_perc=1.0,batch_size=batch_size,num_workers=0)

testImages = iter(trainloader_pixel).next()

testimg = testImages[0].to(device)
print(testimg.shape)

output = G(testimg)

fake_images = output.view(output.size(0), 3, 128, 128)
f = fake_images.detach().cpu().numpy()
save_image(fake_images.data, 'images/GAN_IO/CycleGAN_output.png')

testImages = testimg.view(testimg.size(0), 3, 128, 128)
t = testImages.detach().cpu().numpy()
save_image(testimg.data, 'images/GAN_IO/CycleGAN_input.png')

350 images from the dataset
torch.Size([10, 3, 128, 128])
