In [1]:
##################### Importing necessary libraries ############
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import sys
import cv2
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import PIL
import random
from scipy import ndimage
import glob
import tensorflow as tf
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from torchvision.utils import save_image
from time import time
from skimage.data import astronaut,rocket
from skimage.metrics import structural_similarity
from PIL import Image
from statistics import mean
# %pip install -U git+https://github.com/szagoruyko/pytorchviz.git@master
# from torchviz import make_dot
from torch.autograd import Variable
device = ('cuda' if torch.cuda.is_available() else 'cpu')
# !pip install torchmetrics

In [2]:
################# Creating datasets ##############

################# Training Dataset ############
class train_Dataset(torch.utils.data.Dataset):
    """Training split: aerial ``.tiff`` images paired with ``.tif`` road labels.

    Image and label paths are globbed from the ``train`` / ``train_labels``
    folders and sorted, so index ``i`` of one list is paired with index ``i``
    of the other.

    Args:
        sharpen: when True, the aerial image is filtered with the 3x3
            sharpening kernel before it is converted to a tensor. Defaults to
            False, which reproduces what this dataset returned before: the
            sharpened image used to be computed on every access and then
            discarded, so the unfiltered image was always the one returned.
    """

    # 3x3 sharpening kernel, only applied when ``sharpen`` is True.
    _SHARPEN_KERNEL = np.array([[0, -1, 0],
                                [-1, 5, -1],
                                [0, -1, 0]])

    def __init__(self, sharpen=False):
        super(train_Dataset, self).__init__()
        self.sharpen = sharpen
        self.paths_train = sorted(
            glob.glob('../input/massachusetts-roads-dataset/tiff/train/*.tiff'))
        self.paths_train_label = sorted(
            glob.glob('../input/massachusetts-roads-dataset/tiff/train_labels/*.tif'))

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors cropped to rows/columns 328:840."""
        # The context managers close the files once ToTensor has read the pixels.
        with Image.open(self.paths_train[idx]) as image:
            if self.sharpen:
                source = cv2.filter2D(np.array(image), -1, self._SHARPEN_KERNEL)
            else:
                source = image
            image_tensor = transforms.ToTensor()(source)
        with Image.open(self.paths_train_label[idx]) as label:
            label_tensor = transforms.ToTensor()(label)
        image_resized = image_tensor[:, 328:840, 328:840]
        label_resized = label_tensor[:, 328:840, 328:840]
        return image_resized, label_resized

    def __len__(self):
        """Number of label files found for the split."""
        return len(self.paths_train_label)

################# Testing Dataset ############
class test_Dataset(torch.utils.data.Dataset):
    """Validation split: aerial ``.tiff`` images paired with ``.tif`` road labels.

    Image and label paths are globbed from the ``val`` / ``val_labels`` folders.
    BOTH lists are sorted so that index ``i`` of one is paired with index ``i``
    of the other; previously the image list was sorted twice and the label
    list was left in glob order, which could pair an image with the wrong label.

    Args:
        sharpen: when True (the default, matching the previous behaviour of
            this class) the aerial image is filtered with the 3x3 sharpening
            kernel before it is converted to a tensor.
    """

    # 3x3 sharpening kernel applied to the aerial image when ``sharpen`` is True.
    _SHARPEN_KERNEL = np.array([[0, -1, 0],
                                [-1, 5, -1],
                                [0, -1, 0]])

    def __init__(self, sharpen=True):
        super(test_Dataset, self).__init__()
        self.sharpen = sharpen
        self.paths_test = sorted(
            glob.glob('../input/massachusetts-roads-dataset/tiff/val/*.tiff'))
        self.paths_test_label = sorted(
            glob.glob('../input/massachusetts-roads-dataset/tiff/val_labels/*.tif'))

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors cropped to rows/columns 328:840."""
        # The context managers close the files once the pixels have been read.
        with Image.open(self.paths_test[idx]) as image:
            if self.sharpen:
                source = cv2.filter2D(np.array(image), -1, self._SHARPEN_KERNEL)
            else:
                source = image
            image_tensor = transforms.ToTensor()(source)
        with Image.open(self.paths_test_label[idx]) as label:
            label_tensor = transforms.ToTensor()(label)
        image_resized = image_tensor[:, 328:840, 328:840]
        label_resized = label_tensor[:, 328:840, 328:840]
        return image_resized, label_resized

    def __len__(self):
        """Number of label files found for the split."""
        return len(self.paths_test_label)

