In [1]:
import os
import sys
import glob
import math
import time
import random
import itertools
import datetime
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.utils import save_image
import torchvision.transforms as transforms

Data preprocessing for the dataloader

In [2]:
class ImageDataset(Dataset):
    def __init__(self, root, input_shape, mode="train1"):
        self.transform = transforms.Compose(
            [
                transforms.Resize((128,128), Image.BICUBIC), #input_shape[-2:], Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        self.mode = mode

    def __getitem__(self, index):

        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        img_A = img.crop((0, 0, w / 2, h))
        img_B = img.crop((w / 2, 0, w, h))

        if np.random.random() < 0.5 and self.mode == "train1":
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        return len(self.files)

Populates the weights with values from a normal distribution

In [3]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

Parameters
1. in_size : Input dimension(channels number) 
2. out_size : Output dimension(channels number)
3. normalize : If it is true add Batch Normalization layer, otherwise skip this layer
4. dropout : probability for dropping a unit

In [4]:
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 3, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size, 0.8))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [5]:
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_size, out_size, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_size, 0.8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

Unet's analog

In [6]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        channels, self.h, self.w = img_shape

        self.fc = nn.Linear(latent_dim, self.h * self.w)

        self.down1 = UNetDown(channels + 1, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512, normalize=False)
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(1024, 512)
        self.up3 = UNetUp(1024, 512)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(128, channels, 3, stride=1, padding=1), nn.Tanh()
        )

    def forward(self, x, z):
        # Propogate noise through fc layer and reshape to img shape
        #x:(N,3,128,128) z:(N,8)
        z = self.fc(z).view(z.size(0), 1, self.h, self.w)#z:(N,1,128,128)
        
        #concating (x and z): (N,4,128,128)
        d1 = self.down1(torch.cat((x, z), 1)) #d1:(N,64,64,64)
        d2 = self.down2(d1)         #d2:(N,128,32,32)
        d3 = self.down3(d2)         #d3:(N,256,16,16)
        d4 = self.down4(d3)         #d4:(N,512,8,8)
        d5 = self.down5(d4)         #d5:(N,512,4,4)
        d6 = self.down6(d5)         #d6:(N,512,2,2)
        d7 = self.down7(d6)         #d7:(N,512,1,1)
        u1 = self.up1(d7, d6)       #u1:(N,1024,2,2)
        u2 = self.up2(u1, d5)       #u2:(N,1024,4,4)
        u3 = self.up3(u2, d4)       #u3:(N,1024,8,8)
        u4 = self.up4(u3, d3)       #u4:(N,512,16,16)
        u5 = self.up5(u4, d2)       #u5:(N,256,32,32)
        u6 = self.up6(u5, d1)       #u6:(N,128,64,64)

        return self.final(u6)       #final:(N,3,128,128)

MultiDiscriminator uses three discriminators for object sizes 32, 64, 128

In [7]:
class MultiDiscriminator(nn.Module):
    def __init__(self, input_shape):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        channels, _, _ = input_shape
        # Extracts discriminator models
        self.models = nn.ModuleList()
        for i in range(3):
            self.models.add_module(
                "disc_%d" % i,
                nn.Sequential(
                    *discriminator_block(channels, 64, normalize=False),
                    *discriminator_block(64, 128),
                    *discriminator_block(128, 256),
                    *discriminator_block(256, 512),
                    nn.Conv2d(512, 1, 3, padding=1)
                ),
            )

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """Computes the MSE between model output and scalar gt"""
        loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)])
        return loss

    def forward(self, x):
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs

Encoder

In [8]:
class Encoder(nn.Module):
    def __init__(self, latent_dim, input_shape):
        super(Encoder, self).__init__()
        resnet18_model = resnet18(pretrained=False)
        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-3])
        self.pooling = nn.AvgPool2d(kernel_size=8, stride=8, padding=0)
        # Output is mu and log(var) for reparameterization trick used in VAEs
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, img):
        #img : (N, 3, 128, 128)
        out = self.feature_extractor(img)  # out : (N, 256, 8, 8)
        out = self.pooling(out)            # out : (N, 256, 1, 1)
        out = out.view(out.size(0), -1)    # out : (N, 256)
        mu = self.fc_mu(out)               # mu : (N, latent_dim)
        logvar = self.fc_logvar(out)       # logvar : (N, latent_dim)
        return mu, logvar

In [9]:
def reparameterization(mu, logvar):
    std = torch.exp(logvar / 2)
    sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), latent_dim))))
    z = sampled_z * std + mu
    return z
    # z : (N, latent_dim)

In [10]:
# Archive with training files
!unzip -q /content/drive/MyDrive/cv_project/cv_project_s2p.zip

