In [None]:
# Importing important libraries
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp
import pickle 
from google.colab import drive
import torch, torchvision
import numpy as np 
import random 
import os
from skimage import color
from google.colab import drive
import torch.cuda as cuda
import matplotlib.pyplot as plt
from PIL import Image


drive.mount('/content/drive')
!nvidia-smi

In [2]:
class Dataset(torch.utils.data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, images_names, images_path, split):
        'Initialization'
        self.images_names = images_names # list of image names : [353397003_1dca2e74c2_138_97443916@N00, 353560364_72da5ae504_163_27027471@N00, ...]
        self.images_path = images_path
        self.split = split


  def __len__(self):
        'Denotes the total number of samples'
        return len(self.images_names)

  def __getitem__(self, index):
          'Generates one sample of data'
          ID = self.images_names[index]
          img = Image.open(self.images_path + ID +'.jpg').resize((dimension,dimension))
          if self.split == 'Train':
            img = torchvision.transforms.RandomHorizontalFlip()(img)

          img = img.convert('RGB') # (768,768,3)  
          img = np.array(img)  #(768,768,3)
          img = color.rgb2lab(img).astype("float32") #(768,768,3)
          img = torchvision.transforms.ToTensor()(img) #(3,768,768)  0 -> L  |  1 -> a | 2 -> b

          L = torch.unsqueeze(img[0,:,:], 0) #(1,768,768) 
          ab = img[1:3,:,:] #(2,768,768)

          L = (L/50.0) - 1 ## -1 ... 1
          ab = ab / 128 ## -1 ... 1

          return ab, L



In [3]:
# A function that checks if an input image is Black&White or not!  
# source: https://stackoverflow.com/questions/23660929/how-to-check-whether-a-jpeg-image-is-color-or-gray-scale-using-only-python-stdli

def is_grey_scale(img_path):
    img = Image.open(img_path).convert('RGB')
    w, h = img.size
    for i in range(w):
        for j in range(h):
            r, g, b = img.getpixel((i,j))
            if r != g != b: 
                return False
    return True

In [4]:
images_path = "/content/drive/MyDrive/DS_Total/"
images_names = [".".join(image.split(".")[:-1]) for image in os.listdir(images_path)]
random.shuffle(images_names)

######################### JUST RUN THIS BLOCK IN THE FIRST RUN ###############################

###### removing Black and White images 
# print("Number of Images Before Removing Black and White Images :  ", len(images_names))
# for ID in images_names:
#           if is_grey_scale(images_path + ID +'.jpg') == True:
#             images_names.remove(ID)
# print("Number of Images after Removing Black and White Images :  ", len(images_names))

##### spliting the data 
# train_images_names = images_names[:4500]
# test_images_names = images_names[4500:]

# with open("/content/drive/My Drive/val2017/train_2.txt", "wb") as fp:   #Pickling
#     pickle.dump(train_images_names, fp)

# with open("/content/drive/My Drive/val2017/test_2.txt", "wb") as fp:   #Pickling
#     pickle.dump(test_images_names, fp)

###############################################################################################


with open("/content/drive/My Drive/val2017/train_2.txt", "rb") as fp:   # Unpickling
    train_images_names = pickle.load(fp)

with open("/content/drive/My Drive/val2017/test_2.txt", "rb") as fp:   # Unpickling
    test_images_names = pickle.load(fp)


In [5]:
# Number of Images Before Removing Black and White Images :   5085
# Number of Images after Removing Black and White Images :   4958

In [6]:
epochs = 0
val_best_loss = np.inf
dimension = 768
alpha = 10
disc_iteration =  1

batch_size_train = 3
batch_size_test = 16

params_train = {
        'batch_size' : batch_size_train ,
        'shuffle': True , 
        'num_workers': 0
}        

params_test = {
        'batch_size' : batch_size_test ,
        'shuffle': False , 
        'num_workers': 0
}

train_data = Dataset(train_images_names, images_path, 'Train')
train_data_generator = torch.utils.data.DataLoader(train_data, **params_train, drop_last=True)

test_data = Dataset(test_images_names, images_path, 'Test')
test_data_generator = torch.utils.data.DataLoader(test_data, **params_test,  drop_last=False)

number_of_train_batches = len(train_images_names) // batch_size_train 
number_of_test_batches = len(test_images_names) // batch_size_test 

print("\nLentgh of Train Images : {}".format(len(train_images_names)))
print("Lentgh of Test Images : {}\n".format(len(test_images_names)))
print("Number of train batches :",number_of_train_batches)
print("Number of test batches :",number_of_test_batches)



Lentgh of Train Images : 4500
Lentgh of Test Images : 458

Number of train batches : 1500
Number of test batches : 28


