<a href="https://colab.research.google.com/github/CanKeles5/ColorizeMountainAdversarial/blob/master/ColorizeMountainAdversarial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Notebook setup: (re)load the autoreload extension and set it to mode 2 so
# imported modules are reloaded before each cell runs, and render matplotlib
# figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [0]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.utils.data.sampler
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader as DL
from torch.utils.data import *
from PIL import Image, ImageFilter
import os
import cv2
import numpy
import random
import fnmatch

In [0]:
# Pick the first CUDA GPU when one is visible to PyTorch, otherwise fall back
# to the CPU. Later cells move models and tensors onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [0]:
# Create ~/.kaggle (no error if it already exists) and move kaggle.json from
# the current working directory into it.
# NOTE(review): presumably kaggle.json is the Kaggle API token uploaded by the
# user; because it is moved, re-running this cell fails unless it is uploaded again.
! mkdir -p ~/.kaggle/
! mv kaggle.json ~/.kaggle/

In [0]:
# Download the "puneet6060/intel-image-classification" dataset with the Kaggle CLI.
! kaggle datasets download -d puneet6060/intel-image-classification

In [0]:
# Extract the downloaded archive into /content/dataset.
# unzip flags: -q quiet, -n never overwrite files that already exist (so the
# cell can be re-run), -d target directory.
path = "/content/intel-image-classification.zip"
to = "/content/dataset"
! unzip -q -n {path} -d {to}

In [0]:
image_folder = "/content/dataset/seg_train/seg_train/mountain"

# Collect every file below `image_folder` whose directory path contains
# "mountain". os.walk yields entries in an arbitrary, filesystem-dependent
# order, and later cells split the data purely by list index, so the paths are
# sorted to make the train/test/validation split reproducible across runs.
image_paths = sorted(
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk(image_folder)
    if fnmatch.fnmatch(dirname, '*mountain*')
    for filename in filenames
)

In [0]:
class MyDataset(Dataset):
  """Image dataset that yields augmented 256x256 RGB tensors.

  Each item is loaded from disk, resized to 256x256, optionally augmented
  (random horizontal flip, rotation of up to 5 degrees, mild perspective
  distortion), median-filtered and converted with ``TF.to_tensor``.

  Args:
    image_paths: sequence of image file paths.
    train: when True (default) the random augmentations are applied; when
      False only the resize, median filter and tensor conversion run.
  """

  def __init__(self, image_paths, train=True):
    self.image_paths = image_paths
    self.train = train

    # Build the PIL-level pipeline once instead of on every sample.
    steps = [transforms.Resize(size=(256, 256))]
    if train:
      steps += [
          transforms.RandomHorizontalFlip(),
          transforms.RandomRotation(degrees=5),
          transforms.RandomPerspective(distortion_scale=0.05),
      ]
    self._pil_tfms = transforms.Compose(steps)

  def transforms(self, image):
    """Apply the PIL pipeline, a median filter, and convert to a tensor."""
    image = self._pil_tfms(image)
    image = image.filter(ImageFilter.MedianFilter())
    return TF.to_tensor(image)

  def __getitem__(self, index):
    # Close the file handle promptly and force three channels: the rest of
    # the notebook indexes channels 0..2 to compute the grayscale input.
    with Image.open(self.image_paths[index]) as handle:
      image = handle.convert("RGB")
    return self.transforms(image)

  def __len__(self):
    return len(self.image_paths)

In [0]:
class ValidateSet(Dataset):
  """Image dataset for validation: deterministic preprocessing only.

  Each item is loaded from disk, resized to 256x256, median-filtered and
  converted with ``TF.to_tensor``. No random augmentation is applied.

  Args:
    image_paths: sequence of image file paths.
    train: accepted for signature parity with ``MyDataset``; not used.
  """

  def __init__(self, image_paths, train=True):
    self.image_paths = image_paths

  def transforms(self, image):
    """Resize to 256x256, median-filter, and convert to a tensor."""
    image = transforms.Resize(size=(256, 256))(image)
    image = image.filter(ImageFilter.MedianFilter())
    return TF.to_tensor(image)

  def __getitem__(self, index):
    # Close the file handle promptly and force three channels: downstream
    # code indexes channels 0..2 to compute the grayscale input.
    with Image.open(self.image_paths[index]) as handle:
      image = handle.convert("RGB")
    return self.transforms(image)

  def __len__(self):
    return len(self.image_paths)

