In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import time
import os
import numpy as np
from pathlib import Path
from PIL import Image
from skimage.transform import resize
import helper
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib.image import imread
from torchvision import datasets
from torchvision import datasets, transforms, models
from torch import nn, optim, Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from torch.utils.data import DataLoader, Dataset, TensorDataset
import random
# Seed shared by the notebook's random sampling steps.
seed = 42

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

In [None]:
# Root folder of the dataset; every band and the ground truth live in a sibling sub-folder.
base_path = Path('../input/95cloud-cloud-segmentation-on-satellite-images/95-cloud_training_only_additional_to38-cloud')
red_dir = base_path.joinpath('train_red_additional_to38cloud')
blue_dir = base_path.joinpath('train_blue_additional_to38cloud')
green_dir = base_path.joinpath('train_green_additional_to38cloud')
nir_dir = base_path.joinpath('train_nir_additional_to38cloud')
gt_dir = base_path.joinpath('train_gt_additional_to38cloud')

In [None]:
def OpenAsArray(self, idx, invert=False):
    """Load the red/green/blue bands of sample ``idx`` and scale them to [0, 1].

    Parameters
    ----------
    self : object
        Any object exposing ``files`` (a sequence of dicts with the keys
        ``'red'``, ``'green'`` and ``'blue'`` pointing at image files) and
        ``transform`` (a callable or a falsy value).
    idx : int
        Position of the sample inside ``self.files``.
    invert : bool, optional
        When true the result is cast to ``uint16`` and returned channel-first
        ``(3, H, W)`` instead of channel-last ``(H, W, 3)``.

    Returns
    -------
    numpy.ndarray
        Integer data divided by the maximum of its dtype; non-integer data is
        returned unscaled.

    Notes
    -----
    The previous implementation divided by ``np.finfo(raw_rgb).max``.
    ``np.finfo`` only describes floating-point types and was handed the array
    itself, so the call raised instead of normalising.

    NOTE(review): the dataset classes define their own ``OpenAsArray``; this
    module-level copy does not appear to be called in the visible cells.
    """
    band_paths = self.files[idx]
    raw_rgb = np.stack(
        [np.array(Image.open(band_paths[band])) for band in ('red', 'green', 'blue')],
        axis=2,
    )

    if self.transform:
        # NOTE(review): casting to uint8 wraps values above 255; confirm the
        # source images are 8-bit before relying on this branch.
        raw_rgb = self.transform(Image.fromarray(raw_rgb.astype(np.uint8), 'RGB'))

    if invert:
        raw_rgb = np.array(raw_rgb, dtype=np.uint16).transpose((2, 0, 1))

    # The transform may have returned a non-ndarray object; normalise the type
    # so the dtype inspection below is well defined.
    raw_rgb = np.asarray(raw_rgb)
    if np.issubdtype(raw_rgb.dtype, np.integer):
        return raw_rgb / np.iinfo(raw_rgb.dtype).max
    return raw_rgb

