In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import time
import os
import numpy as np
from pathlib import Path
from PIL import Image
from skimage.transform import resize
import helper
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib.image import imread
from torchvision import datasets
from torchvision import datasets, transforms, models
from torch import nn, optim, Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from torch.utils.data import DataLoader, Dataset, TensorDataset
import random
# Seed used by CloudDataset (via random.seed) when it draws its random subset of patches.
seed=42
# NOTE(review): PyTorch's global RNG is seeded with 14 rather than with `seed` above —
# confirm that using two different seed values is intentional.
torch.manual_seed(14)

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

In [None]:
class CloudDataset (Dataset):
    """Dataset of multi-band image patches paired with binary cloud masks.

    Every file found directly inside ``red_dir`` defines one sample. The paths
    of the matching green / blue / NIR / ground-truth files are derived by
    replacing ``'red'`` in the red file name with the other band name.

    Parameters
    ----------
    red_dir, blue_dir, green_dir, nir_dir, gt_dir : pathlib.Path
        Directories holding the per-band patches and the ground-truth masks.
    n_samples : int or None, default 10000
        Maximum number of patches to keep. The value is clamped to the number
        of available files, so small folders no longer raise ``ValueError``.
        ``None`` keeps every file (in shuffled order).
    sample_seed : int or None, default None
        Seed for the subset draw. ``None`` falls back to the notebook-level
        ``seed`` variable.

    Notes
    -----
    * Files are sorted before sampling, so the selected subset depends only on
      the file names and the seed, not on the filesystem's listing order.
    * Building the dataset calls ``random.seed``, i.e. it reseeds Python's
      global ``random`` module.
    """

    def __init__(self, red_dir, blue_dir, green_dir, nir_dir, gt_dir,
                 n_samples=10000, sample_seed=None):

        # iterdir() yields entries in arbitrary order; sort to make the seeded
        # sample reproducible across machines and runs.
        red_files = sorted(f for f in red_dir.iterdir() if not f.is_dir())

        # One dictionary of band -> path per red patch.
        self.files = [self.combine_files(f, green_dir, blue_dir, nir_dir, gt_dir)
                      for f in red_files]

        if sample_seed is None:
            sample_seed = seed  # notebook-level seed
        random.seed(sample_seed)

        # Never ask random.sample for more items than exist.
        if n_samples is None:
            k = len(self.files)
        else:
            k = min(n_samples, len(self.files))
        self.files = random.sample(self.files, k=k)

    def combine_files(self, red_file: Path, green_dir, blue_dir, nir_dir, gt_dir):
        """Return a dict mapping 'red', 'green', 'blue', 'nir', 'gt' to file paths.

        The sibling paths reuse the red file's name with every occurrence of
        ``'red'`` replaced by the respective band name.
        """
        files = {'red': red_file,
                 'green': green_dir/red_file.name.replace('red', 'green'),
                 'blue': blue_dir/red_file.name.replace('red', 'blue'),
                 'nir': nir_dir/red_file.name.replace('red', 'nir'),
                 'gt': gt_dir/red_file.name.replace('red', 'gt')}

        return files

    @staticmethod
    def _read_band(path):
        """Load one image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def OpenAsArray(self, idx, invert=False, include_nir=False):
        """Load sample ``idx`` as a float array scaled to [0, 1].

        Parameters
        ----------
        idx : int
            Index into ``self.files``.
        invert : bool
            If True, return channels-first (C, H, W) instead of (H, W, C).
        include_nir : bool
            If True, append the NIR band as a fourth channel.

        Returns
        -------
        numpy.ndarray
            Stacked bands divided by the largest value representable in the
            stacked array's integer dtype (``np.iinfo`` requires an integer
            dtype, so floating-point source images are not supported).
        """
        entry = self.files[idx]
        raw_rgb = np.stack([self._read_band(entry['red']),
                            self._read_band(entry['green']),
                            self._read_band(entry['blue'])], axis=2)

        if include_nir:
            nir = np.expand_dims(self._read_band(entry['nir']), axis=2)
            raw_rgb = np.concatenate([raw_rgb, nir], axis=2)

        if invert:
            raw_rgb = raw_rgb.transpose((2, 0, 1))

        return (raw_rgb / np.iinfo(raw_rgb.dtype).max)

    def OpenMask(self, idx, add_dims=False):
        """Load the ground-truth mask of sample ``idx`` as a 0/1 integer array.

        Pixels equal to 255 become 1, everything else becomes 0. With
        ``add_dims=True`` a leading axis of length 1 is added.
        """
        raw_mask = self._read_band(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask == 255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __len__(self):
        """Number of samples kept after sub-sampling."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a float32 channels-first tensor with the NIR band
        included, and the int64 mask tensor."""
        x = torch.tensor(self.OpenAsArray(idx, invert=True, include_nir=True), dtype=torch.float32)
        y = torch.tensor(self.OpenMask(idx, add_dims=False), dtype=torch.int64)

        return x, y

    def open_as_pil(self, idx):
        """Return the RGB bands of sample ``idx`` as an 8-bit PIL image."""
        # Scale [0, 1] to [0, 255]. Scaling by 256 would map full-intensity
        # pixels to 256, which does not fit in uint8.
        arr = np.rint(255 * self.OpenAsArray(idx))

        return Image.fromarray(arr.astype(np.uint8), 'RGB')

    def __repr__(self):
        """Short human-readable summary of the dataset size."""
        s = 'Dataset class with {} files'.format(self.__len__())

        return s

In [None]:
# Root folder of the training data, followed by one sub-folder per band
# and one for the ground-truth masks.
base_path = Path('../input/95cloud-cloud-segmentation-on-satellite-images/95-cloud_training_only_additional_to38-cloud')
red_dir = base_path.joinpath('train_red_additional_to38cloud')
blue_dir = base_path.joinpath('train_blue_additional_to38cloud')
green_dir = base_path.joinpath('train_green_additional_to38cloud')
nir_dir = base_path.joinpath('train_nir_additional_to38cloud')
gt_dir = base_path.joinpath('train_gt_additional_to38cloud')

In [None]:
# Build the full dataset from the five band / ground-truth folders.
data = CloudDataset(red_dir, blue_dir, green_dir, nir_dir, gt_dir)  
# splitting the data into train, validation, and test datasets

# 75% train, 15% validation; the test set takes whatever is left so the three sizes
# always add up to len(data) despite the int() truncation.
train_size = int(0.75 * len(data))
valid_size = int(0.15 * len(data))
test_size  = len(data) - train_size - valid_size
remaining_size = len(data) - train_size 

# Two-stage split: first train vs. the rest, then the rest into validation and test.
# NOTE(review): no generator is passed to random_split, so the split presumably depends on
# the global torch RNG state at this point — confirm if exact reproducibility matters.
train_dataset, remaining_dataset = torch.utils.data.random_split(data, [train_size, remaining_size])
valid_dataset, test_dataset      = torch.utils.data.random_split(remaining_dataset, [valid_size, test_size])


# Report the size of each subset.
print('\t\t\tDataset')
print("Train data: \t\t{}".format(len(train_dataset)),
      "\nValidation data: \t{}".format(len(valid_dataset)),
     "\nTest data: \t\t{}".format(len(test_dataset)))



# Batches of 32 for every subset; train and validation batches are shuffled, test batches are not.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_dataset , batch_size=32, shuffle=False, num_workers=2)

# Pull one validation batch to inspect the tensor shapes.
rgb_img, mask = next(iter(valid_loader))

print('\n')
print('Raw RGB image shape on batch size = {}'.format(rgb_img.size()))
print('Cloud Mask shape on batch size    = {}'.format(mask.size()))

In [None]:
# Fetch a single (image, mask) pair through __getitem__ and display both tensor shapes.
x, y =data[100]
x.shape, y. shape

In [None]:
# Side by side: the RGB composite of sample 100 (left) and its cloud mask (right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 9))
image_panel, mask_panel = ax
image_panel.imshow(data.OpenAsArray(100))
mask_panel.imshow(data.OpenMask(100))

## **Model: A simple U-Net**

In [None]:
class U_net(nn.Module):
    """Small three-level U-Net style encoder/decoder with skip connections.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input image.
    out_channels : int
        Number of channels of the output map (one per class).

    Notes
    -----
    The network defines ``forward`` (not ``__call__``), so ``model(x)`` goes
    through ``nn.Module.__call__`` and registered hooks are executed.
    Each contracting block halves the spatial size and each expanding block
    doubles it, so the skip concatenations only line up when the input height
    and width survive three halvings exactly (e.g. multiples of 8).
    """

    def __init__(self, in_channels, out_channels):
        super(U_net, self).__init__()

        # downsampling part
        self.DownConv1 = self.ContractBlock(in_channels, 32, 7, 3)
        self.DownConv2 = self.ContractBlock(32, 64, 3, 1)
        self.DownConv3 = self.ContractBlock(64, 128, 3, 1)

        # upsampling part (inputs of UpConv2/UpConv1 are doubled by the skip concatenation)
        self.UpConv3 = self.ExpandBlock(128, 64, 3, 1)
        self.UpConv2 = self.ExpandBlock(64*2, 32, 3, 1)
        self.UpConv1 = self.ExpandBlock(32*2, out_channels, 3, 1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images laid out as (N, in_channels, H, W).

        Returns
        -------
        torch.Tensor
            Output of the last expanding block, with ``out_channels`` channels.
        """
        DownConv1 = self.DownConv1(x)
        DownConv2 = self.DownConv2(DownConv1)
        DownConv3 = self.DownConv3(DownConv2)

        UpConv3 = self.UpConv3(DownConv3)
        # Concatenate decoder features with the encoder features of the same level along channels.
        UpConv2 = self.UpConv2(torch.cat([UpConv3, DownConv2], 1))
        UpConv1 = self.UpConv1(torch.cat([UpConv2, DownConv1], 1))

        return UpConv1

    def ContractBlock(self, in_channels, out_channels, kernel_size, padding):
        """Build an encoder block: two Conv-BatchNorm-ReLU stages followed by a
        stride-2 max-pool."""
        contract = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        return contract

    def ExpandBlock(self, in_channels, out_channels, kernel_size, padding):
        """Build a decoder block: two Conv-BatchNorm-ReLU stages followed by a
        stride-2 transposed convolution."""
        expand = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))

        return expand

