<a href="https://colab.research.google.com/github/SayedAkhtar/ATG-intership/blob/master/ML_Image_Inpainting_Using_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version 1.5

In [None]:
from google.colab import drive

In [None]:
import torch
from torch import nn

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

In [None]:
# Select the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA.
if device.type == 'cuda':
    # `memory_cached` is the deprecated name of `memory_reserved`; prefer the
    # current API and only fall back to the old name on PyTorch builds that lack it.
    reserved_fn = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
    print('Cached:   ', round(reserved_fn(0) / 1024**3, 1), 'GB')

Using device: cuda

Tesla P4
Memory Usage:
Allocated: 0.1 GB
Cached:    3.9 GB


In [None]:
# Training length and batching: the training loop runs epochs up to (but not
# including) EPOCHS; BATCH_SIZE is the DataLoader batch size.
EPOCHS, BATCH_SIZE = 100, 26
# Adam settings shared by both optimizers: learning rate and beta1.
LR, BETAL = 2e-4, 0.5
# Weight of the L2 reconstruction term in the generator loss; the adversarial
# term is weighted by (1 - WTL2).
WTL2 = 0.999

In [None]:
# Output folders for sample grids and per-image dumps. Each folder is created
# independently with exist_ok=True, so re-running the cell (or a partially
# existing tree) never skips a folder, while a genuine failure (e.g. a file
# already occupying one of these paths) still raises.
for _result_dir in ('./result/real', './result/real_individual',
                    './result/recon', './result/recon_individual'):
  os.makedirs(_result_dir, exist_ok=True)

In [None]:
# ls


# **LOADING DATA**


