In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib.animation as animation
import random
import math
import io

from PIL import Image
from copy import deepcopy
from IPython.display import HTML

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import torch.nn as nn

In [None]:
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bare 2-D convolution layer (no normalisation, no activation).

    All arguments are forwarded unchanged to ``nn.Conv2d``.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    return layer
def conv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, inst_norm=False):
    """Build a ``Conv2d`` followed by a normalisation layer.

    Returns an ``nn.Sequential`` of two modules: the convolution, then
    ``nn.InstanceNorm2d`` when ``inst_norm`` equals ``True`` or
    ``nn.BatchNorm2d`` otherwise (both with momentum=0.1, eps=1e-5).
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size,
                            stride=stride, padding=padding)
    # The explicit equality test is kept so only values equal to True
    # select instance normalisation, exactly as before.
    norm_cls = nn.InstanceNorm2d if inst_norm == True else nn.BatchNorm2d
    normalisation = norm_cls(out_channels, momentum=0.1, eps=1e-5)
    return nn.Sequential(convolution, normalisation)


def tconv(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0,):
    """Build a bare 2-D transposed convolution (no normalisation, no activation).

    All arguments are forwarded unchanged to ``nn.ConvTranspose2d``.
    """
    layer = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    return layer



def tconv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, inst_norm=False):
    """Build a ``ConvTranspose2d`` followed by a normalisation layer.

    Returns an ``nn.Sequential`` of two modules: the transposed convolution,
    then ``nn.InstanceNorm2d`` when ``inst_norm`` equals ``True`` or
    ``nn.BatchNorm2d`` otherwise (both with momentum=0.1, eps=1e-5).
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding,
                                  output_padding=output_padding)
    # The explicit equality test is kept so only values equal to True
    # select instance normalisation, exactly as before.
    norm_cls = nn.InstanceNorm2d if inst_norm == True else nn.BatchNorm2d
    normalisation = norm_cls(out_channels, momentum=0.1, eps=1e-5)
    return nn.Sequential(upsample, normalisation)

In [None]:
dim_c = 3   # number of image channels consumed and produced by the generator
dim_g = 64  # base channel width of the generator