In [None]:
# Evaluation Metrics

class Evaluation_Metrics(nn.Module):
    """Smoothed Dice, IoU, recall and precision for a two-channel prediction.

    The prediction is softmaxed over dim 1, channel 1 is rounded to a hard
    0/1 mask, and the counts are summed over the whole batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        """Return a dict of scalar tensors keyed by metric name.

        ``smooth`` is added to every numerator and denominator.
        """
        # Hard mask from the probability of channel 1.
        hard_mask = torch.softmax(prediction, dim=1)[:, 1].round()
        background = 1 - gt

        # Confusion-matrix counts needed by the metrics below.
        TP = (hard_mask * gt).sum()
        FP = (hard_mask * background).sum()
        FN = ((1 - hard_mask) * gt).sum()

        scores = {}
        # Dice score (equivalently F1)
        scores['Dice_Score/F1_Score'] = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)
        # Jaccard coefficient: intersection over union
        scores['IoU'] = (TP + smooth) / (TP + FP + FN + smooth)
        scores['Recall'] = (TP + smooth) / (TP + FN + smooth)
        scores['Precision'] = (TP + smooth) / (TP + FP + smooth)

        return scores


# Dice Loss function
class DiceLoss(nn.Module):
    """Negative log of the smoothed soft Dice coefficient of channel 1."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        """Return ``-log(dice)`` computed over the whole batch.

        ``prediction`` is softmaxed over dim 1 and only channel 1 is used;
        ``gt`` is cast to float32.
        """
        # Probability of channel 1, flattened to a vector.
        cloud_prob = torch.softmax(prediction, dim=1)[:, 1].reshape(-1)
        target = gt.reshape(-1).float()

        overlap = torch.sum(cloud_prob * target)
        dice_coeff = (2. * overlap + smooth) / (cloud_prob.sum() + target.sum() + smooth)

        return -dice_coeff.log()
    

#Binary cross-entropy (BCE)-Dice loss function
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and (1 - smoothed soft Dice) on channel 1."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, gt, smooth=1):
        """Return ``BCE + (1 - dice)`` computed over the whole batch.

        ``prediction`` is softmaxed over dim 1 and only channel 1 is used;
        ``gt`` is cast to float32.
        """
        # Probability of channel 1, flattened to a vector.
        cloud_prob = torch.softmax(prediction, dim=1)[:, 1].reshape(-1)
        target = gt.reshape(-1).float()

        overlap = torch.sum(cloud_prob * target)
        dice_coeff = (2. * overlap + smooth) / (cloud_prob.sum() + target.sum() + smooth)
        dice_term = 1 - dice_coeff
        bce_term = F.binary_cross_entropy(cloud_prob, target, reduction='mean')

        return bce_term + dice_term

In [None]:
# checking if the network works, using one batch of the training set

# The constructor receives 4 input channels and 2 output channels.
U_net_model = U_net(4, 2)
rgb_img, mask = next(iter(train_loader))
# NOTE(review): this forward pass runs outside torch.no_grad() and before any .eval() call —
# confirm that is acceptable for a shape check.
output = U_net_model (rgb_img)
# Displayed as the cell output.
output.shape

In [None]:
# Move the network onto the device chosen earlier (GPU when one is available).
model = U_net_model.to(device)

# Metric calculator, the training loss (BCE + Dice) and the log-Dice loss.
Evaluation = Evaluation_Metrics()
loss_fn = DiceBCELoss()
logDice_loss = DiceLoss()

# Adam over every parameter of the network.
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

In [None]:
def plot_DSC(train_DSC, valid_DSC):
    """Plot the per-epoch training and validation Dice coefficients.

    Parameters
    ----------
    train_DSC, valid_DSC : sequence of float
        One value per epoch. Training is drawn in blue, validation in red.

    The figure is left as the current pyplot figure so the caller can save it.
    """
    plt.figure(figsize=(10, 10))
    plt.plot(train_DSC, '-bx')
    plt.plot(valid_DSC, '-rx')
    plt.xlabel('epoch')
    # The curves are Dice coefficients, so label the axis accordingly.
    plt.ylabel('Dice coefficient')
    # Same legend as plot_losses so the two curves can be told apart.
    plt.legend(['Training', 'Validation'])
    plt.title('Dice Coefficient vs. No. of epochs')
    
def plot_losses(train_losses, valid_losses):
    """Plot the per-epoch training (blue) and validation (red) losses.

    The new figure becomes the current pyplot figure so the caller can save it.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    for series, line_style in ((train_losses, '-bx'), (valid_losses, '-rx')):
        ax.plot(series, line_style)
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend(['Training', 'Validation'])
    ax.set_title('Dice Loss vs. No. of epochs')

