In [None]:
import torch 
import torchvision
import numpy
from torchvision import transforms
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torch.optim as optim
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
def weights_init(m):
    """DCGAN-style parameter initialisation, meant to be passed to ``nn.Module.apply``.

    * Modules whose class name contains ``'Conv'`` get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm'`` get weights drawn from
      N(1, 0.02) and a zero bias.
    * Every other module is left untouched.
    """
    name = type(m).__name__
    if 'Conv' in name:
        with torch.no_grad():
            m.weight.normal_(0.0, 0.02)
    elif 'BatchNorm' in name:
        with torch.no_grad():
            m.weight.normal_(1.0, 0.02)
            m.bias.zero_()
def denorm(x):
    """Map tensor values from [-1, 1] to [0, 1], clamping anything outside that range."""
    shifted = (x + 1) / 2
    return torch.clamp(shifted, 0, 1)

In [None]:
def plot_result(generator, real_image, num_epoch, save=True, save_dir='MNIST_DCGAN_results/', show=False, fig_size=(5, 5)):
    """Plot the generator's output for ``real_image`` as an image grid; optionally save/show it.

    Args:
        generator: image-to-image network. It is put in eval mode for the forward pass
            and restored to its previous train/eval mode afterwards.
        real_image: input batch for ``generator``; presumably (N, C, H, W) with values in
            [-1, 1] — TODO confirm against the data pipeline.
        num_epoch: zero-based epoch index; title and filename use ``num_epoch + 1``.
        save: if True, write the figure as a PNG into ``save_dir`` (created if missing).
        save_dir: output directory.
        show: if True, display the figure with ``plt.show()``; otherwise close it.
        fig_size: matplotlib figure size in inches.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Preview only: no autograd graph needed.
        with torch.no_grad():
            gen_image = denorm(generator(real_image))
    finally:
        generator.train(was_training)

    # Near-square grid that fits every image in the batch.
    n_images = gen_image.size(0)
    n_rows = max(1, int(np.sqrt(n_images)))
    n_cols = max(1, int(np.ceil(n_images / n_rows)))
    # squeeze=False: always a 2-D array of Axes, even for a single image.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=fig_size, squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= n_images:
            continue
        # (C, H, W) -> (H, W, C): imshow wants channels last; a plain .view() scrambles pixels.
        img = gen_image[idx].detach().cpu().permute(1, 2, 0).numpy()
        if img.shape[2] == 1:
            img = img[:, :, 0]  # single channel -> 2-D array so cmap='gray' applies
        ax.imshow(img, cmap='gray', aspect='equal')

    fig.subplots_adjust(wspace=0, hspace=0)
    title = 'Epoch {0}'.format(num_epoch+1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'MNIST_DCGAN_epoch_{:d}'.format(num_epoch+1) + '.png')
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)



In [None]:
## Generator U-net
class generator(nn.Module):
    """U-Net image-to-image generator with skip connections.

    Seven stride-2 encoder blocks each halve H and W. Seven stride-2 decoder blocks
    each double them, concatenating the matching encoder output before every block
    except the first. Input H and W must therefore be divisible by 2**7 = 128. The
    output has the input's spatial size and values in [-1, 1] (final Tanh).

    Args:
        input_size: not used by the model; kept so existing calls keep working.
        in_c: number of input channels.
        out_c: number of output channels (default 3, the previously hard-coded value).
    """

    def __init__(self, input_size = 256, in_c = 3, out_c = 3):
        # Zero-argument super(): the notebook rebinds the name `generator` to an
        # instance right after this class, which made super(generator, self) raise
        # TypeError whenever the class was instantiated again.
        super().__init__()

        # Encoder. Block 1 now ends in LeakyReLU like blocks 2-6; without it, conv1+BN
        # fed straight into conv2 with no non-linearity in between.
        self.encoder_block_1 = self._down(in_c, 64)
        self.encoder_block_2 = self._down(64, 128)
        self.encoder_block_3 = self._down(128, 256)
        self.encoder_block_4 = self._down(256, 512)
        self.encoder_block_5 = self._down(512, 512)
        self.encoder_block_6 = self._down(512, 512)
        # Innermost block: conv + ReLU, no BatchNorm.
        self.encoder_block_7 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Decoder. Blocks 6..1 receive [previous decoder output, encoder skip]
        # concatenated on channels, hence the doubled input widths.
        self.decoder_block_7 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.ReLU(True),
        )
        self.decoder_block_6 = self._up(1024, 512)
        self.decoder_block_5 = self._up(1024, 512)
        self.decoder_block_4 = self._up(1024, 256)
        self.decoder_block_3 = self._up(512, 128)
        self.decoder_block_2 = self._up(256, 64)
        self.decoder_block_1 = nn.Sequential(
            nn.ConvTranspose2d(128, out_c, 4, 2, 1),
            nn.Tanh(),
        )

    @staticmethod
    def _down(in_ch, out_ch):
        """Encoder block: 4x4 stride-2 conv (no bias) -> BatchNorm -> LeakyReLU(0.2); halves H, W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @staticmethod
    def _up(in_ch, out_ch):
        """Decoder block: 4x4 stride-2 transposed conv -> BatchNorm -> ReLU; doubles H, W."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Map ``x`` (N, in_c, H, W) to (N, out_c, H, W) through the U-Net."""
        x1 = self.encoder_block_1(x)
        x2 = self.encoder_block_2(x1)
        x3 = self.encoder_block_3(x2)
        x4 = self.encoder_block_4(x3)
        x5 = self.encoder_block_5(x4)
        x6 = self.encoder_block_6(x5)
        x7 = self.encoder_block_7(x6)

        y7 = self.decoder_block_7(x7)
        y6 = self.decoder_block_6(torch.cat((y7, x6), 1))
        y5 = self.decoder_block_5(torch.cat((y6, x5), 1))
        y4 = self.decoder_block_4(torch.cat((y5, x4), 1))
        y3 = self.decoder_block_3(torch.cat((y4, x3), 1))
        y2 = self.decoder_block_2(torch.cat((y3, x2), 1))
        return self.decoder_block_1(torch.cat((y2, x1), 1))