In [7]:
class patch_GAN(torch.nn.Module): # 3 * 768 * 768
  def __init__(self):
    super().__init__()
    self.conv1=torch.nn.Conv2d(3, 4, kernel_size = 4, stride = 2, padding = 1 ) # size =  torch.Size([batch_size, 4, 384, 384])
    self.bn1 = torch.nn.BatchNorm2d(4)
    self.relu=torch.nn.ReLU() 

    self.conv2=torch.nn.Conv2d(4, 4, kernel_size = 4, stride = 2, padding = 1 ) # size =  torch.Size([batch_size, 4, 192, 192])
    self.bn2 = torch.nn.BatchNorm2d(4)

    self.conv3=torch.nn.Conv2d(4, 4, kernel_size = 4, stride = 2, padding = 1 ) # size =  torch.Size([batch_size, 4, 96, 96])
    self.bn3 = torch.nn.BatchNorm2d(4)

    self.conv4=torch.nn.Conv2d(4, 4, kernel_size = 4, stride = 1, padding = 1 ) # size =  torch.Size([batch_size, 4, 95, 95])
    self.bn4 = torch.nn.BatchNorm2d(4)

    self.conv5=torch.nn.Conv2d(4, 1, kernel_size = 4, stride = 1, padding = 1 ) # size =  torch.Size([batch_size, 4, 94, 94])
    self.sigmoid=torch.nn.Sigmoid() 


  def forward(self, x):
    out=self.conv1(x)
    out=self.bn1(out)
    out=self.relu(out)

    out=self.conv2(out)
    out=self.bn2(out)
    out=self.relu(out)

    out=self.conv3(out)
    out=self.bn3(out)
    out=self.relu(out)

    out=self.conv4(out)
    out=self.bn4(out)
    out=self.relu(out)

    out=self.conv5(out)
    out=self.sigmoid(out) # size =  torch.Size([batch_size, 1, 94, 94])

    return out 

discriminator = patch_GAN()
x  = torch.randn(1, 3, dimension, dimension)  # A random image
output = discriminator(x)
print("Output shape of Discriminator : ", output.shape)

discriminator_output_dimension = output.shape[2] # 94


Output shape of Discriminator :  torch.Size([1, 1, 94, 94])


In [None]:
# https://github.com/qubvel/segmentation_models.pytorch
generator = smp.Unet(
    encoder_name="efficientnet-b5", # choose encoder, e.g. xception , mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2
)
discriminator = patch_GAN()

generator.float()
discriminator.float()


if cuda.is_available():
         generator = generator.cuda()
         discriminator = discriminator.cuda()

BCE_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.001)

checkpoint=torch.load("/content/drive/My Drive/val2017/Generator_2.pth")
generator.load_state_dict(checkpoint)

checkpoint=torch.load("/content/drive/My Drive/val2017/Discriminator_2.pth")
discriminator.load_state_dict(checkpoint)

checkpoint = torch.load("/content/drive/My Drive/val2017/Optimizer_G_2.pth")
optimizer_G.load_state_dict(checkpoint)

checkpoint = torch.load("/content/drive/My Drive/val2017/Optimizer_D_2.pth")
optimizer_D.load_state_dict(checkpoint)

with open("/content/drive/My Drive/val2017/Epoch_2.txt", "r") as f:   
   epochs = int(f.read())