In [None]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, Evaluation, epochs, save_path):
    """Train ``model`` and keep the checkpoint with the lowest validation loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. Batches are moved to the device holding its parameters.
    train_dl, valid_dl : iterable of (x, y) batches
        Training and validation data loaders.
    loss_fn : callable(outputs, y) -> scalar tensor
        Loss that is back-propagated during the training phase.
    optimizer : torch.optim.Optimizer
        Optimiser stepping the model parameters.
    Evaluation : callable(outputs, y) -> dict of scalar tensors
        Must provide the keys 'Dice_Score/F1_Score', 'IoU', 'Recall' and 'Precision'.
    epochs : int
        Number of passes over the training data.
    save_path : str or path-like
        Where ``model.state_dict()`` is written whenever the validation loss
        is less than or equal to the best value seen so far.

    Returns
    -------
    tuple
        ``(train_losses, valid_losses, train_DSC, train_IoU, train_recall,
        train_precision, valid_DSC, valid_IoU, valid_recall, valid_precision,
        best_DSC)``. Each list holds exactly one value per epoch. ``best_DSC``
        is the validation Dice score of the epoch with the lowest validation loss.
    """
    metric_keys = ('Dice_Score/F1_Score', 'IoU', 'Recall', 'Precision')

    # Send batches to wherever the model lives rather than to a notebook-level global.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = torch.device('cpu')

    start = time.time()
    min_valid_epoch_loss = np.inf
    best_DSC = 0.0

    # history[phase][name] -> list with one entry per epoch
    history = {phase: {name: [] for name in ('loss',) + metric_keys}
               for phase in ('train', 'valid')}

    for e in range(epochs):
        print('Epoch  {}/{}'.format(e, epochs-1))
        print('-' * 10)

        # Each epoch has a training and a validation phase
        for phase, dataloader in (('train', train_dl), ('valid', valid_dl)):
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'valid'

            running = dict.fromkeys(('loss',) + metric_keys, 0.0)
            n_seen = 0

            for x, y in dataloader:
                x = x.to(run_device, dtype=torch.float)
                y = y.to(run_device)
                # Weight by the real batch size: the last batch may be smaller
                # than dataloader.batch_size.
                n_batch = x.size(0)

                if is_train:
                    optimizer.zero_grad()
                    outputs = model(x)
                    loss = loss_fn(outputs, y)
                    loss.backward()
                    optimizer.step()
                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y)

                # Metrics are only reported, never back-propagated.
                with torch.no_grad():
                    acc = Evaluation(outputs, y)

                running['loss'] += loss.item() * n_batch
                for key in metric_keys:
                    running[key] += acc[key].item() * n_batch
                n_seen += n_batch

            denom = max(n_seen, 1)  # avoid dividing by zero on an empty loader
            epoch_loss      = running['loss'] / denom
            epoch_DSC       = running['Dice_Score/F1_Score'] / denom
            epoch_IoU       = running['IoU'] / denom
            epoch_Recall    = running['Recall'] / denom
            epoch_Precision = running['Precision'] / denom

            print('{}\nDice Loss: {:.3f}\tDice Coefficient: {:.3f}\tJaccard Coefficient: {:.3f}\tPrecision: {:.3f}\tRecall: {:.3f}'.format( phase, epoch_loss, epoch_DSC, epoch_IoU, epoch_Precision, epoch_Recall))
            print()

            # Record each value only in the history of the current phase.
            phase_history = history[phase]
            phase_history['loss'].append(epoch_loss)
            phase_history['Dice_Score/F1_Score'].append(epoch_DSC)
            phase_history['IoU'].append(epoch_IoU)
            phase_history['Recall'].append(epoch_Recall)
            phase_history['Precision'].append(epoch_Precision)

            # save model if validation loss has decreased
            if phase == 'valid' and epoch_loss <= min_valid_epoch_loss:
                print('Validation Dice loss decreased ({:.3f} --> {:.3f}).  Saving model ...'.format(
                    min_valid_epoch_loss, epoch_loss))
                print('Best Validation Dice Score ({:.3f} --> {:.3f}).'.format(
                    best_DSC, epoch_DSC))
                torch.save(model.state_dict(), save_path)
                min_valid_epoch_loss = epoch_loss
                best_DSC = epoch_DSC

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    tr, va = history['train'], history['valid']
    return (tr['loss'], va['loss'],
            tr['Dice_Score/F1_Score'], tr['IoU'], tr['Recall'], tr['Precision'],
            va['Dice_Score/F1_Score'], va['IoU'], va['Recall'], va['Precision'],
            best_DSC)