# Generator
class Gen(nn.Module):
    """Encoder/decoder generator with skip connections.

    Eight stride-2 convolutions (``n1``..``n8``) halve the spatial size at
    every step; eight stride-2 transposed convolutions (``m1``..``m8``) double
    it again. The output of each decoder stage is concatenated on the channel
    axis with the encoder feature map of matching resolution, which is why
    ``m2``..``m8`` take twice as many input channels as the previous stage
    emits. The final activation is ``tanh``.

    Args:
        inst_norm: forwarded to ``conv_n`` / ``tconv_n``; selects instance
            normalisation instead of batch normalisation.

    Note:
        The input height and width must survive eight halvings and match on
        the way back up for the concatenations to work; the notebook feeds
        256x256 images.
    """

    def __init__(self, inst_norm=False):
        # Zero-argument super(): the two-argument form looks up the global
        # name `Gen` at call time, and this notebook later rebinds `Gen` to a
        # model instance, which would make any later construction fail.
        super().__init__()
        # Encoder (first and last stage have no normalisation layer).
        self.n1 = conv(dim_c, dim_g, 4, 2, 1) 
        self.n2 = conv_n(dim_g, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.n3 = conv_n(dim_g*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.n4 = conv_n(dim_g*4, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n5 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n6 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n7 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n8 = conv(dim_g*8, dim_g*8, 4, 2, 1)

        # Decoder; the `*2` input widths account for the skip concatenation.
        self.m1 = tconv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m2 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m3 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m4 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m5 = tconv_n(dim_g*8*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.m6 = tconv_n(dim_g*4*2, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.m7 = tconv_n(dim_g*2*2, dim_g*1, 4, 2, 1, inst_norm=inst_norm)
        self.m8 = tconv(dim_g*1*2, dim_c, 4, 2, 1)
        self.tanh = nn.Tanh()

    def forward(self,x):
        """Map an input batch to a generated batch of the same spatial size.

        Args:
            x: tensor with ``dim_c`` channels on axis 1.

        Returns:
            Tensor with ``dim_c`` channels, values in [-1, 1] (tanh output).
        """
        # Encoder: leaky-ReLU is applied to the previous stage's output
        # before each convolution.
        n1 = self.n1(x)
        n2 = self.n2(F.leaky_relu(n1, 0.2))
        n3 = self.n3(F.leaky_relu(n2, 0.2))
        n4 = self.n4(F.leaky_relu(n3, 0.2))
        n5 = self.n5(F.leaky_relu(n4, 0.2))
        n6 = self.n6(F.leaky_relu(n5, 0.2))
        n7 = self.n7(F.leaky_relu(n6, 0.2))
        n8 = self.n8(F.leaky_relu(n7, 0.2))
        # Decoder. `training=True` is hard-coded, so dropout in the first
        # three stages stays active regardless of train()/eval() mode.
        m1 = torch.cat([F.dropout(self.m1(F.relu(n8)), 0.5, training=True), n7], 1)
        m2 = torch.cat([F.dropout(self.m2(F.relu(m1)), 0.5, training=True), n6], 1)
        m3 = torch.cat([F.dropout(self.m3(F.relu(m2)), 0.5, training=True), n5], 1)
        m4 = torch.cat([self.m4(F.relu(m3)), n4], 1)
        m5 = torch.cat([self.m5(F.relu(m4)), n3], 1)
        m6 = torch.cat([self.m6(F.relu(m5)), n2], 1)
        m7 = torch.cat([self.m7(F.relu(m6)), n1], 1)
        m8 = self.m8(F.relu(m7))

        return self.tanh(m8)

In [None]:
# Instantiate a generator with the default settings (inst_norm=False).
# NOTE(review): `Generator` does not appear to be used again in the visible
# notebook; the training cells build their own instance - confirm it is needed.
Generator = Gen()

In [None]:
dim_d = 64  # base channel width of the discriminator

# Discriminator
class Disc(nn.Module):
    """Conditional patch discriminator.

    The two input images are concatenated on the channel axis (hence
    ``dim_c*2`` input channels) and passed through five 4x4 convolutions;
    the first three use stride 2, the last two stride 1. The output is a
    single-channel map of per-patch probabilities (sigmoid).

    Args:
        inst_norm: forwarded to ``conv_n``; selects instance normalisation
            instead of batch normalisation for ``c2``..``c4``.
    """

    def __init__(self, inst_norm=False): 
        # Zero-argument super(): the two-argument form looks up the global
        # name `Disc` at call time, and this notebook later rebinds `Disc` to
        # a model instance, which would make any later construction fail.
        super().__init__()
        self.c1 = conv(dim_c*2, dim_d, 4, 2, 1) 
        self.c2 = conv_n(dim_d, dim_d*2, 4, 2, 1, inst_norm=inst_norm)
        self.c3 = conv_n(dim_d*2, dim_d*4, 4, 2, 1, inst_norm=inst_norm)
        self.c4 = conv_n(dim_d*4, dim_d*8, 4, 1, 1, inst_norm=inst_norm)
        self.c5 = conv(dim_d*8, 1, 4, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """Score an image pair.

        Args:
            x, y: tensors with ``dim_c`` channels each and identical batch and
                spatial dimensions (they are concatenated on axis 1).

        Returns:
            Tensor of shape (batch, 1, H', W') with values in (0, 1).
        """
        xy=torch.cat([x,y],dim=1)
        xy=F.leaky_relu(self.c1(xy), 0.2)
        xy=F.leaky_relu(self.c2(xy), 0.2)
        xy=F.leaky_relu(self.c3(xy), 0.2)
        xy=F.leaky_relu(self.c4(xy), 0.2)
        xy=self.c5(xy)

        return self.sigmoid(xy)

In [None]:
# Instantiate a discriminator with the default settings (inst_norm=False);
# it is only used below to display the module structure.
Discriminator = Disc()

In [None]:
# Display the discriminator's layer structure (module repr).
Discriminator

In [None]:
# read = plt.imread('AllHumans/glued_human_face_101.jpg')

# plt.imshow(read)

# def Draw_images(images):
#     for i in range(len(images)):
#         read = plt.imread('AllHumans/'+images[i])
#         plt.imshow(read)
#         plt.show()

# img = [i for i in os.walk('AllHumans/')][0][2]

# Draw_images(img[:10])

# read.shape

# read.shape

# plt.imshow(read[:,:256,:])

# plt.imshow(read[:,256:,:])

# training_images= [i for i in os.walk('AllHumans/')][0][2]

# len(training_images)

# 512//2

# training_images = sorted(training_images)

In [None]:
# plt.imshow(plt.imread('AllHumans/'+training_images[-3]))

In [None]:
def preprocess_images(images,path):
    """Split side-by-side image files into their left and right halves.

    Each file ``path + name`` is read with ``plt.imread`` and cut at half its
    width. Works for both colour (H, W, C) and grayscale (H, W) arrays.

    Args:
        images: iterable of file names.
        path: prefix prepended to every file name by plain string
            concatenation, so a directory must end with a separator.

    Returns:
        ``(original_images, sketch_images)``: two dicts keyed by file name,
        holding the left half and the right half of each image respectively.
    """
    original_images = {}
    sketch_images = {}
    for name in images:
        combined = plt.imread(path + name)
        # Only the width is needed; reading shape[1] (instead of unpacking
        # three dimensions) keeps 2-D grayscale images from raising.
        half = combined.shape[1] // 2
        original_images[name] = combined[:, :half]
        sketch_images[name] = combined[:, half:]
    return original_images, sketch_images

In [None]:
# testing_image= [i for i in os.walk('test/')][0][2]

# import pandas as pd

# # for i in list(ground_truth.keys()):
# #     cv2.imwrite('ground_truth/'+i,ground_truth[i])
    

# # for i in list(sketch_images.keys()):
# #     cv2.imwrite('sketch_images/'+i,sketch_images[i])

# training_images= [i for i in os.walk('AllHumans/')][0][2]



# sketch_images_df = pd.DataFrame(sorted([i for i in os.walk('sketch_images/')][0][2]),columns=['sketch_images_path'])

# ground_truth_df = pd.DataFrame(sorted([i for i in os.walk('ground_truth/')][0][2]),columns=['ground_truth_path'])

# training_data = pd.concat([sketch_images_df,ground_truth_df],axis=1)

# training_data

# ground_truth , sketch_images = preprocess_images(testing_image,'test/')

# import cv2

# # for i in list(ground_truth.keys()):
# #     cv2.imwrite('ground_truth_test/'+i,ground_truth[i])

# # for i in list(sketch_images.keys()):
# #     cv2.imwrite('sketch_images_test/'+i,sketch_images[i])





# sketch_images_test_df = pd.DataFrame(sorted([i for i in os.walk('sketch_images_test/')][0][2]),columns=['sketch_images'])

# ground_truth_test_df = pd.DataFrame(sorted([i for i in os.walk('ground_truth_test/')][0][2]),columns=['ground_truth'])

# sketch_images_test_df['sketch_images_path'] = sketch_images_test_df.sketch_images.apply(lambda x: 'sketch_images_test/'+x)
# ground_truth_test_df['ground_truth_path'] = ground_truth_test_df.ground_truth.apply(lambda x: 'ground_truth_test/'+x)

# ground_truth_test_df

# validation_data = pd.concat([sketch_images_test_df,ground_truth_test_df],axis=1)

# validation_data=validation_data.drop(columns=['sketch_images','ground_truth'])

In [None]:
# validation_data.to('validation_data.pkl')

In [None]:
# plt.imshow(ground_truth[0])

In [None]:
# for i in range(10):
#     plt.imshow(ground_truth[i])
#     plt.show()

In [None]:
# for i in range(10):
#     plt.imshow(sketch_images[i])
#     plt.show()

In [None]:
# plt.imshow(plt.imread(sketch_images_df.iloc[2].image_path))

In [None]:
# plt.imshow(plt.imread(ground_truth_df.iloc[2].image_path))

In [None]:
# Loss functions: binary cross-entropy for the adversarial term (the
# discriminator ends in a sigmoid) and L1 for the reconstruction term.
BCE = nn.BCELoss() #binary cross-entropy
L1 = nn.L1Loss() 

# Build the models on GPU 0 with the default inst_norm=False (batch norm).
# This cell rebinds the names `Gen` and `Disc` from the classes to model
# instances. The isinstance guard keeps the cell re-runnable: on a second run
# the names already hold instances, and "calling" them would invoke forward()
# and fail. Re-running therefore keeps the existing models and only rebuilds
# the optimizers; re-run the class definition cells first to start from fresh
# weights.
if isinstance(Gen, type):
    Gen = Gen().cuda(0)
if isinstance(Disc, type):
    Disc = Disc().cuda(0)

#optimizers
Gen_optim = torch.optim.Adam(Gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
Disc_optim = torch.optim.Adam(Disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [None]:
from PIL import Image

In [None]:
# training_data.sketch_images_path = training_data.sketch_images_path.apply(lambda x: 'sketch_images/'+x)
# training_data.ground_truth_path = training_data.ground_truth_path.apply(lambda x: 'ground_truth/'+x)

In [None]:
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
import torch

class dataset(Dataset): 
    """Paired image dataset backed by a DataFrame of file paths.

    Args:
        data: DataFrame with the columns ``'sketch_images_path'`` (network
            input) and ``'ground_truth_path'`` (target). A copy is stored.

    Each item is a tuple ``(input, target)`` of tensors produced by the same
    pipeline: resize to 256x256, convert to tensor, normalise with mean 0.5
    and std 0.5.
    """

    def __init__(self, data):
        self.data= data.copy()
        self.transform = transforms.Compose([transforms.Resize((256,256)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,)),])

    def __len__(self):
        """Number of image pairs."""
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open an image file and return it as a 3-channel RGB image.

        Converting inside the ``with`` block forces the pixel data to be read
        and releases the file handle right away. It also guarantees three
        channels, which the networks require, even for grayscale or RGBA files.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx): 
        """Return the transformed ``(input, target)`` pair at position ``idx``."""
        current = self.data.iloc[idx]
        X = self._load_rgb(current['sketch_images_path'])
        y = self._load_rgb(current['ground_truth_path'])
        return self.transform(X) , self.transform(y)

      

In [None]:
import pandas as pd

# Load the tables of image paths.
# NOTE(review): the validation table is read with read_csv although the file
# is named *.pkl and the training table uses read_pickle - confirm the on-disk
# format of 'validation_data.pkl'.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
validation_data = pd.read_csv('validation_data.pkl')
training_data = pd.read_pickle('training_data.pkl')

# Batches of 5 pairs, reshuffled on every pass.
validation_loader = DataLoader(dataset(validation_data), batch_size=5, shuffle=True)
training_loader = DataLoader(dataset(training_data), batch_size=5, shuffle=True)

# One sample batch from each loader for the preview cells below.
x_test, y_test = next(iter(validation_loader))
x_train, y_train = next(iter(training_loader))

In [None]:
def show_E2S(batch1, batch2, title1, title2):
    """Plot two image batches side by side.

    Each batch is laid out as a single-column grid (``make_grid`` with
    ``nrow=1``, 5 px padding, normalised) and drawn in its own subplot under
    the given title, with the axes hidden.
    """
    plt.figure(figsize=(15,15))
    panels = ((batch1, title1), (batch2, title2))
    for position, (batch, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.axis("off")
        plt.title(title)
        grid = vutils.make_grid(batch, nrow=1, padding=5, normalize=True).cpu()
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(np.transpose(grid, (1,2,0)))

# x_train holds the network inputs ('sketch_images_path') and y_train the
# targets ('ground_truth_path'); pass them in the order the titles describe.
show_E2S(x_train,y_train,"input X (edges)","ground truth y (faces)")

In [None]:
def compare_batches(batch1, batch2, title1, title2, batch3, title3):
    """Plot three image batches side by side.

    Each batch is laid out as a single-column grid (``make_grid`` with
    ``nrow=1``, 5 px padding, normalised) and drawn in its own subplot under
    the matching title, with the axes hidden. Note the argument order: the
    third batch and its title come after the first two titles.
    """
    plt.figure(figsize=(15,15))
    panels = ((batch1, title1), (batch2, title2), (batch3, title3))
    for position, (batch, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.axis("off")
        plt.title(title)
        grid = vutils.make_grid(batch, nrow=1, padding=5, normalize=True).cpu()
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(np.transpose(grid, (1,2,0)))
# Run the generator on the sample validation batch without tracking gradients
# and show input / prediction / ground truth next to each other.
# NOTE(review): no eval() call precedes this in the visible notebook, so the
# model is presumably still in training mode here - confirm that is intended.
with torch.no_grad():
    fk = Gen(x_test.cuda(0))
compare_batches(x_test, fk, "input image", "prediction", y_test, "ground truth")

In [None]:
img_list = []
# One independent list per metric. (A chained `a = b = c = []` binds every
# name to the same list object, which would interleave all four metrics.)
Disc_losses, Gen_losses, Gen_GAN_losses, Gen_L1_losses = [], [], [], []
from tqdm import tqdm 
iter_per_plot = 500   # log / validate / plot every this many steps
epochs = 10
L1_lambda = 100.0     # weight of the L1 reconstruction term in the generator loss
log_dir = 'log_path/'
os.makedirs(log_dir, exist_ok=True)  # savefig fails if the directory is missing

Gen.train()
for ep in range(epochs):
    # Wrapping the loader (not the enumerate object) lets tqdm show a total.
    for steps, (x, y) in enumerate(tqdm(training_loader)):
        x = x.cuda(0)
        y = y.cuda(0)

        # ---------------- discriminator update ----------------
        Disc.zero_grad()
        # Score for the real pair.
        r_patch = Disc(y, x)
        # Target maps are derived from the discriminator output, so they
        # always match its patch size, device and dtype.
        r_masks = torch.ones_like(r_patch)   # real mask
        f_masks = torch.zeros_like(r_patch)  # fake mask
        r_gan_loss = BCE(r_patch, r_masks)

        fake = Gen(x)
        # Score for the fake pair; detach so this step does not
        # back-propagate into the generator.
        f_patch = Disc(fake.detach(), x)
        f_gan_loss = BCE(f_patch, f_masks)

        # The discriminator must learn from both real and fake pairs.
        Disc_loss = r_gan_loss + f_gan_loss
        Disc_loss.backward()
        Disc_optim.step()

        # ------------------ generator update ------------------
        Gen.zero_grad()
        f_patch = Disc(fake, x)
        # The generator is rewarded when its output is scored as real.
        f_gan_loss = BCE(f_patch, r_masks)

        L1_loss = L1(fake, y)
        Gen_loss = f_gan_loss + L1_lambda*L1_loss
        Gen_loss.backward()
        Gen_optim.step()

        if (steps+1) % iter_per_plot == 0:
            print('Epoch [{}/{}], Step [{}/{}], disc_loss: {:.4f}, gen_loss: {:.4f},Disc(real): {:.2f}, Disc(fake):{:.2f}, gen_loss_gan:{:.4f}, gen_loss_L1:{:.4f}'.format(ep+1, epochs, steps+1, len(training_loader), Disc_loss.item(), Gen_loss.item(), r_patch.mean().item(), f_patch.mean().item(), f_gan_loss.item(), L1_loss.item()))

            Gen_losses.append(Gen_loss.item())
            Disc_losses.append(Disc_loss.item())
            Gen_GAN_losses.append(f_gan_loss.item())
            Gen_L1_losses.append(L1_loss.item())

            # Preview on one validation batch. Only a single batch is ever
            # plotted, so there is no need to run the whole validation set,
            # and the training loop's `steps` counter is not overwritten.
            Gen.eval()
            with torch.no_grad():
                x_test, y_test = next(iter(validation_loader))
                val_fake = Gen(x_test.cuda(0)).cpu()
            # Back to training mode for the remaining steps.
            Gen.train()

            figs = plt.figure(figsize=(10,10))
            panels = ((x_test, "input image"),
                      (val_fake, "generated image"),
                      (y_test, "ground truth"))
            for position, (batch, title) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.axis("off")
                plt.title(title)
                grid = vutils.make_grid(batch, nrow=1, padding=5, normalize=True).cpu()
                # make_grid returns (C, H, W); imshow expects (H, W, C).
                plt.imshow(np.transpose(grid, (1,2,0)))

            # NOTE(review): the file name only contains the epoch, so several
            # plots within one epoch overwrite each other - confirm intended.
            plt.savefig(os.path.join(log_dir,'gan'+"-"+str(ep) +".png"))
            plt.close(figs)
            img_list.append(figs)

In [None]:
# t= Gen(x_test.cuda(0))

In [None]:
# Gen

In [None]:
# torch.save(Gen,'Generative.pt')

In [None]:
# torch.save({
# 'model_state_dict': Gen.state_dict(),
# 'optimizer_state_dict': Gen_optim.state_dict(),
# }, 'Gen')

In [None]:
# for steps , (x,y) in enumerate(training_loader):
#     Disc.zero_grad()
#     r_p = torch.ones(5,1,30,30).cuda(0)
#     r_f = torch.zeros(5,1,30,30).cuda(0)
#     r_pc= Disc(y.cuda(0),x.cuda(0))
#     loss_d = BCE(r_pc,r_p)
    
    
#     fake = Gen(x.cuda(0))
#     f_g = Disc(fake,x.cuda(0))
#     lossf = BCE(f_g,r_f)
#     Disc_loss = lossf + loss_d
#     Disc_loss.backward()
#     Disc_optim.step()
    
#     Gen.zero_grad()
#     Disc(fake,x)
    
    
    

In [None]:
# f_patch = Disc(fake,x)
# f_gan_loss=BCE(f_patch,r_masks)

# L1_loss = L1(fake,y)
# Gen_loss = f_gan_loss + L1_lambda*L1_loss
# Gen_loss.backward()

# Gen_optim.step()

In [None]:
# Bare expression: displays the repr of the imported torch module.
# NOTE(review): looks like leftover scratch output - confirm it can be removed.
torch

In [None]:
# Display the values logged under Disc_losses by the training cell.
Disc_losses

In [None]:
# 