In [3]:
################# Defining Dataloaders ############
'''Dataloader provides desired data from dataset with specific batch size '''

# Build both dataset objects first, then wrap each in a loader that yields one
# (image, label) pair per step from the main process (no worker processes).
train_data = train_Dataset()
test_data = test_Dataset()
train_loader = DataLoader(dataset=train_data, num_workers=0, batch_size=1)
test_loader = DataLoader(dataset=test_data, num_workers=0, batch_size=1)

In [None]:
########## Number of images in the train & test datasets ###########
# One size per line: training split first, then the test split.
print(len(train_data), len(test_data), sep='\n')

In [4]:
################## Observing aerial image and ground truth ###############
# Walk the training loader and display only the batch at index 30: the aerial
# image on the left, its ground-truth mask (grey colormap) on the right.
for batch_i, (arial_image, ground_truth) in enumerate(train_loader):
    if batch_i != 30:
        continue
    fig, ax = plt.subplots(1, 2, figsize=(25, 25))
    # Tensors are channel-first; imshow wants channel-last, hence the permute.
    ax[0].imshow(arial_image[0].permute(1, 2, 0))
    ax[1].imshow(ground_truth[0].permute(1, 2, 0), 'gray')
    print(batch_i)
    break

In [5]:
############ Defining classes which are used in Network #############

############ The transition-down layer is used in the downsampling path of the generator and of the discriminator #############
class Transition_Down(nn.Module):
  """Downsampling stage: dropout, stride-2 convolution, batch norm, LeakyReLU.

  The convolution doubles the channel count (``input_channels`` ->
  ``2 * input_channels``). The layers live in ``self.TD`` in the order
  dropout, conv, batch norm, activation.
  """

  def __init__(self, input_channels):
    super().__init__()
    doubled = 2 * input_channels
    stages = [
        nn.Dropout2d(0.5),
        nn.Conv2d(input_channels, doubled, kernel_size=4, padding=1, stride=2),
        nn.BatchNorm2d(doubled),
        nn.LeakyReLU(negative_slope=0.01),
    ]
    self.TD = nn.Sequential(*stages)

  def forward(self, x):
    """Apply the four layers to ``x`` and return the result."""
    return self.TD(x)
    
####################################################################################################