In [None]:
# Checkpoint file that train() overwrites whenever the validation loss improves.
Unet = './Benchmark_Unet_DiceBCE_lr5.pt'

# Train for 50 epochs and unpack the per-epoch histories returned by train().
(train_losses, valid_losses,
 train_DSC, train_IoU, train_recall, train_precision,
 valid_DSC, valid_IoU, valid_recall, valid_precision,
 best_DSC) = train(
    model, train_loader, valid_loader, loss_fn, optimizer, Evaluation,
    epochs=50, save_path=Unet,
)

In [None]:
# Draw the Dice-coefficient curves, then write the current figure to disk.
plot_DSC(train_DSC, valid_DSC)
plt.savefig('./Dice_plot_Benchmark_Unet_lr5.png')

In [None]:
# Draw the loss curves, then write the current figure to disk.
plot_losses(train_losses, valid_losses)
plt.savefig('./Losses_plot_Benchmark_Unet_lr5.png')

In [None]:
def batch_to_img(xb, idx):
    """Return the first three channels of sample ``idx`` as an (H, W, 3) NumPy array.

    Parameters
    ----------
    xb : torch.Tensor
        Batch laid out as (N, C, H, W) with at least three channels.
    idx : int
        Index of the sample inside the batch.

    The tensor is detached and copied to the CPU first, so batches that live
    on the GPU or track gradients can be converted as well.
    """
    img = np.array(xb[idx, 0:3].detach().cpu())
    # channels-first -> channels-last, the layout imshow expects
    return img.transpose((1, 2, 0))