In [32]:
epoch = 0                      #epoch to start training from
n_epochs = 150                 #number of epochs of training
dataset_name = "sketch2face"   #name of the dataset
batch_size = 16                 #size of the batches
lr = 0.0005                    #adam: learning rate
b1 = 0.6                       #adam: decay of first order momentum of gradient
b2 = 0.999                     #adam: decay of second order momentum of gradient
n_cpu = 8                      #number of cpu threads to use during batch generation
img_height = 128               #size of image height
img_width = 128                #size of image width
channels = 3                   #number of image channels
latent_dim = 8                 #number of latent codes
lambda_pixel = 10              #pixelwise loss weight
lambda_latent = 0.6            #latent loss weight
lambda_kl = 0.02               #kullback-leibler loss weight
mae_loss = torch.nn.L1Loss()   #Mean Absolute error loss


input_shape = (channels, img_height, img_width)       #shape of input image (tuple)

cuda = True if torch.cuda.is_available() else False   #availability of GPU
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor


generator = Generator(latent_dim, input_shape)    #Initialize generator
encoder = Encoder(latent_dim, input_shape)        #Initialize encoder
D_VAE = MultiDiscriminator(input_shape)           #initialize discriminators
D_LR = MultiDiscriminator(input_shape)

# Uncomment for further learning

# generator.load_state_dict(torch.load('/content/generator_100.pth'))
# encoder.load_state_dict(torch.load('/content/encoder_100.pth'))
# D_VAE.load_state_dict(torch.load('/content/images/D_VAE_100.pth'))
# D_LR.load_state_dict(torch.load('/content/images/D_LR_100.pth'))

if cuda:
    generator = generator.cuda()
    encoder.cuda()
    D_VAE = D_VAE.cuda()
    D_LR = D_LR.cuda()
    mae_loss.cuda()

# On initial use uncomment

    # Initialize weights
    # generator.apply(weights_init_normal)
    # D_VAE.apply(weights_init_normal)
    # D_LR.apply(weights_init_normal)

In [21]:
os.makedirs("images/%s" % dataset_name, exist_ok=True)

In [13]:
def get_concat_h(im2, im1):
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


# defining the size of image 
SIZE = 128

image_path = '/content/cv_project_s2p/train/photo'
sketch_path = '/content/cv_project_s2p/train/sketch'

image_file = os.listdir(image_path)


for i in range(len(image_file)):
    im1 = Image.open('/content/cv_project_s2p/train/photo/' + image_file[i])
    im2 = Image.open('/content/cv_project_s2p/train/sketch/' + image_file[i])
    im1 = im1.resize((SIZE, SIZE))
    im2 = im2.resize((SIZE, SIZE))
    get_concat_h(im1, im2).save('/content/cv_project_s2p/train1/' + str(i) + '.jpg')



image_path = '/content/cv_project_s2p/val/photo'
sketch_path = '/content/cv_project_s2p/val/sketch'

image_file = os.listdir(image_path)


for i in range(len(image_file)):
    im1 = Image.open('/content/cv_project_s2p/val/photo/' + image_file[i])
    im2 = Image.open('/content/cv_project_s2p/val/sketch/' + image_file[i])
    im1 = im1.resize((SIZE, SIZE))
    im2 = im2.resize((SIZE, SIZE))
    get_concat_h(im1, im2).save('/content/cv_project_s2p/val1/' + str(i) + '.jpg')



image_path = '/content/cv_project_s2p/test/photo'
sketch_path = '/content/cv_project_s2p/test/sketch'

image_file = os.listdir(image_path)

for i in range(len(image_file)):
    im1 = Image.open('/content/cv_project_s2p/test/photo/' + image_file[i])
    im2 = Image.open('/content/cv_project_s2p/test/skecth/' + image_file[i])
    im1 = im1.resize((SIZE, SIZE))
    im2 = im2.resize((SIZE, SIZE))
    get_concat_h(im2, im1).save('/content/cv_project_s2p/test1/' + str(i) + '.jpg')

In [33]:
optimizer_E = torch.optim.Adam(encoder.parameters(), lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_VAE = torch.optim.Adam(D_VAE.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_LR = torch.optim.Adam(D_LR.parameters(), lr=lr, betas=(b1, b2))

In [15]:
dataloader = DataLoader(
    ImageDataset("/content/cv_project_s2p", input_shape),
    batch_size=16,
    shuffle=True,
    num_workers=n_cpu,
)
val_dataloader = DataLoader(
    ImageDataset("/content/cv_project_s2p", input_shape, mode='val1'),
    batch_size=5,
    shuffle=False,
    num_workers=1,
)

test_dataloader = DataLoader(
    ImageDataset("/content/cv_project_s2p", input_shape, mode='test1'),
    batch_size=16,
    shuffle=False,
    num_workers=1,
)

  "Argument interpolation should be of type InterpolationMode instead of int. "
  cpuset_checked))


In [16]:
os.makedirs("images/sketch2face/img_val", exist_ok=True)
os.makedirs("images/sketch2face/img_test", exist_ok=True)