############ Transition up layer is used in upsampling path of generator  #############
class Transition_Up(nn.Module):
  """Upsampling stage: dropout, stride-2 transposed convolution, batch norm, ReLU.

  The transposed convolution maps ``input_channels`` to ``output_channels``.
  The layers live in ``self.TU`` in the order dropout, transposed conv,
  batch norm, activation.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    stages = [
        nn.Dropout2d(0.5),
        nn.ConvTranspose2d(input_channels, output_channels,
                           kernel_size=4, padding=1, stride=2),
        nn.BatchNorm2d(output_channels),
        nn.ReLU(),
    ]
    self.TU = nn.Sequential(*stages)

  def forward(self, x):
    """Apply the four layers to ``x`` and return the result."""
    return self.TU(x)
     
    

In [6]:
############ Defining Generator & Discriminator by using the classes which are defined previously #############
############ Generator is used to produce segmentation map ############
class Generator(nn.Module):
  """Encoder-decoder with skip connections that maps a 3-channel image to a
  1-channel map squashed by ``tanh``.

  Encoder: a stride-2 convolution (3 -> 32 channels) followed by four
  ``Transition_Down`` stages (32 -> 64 -> 128 -> 256 -> 512).
  Decoder: four ``Transition_Up`` stages, each (after the first) fed with the
  previous decoder output concatenated with the matching encoder output, then
  a final transposed convolution (64 -> 1 channel).
  """

  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 32, kernel_size=4, padding=1, stride=2)
    self.LReLU = nn.LeakyReLU(negative_slope=0.1)
    # Sub-modules are created in the order TD1..TD4, TU1..TU4 under those names.
    for index, width in enumerate((32, 64, 128, 256), start=1):
      setattr(self, f"TD{index}", Transition_Down(width))
    up_channels = ((512, 256), (512, 128), (256, 64), (128, 32))
    for index, (fan_in, fan_out) in enumerate(up_channels, start=1):
      setattr(self, f"TU{index}", Transition_Up(fan_in, fan_out))
    self.deconv = nn.ConvTranspose2d(64, 1, kernel_size=4, padding=1, stride=2)
    self.relu = nn.ReLU()  # registered but not used in forward()
    self.Tan = nn.Tanh()

  def forward(self, input):
    """Run the encoder, then the decoder with channel-wise skip concatenation."""
    stem = self.LReLU(self.conv(input))
    down1 = self.TD1(stem)
    down2 = self.TD2(down1)
    down3 = self.TD3(down2)
    down4 = self.TD4(down3)

    # Each decoder stage sees the encoder feature map of the same depth.
    up1 = self.TU1(down4)
    up2 = self.TU2(torch.cat((down3, up1), 1))
    up3 = self.TU3(torch.cat((down2, up2), 1))
    up4 = self.TU4(torch.cat((down1, up3), 1))
    return self.Tan(self.deconv(torch.cat((stem, up4), 1)))

############### Discriminator is used to distinguish between ground truth & fake segmentations ######## 
class Discriminator(nn.Module):
  """Patch-wise critic over a 4-channel stack of two tensors.

  ``forward`` concatenates its two arguments along the channel axis (their
  channel counts must add up to 4), pushes the stack through a stride-2
  convolution, four ``Transition_Down`` stages and a last stride-2 convolution
  to a single channel, and squashes the result with a sigmoid.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(4, 32, kernel_size=4, padding=1, stride=2)
    self.LReLU = nn.LeakyReLU(negative_slope=0.1)
    # Sub-modules are created in the order TD1..TD4 under those names.
    for index, width in enumerate((32, 64, 128, 256), start=1):
      setattr(self, f"TD{index}", Transition_Down(width))
    self.conv2 = nn.Conv2d(512, 1, kernel_size=4, padding=1, stride=2)
    self.Sigm = nn.Sigmoid()

  def forward(self, input, ground_truth):
    """Return the sigmoid score map for the channel-wise stack of the inputs."""
    stacked = torch.cat((input, ground_truth), 1)
    features = self.LReLU(self.conv1(stacked))
    for stage in (self.TD1, self.TD2, self.TD3, self.TD4):
      features = stage(features)
    return self.Sigm(self.conv2(features))

In [7]:
############### Testing the functionality of generator & discriminator by giving a random tensor
# Shape check only: push one random 512x512 RGB tensor through a fresh
# generator, then feed that input together with the generator output to a
# fresh discriminator, and print both output shapes.
test_gen = Generator()
test_disc = Discriminator()
input_of_generator = torch.randn((1, 3, 512, 512))  # random tensor
output_of_generator = test_gen(input_of_generator)
output_of_discriminator = test_disc(input_of_generator, output_of_generator)
print(output_of_generator.shape, output_of_discriminator.shape, sep='\n')

In [None]:
############ Plotting architecture of generator & discriminator ###########
# `make_dot` comes from the optional torchviz package; its top-of-notebook
# import is commented out, so import it here and skip the drawing (instead of
# failing with a NameError) when the package is not installed.
try:
    from torchviz import make_dot
except ImportError:
    make_dot = None
    print("torchviz is not installed - skipping the architecture plot")

iput_of_generator = torch.randn((1,3,512,512))
iput_of_discriminator = torch.randn((1,1,512,512))
mymodel_generator=Generator()
mymodel_discriminator=Discriminator()
if make_dot is not None:
    # The parameter names must come from the model that produced the output
    # (this used to reference an undefined name `mymodel`).
    # NOTE(review): the string passed to render() is used as the file name; confirm
    # the rendered file ends up with the extension/format that is intended.
    make_dot(mymodel_generator(iput_of_generator), params=dict(mymodel_generator.named_parameters())).render("Generator_model.jpeg")

