In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import time
import os
import numpy as np
from pathlib import Path
from PIL import Image
from skimage.transform import resize
import helper
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib.image import imread
from torchvision import datasets
from torchvision import datasets, transforms, models
from torch import nn, optim, Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from torch.utils.data import DataLoader, Dataset, TensorDataset
import random
# Notebook-wide RNG seed; RGBCloudDataset uses it to seed Python's `random`
# module before sub-sampling the list of image files.
seed=42

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('Using device:', device)
print()

In [3]:
class RGBCloudDataset (Dataset):
    """Dataset of RGB image patches paired with binary ground-truth masks.

    Each sample is described by four files (red, green, blue and ground-truth)
    that share a file name except for the band token (``red`` / ``green`` /
    ``blue`` / ``gt``). The red directory is the reference: every regular file
    found there becomes one sample.
    """

    def __init__(self, red_dir, blue_dir, green_dir, gt_dir,
                 sample_size=4000, random_seed=None):
        """Index the files and keep a seeded random subset of them.

        Args:
            red_dir, blue_dir, green_dir, gt_dir: ``pathlib.Path`` directories
                holding the per-band images and the masks.
            sample_size: maximum number of samples to keep. It is clamped to
                the number of available files, so a small directory no longer
                raises ``ValueError``. ``None`` keeps every file, in sorted
                order.
            random_seed: seed for the sub-sampling. ``None`` falls back to the
                notebook-level ``seed`` variable.
        """
        # Sort first: iterdir() yields entries in arbitrary, OS-dependent
        # order, which would make the seeded sample differ between machines.
        red_files = sorted(f for f in red_dir.iterdir() if not f.is_dir())
        self.files = [self.combine_files(f, green_dir, blue_dir, gt_dir)
                      for f in red_files]

        if random_seed is None:
            random_seed = seed
        # Seeds the global `random` module, as the notebook always did.
        random.seed(random_seed)
        if sample_size is not None:
            k = min(sample_size, len(self.files))
            self.files = random.sample(self.files, k=k)

    def combine_files(self, red_file: Path, green_dir, blue_dir,  gt_dir):
        """Return a dict with the paths of all bands matching ``red_file``.

        The sibling file names are derived by replacing ``'red'`` in the red
        file's name with ``'green'``, ``'blue'`` and ``'gt'``.
        """
        files = {'red': red_file,
                 'green': green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    @staticmethod
    def _read_band(path):
        """Load one image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def OpenAsArray(self, idx, invert=False):
        """Load sample ``idx`` as a float array scaled to [0, 1].

        Args:
            idx: index into ``self.files``.
            invert: when True return channel-first (3, H, W); otherwise the
                array is channel-last (H, W, 3).

        The values are divided by the maximum of the image's integer dtype
        (``np.iinfo`` raises for non-integer images).
        """
        entry = self.files[idx]
        raw_rgb = np.stack([self._read_band(entry['red']),
                            self._read_band(entry['green']),
                            self._read_band(entry['blue'])], axis=2)

        if invert:
            raw_rgb = raw_rgb.transpose((2, 0, 1))

        return (raw_rgb / np.iinfo(raw_rgb.dtype).max)

    def OpenMask(self, idx, add_dims=False):
        """Load the mask of sample ``idx`` as a 0/1 integer array.

        Pixels equal to 255 become 1, everything else 0. With ``add_dims`` a
        leading axis of length 1 is added.
        """
        raw_mask = self._read_band(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask == 255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        """Number of samples kept after sub-sampling."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: float32 (3, H, W) and int64 (H, W) tensors."""
        x = torch.tensor(self.OpenAsArray(idx, invert=True), dtype=torch.float32)
        y = torch.tensor(self.OpenMask(idx, add_dims=False), dtype=torch.int64)

        return x, y

    def open_as_pil(self, idx):
        """Return sample ``idx`` as an 8-bit RGB PIL image."""
        # Scale [0, 1] to [0, 255]. The previous factor of 256 mapped a
        # full-intensity pixel to 256, which wraps around to 0 in uint8.
        arr = np.rint(255 * self.OpenAsArray(idx))

        return Image.fromarray(arr.astype(np.uint8), 'RGB')

    def __repr__(self):
        """Short human-readable summary."""
        s = 'Dataset class with {} files'.format(self.__len__())

        return s

In [4]:
# Root folder of the training patches; every band lives in its own
# sub-directory below it. The path is relative to the notebook's working
# directory (looks like a Kaggle `../input` layout — adjust when running elsewhere).
base_path = Path('../input/95cloud-cloud-segmentation-on-satellite-images/95-cloud_training_only_additional_to38-cloud')
red_dir   = base_path/'train_red_additional_to38cloud'
blue_dir  = base_path/'train_blue_additional_to38cloud'
green_dir = base_path/'train_green_additional_to38cloud'
# NOTE(review): nir_dir is defined but not referenced by any other visible cell.
nir_dir   = base_path/'train_nir_additional_to38cloud'
# Ground-truth masks.
gt_dir    = base_path/'train_gt_additional_to38cloud' 

In [5]:
class Resize_data(object):
    """Callable transform resizing an ``(image, mask)`` pair to a square size.

    Args:
        size: target height and width in pixels (default 256).

    Calling the instance with ``sample = (x, y)`` returns the resized pair.
    The first axis of ``x`` is kept as is (presumably the channel axis of a
    channel-first image — confirm against the caller); ``y`` is a 2-D mask.
    """
    def __init__(self, size = 256):
        self.size = size

    def __call__(self, sample):
        x, y = sample
        image = resize(x, (x.shape[0], self.size, self.size), mode = "constant",
                       preserve_range = True, anti_aliasing = False)
        # order=0 selects nearest-neighbour interpolation: a label mask must
        # not be blended, otherwise resizing creates fractional values along
        # class boundaries that are no longer valid labels.
        mask = resize(y, (self.size, self.size), order = 0, mode = "constant",
                      preserve_range = True, anti_aliasing = False)
        return (image, mask)

In [6]:
# NOTE(review): this torchvision pipeline is built but never passed to the
# dataset or the loaders in the visible cells.
train_transforms=transforms.Compose([transforms.Resize(256),
                                    transforms.ToTensor()])
RGBdata = RGBCloudDataset(red_dir, blue_dir, green_dir, gt_dir)

# Split the data into train (75%), validation (15%) and test (remainder).
train_size = int(0.75 * len(RGBdata))
valid_size = int(0.15 * len(RGBdata))
remaining_size = len(RGBdata) - train_size
test_size  = remaining_size - valid_size

# A seeded generator makes both splits reproducible across kernel restarts;
# without it every run trains and evaluates on different subsets.
split_generator = torch.Generator().manual_seed(seed)

RGBtrain_dataset, RGBremaining_dataset = torch.utils.data.random_split(
    RGBdata, [train_size, remaining_size], generator=split_generator)
RGBvalid_dataset, RGBtest_dataset      = torch.utils.data.random_split(
    RGBremaining_dataset, [valid_size, test_size], generator=split_generator)


print('\t\t\tDataset')
print("Train data: \t\t{}".format(len(RGBtrain_dataset)),
      "\nValidation data: \t{}".format(len(RGBvalid_dataset)),
     "\nTest data: \t\t{}".format(len(RGBtest_dataset)))

In [7]:
# Only the training loader shuffles. The loss and the metrics are computed
# per batch, so reshuffling the validation/test sets changes the batch
# composition every epoch and adds noise to the validation loss that is used
# to pick the checkpoint.
RGBtrain_loader = DataLoader(RGBtrain_dataset, batch_size=5, shuffle=True, num_workers=2)
RGBvalid_loader = DataLoader(RGBvalid_dataset, batch_size=5, shuffle=False, num_workers=2)
RGBtest_loader  = DataLoader(RGBtest_dataset , batch_size=5, shuffle=False, num_workers=2)

# Pull one training batch to check the tensor shapes.
rgb_img, mask = next(iter(RGBtrain_loader))

print('\n')
print('Raw RGB image shape on batch size = {}'.format(rgb_img.size()))
print('Cloud Mask shape on batch size    = {}'.format(mask.size()))

In [8]:
# Fetch one sample through __getitem__ and display the shapes of the image
# tensor and of the mask tensor.
x, y =RGBdata[100]
x.shape, y. shape

In [9]:
# Side-by-side preview of sample 100: the normalised RGB array on the left
# (channel-last, since invert defaults to False) and its 0/1 mask on the right.
fig, ax = plt.subplots(1,2, figsize = (10,9))
ax[0].imshow(RGBdata.OpenAsArray(100))
ax[1].imshow(RGBdata.OpenMask(100))

## **Unet++**

In [8]:
class UnetPlusPlus_block_nested (nn.Module):
    """Convolutional unit: two (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use padding 1 and no bias, so the spatial size of the
    input is preserved while the channels go
    ``in_channels -> mid_channels -> out_channels``.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        conv_options = dict(kernel_size=3, padding=1, bias=False)

        # Sub-module names and creation order are kept as they were so that
        # state_dict keys and seeded initialisation stay the same.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, **conv_options)
        self.bn1   = nn.BatchNorm2d(mid_channels)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, out_channels, **conv_options)
        self.bn2   = nn.BatchNorm2d(out_channels)

    def forward (self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))


class UnetPlusPlus_nested(nn.Module):
    """Nested U-Net with four down-sampling levels and dense skip connections.

    Node ``conv{i}_{j}`` sits at depth ``i`` and is the ``j``-th node along the
    skip pathway of that depth. It receives every earlier node of the same
    depth plus the up-sampled node from the depth below. The prediction is a
    1x1 convolution on the last top-level node (``x0_4``).

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: number of output channels / classes (default 2).

    The input height and width should be divisible by 16 (four 2x poolings),
    otherwise the up-sampled maps do not match the skip tensors they are
    concatenated with.
    """

    def __init__(self, in_channels=3, out_channels=2):
        super().__init__()

        base = 64
        # Channel widths per depth: 64, 128, 256, 512, 1024.
        f0, f1, f2, f3, f4 = (base * 2 ** depth for depth in range(5))
        block = UnetPlusPlus_block_nested

        self.pool     = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (j = 0).
        self.conv0_0 = block(in_channels, f0, f0)
        self.conv1_0 = block(f0, f1, f1)
        self.conv2_0 = block(f1, f2, f2)
        self.conv3_0 = block(f2, f3, f3)
        self.conv4_0 = block(f3, f4, f4)

        # j = 1: one same-depth input plus the up-sampled deeper node.
        self.conv0_1 = block(f0 + f1, f0, f0)
        self.conv1_1 = block(f1 + f2, f1, f1)
        self.conv2_1 = block(f2 + f3, f2, f2)
        self.conv3_1 = block(f3 + f4, f3, f3)

        # j = 2: two same-depth inputs.
        self.conv0_2 = block(f0 * 2 + f1, f0, f0)
        self.conv1_2 = block(f1 * 2 + f2, f1, f1)
        self.conv2_2 = block(f2 * 2 + f3, f2, f2)

        # j = 3: three same-depth inputs.
        self.conv0_3 = block(f0 * 3 + f1, f0, f0)
        self.conv1_3 = block(f1 * 3 + f2, f1, f1)

        # j = 4: four same-depth inputs.
        self.conv0_4 = block(f0 * 4 + f1, f0, f0)

        self.final   = nn.Conv2d(f0, out_channels, kernel_size=1)

    def forward (self, x):
        """Return the ``(N, out_channels, H, W)`` logits for image batch ``x``."""
        pool, up = self.pool, self.upsample

        def merge(*feature_maps):
            # Concatenate along the channel axis.
            return torch.cat(feature_maps, 1)

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(pool(x0_0))
        x0_1 = self.conv0_1(merge(x0_0, up(x1_0)))

        x2_0 = self.conv2_0(pool(x1_0))
        x1_1 = self.conv1_1(merge(x1_0, up(x2_0)))
        x0_2 = self.conv0_2(merge(x0_0, x0_1, up(x1_1)))

        x3_0 = self.conv3_0(pool(x2_0))
        x2_1 = self.conv2_1(merge(x2_0, up(x3_0)))
        x1_2 = self.conv1_2(merge(x1_0, x1_1, up(x2_1)))
        x0_3 = self.conv0_3(merge(x0_0, x0_1, x0_2, up(x1_2)))

        x4_0 = self.conv4_0(pool(x3_0))
        x3_1 = self.conv3_1(merge(x3_0, up(x4_0)))
        x2_2 = self.conv2_2(merge(x2_0, x2_1, up(x3_1)))
        x1_3 = self.conv1_3(merge(x1_0, x1_1, x1_2, up(x2_2)))
        x0_4 = self.conv0_4(merge(x0_0, x0_1, x0_2, x0_3, up(x1_3)))

        return self.final(x0_4)

In [9]:
# Evaluation Metrics

class Evaluation_Metrics(nn.Module):
    """Binary segmentation metrics computed from two-class logits.

    The foreground prediction is the rounded softmax probability of channel 1.
    All counts are summed over the whole batch, so each call yields one value
    per metric for the batch.
    """

    def __init__(self):
        super(Evaluation_Metrics, self).__init__()

    @torch.no_grad()
    def forward(self, prediction, gt, smooth=1):
        """Return Dice/F1, IoU, recall and precision as 0-dim tensors.

        Args:
            prediction: logits of shape (N, C, H, W); channel 1 is the
                foreground class.
            gt: 0/1 ground-truth mask of shape (N, H, W).
            smooth: constant added to every numerator and denominator so that
                empty masks do not cause a division by zero.

        Returns:
            dict with the keys 'Dice_Score/F1_Score', 'IoU', 'Recall' and
            'Precision'.

        Runs under ``torch.no_grad()``: the metrics are for reporting only
        (rounding has no useful gradient), so no autograd graph is kept alive
        when this is called on training outputs.
        """
        pred = torch.round(prediction.softmax(dim=1)[:, 1])

        # True positives, false positives and false negatives. True negatives
        # are not needed by any of the metrics below.
        TP = torch.sum(pred * gt)
        FP = torch.sum(pred * (1 - gt))
        FN = torch.sum((1 - pred) * gt)

        # Dice score / F1 score
        Dice_Score = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)

        # Jaccard coefficient: intersection over union
        IoU = (TP + smooth) / (TP + FP + FN + smooth)

        Recall = (TP + smooth) / (TP + FN + smooth)

        Precision = (TP + smooth) / (TP + FP + smooth)

        return {'Dice_Score/F1_Score': Dice_Score, 'IoU': IoU, 'Recall': Recall, 'Precision': Precision}


# Combined Dice + binary cross-entropy loss
class DiceBCELoss(nn.Module):
    """Sum of the soft Dice loss and the BCE of the foreground probability."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        """Return ``BCE + (1 - Dice)`` as a 0-dim tensor.

        Args:
            prediction: logits of shape (N, C, H, W); channel 1 is taken as the
                foreground class.
            gt: 0/1 mask of shape (N, H, W).
            smooth: additive constant in the Dice numerator and denominator.
        """
        # Softmax over the class axis gives probabilities between 0 and 1.
        # Keeping channel 1 turns (N, C, H, W) into (N, H, W), which is then
        # flattened together with the labels.
        foreground = prediction.softmax(dim=1)[:, 1].reshape(-1)
        target = gt.reshape(-1).to(torch.float32)

        overlap = torch.sum(foreground * target)
        dice_coefficient = (2. * overlap + smooth) / (foreground.sum() + target.sum() + smooth)
        bce = F.binary_cross_entropy(foreground, target, reduction='mean')

        return bce + (1 - dice_coefficient)
    

# Negative-log Dice loss
class DiceLoss(nn.Module):
    """Loss equal to ``-log`` of the soft Dice coefficient."""

    def __init__(self):
        super().__init__()

    def forward(self,prediction, gt, smooth=1):
        """Return ``-log(Dice)`` as a 0-dim tensor.

        Args:
            prediction: logits of shape (N, C, H, W); channel 1 is taken as the
                foreground class.
            gt: 0/1 mask of shape (N, H, W).
            smooth: additive constant in the Dice numerator and denominator.
        """
        # Class probabilities via softmax; channel 1 reduces (N, C, H, W) to
        # (N, H, W) before predictions and labels are flattened.
        foreground = prediction.softmax(dim=1)[:, 1].reshape(-1)
        target = gt.reshape(-1).to(torch.float32)

        overlap = torch.sum(foreground * target)
        dice_coefficient = (2. * overlap + smooth) / (foreground.sum() + target.sum() + smooth)

        return -torch.log(dice_coefficient)

In [10]:
# Use GPU if it is available
# Build the network (3 input bands, 2 output classes) on the selected device.
model = UnetPlusPlus_nested(in_channels=3, out_channels=2).to(device)

# Metric calculator and loss functions.
Evaluation   = Evaluation_Metrics()
loss_fn      = DiceBCELoss()
logDice_loss = DiceLoss()

# Adam optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [13]:
!pip install torchsummary
from torchsummary import summary
summary(model, (3, 384, 384))  

In [None]:
# Release cached, currently unused GPU memory before the smoke test.
torch.cuda.empty_cache()

# Smoke test: push one training batch through the network and compare the
# output shape with the mask shape.
rgb_img, mask = next(iter(RGBtrain_loader))

rgb_img, mask = rgb_img.to(device), mask.to(device)
# No gradients are needed for a shape check; without no_grad the full
# autograd graph of the network would be kept in memory.
with torch.no_grad():
    output = model (rgb_img)
output.shape, mask.shape

In [None]:
# memory_summary() returns a multi-line table as a string: print it so the
# line breaks are rendered instead of shown as '\n' escapes, and skip it on
# machines without CUDA, where the call raises.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [11]:
def plot_DSC(train_DSC, valid_DSC, metric):
    """Plot a per-epoch score for the training and validation phases.

    Args:
        train_DSC: sequence with one training value per epoch (blue crosses).
        valid_DSC: sequence with one validation value per epoch (red crosses).
        metric: axis/title label, e.g. 'Accuracy' or 'Dice Coefficient'.

    The figure is left as the current pyplot figure so that a following
    ``plt.savefig`` writes it to disk.
    """
    plt.figure(figsize=(10, 10))
    plt.plot(train_DSC, '-bx')
    plt.plot(valid_DSC, '-rx')
    plt.xlabel('epoch')
    plt.ylabel(metric)
    # Same legend as plot_losses; without it the two curves cannot be told apart.
    plt.legend(['Training', 'Validation'])
    plt.title(metric+' vs. No. of epochs');
    
    
def plot_losses(train_losses, valid_losses, metric):
    """Draw the training and validation loss curves against the epoch index.

    Args:
        train_losses: one training loss per epoch (blue crosses).
        valid_losses: one validation loss per epoch (red crosses).
        metric: axis/title label, e.g. 'Dice Loss' or '- Log Dice loss'.
    """
    plt.figure(figsize=(10, 10))

    curves = ((train_losses, '-bx'), (valid_losses, '-rx'))
    for values, line_format in curves:
        plt.plot(values, line_format)

    plt.xlabel('epoch')
    plt.ylabel(metric)
    plt.legend(['Training', 'Validation'])
    plt.title(metric + ' vs. No. of epochs')

In [None]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, Evaluation, epochs, save_path):
    """Train ``model`` and validate it after every epoch.

    Each epoch runs a training pass over ``train_dl`` followed by a validation
    pass over ``valid_dl``. Whenever the validation loss is less than or equal
    to the best one seen so far, the model's ``state_dict`` is saved to
    ``save_path``.

    Args:
        model: network to optimise; expected to already live on the
            notebook-level ``device``.
        train_dl, valid_dl: data loaders yielding ``(x, y)`` batches.
        loss_fn: callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        optimizer: optimizer bound to ``model``'s parameters.
        Evaluation: callable ``Evaluation(outputs, y)`` returning a dict of
            scalar tensors with the keys 'Dice_Score/F1_Score', 'IoU',
            'Recall' and 'Precision'.
        epochs: number of epochs to run.
        save_path: file the best checkpoint is written to.

    Returns:
        ``(train_losses, valid_losses, train_DSC, train_IoU, train_recall,
        train_precision, valid_DSC, valid_IoU, valid_recall, valid_precision,
        best_DSC)``. Every list holds exactly one entry per epoch for its
        phase; ``best_DSC`` is the validation Dice score of the saved
        checkpoint.
    """
    start = time.time()
    min_valid_epoch_loss = np.inf
    best_DSC = 0.0
    train_losses, valid_losses = [], []
    train_DSC, train_IoU, train_recall, train_precision = [], [], [], []
    valid_DSC, valid_IoU, valid_recall, valid_precision = [], [], [], []

    for e in range(epochs):
        print('Epoch  {}/{}'.format(e, epochs-1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # enable batch-norm updates etc.
                dataloader = train_dl
            else:
                model.eval()
                dataloader = valid_dl

            running_loss = 0.0
            running_DSC = 0.0
            running_IoU = 0.0
            running_recall = 0.0
            running_precision = 0.0
            samples_seen = 0

            for x, y in dataloader:
                x = x.to(device, dtype=torch.float)
                y = y.to(device)
                # Weight by the real batch size: the last batch may be
                # smaller than dataloader.batch_size.
                batch_n = x.size(0)

                if phase == 'train':
                    optimizer.zero_grad()
                    outputs = model(x)
                    loss = loss_fn(outputs, y)
                    loss.backward()
                    optimizer.step()
                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y)

                # Metrics are only reported, never back-propagated.
                with torch.no_grad():
                    acc = Evaluation(outputs, y)

                running_DSC += acc['Dice_Score/F1_Score'].item() * batch_n
                running_IoU += acc['IoU'].item() * batch_n
                running_recall += acc['Recall'].item() * batch_n
                running_precision += acc['Precision'].item() * batch_n
                running_loss += loss.item() * batch_n
                samples_seen += batch_n

            # Average over the samples actually processed (guarding against
            # an empty loader).
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_DSC = running_DSC / denominator
            epoch_IoU = running_IoU / denominator
            epoch_Recall = running_recall / denominator
            epoch_Precision = running_precision / denominator

            print('{}\nDice Loss: {:.3f}\tDice Coefficient: {:.3f}\tJaccard Coefficient: {:.3f}\tPrecision: {:.3f}\tRecall: {:.3f}'.format( phase, epoch_loss, epoch_DSC, epoch_IoU, epoch_Precision, epoch_Recall))
            print()

            # Record the epoch in the history of the current phase only. The
            # former one-line conditional was parsed as a tuple, so most
            # lists received an entry in both phases.
            if phase == 'train':
                train_losses.append(epoch_loss)
                train_DSC.append(epoch_DSC)
                train_IoU.append(epoch_IoU)
                train_recall.append(epoch_Recall)
                train_precision.append(epoch_Precision)
            else:
                valid_losses.append(epoch_loss)
                valid_DSC.append(epoch_DSC)
                valid_IoU.append(epoch_IoU)
                valid_recall.append(epoch_Recall)
                valid_precision.append(epoch_Precision)

                # Save the model if the validation loss has not increased.
                if epoch_loss <= min_valid_epoch_loss:
                    print('Validation Dice loss decreased ({:.3f} --> {:.3f}).  Saving model ...'.format(
                        min_valid_epoch_loss, epoch_loss))
                    print('Best Validation Dice Score ({:.3f} --> {:.3f}).'.format(
                        best_DSC, epoch_DSC))
                    torch.save(model.state_dict(), save_path)
                    min_valid_epoch_loss = epoch_loss
                    best_DSC = epoch_DSC

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return train_losses, valid_losses, train_DSC, train_IoU, train_recall, train_precision,valid_DSC, valid_IoU, valid_recall, valid_precision, best_DSC

In [None]:
# Checkpoint file for the best weights of the full-depth model.
Unetpp = './Unetpp_DiceBCE.pt'

# Train for 20 epochs and keep the per-epoch histories for plotting.
(train_losses, valid_losses,
 train_DSC, train_IoU, train_recall, train_precision,
 valid_DSC, valid_IoU, valid_recall, valid_precision,
 best_DSC) = train(model, RGBtrain_loader, RGBvalid_loader, loss_fn, optimizer,
                   Evaluation, epochs=20, save_path=Unetpp)

In [None]:
# Plot the per-epoch Dice coefficient of both phases, then write the current
# pyplot figure to disk.
plot_DSC(train_DSC, valid_DSC, 'Dice Coefficient')
plt.savefig('./Dice_plot_UnetPlusPlus.png')

In [None]:
# Plot the per-epoch loss of both phases and save the figure.
# NOTE(review): the curves come from loss_fn, which was set to DiceBCELoss
# above, so the 'Dice Loss' label understates what is plotted — confirm.
plot_losses(train_losses, valid_losses, 'Dice Loss')
plt.savefig('./Losses_plot_UnetPlusPlus.png')

## **3 Layers Unet++**

In [None]:
# NOTE(review): this class repeats the definition given in the earlier
# "Unet++" section of this notebook; re-running the cell simply rebinds the
# same name.
class UnetPlusPlus_block_nested (nn.Module):
    """Convolutional unit: two (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use padding 1 and no bias, so the spatial size of the
    input is preserved while the channels go
    ``in_channels -> mid_channels -> out_channels``.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()

        conv_options = dict(kernel_size=3, padding=1, bias=False)

        # Sub-module names and creation order are kept as they were so that
        # state_dict keys and seeded initialisation stay the same.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, **conv_options)
        self.bn1   = nn.BatchNorm2d(mid_channels)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, out_channels, **conv_options)
        self.bn2   = nn.BatchNorm2d(out_channels)

    def forward (self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))