# Build the U-Net with its default configuration. NOTE: this rebinds the name
# `generator` from the class to the model instance used by later cells.
generator = generator(in_c=3)

In [None]:
# Discriminator

class discriminator(nn.Module):
    """Conditional discriminator that scores an (input, target) image pair.

    The two images are concatenated on the channel axis and passed through seven
    4x4 stride-2 convolutions ending in a Sigmoid. For 128x128 images the output is
    (N, 1, 1, 1); for 256x256 images it is (N, 1, 2, 2).

    Args:
        input_ch: channels per image. The first conv sees ``2 * input_ch`` channels,
            which equals the previously hard-coded 6 for the default of 3.
    """

    def __init__(self, input_ch = 3):
        # Zero-argument super(): the notebook rebinds `discriminator` to an instance
        # right after this class, which broke super(discriminator, self) on re-instantiation.
        super().__init__()

        self.main_block = nn.Sequential(
            nn.Conv2d(2 * input_ch, 64, 4, 2, 1, bias=False),   # output size 64 (128x128 input)
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),            # output size 32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(128, 512, 4, 2, 1, bias=False),           # output size 16
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(512, 512, 4, 2, 1, bias=False),           # output size 8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(512, 512, 4, 2, 1, bias=False),           # output size 4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(512, 512, 4, 2, 1, bias=False),           # output size 2
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(512, 1, 4, 2, 1, bias=False),             # output size 1
            nn.Sigmoid(),
        )

    def forward(self, input_img, target_img):
        """Return per-pair real/fake probabilities for ``(input_img, target_img)``."""
        concatenated = torch.cat((input_img, target_img), 1)
        return self.main_block(concatenated)
# Build the discriminator with its default configuration. NOTE: this rebinds the
# name `discriminator` from the class to the model instance used by later cells.
discriminator = discriminator(input_ch=3)


In [None]:
# Initialise both networks with weights_init. Then move them to the GPU when one
# is available, falling back to the CPU instead of failing on a hard-coded .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

discriminator.apply(weights_init)
generator.apply(weights_init)

discriminator = discriminator.to(device)
generator = generator.to(device)



In [None]:
# Data loading: paired images, with inputs in dataset/A and targets in dataset/B.
# NOTE(review): absolute, machine-specific path; adjust DATASET_ROOT for your setup.
DATASET_ROOT = "/home/abdullah/Documents/ai/models/Generative Adversarial Networks/Own implementation of Gan/pix2pix/dataset/"

# ToTensor yields [0, 1]. Normalize maps that to [-1, 1], matching the generator's Tanh
# output (the L1 loss compares them directly) and the range `denorm` expects.
# NOTE(review): the networks assume 128x128 images; confirm this, or add transforms.Resize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