In [None]:
########################################### TRAINING ##########################################
generator = generator.train()
for epoch in range(epochs+1, epochs + 100):

  disc_losses = []
  gen_losses = []
  for colored_images,  BW_images in train_data_generator:
    
     for k in range(disc_iteration):
        #  Colored_images(Original image): tensor (batch_size, 2, 768, 768)  |  BW_images : Corrosponds to Colored_images  , tensor (batch_size, 1, 768, 768)
        # In fact-> colored_images contain ab-channels for the images (Real version) ,
        #          BW_images contain L-channel for the images (Real version)
        if cuda.is_available():
            BW_images = BW_images.cuda()
            colored_images = colored_images.cuda()

        with torch.no_grad():
           Generated_colored_image = generator(BW_images).detach() # tensor (batch_size, 2, 768, 768)
           # In fact-> Generated_colored_image contain the predicted values for ab-channels (Fake version)


        BW_and_Generated_images = torch.cat((BW_images, Generated_colored_image), dim = 1) # tensor (batch_size, 3, 768, 768)
        BW_and_colored_images = torch.cat((BW_images, colored_images), dim = 1) # tensor (batch_size, 3, 768, 768)
        
        label_0_for_BW_and_Generated_images = torch.zeros(batch_size_train, 1, discriminator_output_dimension, discriminator_output_dimension)
        label_1_for_BW_and_colored_images = torch.ones(batch_size_train, 1, discriminator_output_dimension, discriminator_output_dimension)

        x = torch.cat((BW_and_Generated_images, BW_and_colored_images), dim = 0) # tensor (2*batch_size, 3, 768, 768)
        y = torch.cat((label_0_for_BW_and_Generated_images,label_1_for_BW_and_colored_images), dim = 0) # tensor (2*batch_size, 1, 94, 94)

        c = list(zip(x, y))
        random.shuffle(c)
        x, y = zip(*c) 

        
        x = torch.stack(list(x), dim=0) # tensor (2*batch_size, 3, 768, 768)
        y = torch.stack(list(y), dim=0) # tensor (2*batch_size, 1, 94, 94)
        
        if cuda.is_available():
            y = y.cuda()
            x = x.cuda()

        dis = discriminator(x.float()) # tensor (2*batch_size, 1, 94, 94) 
        dis = dis.view(-1, 1)  # tensor (2*batch_size*94*94 , 1) 
        y = y.view(-1, 1).float()  # tensor (2*batch_size*94*94 , 1) 

        disc_loss = BCE_loss(dis, target = y) * 0.5

        optimizer_D.zero_grad()
        disc_loss.backward()
        optimizer_D.step()
        disc_losses.append(disc_loss.item())

        

     ################################################################################################
     if cuda.is_available():
            BW_images = BW_images.cuda()

     Generated_colored_image = generator(BW_images) # tensor (batch_size, 2, 768, 768)

     BW_and_Generated_images = torch.cat((BW_images, Generated_colored_image), dim = 1) # tensor (batch_size, 3, 768, 768)


     if cuda.is_available():
            BW_and_Generated_images = BW_and_Generated_images.cuda()
            colored_images = colored_images.cuda()

     dis = discriminator(BW_and_Generated_images.float()) # tensor (batch_size, 1, 94, 94) 
     dis = dis.view(-1, 1)  # tensor (batch_size*94*94 , 1) 
     y = torch.ones(batch_size_train, 1, discriminator_output_dimension, discriminator_output_dimension)
     y = y.view(-1, 1).float()  # tensor (batch_size*94*94 , 1) 
     if cuda.is_available():
            y = y.cuda()

     pixel_distance_loss_tensor =  torch.subtract(Generated_colored_image, colored_images) # tensor (batch_size, 3, 768, 768)
     pixel_distance_loss_tensor =  pixel_distance_loss_tensor.view(pixel_distance_loss_tensor.shape[0], -1) # tensor (batch_size, 3*768*768)
     pixel_distance_loss_tensor_norm = torch.norm(pixel_distance_loss_tensor, p=1, dim=1).reshape(-1,1) # tensor (batch_size, 1)
     pixel_distance_loss = torch.mean(pixel_distance_loss_tensor_norm,0) # a number

     gen_loss =   BCE_loss(dis, target = y) + alpha * pixel_distance_loss  # a number

     optimizer_G.zero_grad()   
     gen_loss.backward()
     optimizer_G.step()

     gen_losses.append(gen_loss.item())



     ################################################################################################

  print("Iteration : {} -> Train , D_LOSS : {}".format(epoch, np.mean(disc_losses)))
  print("Iteration : {} -> Train , G_LOSS : {}".format(epoch, np.mean(gen_losses)))
  
  if epoch%3==0:
      torch.save(generator.state_dict(),"/content/drive/My Drive/val2017/Generator_"+str(epoch)+"_2.pth")
      torch.save(discriminator.state_dict(),"/content/drive/My Drive/val2017/Discriminator_"+str(epoch)+"_2.pth")
      torch.save(optimizer_G.state_dict(),"/content/drive/My Drive/val2017/Optimizer_G_"+str(epoch)+"_2.pth")
      torch.save(optimizer_D.state_dict(),"/content/drive/My Drive/val2017/Optimizer_D_"+str(epoch)+"_2.pth")
      print("[SAVED !]")


  torch.save(generator.state_dict(),"/content/drive/My Drive/val2017/Generator_2.pth")
  torch.save(discriminator.state_dict(),"/content/drive/My Drive/val2017/Discriminator_2.pth")
  torch.save(optimizer_G.state_dict(),"/content/drive/My Drive/val2017/Optimizer_G_2.pth")
  torch.save(optimizer_D.state_dict(),"/content/drive/My Drive/val2017/Optimizer_D_2.pth")


  with open("/content/drive/My Drive/val2017/Epoch_2.txt", "w+") as f :   
      f.write(str(epoch))

  with open("/content/drive/My Drive/val2017/Results_2.txt", "a+") as f:   
     f.write("\n Iteration : {} -> Train D_LOSS : {}".format(epoch, np.mean(disc_losses)))
     f.write("\n Iteration : {} -> Train G_LOSS : {}".format(epoch, np.mean(gen_losses)))


  print("______________________________________________")

    