In [0]:
# The first 2500 paths form the pool that the train and test loaders sample from.
dataset = MyDataset(image_paths[0:2500])
len(dataset)

In [0]:
# Paths 2500..2511 (at most 12 images) are held out for the final qualitative check.
validate_set = ValidateSet(image_paths[2500: 2512])

In [0]:
# Index ranges into `dataset`: the first 2300 items are used for training and
# the following 200 (up to index 2499) for testing.
train_indices = range(2300)
test_indices = range(len(train_indices), 2500)

In [0]:
batch_size = 4

# Both loaders read from the same `dataset`; the samplers restrict each one to
# its own index range and draw those indices in random order.
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=test_sampler,
)

In [0]:
def en_double_conv(in_channels, out_channels):
    """Build an encoder block: two 3x3 convolutions with LeakyReLU(0.2).

    Layer order is Conv -> LeakyReLU -> Conv -> BatchNorm -> LeakyReLU. The
    convolutions use stride 1 and padding 1, so spatial size is preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the five layers.
    """
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
    layers = [
        nn.Conv2d(in_channels, out_channels, **conv_kwargs),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(out_channels, out_channels, **conv_kwargs),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

def dec_double_conv(in_channels, out_channels):
  """Build a decoder block: two 3x3 convolutions with ReLU activations.

  Layer order is Conv -> ReLU -> Conv -> BatchNorm -> ReLU. The convolutions
  use stride 1 and padding 1, so spatial size is preserved.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` holding the five layers.
  """
  conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
  layers = [
      nn.Conv2d(in_channels, out_channels, **conv_kwargs),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, **conv_kwargs),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

# Base channel width of the generator: every layer width in G is a multiple of n.
n=8

class G(nn.Module):
    """Encoder/decoder generator with skip connections.

    The encoder is eight ``en_double_conv`` blocks, each followed by a 2x2
    max-pool. The decoder is eight stride-2 transposed convolutions; before
    every decoder stage except the first, the running feature map is
    concatenated channel-wise with the encoder output of matching resolution.
    Dropout(0.5) follows decoder stages 7 through 2, and the final stage maps
    to 3 channels passed through tanh.

    Input is a single-channel image batch. Because of the eight halvings and
    the skip concatenations, height and width must be multiples of 256.
    """

    def __init__(self):
        super().__init__()

        wide = n*8  # channel width of the four deepest encoder stages

        # Encoder blocks (registration order matters for parameter order).
        self.dconv_1 = en_double_conv(1, n)
        self.dconv_2 = en_double_conv(n, n*2)
        self.dconv_3 = en_double_conv(n*2, n*4)
        self.dconv_4 = en_double_conv(n*4, wide)
        self.dconv_5 = en_double_conv(wide, wide)
        self.dconv_6 = en_double_conv(wide, wide)
        self.dconv_7 = en_double_conv(wide, wide)
        self.dconv_8 = en_double_conv(wide, wide)

        self.dropout = nn.Dropout(0.5)
        self.maxpool = nn.MaxPool2d(2)

        # Decoder: kernel 4, stride 2, padding 1 doubles the spatial size.
        # Each stage's input width is the previous output plus the skip width.
        self.TConv8 = nn.ConvTranspose2d(wide, wide, 4, 2, 1)
        self.TConv7 = nn.ConvTranspose2d(wide*2, wide*2, 4, 2, 1)
        self.TConv6 = nn.ConvTranspose2d(wide*3, wide*3, 4, 2, 1)
        self.TConv5 = nn.ConvTranspose2d(wide*4, wide*4, 4, 2, 1)
        self.TConv4 = nn.ConvTranspose2d(wide*5, wide*5, 4, 2, 1)
        self.TConv3 = nn.ConvTranspose2d(n*44, n*44, 4, 2, 1)
        self.TConv2 = nn.ConvTranspose2d(n*46, n*46, 4, 2, 1)
        self.TConv1 = nn.ConvTranspose2d(n*47, 3, 4, 2, 1)

    def forward(self, x):
        encoder_stages = (
            self.dconv_1, self.dconv_2, self.dconv_3, self.dconv_4,
            self.dconv_5, self.dconv_6, self.dconv_7, self.dconv_8,
        )

        # Run the encoder, keeping every pooled output for the skip links.
        skips = []
        features = x
        for stage in encoder_stages:
            features = self.maxpool(stage(features))
            skips.append(features)

        # Deepest stage has no skip input and no dropout.
        out = self.TConv8(skips[7])

        # Stages 7..2: concatenate the matching encoder output, upsample, drop.
        middle_stages = (
            self.TConv7, self.TConv6, self.TConv5,
            self.TConv4, self.TConv3, self.TConv2,
        )
        for upsample, skip in zip(middle_stages, reversed(skips[1:7])):
            out = torch.cat([out, skip], dim=1)
            out = upsample(out)
            out = self.dropout(out)

        # Final stage: concatenate the shallowest skip, map to 3 channels, tanh.
        out = torch.cat([out, skips[0]], dim=1)
        out = self.TConv1(out)
        return torch.tanh(out)

In [0]:
class D(nn.Module):
  """Convolutional discriminator for 3-channel images.

  Eight 2x2 stride-2 convolutions halve the spatial size each time while the
  channel count grows 3 -> 16 -> 32 -> ... -> 1024 -> 1. All but the first
  and last convolution are followed by BatchNorm; every convolution except
  the last is followed by LeakyReLU(0.2), and the last by a sigmoid. For a
  256x256 input the output has shape (batch, 1, 1, 1).
  """

  def __init__(self):
    super(D, self).__init__()

    # First stage has no BatchNorm.
    layers = [
        nn.Conv2d(3, 16, 2, 2, 0),
        nn.LeakyReLU(0.2, inplace=True),
    ]

    # Six doubling stages: 16->32, 32->64, ..., 512->1024.
    width = 16
    while width < 1024:
      layers.append(nn.Conv2d(width, width*2, 2, 2, 0))
      layers.append(nn.BatchNorm2d(width*2))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      width = width*2

    # Collapse to a single probability per image.
    layers.append(nn.Conv2d(1024, 1, 2, 2, 0))
    layers.append(nn.Sigmoid())

    self.main = nn.Sequential(*layers)

  def forward(self, im):
    return self.main(im)

In [0]:
# Instantiate both networks and move their parameters onto the selected device.
Generator = G().to(device)
Discriminator = D().to(device)

In [0]:
def weights_init(m):
    """Re-initialise a single ``nn.Conv2d`` layer in place.

    Weights are drawn from N(0, 0.02) and the bias, when the layer has one,
    from N(0, 0.001). Any other module type is left untouched (this includes
    ``nn.ConvTranspose2d``, which is not a subclass of ``nn.Conv2d``).

    The function only inspects the module it is given and does not recurse,
    so it is meant to be used as ``model.apply(weights_init)``.

    Args:
        m: the module to (possibly) re-initialise.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, 0, 0.02)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.normal_(m.bias, 0, 0.001)

In [0]:
# weights_init only acts on the module it receives, and the top-level networks
# are not Conv2d instances, so calling it on them directly does nothing.
# Module.apply runs the function on every submodule, which is what actually
# reaches the convolution layers. The trailing ';' hides the returned module's repr.
Generator.apply(weights_init)
Discriminator.apply(weights_init);

In [0]:
# Total number of scalar parameters in each network.
generator_param_count = sum(p.numel() for p in Generator.parameters())
discriminator_param_count = sum(p.numel() for p in Discriminator.parameters())

print("Number of parameters in Generator: ", generator_param_count)
print("Number of parameters in Discriminator: ", discriminator_param_count)

In [0]:
# Loss functions: binary cross-entropy for the discriminator (`criterion`) and
# for the generator's adversarial term (`adv_criterion`), plus an L1
# reconstruction term for the generator.
criterion = nn.BCELoss()
adv_criterion = nn.BCELoss()
l1_criterion = nn.L1Loss()

# Both networks use Adam with the same learning rate.
adam_settings = dict(lr=6e-6)
G_optim = torch.optim.Adam(Generator.parameters(), **adam_settings)
D_optim = torch.optim.Adam(Discriminator.parameters(), **adam_settings)

In [0]:
# Put both networks in training mode (affects Dropout and BatchNorm layers).
Discriminator.train()
Generator.train()

In [0]:
def save_pic(epoch_no):
  """Colourise ``dataset[0]`` with the generator and save it as ``<epoch_no>.jpg``.

  The sample is converted to grayscale with the 0.2989/0.5870/0.1140 channel
  weights, run through ``Generator`` in eval mode without gradient tracking,
  clamped to [0, 1] and written to the current working directory. The
  generator is put back in training mode afterwards, even if inference fails.

  Args:
    epoch_no: value used (via ``str``) as the output file name stem.
  """
  im = dataset[0]
  im = (0.2989*im[0,:,:] + 0.5870*im[1,:,:] + 0.1140*im[2,:,:])
  # Add batch and channel axes; follow `device` so this also works on CPU.
  im = im.unsqueeze(0).unsqueeze(0).to(device)

  Generator.eval()
  try:
    with torch.no_grad():
      output = Generator(im)
  finally:
    Generator.train()

  p = output[0].cpu().clamp(0.0, 1.0)
  transforms.ToPILImage()(p).save(str(epoch_no) + ".jpg")

In [0]:
# Containers for per-epoch loss histories (discriminator / generator, train / test).
D_losses_train = []
G_losses_train = []

D_losses_test = []
G_losses_test = []

In [0]:
def shuffle_data(fake_im, real_im):
  """Stack generated and real images into one labelled discriminator batch.

  Despite the name, nothing is shuffled: the fake images come first, followed
  by the real ones.

  Args:
    fake_im: tensor of generated images, batch on dim 0.
    real_im: tensor of real images, same trailing shape as ``fake_im``.

  Returns:
    ``(data, labels)`` where ``data`` is the concatenation along dim 0 and
    ``labels`` is a 1-D float tensor on the same device as ``data`` holding
    0 for each fake image and 1 for each real image.
  """
  data = torch.cat((fake_im, real_im), dim=0)
  # Count each source separately so the labels stay aligned with `data` even
  # when the two batches differ in size.
  labels = torch.cat((
      torch.zeros(fake_im.shape[0], device=data.device),
      torch.ones(real_im.shape[0], device=data.device),
  ))

  return data, labels

In [0]:
# Sanity check: number of samples in the training and test index ranges.
print(len(train_indices))
print(len(test_indices))

In [0]:
n_epochs = 300

for epoch in range(n_epochs):
  D_train_loss = 0.0
  G_train_loss = 0.0
  G_test_loss = 0.0

  # ---------------------------- training pass ----------------------------
  Generator.train()
  Discriminator.train()
  for real_im in train_loader:
    real_im = real_im.to(device)
    # Luma-weighted grayscale of the RGB batch, with the channel axis restored.
    gray_scale_im = (0.2989*real_im[:,0,:,:] + 0.5870*real_im[:,1,:,:] + 0.1140*real_im[:,2,:,:])
    gray_scale_im = gray_scale_im.unsqueeze(1)

    ##########Train the discriminator##########
    D_optim.zero_grad()
    # The generator is not updated in this step, so skip building its graph.
    with torch.no_grad():
      fake_img = Generator(gray_scale_im)

    data, labels = shuffle_data(fake_img, real_im)
    # Flatten (N, 1, 1, 1) -> (N,): BCELoss needs input and target shapes to match.
    guess = Discriminator(data).view(-1)

    D_loss = criterion(guess, labels.to(device))
    D_train_loss += D_loss.item()
    D_loss.backward()
    D_optim.step()
    ###########################################

    ############Train the generator############
    G_optim.zero_grad()
    fake_img = Generator(gray_scale_im)
    guess = Discriminator(fake_img).view(-1)
    # Target "real" for every fake; sized from `guess` so a short final batch
    # or a different batch_size still works.
    G_loss_adv = adv_criterion(guess, torch.ones_like(guess))
    G_loss_l1 = l1_criterion(fake_img, real_im)
    G_loss = G_loss_adv + G_loss_l1*100
    G_train_loss += G_loss.item()
    G_loss.backward()
    G_optim.step()
    ###########################################

  # --------------------------- evaluation pass ---------------------------
  # Eval mode keeps test batches from updating BatchNorm running statistics
  # and disables Dropout; no_grad avoids building autograd graphs.
  Generator.eval()
  Discriminator.eval()
  with torch.no_grad():
    for real_im in test_loader:
      real_im = real_im.to(device)
      gray_scale_im = (0.2989*real_im[:,0,:,:] + 0.5870*real_im[:,1,:,:] + 0.1140*real_im[:,2,:,:])
      gray_scale_im = gray_scale_im.unsqueeze(1)

      fake_img = Generator(gray_scale_im)

      guess = Discriminator(fake_img).view(-1)
      adv_loss = adv_criterion(guess, torch.ones_like(guess))
      l1_loss = l1_criterion(fake_img, real_im)
      G_loss = adv_loss + l1_loss*100
      G_test_loss += G_loss.item()
  Generator.train()
  Discriminator.train()

  # Each accumulated term is already a per-batch mean, so average over the
  # number of batches rather than the number of samples.
  D_train_loss = D_train_loss/len(train_loader)
  G_train_loss = G_train_loss/len(train_loader)
  G_test_loss = G_test_loss/len(test_loader)
  print("Epoch " + str(epoch) + ", Train: " + str(G_train_loss) + " , Test: " + str(G_test_loss))
  D_losses_train.append(D_train_loss)
  G_losses_train.append(G_train_loss)
  G_losses_test.append(G_test_loss)
  save_pic(epoch)

In [0]:
# Generator loss per epoch. Labels and a legend are needed to tell the two
# curves apart; the trailing ';' suppresses the repr of the last call.
plt.plot(G_losses_train, label="Train")
plt.plot(G_losses_test, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Generator loss (adversarial + 100 * L1)")
plt.title("Generator loss")
plt.legend();

In [0]:
def d_save_pic(i):
  """Save the real, grayscale and colourised versions of ``validate_set[i]``.

  Writes three JPEG files to the current working directory:
  ``real<i>.jpg`` (the preprocessed original), ``bw<i>.jpg`` (its grayscale
  conversion with the 0.2989/0.5870/0.1140 channel weights) and ``res<i>.jpg``
  (the generator output clamped to [0, 1]). The generator runs in eval mode
  without gradient tracking and is put back in training mode afterwards,
  even if inference fails.

  Args:
    i: index into ``validate_set``; also used in the output file names.
  """
  real = validate_set[i]
  gray_scale = (0.2989*real[0,:,:] + 0.5870*real[1,:,:] + 0.1140*real[2,:,:])
  # Add batch and channel axes; follow `device` so this also works on CPU.
  im = gray_scale.unsqueeze(0).unsqueeze(0).to(device)

  Generator.eval()
  try:
    with torch.no_grad():
      output = Generator(im)
  finally:
    Generator.train()

  p = output[0].cpu().clamp(0.0, 1.0)

  to_pil = transforms.ToPILImage()
  to_pil(real).save("real"+str(i)+".jpg")
  to_pil(p).save("res" + str(i) + ".jpg")
  to_pil(gray_scale).save("bw" + str(i) + ".jpg")

In [0]:
# Export every validation sample. Iterating over the actual set size avoids an
# IndexError when fewer than 12 validation images are available.
for i in range(len(validate_set)):
    d_save_pic(i)