In [8]:
########## Defining Loss Function classes for each network ###########
class MaxminLoss(nn.Module):
  """Element-wise ``log(real_output) + log(1 - fake_output)``.

  No reduction is applied: the result has the broadcast shape of the inputs.
  """

  def __init__(self):
      super().__init__()

  def forward(self, real_output, fake_output):
      """Return the sum of the two log terms for the given score tensors."""
      real_term = torch.log(real_output)
      fake_term = torch.log(1 - fake_output)
      return real_term + fake_term


In [9]:
################# Initializing networks and parameters part ##############
# Networks, moved to the selected device (discriminator is built first).
D = Discriminator().to(device)
G = Generator().to(device)

# Loss criteria.
BCE = nn.BCELoss()
BCE_with_Sigmoid = nn.BCEWithLogitsLoss()
MSE = nn.MSELoss()

# One Adam optimizer per network, same hyper-parameters for both.
d_optim = torch.optim.Adam(params=D.parameters(), lr=0.01, betas=(0.9, 0.999))
g_optim = torch.optim.Adam(params=G.parameters(), lr=0.01, betas=(0.9, 0.999))

In [None]:
# NOTE: this cell trains the GAN from scratch and takes a long time.
# A trained generator checkpoint is loaded in a later cell, so running this
# cell is only needed when the model has to be retrained.
###################### Training Loop #####################
batch_size=1    # not referenced by the loop below
lambdaa = 0.9   # not referenced by the loop below
d_losses = []   # one discriminator loss per training step, across all epochs
g_losses = []   # one generator loss per training step, across all epochs
N_EPOCHS = 20

for epoch in range(N_EPOCHS):
    # Where this epoch's entries start in the running loss lists, so the
    # summary printed at the end of the epoch averages this epoch only.
    epoch_start = len(g_losses)
    for batch_i, (arial_image,ground_truth) in enumerate(train_loader):
        arial_image = arial_image.to(device)
        ground_truth = ground_truth.to(device)

        #######################
        ##Training Discriminator##
        #######################
        # Real pair: the discriminator should answer with all ones.
        real_outputs = D(ground_truth, arial_image)
        real_outputs_squeezed = real_outputs.squeeze()
        ones = torch.ones_like(real_outputs_squeezed)  # per-patch "real" targets
        d_loss_real = BCE(real_outputs_squeezed, ones)

        # Fake pair: the discriminator should answer with all zeros. The
        # generator output is detached because this step only updates D, so
        # there is no need to back-propagate into G here.
        fake_segmentations = G(arial_image).detach()
        fake_outputs = D(fake_segmentations, arial_image)
        fake_outputs_squeezed = fake_outputs.squeeze()
        zeros = torch.zeros_like(fake_outputs_squeezed)  # per-patch "fake" targets
        d_loss_fake = BCE(fake_outputs_squeezed, zeros)

        # Gradient step on D.
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        #######################
        ####Training Generator####
        #######################
        fake_segmentations = G(arial_image)
        fake_outputs = D(fake_segmentations, arial_image)
        fake_outputs_squeezed = fake_outputs.squeeze()
        # The true labels of these samples are zeros; the generator is trained
        # to make the discriminator answer ones instead.
        g_loss = BCE(fake_outputs_squeezed, ones)
        g_optim.zero_grad()
        g_loss.backward()  # also fills D's gradients; they are zeroed at the next D step
        g_optim.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        if batch_i % 10 == 0 :
            train_string1 = f"information of each batch => batch : {batch_i}// g_loss : {g_loss.item(): 0.4f} // d_loss : {d_loss.item(): 0.4f} "
            print(train_string1)

    # Per-epoch means (previously averaged over every step since epoch 0).
    g_loss_mean = mean(g_losses[epoch_start:])
    d_loss_mean = mean(d_losses[epoch_start:])
    train_string2 = f"information of each epoch => epoch : {epoch}// g_loss_mean : {g_loss_mean: 0.4f} // d_loss : {d_loss_mean: 0.4f} "
    print(train_string2 )
    