In [17]:
def sample_images(epoch_done, path="/content/images/sketch2face/", mode='img_val'):

    generator.eval()
    if mode == 'img_val':
        imgs = next(iter(val_dataloader))
    else:
        imgs = next(iter(test_dataloader))
    img_samples = None
    for img_A, img_B in zip(imgs["A"], imgs["B"]):

        # Repeat input image by number of desired columns
        real_A = img_A.view(1, *img_A.shape).repeat(latent_dim, 1, 1, 1)
        real_A = Variable(real_A.type(Tensor))

        # Sample latent representations
        sampled_z = Variable(Tensor(np.random.normal(0, 1, (latent_dim, latent_dim))))
        # Generate samples
        fake_B = generator(real_A, sampled_z)
        # Concatenate samples horisontally
        fake_B = torch.cat([x for x in fake_B.data.cpu()], -1)
        img_sample = torch.cat((img_A, fake_B), -1)
        img_sample = img_sample.view(1, *img_sample.shape)
        # Concatenate with previous samples vertically
        img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)
    save_image(img_samples, path + mode + "/images_"+ str(epoch_done) + ".png", nrow=8, normalize=True)
    
    generator.train()

In [18]:
os.makedirs("images/sketch2face/checkpoints", exist_ok=True)

# Train

In [ ]:
# Adversarial loss
valid = 1
fake = 0

prev_time = time.time()
for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------

        optimizer_E.zero_grad()
        optimizer_G.zero_grad()

        # ----------
        # cVAE-GAN
        # ----------

        # Produce output using encoding of B (cVAE-GAN)
        mu, logvar = encoder(real_B)
        encoded_z = reparameterization(mu, logvar)
        fake_B = generator(real_A, encoded_z)

        # Pixelwise loss of translated image by VAE
        loss_pixel = mae_loss(fake_B, real_B)
        # Kullback-Leibler divergence of encoded B
        loss_kl = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - logvar - 1)
        # Adversarial loss
        loss_VAE_GAN = D_VAE.compute_loss(fake_B, valid)

        # ---------
        # cLR-GAN
        # ---------

        # Produce output using sampled z (cLR-GAN)
        sampled_z = Variable(Tensor(np.random.normal(0, 1, (real_A.size(0), latent_dim))))
        _fake_B = generator(real_A, sampled_z)
        # cLR Loss: Adversarial loss
        loss_LR_GAN = D_LR.compute_loss(_fake_B, valid)

        # ----------------------------------
        # Total Loss (Generator + Encoder)
        # ----------------------------------

        loss_GE = loss_VAE_GAN + loss_LR_GAN + lambda_pixel * loss_pixel + lambda_kl * loss_kl

        loss_GE.backward(retain_graph=True)
        optimizer_E.step()

        # ---------------------
        # Generator Only Loss
        # ---------------------

        # Latent L1 loss
        _mu, _ = encoder(_fake_B)
        loss_latent = lambda_latent * mae_loss(_mu, sampled_z)

        loss_latent.backward()
        optimizer_G.step()

        # ----------------------------------
        #  Train Discriminator (cVAE-GAN)
        # ----------------------------------

        optimizer_D_VAE.zero_grad()

        loss_D_VAE = D_VAE.compute_loss(real_B, valid) + D_VAE.compute_loss(fake_B.detach(), fake)

        loss_D_VAE.backward()
        optimizer_D_VAE.step()

        # ---------------------------------
        #  Train Discriminator (cLR-GAN)
        # ---------------------------------

        optimizer_D_LR.zero_grad()

        loss_D_LR = D_LR.compute_loss(real_B, valid) + D_LR.compute_loss(_fake_B.detach(), fake)

        loss_D_LR.backward()
        optimizer_D_LR.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D VAE_loss: %f, LR_loss: %f] [G loss: %f, pixel: %f, kl: %f, latent: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D_VAE.item(),
                loss_D_LR.item(),
                loss_GE.item(),
                loss_pixel.item(),
                loss_kl.item(),
                loss_latent.item(),
                time_left,
            )
        )
        
    if epoch % 4 == 0:
        sample_images(epoch + 1, mode='img_test')
        sample_images(epoch + 1, mode='img_val')
        torch.save(generator.state_dict(), "/content/images/sketch2face/checkpoints/generator_"+ str(epoch + 1) + ".pth")
        torch.save(encoder.state_dict(), "/content/images/sketch2face/checkpoints/encoder_"+ str(epoch + 1) + ".pth")
        torch.save(D_VAE.state_dict(), "/content/images/sketch2face/checkpoints/D_VAE_"+ str(epoch + 1) + ".pth")
        torch.save(D_LR.state_dict(), "/content/images/sketch2face/checkpoints/D_LR_"+ str(epoch + 1) + ".pth")