# NOTE(review): this rebinds the name UnetPlusPlus_nested, which an earlier
# cell already uses for the deeper variant; objects created before this cell
# keep the old class.
class UnetPlusPlus_nested(nn.Module):
    """Nested U-Net with three down-sampling levels and dense skip connections.

    Node ``conv{i}_{j}`` sits at depth ``i`` and is the ``j``-th node along the
    skip pathway of that depth. It receives every earlier node of the same
    depth plus the up-sampled node from the depth below. The prediction is a
    1x1 convolution on the last top-level node (``x0_3``).

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: number of output channels / classes (default 2).

    The input height and width should be divisible by 8 (three 2x poolings),
    otherwise the up-sampled maps do not match the skip tensors they are
    concatenated with.
    """

    def __init__(self, in_channels=3, out_channels=2):
        super().__init__()

        base = 64
        # Channel widths per depth: 64, 128, 256, 512.
        f0, f1, f2, f3 = (base * 2 ** depth for depth in range(4))
        block = UnetPlusPlus_block_nested

        self.pool     = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (j = 0).
        self.conv0_0 = block(in_channels, f0, f0)
        self.conv1_0 = block(f0, f1, f1)
        self.conv2_0 = block(f1, f2, f2)
        self.conv3_0 = block(f2, f3, f3)

        # j = 1: one same-depth input plus the up-sampled deeper node.
        self.conv0_1 = block(f0 + f1, f0, f0)
        self.conv1_1 = block(f1 + f2, f1, f1)
        self.conv2_1 = block(f2 + f3, f2, f2)

        # j = 2: two same-depth inputs.
        self.conv0_2 = block(f0 * 2 + f1, f0, f0)
        self.conv1_2 = block(f1 * 2 + f2, f1, f1)

        # j = 3: three same-depth inputs.
        self.conv0_3 = block(f0 * 3 + f1, f0, f0)

        self.final   = nn.Conv2d(f0, out_channels, kernel_size=1)

    def forward (self, x):
        """Return the ``(N, out_channels, H, W)`` logits for image batch ``x``."""
        pool, up = self.pool, self.upsample

        def merge(*feature_maps):
            # Concatenate along the channel axis.
            return torch.cat(feature_maps, 1)

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(pool(x0_0))
        x0_1 = self.conv0_1(merge(x0_0, up(x1_0)))

        x2_0 = self.conv2_0(pool(x1_0))
        x1_1 = self.conv1_1(merge(x1_0, up(x2_0)))
        x0_2 = self.conv0_2(merge(x0_0, x0_1, up(x1_1)))

        x3_0 = self.conv3_0(pool(x2_0))
        x2_1 = self.conv2_1(merge(x2_0, up(x3_0)))
        x1_2 = self.conv1_2(merge(x1_0, x1_1, up(x2_1)))
        x0_3 = self.conv0_3(merge(x0_0, x0_1, x0_2, up(x1_2)))

        return self.final(x0_3)