In [None]:
# Resize the shorter side to 128, centre-crop to 128x128, then map every RGB
# channel from ToTensor's [0, 1] to [-1, 1] via (x - 0.5) / 0.5, the same range
# as the generator's Tanh output.
transform = transforms.Compose([transforms.Resize(128),
                                transforms.CenterCrop(128),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# ImageFolder expects images grouped in sub-directories under ./Dataset/.
dataset = dset.ImageFolder(root='./Dataset/', transform=transform)
assert len(dataset) > 0, 'no images found under ./Dataset/'
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# **Generator**

In [None]:
import torch
from torch import nn
class generator(nn.Module):
  """Image-to-image network: two convolutions followed by two transposed convolutions.

  Every stage uses kernel 4, stride 1, padding 1, so for an (N, 3, H, W) input
  the spatial size goes H -> H-1 -> H-2 -> H-1 -> H and the output has the same
  shape as the input. The final Tanh keeps values in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    # Stage 1: 3 -> 128 channels, no normalisation on the first layer.
    self.t1 = nn.Sequential(
      nn.Conv2d(3, 128, kernel_size=4, stride=1, padding=1),
      nn.LeakyReLU(0.2, inplace=True),
    )
    # Stage 2: 128 -> 512 channels.
    self.t2 = nn.Sequential(
      nn.Conv2d(128, 512, kernel_size=4, stride=1, padding=1),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
    )
    # Stage 3: 512 -> 128 channels, growing the spatial size back by one.
    self.t3 = nn.Sequential(
      nn.ConvTranspose2d(512, 128, kernel_size=4, stride=1, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
    )
    # Stage 4: back to 3 image channels.
    self.t4 = nn.Sequential(
      nn.ConvTranspose2d(128, 3, kernel_size=4, stride=1, padding=1),
      nn.Tanh(),
    )

  def forward(self, x):
    """Run ``x`` through the four stages in order."""
    for stage in (self.t1, self.t2, self.t3, self.t4):
      x = stage(x)
    return x

In [None]:
import torch
from torch import nn
class discriminator(nn.Module):
  """Real/fake classifier built from three strided convolutions.

  Every stage uses kernel 6, stride 4, no padding. For an (N, 3, 128, 128)
  input the spatial size goes 128 -> 31 -> 7 -> 1, so the output is
  (N, 1, 1, 1). The final Sigmoid maps scores into [0, 1].
  """

  def __init__(self):
    super().__init__()
    # Stage 1: 3 -> 64 channels.
    self.t1 = nn.Sequential(
      nn.Conv2d(3, 64, kernel_size=6, stride=4, padding=0),
      nn.BatchNorm2d(64),
      nn.LeakyReLU(0.2, inplace=True),
    )
    # Stage 2: 64 -> 256 channels.
    self.t2 = nn.Sequential(
      nn.Conv2d(64, 256, kernel_size=6, stride=4, padding=0),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
    )
    # Stage 3: collapse to a single sigmoid score channel.
    self.t3 = nn.Sequential(
      nn.Conv2d(256, 1, kernel_size=6, stride=4, padding=0),
      nn.Sigmoid(),
    )

  def forward(self, x):
    """Map images ``x`` to sigmoid scores via the three stages."""
    return self.t3(self.t2(self.t1(x)))

# **Initializing the Discriminator and Generator**

In [None]:
def weights_init(m):
  """Re-initialise a module's parameters in place; intended for ``nn.Module.apply``.

  Modules whose class name contains ``'Conv'`` (e.g. Conv2d, ConvTranspose2d)
  get weights drawn from N(0, 0.02). Modules whose class name contains
  ``'BatchNorm'`` get weights from N(1, 0.02) and a zero bias. Conv biases and
  every other module are left untouched.

  Args:
    m: the module currently being visited.
  """
  classname = m.__class__.__name__
  if 'Conv' in classname:
    nn.init.normal_(m.weight, 0.0, 2e-2)
  elif 'BatchNorm' in classname:
    # BatchNorm built with affine=False has weight/bias set to None; skip them
    # instead of crashing with AttributeError.
    if m.weight is not None:
      nn.init.normal_(m.weight, 1.0, 2e-2)
    if m.bias is not None:
      nn.init.zeros_(m.bias)

In [None]:
def scale(img):
  """Min-max rescale ``img`` so its values span [0, 1].

  Args:
    img: tensor (or array) supporting ``min``/``max`` and arithmetic.

  Returns:
    ``(img - img.min()) / (img.max() - img.min())``. A constant input, for
    which that ratio would be 0/0 = NaN, maps to all zeros instead.
  """
  lo = img.min()
  span = img.max() - lo
  shifted = img - lo
  if span == 0:
    # Every element equals `lo`, so `shifted` is already all zeros.
    return shifted
  return shifted / span

In [None]:
# First epoch index the training loop starts from (non-zero when resuming).
resume_epochs = 0

netG = generator()
netG.to(device)
netG.apply(weights_init)

netD = discriminator()
netD.to(device)
netD.apply(weights_init)

criterion = nn.BCELoss()      # adversarial loss on the discriminator's sigmoid scores
criterionMSE = nn.MSELoss()

# Reusable input/target buffers, allocated on `device` so this cell also runs
# on CPU-only machines. Plain tensors suffice: Variable has been a no-op
# wrapper since PyTorch 0.4.
input_real = torch.empty(BATCH_SIZE, 3, 128, 128, device=device)
label = torch.empty(BATCH_SIZE, device=device)
real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETAL, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETAL, 0.999))


In [None]:
for epoch in range(resume_epochs, EPOCHS):
  for i, data in enumerate(dataloader, 0):
    real_data, _ = data
    # NOTE: overwrites the BATCH_SIZE config constant with the current (possibly
    # smaller, final) batch size; kept because later cells may rely on it.
    BATCH_SIZE = real_data.size(0)
    with torch.no_grad():
      input_real.resize_(real_data.size()).copy_(real_data)

    # --- Discriminator step on real images (target: real_label) ---
    netD.zero_grad()
    with torch.no_grad():
      label.resize_(BATCH_SIZE).fill_(real_label)

    # netD returns (B, 1, 1, 1) for 128x128 inputs; flatten to (B,) so input and
    # target shapes match, as BCELoss requires.
    output = netD(input_real).view(-1)
    errD_real = criterion(output, label)
    errD_real.backward()
    D_x = output.mean().item()

    # --- Discriminator step on generated images (target: fake_label) ---
    # NOTE(review): G receives the full, unmasked image, so it is trained as an
    # autoencoder rather than an inpainter; a masked input is presumably
    # intended. TODO confirm.
    fake = netG(input_real)
    label.fill_(fake_label)
    # detach(): this step must not backpropagate into the generator.
    output = netD(fake.detach()).view(-1)
    errD_fake = criterion(output, label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    errD = errD_real + errD_fake
    optimizerD.step()

    # --- Generator step: fool D (adversarial) while staying close to the input (L2) ---
    netG.zero_grad()
    label.fill_(real_label)
    output = netD(fake).view(-1)
    errG_D = criterion(output, label)
    errG_l2 = (fake - input_real).pow(2).mean()
    # WTL2 weights the reconstruction term against the adversarial term.
    errG = (1 - WTL2) * errG_D + WTL2 * errG_l2
    errG.backward()
    D_G_z2 = output.mean().item()
    optimizerG.step()

    print('[%d / %d][%d / %d] LossD: %.4f LossG: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
          % (epoch, EPOCHS, i, len(dataloader), errD.item(), errG_D.item(), errG_l2.item(), D_x, D_G_z1))

    if i % 100 == 0:
      # Capture a real batch and its reconstruction from the SAME iteration so
      # the per-image dump below pairs them correctly.
      sample_real = real_data
      recon_image = fake.detach()
      # Tensors are in [-1, 1]; normalize=True rescales them to [0, 1] for saving.
      vutils.save_image(sample_real,
                        'result/real/real_samples_epoch_%03d.png' % (epoch), normalize=True)
      vutils.save_image(recon_image,
                        'result/recon/recon_samples_epoch_%03d.png' % (epoch), normalize=True)

  if (epoch + 1) % 25 == 0:
    for k in range(recon_image.size(0)):
      image = scale(recon_image[k, :, :])
      vutils.save_image(image, 'result/recon_individual/real_samples_epoch_%03d_img%d.png' % (epoch, k))

      image = scale(sample_real[k, :, :])
      vutils.save_image(image, 'result/real_individual/real_samples_epoch_%03d_img%d.png' % (epoch, k))
    

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


[0 / 100][0 / 1] LossD: 1.9772 LossG: 2.0507 / 27.2472 l_D(x): 0.3701 l_D(G(z)): 0.4518
[1 / 100][0 / 1] LossD: 1.0276 LossG: 2.5495 / 24.6498 l_D(x): 0.7442 l_D(G(z)): 0.4288
[2 / 100][0 / 1] LossD: 0.5034 LossG: 3.4488 / 23.5147 l_D(x): 0.8318 l_D(G(z)): 0.2401
[3 / 100][0 / 1] LossD: 0.4050 LossG: 3.5135 / 22.9634 l_D(x): 0.8613 l_D(G(z)): 0.2120
[4 / 100][0 / 1] LossD: 0.2742 LossG: 3.6534 / 22.6623 l_D(x): 0.8580 l_D(G(z)): 0.1082
[5 / 100][0 / 1] LossD: 0.1873 LossG: 3.7156 / 22.4937 l_D(x): 0.9078 l_D(G(z)): 0.0835
[6 / 100][0 / 1] LossD: 0.1180 LossG: 4.0371 / 22.4205 l_D(x): 0.9393 l_D(G(z)): 0.0526
[7 / 100][0 / 1] LossD: 0.0825 LossG: 4.2404 / 22.3677 l_D(x): 0.9561 l_D(G(z)): 0.0363
[8 / 100][0 / 1] LossD: 0.0713 LossG: 4.1676 / 22.3328 l_D(x): 0.9661 l_D(G(z)): 0.0357
[9 / 100][0 / 1] LossD: 0.0721 LossG: 4.0698 / 22.3108 l_D(x): 0.9717 l_D(G(z)): 0.0421
[10 / 100][0 / 1] LossD: 0.0655 LossG: 4.1674 / 22.2976 l_D(x): 0.9743 l_D(G(z)): 0.0385
[11 / 100][0 / 1] LossD: 0.0481