data_A = dset.ImageFolder(os.path.join(DATASET_ROOT, "A"), transform)
data_B = dset.ImageFolder(os.path.join(DATASET_ROOT, "B"), transform)
# Pairing is positional: ImageFolder sorts by class folder and then by filename, so
# corresponding A/B files must sort into the same order.
assert len(data_A) == len(data_B), "dataset/A and dataset/B must hold the same number of paired images"


class _PairedShuffleSampler(torch.utils.data.Sampler):
    """Yields a new random permutation of range(n) on every iteration.

    Two instances built with the same ``n`` and ``seed`` produce the same sequence of
    permutations. Loaders that use them therefore stay aligned, as long as both are
    iterated in lockstep (for example, via zip).
    """

    def __init__(self, n, seed=0):
        self.n = n
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)
        self.epoch += 1
        return iter(torch.randperm(self.n, generator=rng).tolist())

    def __len__(self):
        return self.n


# Two independently shuffled loaders (previously shuffle=True on each) pair unrelated
# A and B images. Identically seeded samplers keep zip(data_loader_A, data_loader_B) aligned.
PAIR_SHUFFLE_SEED = 0
data_loader_A = torch.utils.data.DataLoader(data_A,
                                            batch_size=1,
                                            sampler=_PairedShuffleSampler(len(data_A), PAIR_SHUFFLE_SEED),
                                            num_workers=8)
data_loader_B = torch.utils.data.DataLoader(data_B,
                                            batch_size=1,
                                            sampler=_PairedShuffleSampler(len(data_B), PAIR_SHUFFLE_SEED),
                                            num_workers=8)

In [None]:
# loss1: adversarial BCE on the discriminator's Sigmoid outputs.
# loss2: L1 distance between generated and target images.
loss1 = nn.BCELoss()
loss2 = nn.L1Loss()

# Adam with its library-default learning rate (1e-3) spelled out, and beta1 = 0.5
# (below Adam's default of 0.9).
adam_betas = (0.5, 0.99)
D_optim = optim.Adam(discriminator.parameters(), lr=1e-3, betas=adam_betas)
G_optim = optim.Adam(generator.parameters(), lr=1e-3, betas=adam_betas)

In [None]:
# Training configuration. `epochs` was previously undefined on a fresh kernel.
epochs = 200      # TODO(review): set the intended number of epochs
lambda_l1 = 100   # weight of the L1 reconstruction term relative to the adversarial term
log_every = 1     # print losses every `log_every` steps

# Train on whichever device the models live on.
device = next(generator.parameters()).device

for i in range(epochs):
    sample_A = None
    # data_loader_A / data_loader_B are expected to yield aligned (input, target) pairs.
    for n, ((real_images_A, _), (real_images_B, _)) in enumerate(zip(data_loader_A, data_loader_B), start=1):
        real_images_A = real_images_A.to(device)
        real_images_B = real_images_B.to(device)

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        discriminator.zero_grad()
        real_decisions = discriminator(real_images_A, real_images_B)
        # Targets shaped like the predictions; BCELoss rejects a (1,) target for a (1, 1, 1, 1) input.
        D_loss_1 = loss1(real_decisions, torch.ones_like(real_decisions))

        fake_images = generator(real_images_A)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_decisions = discriminator(real_images_A, fake_images.detach())
        D_loss_2 = loss1(fake_decisions, torch.zeros_like(fake_decisions))

        D_loss = D_loss_1 + D_loss_2
        D_loss.backward()
        D_optim.step()

        # ---- Generator: fool the updated discriminator and stay close to the target ----
        generator.zero_grad()
        fake_decisions = discriminator(real_images_A, fake_images)
        G_loss1 = loss1(fake_decisions, torch.ones_like(fake_decisions))
        G_loss2 = loss2(fake_images, real_images_B)

        G_loss = G_loss1 + lambda_l1 * G_loss2
        G_loss.backward()
        G_optim.step()

        if n % log_every == 0:
            print("{} epoch, {} step,{} Generator loss, {} Discriminator loss".format(i+1, n, G_loss.item(), D_loss.item()))
        sample_A = real_images_A

    # One preview per epoch. Previously the same epoch's PNG was rewritten on every step.
    if sample_A is not None:
        plot_result(generator, sample_A, i)
        
        
        
        
        