# A. Libraries 📚⬇

In [None]:
import os, cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import albumentations as album

In [None]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

# B. Read Data & Create train / valid splits 📁

## 1. Read Data

In [None]:
# Alternative dataset location (kept for reference):
#DATA_DIR = '/kaggle/input/deepglobe-road-extraction-dataset'
DATA_DIR = '/kaggle/input/dataset-new/Dataset/Satellite Road Extraction Dataset'

metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
# Keep only the rows of the 'train' split and the columns used below.
# .copy() gives an independent frame, so the column assignments that follow
# do not write into a slice of the original frame.
metadata_df = metadata_df.loc[
    metadata_df['split'] == 'train',
    ['image_id', 'sat_image_path', 'mask_path'],
].copy()
# Prefix the paths stored in the CSV with the dataset directory
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# Shuffle DataFrame. The seed is fixed so that the row order - and therefore
# the seeded train / valid split taken from it - is the same on every run.
metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)

## 2. Create train / valid splits

In [None]:
# Hold out a seeded 10% sample of the rows for validation;
# every row that was not sampled stays in the training frame.
valid_df = metadata_df.sample(frac=0.1, random_state=42)
train_df = metadata_df.drop(index=valid_df.index)
(len(train_df), len(valid_df))

In [None]:
# Class table: one row per class with its name and the r / g / b colour used in the masks
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))

# Names and colours as plain Python lists, in the row order of the CSV
class_names = class_dict['name'].tolist()
class_rgb_values = class_dict[['r', 'g', 'b']].to_numpy().tolist()

print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

## 3. Shortlist specific classes to segment

In [None]:
# Useful to shortlist specific classes in datasets with large number of classes
select_classes = ['background', 'road']

# Look up the position of every selected class in the full class list and
# pick the matching rows out of the RGB table
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]

# Report the *selected* subset (not the full class list printed earlier)
print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', select_classes)
print('Class RGB values: ', select_class_rgb_values)

# C. Helper functions for viz. & one-hot encoding/decoding

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument becomes one panel: the keyword name (underscores
    replaced by spaces, title-cased) is used as the panel title and the value
    is handed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(20, 8))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # axis ticks carry no information for image panels
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(images[name])
    plt.show()

# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of per-class boolean maps.

    # Arguments
        label: label array whose last axis holds the colour of each pixel
        label_values: sequence of colours, one per class
    # Returns
        Boolean array with the leading axes of `label` and one trailing
        channel per entry of `label_values`; a channel is True where every
        colour component of the pixel equals that class colour.
    """
    per_class_maps = [
        np.all(np.equal(label, colour), axis=-1)  # pixel matches only if all components match
        for colour in label_values
    ]
    return np.stack(per_class_maps, axis=-1)
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """Collapse the class axis of a one-hot (or per-class score) array.

    # Arguments
        image: array whose last axis has one entry per class
    # Returns
        Array without the last axis, where each value is the index of the
        largest entry along that axis (the class key of the pixel).
    """
    return np.argmax(image, axis=-1)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """Map an array of class keys to the colours of those classes.

    # Arguments
        image: array of class keys (cast to int before the lookup)
        label_values: sequence of colours, indexed by class key
    # Returns
        Array shaped like `image` with one extra trailing axis holding the
        colour looked up for each class key.
    """
    palette = np.array(label_values)
    class_keys = image.astype(int)
    # indexing the palette with an integer array looks up one colour per pixel
    return palette[class_keys]

In [None]:
class RoadsDataset(torch.utils.data.Dataset):
    """DeepGlobe Road Extraction Challenge Dataset.

    Reads an image / mask pair from disk, one-hot encodes the mask and then
    applies, in this order: the optional augmentation, a resize to
    ``target_size`` and the optional preprocessing.

    Args:
        df (pandas.DataFrame): frame with 'sat_image_path' and 'mask_path' columns
        class_rgb_values (list): RGB values of select classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        target_size (tuple): size handed to ``cv2.resize`` as ``dsize``,
            i.e. (width, height)
    """

    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
            target_size=(1024, 1024)
    ):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.target_size = target_size

    @staticmethod
    def _read_rgb(path):
        """Read the file at `path` with OpenCV and return it in RGB channel order."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the (image, mask) pair at index `i` after all transforms."""
        # read images and masks
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])

        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # apply augmentations (same random transform for image and mask)
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        image = cv2.resize(image, self.target_size)
        # Nearest-neighbour keeps the mask strictly one-hot; the default
        # (bilinear) interpolation would blend neighbouring labels into
        # fractional values along class boundaries.
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)
        if mask.ndim == 2:
            # cv2.resize drops the channel axis of single-channel input
            mask = mask[..., np.newaxis]

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.image_paths)

## 3. Visualize Sample Image and Mask 📈

In [None]:
# Dataset without augmentation / preprocessing, used only for a visual check
dataset = RoadsDataset(train_df, class_rgb_values=select_class_rgb_values)
random_idx = random.randint(0, len(dataset)-1)
# Use the randomly drawn index (it was previously computed but ignored in favour of a fixed sample)
image, mask = dataset[random_idx]