def predb_to_mask(predb, idx):
    """Turn the prediction for sample ``idx`` into a per-pixel class-index mask.

    The sample's scores are softmaxed over the channel axis, the index of the
    largest probability is taken per pixel, and the result is returned on the CPU.
    """
    class_probs = torch.softmax(predb[idx], dim=0)
    class_index = torch.argmax(class_probs, dim=0)
    return class_index.cpu()

In [None]:
PATH = "../input/unet-logdice/Benchmark_Unet_DiceBCE_lr5.pt"

model = U_net(4, 2)
# Deserialize onto the CPU first so the checkpoint loads regardless of the
# device it was saved from.
map_location = torch.device('cpu')
load_result = model.load_state_dict(torch.load(PATH, map_location=map_location), strict=False)
# strict=False tolerates key mismatches; report them instead of silently
# evaluating partly random weights.
if load_result.missing_keys or load_result.unexpected_keys:
    print('Warning: checkpoint does not match the model. Missing keys: {} | Unexpected keys: {}'.format(
        load_result.missing_keys, load_result.unexpected_keys))
# The evaluation cells send their inputs to `device`, so the model must live there too.
model = model.to(device)
model.eval()

In [None]:
# Take the first batch of the test loader.
xb, yb = next(iter(test_loader))

# Run the batch on whichever device holds the model's weights, so a model
# loaded on the CPU never receives CUDA inputs (and vice versa).
model_device = next(model.parameters()).device