In [None]:
''' Here we have plotted Discriminator and Generator loss together for 150 
    first samples to show that they are changing and competing with each other '''
# Draw the generator curve first, then the discriminator curve, on one figure.
for curve, tag in ((g_losses, 'g_losses'), (d_losses, 'd_losses')):
    plt.plot(curve, label = tag)
plt.legend()

In [None]:
################# Saving the weights of the trained generator ######################
# Only the generator's state dict is written; the discriminator line is kept
# commented out.
torch.save(obj=G.state_dict(), f='Road_GAN_Generator_1.pt')
# torch.save(D.state_dict(), 'Road_GAN_Discriminator1.pt')

In [10]:
'''
 here you can use my saved model that I have sent. just upload it to your colab or kaggle and then write its directory
 here -> (torch.load(directory))
'''
# Fresh generator on the CPU, filled with the saved weights.
# NOTE(review): the model is left in its default training mode (no .eval()
# call), so its Dropout2d / BatchNorm2d layers behave as during training when
# it is used below - confirm that this is intended.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
mysaved_model = Generator()
mysaved_model.load_state_dict(
    torch.load('../input/generator-1/Road_GAN_Generator_1.pt', map_location='cpu'))

In [11]:
########################### Quantitative Evaluation Functions ####################### 
def Pixels_Counter(segmentation, ground_truth, threshold=0.9):
    """Count confusion-matrix pixels between a predicted map and its label.

    ``segmentation`` is rescaled with ``(segmentation + 1) / 2``; a pixel whose
    rescaled value is ``>= threshold`` is a predicted positive, every other
    pixel a predicted negative. Only the first channel (index 0 of the leading
    axis) of each input is used, over its full spatial extent.

    Args:
        segmentation: channel-first ``(C, H, W)`` tensor or array of scores.
        ground_truth: channel-first ``(C, H, W)`` tensor or array whose first
            channel has the same height and width; 1 marks a positive pixel
            and 0 a negative one.
        threshold: cut-off applied to the rescaled scores (default 0.9).

    Returns:
        ``(TP, TN, FP, FN)`` as Python ints. A predicted positive counts as TP
        when the label equals 1 and as FP otherwise; a predicted negative
        counts as TN when the label equals 0 and as FN otherwise.
    """
    # [:1] keeps the first channel only; the inputs themselves are not modified.
    scores = (torch.as_tensor(segmentation)[:1] + 1) / 2
    truth = torch.as_tensor(ground_truth)[:1]

    predicted_positive = scores >= threshold
    predicted_negative = ~predicted_positive

    # Whole-tensor boolean masks replace the per-pixel Python loops.
    TP = int((predicted_positive & (truth == 1)).sum())
    FP = int((predicted_positive & (truth != 1)).sum())
    TN = int((predicted_negative & (truth == 0)).sum())
    FN = int((predicted_negative & (truth != 0)).sum())
    return TP,TN,FP,FN
##################################################################
############### Matthews Correlation Coefficient ################
def MCC (TP,TN,FP,FN):
    """Matthews correlation coefficient from confusion-matrix counts.

    Returns ``(TP*TN - FP*FN) / sqrt((TP+FN)(TP+FP)(TN+FN)(TN+FP))``, or 0.0
    when the denominator is zero (one of the four marginal sums is empty), in
    which case the ratio is undefined.
    """
    denominator = math.sqrt((TP+FN)*(TP+FP)*(TN+FN)*(TN+FP))
    if denominator == 0:
        return 0.0
    return ((TP*TN)-(FP*FN))/denominator
##################################################################
########################### Recall ###############################
def RECALL (TP,FN):
    """Recall ``TP / (TP + FN)``; 0.0 when there are no actual positives."""
    actual_positives = TP + FN
    if actual_positives == 0:
        return 0.0
    return TP/actual_positives
##################################################################
########################## Precision #############################
def Precision (TP,FP):
    """Precision ``TP / (TP + FP)``; 0.0 when nothing was predicted positive."""
    predicted_positives = TP + FP
    if predicted_positives == 0:
        return 0.0
    return TP/predicted_positives