In [None]:
class RGB_CloudDataset (Dataset):
    """True-colour (red, green, blue) cloud-segmentation dataset.

    Every sample is assembled from single-band image files that share a file
    name apart from the band token (``red`` / ``green`` / ``blue`` / ``gt``).

    Parameters
    ----------
    red_dir, blue_dir, green_dir, gt_dir : pathlib.Path
        Folders holding the red, blue and green bands and the ground-truth masks.
    transform : callable, optional
        Called as ``transform((x, y))`` inside ``__getitem__`` and expected to
        return the transformed ``(x, y)`` pair.
    sample_size : int or None, optional
        Upper bound on the number of files kept (default 2000). ``None`` keeps
        every file. The subset is drawn with the module-level ``seed``.
    """

    def __init__(self, red_dir, blue_dir, green_dir, gt_dir, transform= None, sample_size=2000):
        self.transform = transform

        # Walk the red folder and pair every file with its sibling bands.
        # Sorting makes the seeded draw below reproducible, because
        # ``iterdir`` yields entries in arbitrary order.
        band_files = [self.combine_files(f, green_dir, blue_dir, gt_dir)
                      for f in sorted(red_dir.iterdir()) if not f.is_dir()]

        # A private generator keeps the draw deterministic without reseeding
        # the global ``random`` state, and ``min`` avoids the ValueError that
        # ``random.sample`` raises when fewer files than requested exist.
        if sample_size is None:
            keep = len(band_files)
        else:
            keep = min(sample_size, len(band_files))
        self.files = random.Random(seed).sample(band_files, k=keep)

    def combine_files(self, red_file: Path, green_dir, blue_dir, gt_dir):
        """Return a dict mapping each band name to the file matching ``red_file``."""
        files = {'red': red_file,
                 'green': green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    @staticmethod
    def _to_unit_range(arr):
        """Scale integer data by the maximum of its dtype; leave other data untouched."""
        if np.issubdtype(arr.dtype, np.integer):
            return arr / np.iinfo(arr.dtype).max
        return arr

    def OpenAsArray(self, idx, invert=False):
        """Stack the raw red/green/blue bands of sample ``idx``.

        Returns the unscaled pixel values as ``(H, W, 3)``, or ``(3, H, W)``
        when ``invert`` is true.
        """
        bands = self.files[idx]
        TrueColor = np.stack([np.array(Image.open(bands[band]))
                              for band in ('red', 'green', 'blue')], axis = 2)

        if invert:
            TrueColor = TrueColor.transpose((2, 0, 1))

        return TrueColor

    def OpenMask(self, idx, add_dims=False):
        """Return the binary mask of sample ``idx`` (1 where the file holds 255, else 0).

        With ``add_dims`` a leading axis of length one is added.
        """
        raw_mask = np.array(Image.open(self.files[idx]['gt']))
        raw_mask = np.where(raw_mask==255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``.

        The channel-first image is scaled to [0, 1] before the transform runs,
        matching what ``NirGB_CloudDataset`` returns. This keeps mean/std
        normalisation defined on unit-range data meaningful, and avoids
        handing a raw integer array to ``torch.from_numpy``.
        """
        x = self._to_unit_range(self.OpenAsArray(idx, invert=True))
        y = self.OpenMask(idx, add_dims=False)

        if self.transform is not None:
            x, y = self.transform((x, y))

        return torch.from_numpy(x), torch.from_numpy(y)

    def open_as_pil(self, idx):
        """Return sample ``idx`` as an 8-bit RGB PIL image.

        Wider integer data is rescaled to 0..255 instead of being cast
        directly, because a plain ``astype(np.uint8)`` wraps modulo 256.
        """
        arr = self.OpenAsArray(idx)
        if arr.dtype != np.uint8:
            # NOTE(review): non-integer input is assumed to already be in [0, 1].
            arr = np.clip(np.rint(self._to_unit_range(arr) * 255), 0, 255).astype(np.uint8)

        return Image.fromarray(arr, 'RGB')

    def __repr__(self):
        s = 'Dataset class with {} files'.format(self.__len__())

        return s

In [None]:
class NirGB_CloudDataset (Dataset):
    """False-colour (near-infrared, green, blue) cloud-segmentation dataset.

    Every sample is assembled from single-band image files that share a file
    name apart from the band token (``red`` / ``green`` / ``blue`` / ``nir`` /
    ``gt``).

    Parameters
    ----------
    red_dir, blue_dir, green_dir, nir_dir, gt_dir : pathlib.Path
        Folders holding the individual bands and the ground-truth masks. The
        red folder drives the file listing.
    transform : callable, optional
        Called as ``transform((x, y))`` inside ``__getitem__`` and expected to
        return the transformed ``(x, y)`` pair.
    """

    def __init__(self, red_dir, blue_dir, green_dir, nir_dir, gt_dir, transform = None):
        self.transform = transform

        # Walk the red folder and pair every file with its sibling bands.
        # Sorting gives a stable sample order, because ``iterdir`` yields
        # entries in arbitrary order.
        self.files = [self.combine_files(f, green_dir, blue_dir, nir_dir, gt_dir)
                      for f in sorted(red_dir.iterdir()) if not f.is_dir()]

    def combine_files(self, red_file: Path, green_dir, blue_dir, nir_dir, gt_dir):
        """Return a dict mapping each band name to the file matching ``red_file``."""
        files = {'red': red_file,
                 'green': green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'nir': nir_dir/red_file.name.replace('red', 'nir'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    def OpenAsArray(self, idx):
        """Return the nir/green/blue stack of sample ``idx`` as ``(3, H, W)``.

        Values are divided by the maximum of the integer dtype of the files.
        """
        bands = self.files[idx]
        FalseColor = np.stack([np.array(Image.open(bands[band]))
                               for band in ('nir', 'green', 'blue')], axis = 2)

        FalseColor = FalseColor.transpose((2, 0, 1))

        return (FalseColor / np.iinfo(FalseColor.dtype).max)

    def OpenMask(self, idx, add_dims=False):
        """Return the binary mask of sample ``idx`` (1 where the file holds 255, else 0).

        With ``add_dims`` a leading axis of length one is added.
        """
        raw_mask = np.array(Image.open(self.files[idx]['gt']))
        raw_mask = np.where(raw_mask==255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``.

        The optional transform is applied to the ``(image, mask)`` pair, the
        same way ``RGB_CloudDataset`` does; previously it was stored but
        silently ignored.
        """
        x = self.OpenAsArray(idx)
        y = self.OpenMask(idx, add_dims=False)

        if self.transform is not None:
            x, y = self.transform((x, y))

        return torch.from_numpy(x), torch.from_numpy(y)

    def open_as_pil(self, idx):
        """Return sample ``idx`` as an 8-bit false-colour PIL image.

        ``OpenAsArray`` yields channel-first values in [0, 1], so the array is
        moved to channel-last and scaled by 255 (scaling by 256 wraps a value
        of 1.0 to 0 in ``uint8``). PIL has no ``'NirGB'`` mode; the three
        bands are stored in an ``'RGB'`` image with nir in the first channel.
        """
        channels_last = self.OpenAsArray(idx).transpose((1, 2, 0))
        arr = np.clip(np.rint(channels_last * 255), 0, 255).astype(np.uint8)
        return Image.fromarray(arr, 'RGB')

    def __repr__(self):
        s = 'Dataset class with {} files'.format(self.__len__())
        return s

In [None]:
class Resize(object):
    """Resize an ``(image, mask)`` pair to a square of side ``size``.

    Parameters
    ----------
    size : int, optional
        Height and width of the output (default 256).

    Calling the instance with ``(x, y)`` returns the resized pair. Only the
    last two axes are resized; any leading axes (for example the channel axis
    of a channel-first image) keep their length.
    """

    def __init__(self, size = 256):
        self.size = size

    def __call__(self, sample):
        x, y = sample
        target = (self.size, self.size)

        x_resized = resize(x, tuple(x.shape[:-2]) + target, mode = "constant",
                           preserve_range = True, anti_aliasing = False)
        # order=0 (nearest neighbour) keeps the mask's label values intact.
        # The default interpolation for non-boolean input blends neighbouring
        # pixels and produces fractional labels along cloud edges.
        y_resized = resize(y, tuple(y.shape[:-2]) + target, order = 0, mode = "constant",
                           preserve_range = True, anti_aliasing = False)

        return (x_resized, y_resized)
    

class Normalize(object):
    """Standardise a channel-first image per channel; the mask passes through untouched.

    Parameters
    ----------
    mean, std : sequence of float
        One value per channel, broadcast along the last axis of the
        channel-last view of the image.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, sample):
        image, mask = sample
        # Move channels last so mean/std broadcast per channel, then move them back.
        channels_last = image.transpose((1, 2, 0))
        standardised = (channels_last - self.mean) / self.std
        return standardised.transpose((2, 0, 1)), mask
    
    
# Define transforms for the training data and testing data.
# The mean/std triplets are per-channel statistics applied to unit-range images.
train_transforms = transforms.Compose([Resize(256),
                                       Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Compose iterates over its argument when called, so it must receive a list
# even for a single transform; a bare Normalize object fails at call time.
test_transforms = transforms.Compose([Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [None]:
RGB_data   = RGB_CloudDataset(red_dir, blue_dir, green_dir, gt_dir, transform = train_transforms)
NirGB_data = NirGB_CloudDataset(red_dir, blue_dir, green_dir, nir_dir, gt_dir, transform = None)

# Split the data into train (75 %), validation (15 %) and test (remainder).
RGBtrain_size = int(0.75 * len(RGB_data))
RGBvalid_size = int(0.15 * len(RGB_data))
RGBtest_size  = len(RGB_data) - RGBtrain_size - RGBvalid_size
RGBremaining_size = len(RGB_data) - RGBtrain_size

# Seed the split so the same samples land in each subset on every run.
split_generator = torch.Generator().manual_seed(seed)

RGBtrain_dataset, RGBremaining_dataset = torch.utils.data.random_split(
    RGB_data, [RGBtrain_size, RGBremaining_size], generator=split_generator)
RGBvalid_dataset, RGBtest_dataset = torch.utils.data.random_split(
    RGBremaining_dataset, [RGBvalid_size, RGBtest_size], generator=split_generator)


print('\t\t\tDataset')
print("Train data: \t\t{}".format(len(RGBtrain_dataset)),
      "\nValidation data: \t{}".format(len(RGBvalid_dataset)),
      "\nTest data: \t\t{}".format(len(RGBtest_dataset)))


# Only the training loader shuffles; evaluation loaders keep a fixed order so
# validation and test batches are comparable between epochs and runs.
RGBtrain_loader = DataLoader(RGBtrain_dataset, batch_size=12, shuffle=True, num_workers=2)
RGBvalid_loader = DataLoader(RGBvalid_dataset, batch_size=12, shuffle=False, num_workers=2)
RGBtest_loader  = DataLoader(RGBtest_dataset , batch_size=12, shuffle=False, num_workers=2)

# Pull one validation batch to sanity-check the tensor shapes.
RGBdata_iter = iter(RGBvalid_loader)
rgb_img, mask = next(RGBdata_iter)

print('\n')
print('Raw RGB image shape on batch size = {}'.format(rgb_img.size()))
print('Cloud Mask shape on batch size    = {}'.format(mask.size()))

In [None]:
fig, ax = plt.subplots(1, 2, figsize = (10, 9))

# imshow clips integer RGB data to 0..255, so wider integer data would render
# saturated. Scale integer pixels to [0, 1] floats before plotting.
sample_rgb = RGB_data.OpenAsArray(5)
if np.issubdtype(sample_rgb.dtype, np.integer):
    sample_rgb = sample_rgb / np.iinfo(sample_rgb.dtype).max

ax[0].imshow(sample_rgb)
ax[0].set_title('RGB composite (sample 5)')
ax[1].imshow(RGB_data.OpenMask(5))
ax[1].set_title('Cloud mask (sample 5)');

## **Model: Pretrained Vgg16**

In [None]:
class Vgg16(nn.Module):
    """Pretrained VGG16 backbone with a two-class segmentation head.

    ``modified_network`` is a sequential stack of:

    * every child of the torchvision VGG16 except the last two,
    * a 1x1 convolution mapping 512 feature channels to 2 class scores,
    * a bilinear up-sampling layer producing 256x256 score maps.
    """

    def __init__(self):
        super().__init__()
        # Use a pretrained model
        network = models.vgg16(pretrained=True)
        # Replace the classifier: drop the last two children of the pretrained
        # network and append the segmentation head.
        backbone_layers = list(network.children())[:-2]
        self.modified_network = nn.Sequential(
            *backbone_layers,
            nn.Conv2d(512, 2, kernel_size = 1),
            nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False),
        )

    def forward(self, xb):
        """Return per-pixel class scores for the image batch ``xb``."""
        return self.modified_network(xb)

    def freeze(self):
        """Stop gradient updates for the first sub-module; keep the rest trainable.

        The attribute autograd reads is ``requires_grad``. Assigning to the
        misspelled ``require_grad`` only creates an unused attribute and
        leaves every layer trainable.
        """
        for param in self.modified_network[0].parameters():
            param.requires_grad = False
        for param in self.modified_network[1:].parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Make every layer trainable again."""
        for param in self.modified_network.parameters():
            param.requires_grad = True

In [None]:
# Build the network and move its parameters to the selected device.
model = Vgg16().to(device)
# Cell output: the first sub-module followed by the remaining layers.
(model.modified_network[0], model.modified_network[1:])

In [None]:
# Evaluation Metrics

class Evaluation_Metrics(nn.Module):
    """Batch-level segmentation metrics computed from class-1 predictions.

    ``forward`` takes raw scores with the class axis at position 1 and a
    ground-truth tensor broadcastable against ``prediction[:, 1]``. It returns
    a dict of scalar tensors keyed ``'Dice_Score/F1_Score'``, ``'IoU'``,
    ``'Recall'`` and ``'Precision'``. ``smooth`` is added to every numerator
    and denominator.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        # Hard decision for class 1: softmax over the class axis, then round.
        class1_prob = prediction.softmax(dim=1)[:, 1]
        pred = torch.round(class1_prob)

        # Confusion-matrix counts summed over the whole batch.
        TP = torch.sum(pred * gt)
        FP = torch.sum(pred * (1-gt))
        TN = torch.sum((1-pred) * (1-gt))  # computed as in the original; unused below
        FN = torch.sum((1-pred) * gt)

        metrics = {}
        # Dice score, equivalent to F1
        metrics['Dice_Score/F1_Score'] = (2 * TP + smooth)/(2*TP + FP + FN  + smooth)
        # Jaccard coefficient: intersection over union
        metrics['IoU'] = (TP + smooth)/(TP + FP + FN + smooth)
        # Share of ground-truth positives that were found
        metrics['Recall'] = (TP + smooth)/(TP + FN + smooth)
        # Share of predicted positives that are correct
        metrics['Precision'] = (TP + smooth)/(TP + FP + smooth)

        return metrics


# Dice Loss function
class DiceLoss(nn.Module):
    """Soft Dice loss on the class-1 probability map.

    ``forward`` returns ``1 - Dice``, where Dice is computed over the whole
    batch and ``smooth`` is added to numerator and denominator.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        # Softmax yields probabilities between 0 and 1; keeping class 1 turns
        # the (N, C, H, W) scores into an (N, H, W) probability map.
        pred = prediction.softmax(dim=1)[:, 1]

        overlap = torch.sum(pred * gt)
        total = torch.sum(pred + gt)
        Dice = (2 * overlap + smooth)/(total + smooth)

        return 1 - Dice

    
# Metric calculator and training criterion
Evaluation, loss_fn = Evaluation_Metrics(), DiceLoss()

# Adam over every parameter of the model
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def plot_DSC(train_DSC, valid_DSC):
    """Plot the training and validation Dice coefficient per epoch.

    Parameters
    ----------
    train_DSC, valid_DSC : sequence of float
        One value per epoch; training is drawn in blue, validation in red.
    """
    plt.plot(train_DSC, '-bx')
    plt.plot(valid_DSC, '-rx')
    plt.xlabel('epoch')
    # The curves show the Dice coefficient, not accuracy.
    plt.ylabel('Dice coefficient')
    # Same legend as plot_losses so the two curves can be told apart.
    plt.legend(['Training', 'Validation'])
    plt.title('Dice Coefficient vs. No. of epochs');
    
def plot_losses(train_losses, valid_losses):
    """Plot the training (blue) and validation (red) Dice loss per epoch."""
    curves = ((train_losses, '-bx'), (valid_losses, '-rx'))
    for values, line_style in curves:
        plt.plot(values, line_style)

    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Dice Loss vs. No. of epochs');

In [None]:
# Freeze the pretrained part of the network before training the new head.
model.freeze()
Pretrained_Vgg16 = './Pretrained_Vgg16_CloudSegModel.pt'

# NOTE(review): `train` is not defined in the cells shown here; confirm it is
# defined before this cell runs.
(train_losses, valid_losses,
 train_DSC, train_IoU, train_recall, train_precision,
 valid_DSC, valid_IoU, valid_recall, valid_precision,
 best_DSC) = train(model, RGBtrain_loader, RGBvalid_loader, loss_fn, optimizer, Evaluation,
                   epochs=4, save_path = Pretrained_Vgg16)