with torch.no_grad():
    predb = model(xb.to(model_device))

predb.shape

In [None]:
# One row per sample of the batch actually fetched (it can be smaller than 32).
bs = xb.shape[0]
# squeeze=False keeps `ax` two-dimensional even when the batch holds a single sample.
fig, ax = plt.subplots(bs, 3, figsize=(15, bs*5), squeeze=False)
for column, title in enumerate(('Input (first 3 channels)', 'Ground truth', 'Prediction')):
    ax[0, column].set_title(title)
for i in range(bs):
    ax[i,0].imshow(batch_to_img(xb,i))
    ax[i,1].imshow(yb[i])
    ax[i,2].imshow(predb_to_mask(predb, i))

In [None]:
# Running sums over the test set, each weighted by the number of samples in the batch.
test_loss = 0.0
test_DSC  = 0.0
test_IoU  = 0.0
test_recall  = 0.0
test_precision  = 0.0
n_test_samples = 0

# Evaluate on the device that holds the model's weights.
model_device = next(model.parameters()).device
model.eval()

# iterate over test data
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(model_device, dtype=torch.float)
        y = y.to(model_device)

        outputs = model(x)
        loss    = logDice_loss(outputs, y)
        acc     = Evaluation(outputs, y)

        # Use the real batch size: the last batch is usually smaller than
        # test_loader.batch_size and would otherwise be over-weighted.
        n_batch = x.size(0)
        test_DSC        += acc['Dice_Score/F1_Score'].item() * n_batch
        test_IoU        += acc['IoU'].item() * n_batch
        test_recall     += acc['Recall'].item() * n_batch
        test_precision  += acc['Precision'].item() * n_batch
        test_loss       += loss.item() * n_batch
        n_test_samples  += n_batch

denominator = max(n_test_samples, 1)  # avoid dividing by zero on an empty loader
epoch_loss       = test_loss / denominator
epoch_DSC        = test_DSC / denominator
epoch_IoU        = test_IoU / denominator
epoch_Recall     = test_recall / denominator
epoch_Precision  = test_precision / denominator


print('Dice Loss: {:.3f}\tDice Coefficient: {:.3f}\tJaccard Coefficient: {:.3f}\tPrecision: {:.3f}\tRecall: {:.3f}'.format(epoch_loss, epoch_DSC, epoch_IoU, epoch_Precision, epoch_Recall))
print()