visualize(
    original_image = image,
    # class keys mapped back to their RGB colours
    ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    # class keys themselves (one integer per pixel)
    one_hot_encoded_mask = reverse_one_hot(mask)
)

## 4. Defining Augmentations 🙃

In [None]:
def get_training_augmentation():
    """Build the training-time augmentation pipeline.

    Returns an ``album.Compose`` of a horizontal and a vertical flip, each
    applied with probability 0.5.
    """
    return album.Compose([
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-axis array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): optional data normalization function,
            applied to the image only (can be specific for each pretrained
            neural network)
    Return:
        transform: albumentations.Compose that runs `preprocessing_fn` (when
            given) and then `to_tensor` on both image and mask
    """
    # optional normalisation step comes first, image only
    normalisation = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    # layout / dtype conversion is always the final step, for image and mask
    conversion = [album.Lambda(image=to_tensor, mask=to_tensor)]
    return album.Compose(normalisation + conversion)

## 5. Visualize Augmented Images & Masks

In [None]:
# Dataset with the training augmentation but no preprocessing, for a visual check
augmented_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)

random_idx = random.randint(0, len(augmented_dataset)-1)

# Different augmentations on image/mask pairs:
# fetch the SAME sample three times - the augmentation is drawn anew on every
# access, so each pass may show a different flip of that one image / mask pair.
for _ in range(3):
    image, mask = augmented_dataset[random_idx]
    visualize(
        original_image = image,
        # class keys mapped back to their RGB colours
        ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        # class keys themselves (one integer per pixel)
        one_hot_encoded_mask = reverse_one_hot(mask)
    )

# D. Training DeepLabV3+

## 1. Model Definition

In [None]:
# Model configuration
ENCODER = 'resnet50'            # backbone used as encoder
ENCODER_WEIGHTS = 'imagenet'    # pretrained weights loaded into the encoder
CLASSES = select_classes        # one output channel per selected class
ACTIVATION = 'sigmoid'          # activation applied to the model output

# DeepLabV3+ segmentation model built on the pretrained encoder
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Input normalisation function matching the chosen encoder / weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

## 2. Get Train / Val DataLoaders

In [None]:
# Training data: random flips plus encoder preprocessing
train_dataset = RoadsDataset(
    train_df,
    class_rgb_values=select_class_rgb_values,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Validation data: encoder preprocessing only, no augmentation
valid_dataset = RoadsDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Batches of 4 loaded by 4 worker processes; only the training data is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
)

## 3. Set Model Hyperparams

In [None]:
import segmentation_models_pytorch.utils as smp_utils #import segmentation models pytorch utils

In [None]:
# Set flag to train the model or not. If set to 'False', only prediction is performed (using an older model checkpoint)
TRAINING = True

# Set num of epochs
EPOCHS = 3

# Set device: `cuda` or `cpu`
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load best saved model checkpoint from previous commit (if present)
# This has to happen BEFORE the optimizer is created: the optimizer keeps
# references to the parameters of the model object it was built from, so
# rebinding `model` afterwards would leave it updating the discarded model
# while the loaded one never receives a parameter update.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
PREVIOUS_CHECKPOINT = '../input/road-extraction-from-satellite-images-deeplabv3/best_model.pth'
if os.path.exists(PREVIOUS_CHECKPOINT):
    model = torch.load(PREVIOUS_CHECKPOINT, map_location=DEVICE)
    print('Loaded pre-trained DeepLabV3+ model!')

# define loss function
loss = smp.utils.losses.DiceLoss()

# define metrics
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# define optimizer (bound to whichever model is current at this point)
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.00008),
])

# define learning rate scheduler (not used in this NB)
#lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 
#    optimizer, T_0=1, T_mult=2, eta_min=5e-5,
#)

In [None]:
# define epoch parameters
# Runner for one pass over the training data; it receives the optimizer
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# Runner for one pass over the validation data; same loss / metrics, no optimizer
valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

## 4. Training DeepLabV3+

In [None]:
#measure the time taken to train the model
%%time

if TRAINING:                                                #if TRAINING is True

    best_iou_score = 0.0                                    #initialize the best IoU score
    train_logs_list, valid_logs_list = [], []               #create empty lists to store the training and validation logs

    for i in range(0, EPOCHS):                              #iterate through the number of epochs

        # Perform training & validation
        print('\nEpoch: {}'.format(i))                      #print the epoch number
        train_logs = train_epoch.run(train_loader)          #run the training epoch
        valid_logs = valid_epoch.run(valid_loader)          #run the validation epoch
        train_logs_list.append(train_logs)                  #append the training logs to the list
        valid_logs_list.append(valid_logs)                  #append the validation logs to the list

        # Save model if a better val IoU score is obtained
        if best_iou_score < valid_logs['iou_score']:        #if the best IoU score is less than the validation IoU score
            best_iou_score = valid_logs['iou_score']        #update the best IoU score
            torch.save(model, './best_model.pth')           #save the model checkpoint
            print('Model saved!')                           #print a message

# E. Prediction on Test Data

## 1. Load Trained Model

In [None]:
# NOTE: torch.load unpickles the file - only load checkpoints you trust.

# load best saved model checkpoint from the current run
if os.path.exists('./best_model.pth'):
    best_model = torch.load('./best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')

# load best saved model checkpoint from previous commit (if present)
elif os.path.exists('../input/road-extraction-from-satellite-images-deeplabv3/best_model.pth'):
    best_model = torch.load('../input/road-extraction-from-satellite-images-deeplabv3/best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from a previous commit.')

# no checkpoint anywhere: stop here with a clear message instead of leaving
# `best_model` undefined and failing later with a NameError
else:
    raise FileNotFoundError(
        'No best_model.pth checkpoint found in the working directory or in '
        '../input/road-extraction-from-satellite-images-deeplabv3/ - '
        'train the model first (TRAINING = True).'
    )

## 2. Visualize Sample Test Dataset

In [None]:
# Dataset fed to the DeepLabV3+ model: validation frame with the preprocessing
# transform (normalisation + to_tensor(...)) and no augmentation
test_dataset = RoadsDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Loader with default settings
test_dataloader = DataLoader(test_dataset)

# Same frame without preprocessing or augmentation, for display purposes
test_dataset_vis = RoadsDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
)

# Pick one random sample and show it together with its mask
random_idx = random.randint(0, len(test_dataset_vis)-1)
image, mask = test_dataset_vis[random_idx]

visualize(
    original_image=image,
    # class keys mapped back to their RGB colours
    ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    # class keys themselves (one integer per pixel)
    one_hot_encoded_mask=reverse_one_hot(mask),
)


## 3. Create Sample Prediction Folder

In [None]:
# Folder that receives the side-by-side prediction images
sample_preds_folder = 'sample_predictions/'
# exist_ok=True makes the call a no-op when the folder is already there,
# without a separate check-then-create step
os.makedirs(sample_preds_folder, exist_ok=True)

## 4. Test & Predict Test Dataset

In [None]:
# Inference only: make sure the model is in evaluation mode
best_model.eval()

for idx in range(len(test_dataset)):

    image, gt_mask = test_dataset[idx]                       # preprocessed image / mask (channels first)
    image_vis = test_dataset_vis[idx][0].astype('uint8')     # un-normalised image for display
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)   # add the batch axis
    # Predict test image; no gradients are needed, so do not record them
    with torch.no_grad():
        pred_mask = best_model(x_tensor)
    # drop only the batch axis and move the result to host memory
    pred_mask = pred_mask.squeeze(0).cpu().numpy()
    # Convert pred_mask from `CHW` format to `HWC` format
    pred_mask = np.transpose(pred_mask, (1, 2, 0))
    # Get prediction channel corresponding to foreground
    pred_road_heatmap = pred_mask[:, :, select_classes.index('road')]
    pred_mask = colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values)
    # Convert gt_mask from `CHW` format to `HWC` format
    gt_mask = np.transpose(gt_mask, (1, 2, 0))
    gt_mask = colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values)
    # save image, ground truth and prediction side by side;
    # [:, :, ::-1] flips RGB to the BGR order cv2.imwrite expects, and the
    # explicit uint8 cast hands OpenCV an 8-bit image
    side_by_side = np.hstack([image_vis, gt_mask, pred_mask])[:, :, ::-1].astype('uint8')
    cv2.imwrite(os.path.join(sample_preds_folder, f"sample_pred_{idx}.png"), side_by_side)

    visualize(
        original_image = image_vis,
        ground_truth_mask = gt_mask,
        predicted_mask = pred_mask,
        pred_road_heatmap = pred_road_heatmap,
    )

## 5. Model Evaluation on Test Dataset

In [None]:
# Evaluate the best checkpoint loaded in section E.1 (`best_model`), not
# `model`, which holds the state after the last training epoch and is not
# necessarily the best-scoring one.
test_epoch = smp.utils.train.ValidEpoch(
    best_model,
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

valid_logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
print(f"Mean IoU Score: {valid_logs['iou_score']:.4f}")
print(f"Mean Dice Loss: {valid_logs['dice_loss']:.4f}")

## 6. Plot IoU Score & Dice Loss for Train vs. Val

In [None]:
# One row per epoch, one column per logged value
train_logs_df, valid_logs_df = map(pd.DataFrame, (train_logs_list, valid_logs_list))
# transposed for display: logged values as rows, epochs as columns
train_logs_df.transpose()

In [None]:
# IoU score per epoch, training vs. validation
plt.figure(figsize=(20, 8))
for curve_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(logs_df.index.tolist(), logs_df.iou_score.tolist(), lw=3, label=curve_label)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.title('IoU Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
# write the figure to disk before showing it
plt.savefig('iou_score_plot.png')
plt.show()

In [None]:
# Dice loss per epoch, training vs. validation
plt.figure(figsize=(20, 8))
for curve_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(logs_df.index.tolist(), logs_df.dice_loss.tolist(), lw=3, label=curve_label)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Dice Loss', fontsize=20)
plt.title('Dice Loss Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
# write the figure to disk before showing it
plt.savefig('dice_loss_plot.png')
plt.show()