In [None]:
# Build the shallower network (3 input bands, 2 output classes) on the selected device.
model3 = UnetPlusPlus_nested(in_channels=3, out_channels=2).to(device)

# Fresh metric calculator and loss function.
Evaluation = Evaluation_Metrics()
loss_fn    = DiceBCELoss()

# New Adam optimizer bound to model3's parameters.
optimizer = optim.Adam(model3.parameters(), lr=1e-4)

# Checkpoint file for the best weights of this model.
Unetpp3 = './Unetpp2_DiceBCE.pt'

# Train for 40 epochs; the history variables of the previous run are overwritten.
(train_losses, valid_losses,
 train_DSC, train_IoU, train_recall, train_precision,
 valid_DSC, valid_IoU, valid_recall, valid_precision,
 best_DSC) = train(model3, RGBtrain_loader, RGBvalid_loader, loss_fn, optimizer,
                   Evaluation, epochs=40, save_path=Unetpp3)

In [None]:
# Plot the per-epoch Dice coefficient of the second training run, then write
# the current pyplot figure to disk.
plot_DSC(train_DSC, valid_DSC, 'Dice Coefficient')   
plt.savefig('./Dice_plot_UnetPlusPlus3.png')   

In [None]:
# Plot the per-epoch loss of the second training run and save the figure.
# NOTE(review): loss_fn for this run is DiceBCELoss, so the 'Dice Loss' label
# understates what is plotted — confirm.
plot_losses(train_losses, valid_losses, 'Dice Loss')
plt.savefig('./Losses_plot_UnetPlusPlus3.png')