##################################################################
########################## F1 Score ##############################
def F1 (Precision,Recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = Precision + Recall
    if total == 0:
        return 0.0
    return (2*Precision*Recall)/total
##################################################################
############### Mean Intersection Over Union ################
def MIOU (TP,FP,FN):
    """Intersection over union of the positive class, ``TP / (TP + FP + FN)``.

    Despite the name, no averaging over classes happens here: only the
    positive-class counts enter the ratio. Returns 0.0 when all three counts
    are zero (the ratio is undefined in that case).
    """
    union = TP + FP + FN
    if union == 0:
        return 0.0
    return TP/union
          

In [12]:
##################### Testing the generator model on testing dataset #####################
# Visual and quantitative evaluation of the saved generator on three tiles of
# the test split: each tile is plotted next to its prediction and label, and
# the confusion counts plus MCC / recall / precision / F1 / IoU are printed.

path11=glob.glob("../input/massachusetts-roads-dataset/tiff/test/10378780_15.tiff")
path21=glob.glob("../input/massachusetts-roads-dataset/tiff/test_labels/10378780_15.tif")
path12=glob.glob("../input/massachusetts-roads-dataset/tiff/test/12328750_15.tiff")
path22=glob.glob("../input/massachusetts-roads-dataset/tiff/test_labels/12328750_15.tif")
path13=glob.glob("../input/massachusetts-roads-dataset/tiff/test/20728960_15.tiff")
path23=glob.glob("../input/massachusetts-roads-dataset/tiff/test_labels/20728960_15.tif")


def _evaluate_tile(image_path, label_path, model):
    """Segment one tile with ``model``, plot it and print its metrics.

    Args:
        image_path: path of the aerial image file.
        label_path: path of the matching label file.
        model: network called on the ``(1, 3, 512, 512)`` image crop.

    Returns:
        ``(image, label, segmentation, (TP, TN, FP, FN))`` where ``image`` is
        the batched crop fed to the model, ``label`` the cropped label tensor
        and ``segmentation`` the model output.
    """
    with Image.open(image_path) as image_file:
        image = transforms.ToTensor()(image_file)
    with Image.open(label_path) as label_file:
        label = transforms.ToTensor()(label_file)

    # Same 512x512 window (rows/columns 128:640) for the image and its label.
    image = image[:, 128:640, 128:640].reshape(1, 3, 512, 512)
    label = label[:, 128:640, 128:640]

    # Inference only: no autograd graph is needed for the prediction.
    with torch.no_grad():
        segmentation = model(image)

    fig, ax = plt.subplots(1, 3, figsize=(15, 15))
    ax[0].imshow(image[0].permute(1, 2, 0))
    ax[1].imshow(segmentation[0].permute(1, 2, 0), 'gray')
    ax[2].imshow(label.permute(1, 2, 0), 'gray')

    TP, TN, FP, FN = Pixels_Counter(segmentation[0], label)
    mcc = MCC(TP, TN, FP, FN)
    recall = RECALL(TP, FN)
    precision = Precision(TP, FP)
    f1 = F1(precision, recall)
    miou = MIOU(TP, FP, FN)
    print(f" True_Positive = {TP} //  True_Nagative = {TN} //  False_Positive = {FP} //  False_Nagative = {FN}")
    print(f" MCC : {mcc: 0.4f} // Reacll : {recall: 0.4f} // Precision : {precision: 0.4f} // F1 : {f1: 0.4f} // MIOU : {miou: 0.4f}")
    return image, label, segmentation, (TP, TN, FP, FN)


# The numbered names are kept so that any later cell relying on them still works.
arial_image1, ground_truth1, Segmentation1, (TP1, TN1, FP1, FN1) = _evaluate_tile(path11[0], path21[0], mysaved_model)
arial_image2, ground_truth2, Segmentation2, (TP2, TN2, FP2, FN2) = _evaluate_tile(path12[0], path22[0], mysaved_model)
arial_image3, ground_truth3, Segmentation3, (TP3, TN3, FP3, FN3) = _evaluate_tile(path13[0], path23[0], mysaved_model)
