# Data preprocessing

In [None]:
import pathlib

import numpy as np

import os

import PIL
from PIL import Image

import cv2

import glob

import random

from skimage.io import imshow

import matplotlib.pyplot as plt

In [None]:
# Root folder of the dataset; the Images/ and Masks/ subfolders below it are read further down.
data_dir = pathlib.Path("/kaggle/input/satellite-images-of-water-bodies/Water Bodies Dataset")

In [None]:
# Count every .jpg exactly one folder below data_dir (e.g. Images/ and Masks/ together).
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(image_count)

In [None]:
# Sort the paths: Path.glob yields files in filesystem order, so without sorting
# images[100] could be a different picture on every machine/run.
images = sorted(data_dir.glob('Images/*'))
PIL.Image.open(str(images[100]))

In [None]:
def prep_image(image, crop_size, size_y, size_x):
    """Resize a PIL image and crop ``crop_size`` pixels off every side.

    PIL's ``resize`` takes ``(width, height)``, so the image is resized to
    ``size_y`` columns by ``size_x`` rows. This is the same convention as the
    ``cv2.resize(image, (size_y, size_x))`` call in ``prep_mask``. The crop box is
    then bounded by that width and height, so the returned image converts to an
    array of ``(size_x - 2*crop_size, size_y - 2*crop_size, channels)``. That
    matches the masks reshaped to ``(size_x - 2*crop_size, size_y - 2*crop_size, 1)``.

    Args:
        image: object with PIL-style ``resize``/``crop`` methods.
        crop_size: number of pixels removed from each border after resizing.
        size_y: target width passed to ``resize``.
        size_x: target height passed to ``resize``.

    Returns:
        The resized and cropped image.
    """
    # Resize the image
    prepd_image = image.resize((size_y, size_x))
    # Crop the image to remove the border black pixels.
    # PIL boxes are (left, upper, right, lower): `right` is limited by the width
    # (size_y) and `lower` by the height (size_x).
    borders = (crop_size, crop_size, size_y - crop_size, size_x - crop_size)
    return prepd_image.crop(borders)

In [None]:
# Resize target handed to prep_image / prep_mask. After crop_size pixels are
# removed from each border, arrays are (size_x - 2*crop_size) x (size_y - 2*crop_size).
size_x = size_y = 148
crop_size = 10

In [None]:
# Preview the preprocessing on a single sample image.
sample = Image.open(r"/kaggle/input/satellite-images-of-water-bodies/Water Bodies Dataset/Images/water_body_100.jpg")
img = np.asarray(prep_image(sample, crop_size, size_y, size_x))
imshow(img)

In [None]:
################# Storing Train Images into an array #############
# Paths are sorted so the image order is deterministic: glob returns entries in
# arbitrary filesystem order. The masks must be read in the same sorted order for
# train_images[i] to correspond to train_masks[i].
train_images = []

image_dir = "/kaggle/input/satellite-images-of-water-bodies/Water Bodies Dataset/Images"
for img_path in sorted(glob.glob(os.path.join(image_dir, "*.jpg"))):
    # Close the source file once the resized copy exists.
    with Image.open(img_path) as source:
        img = prep_image(source, crop_size, size_y, size_x)
    img = np.asarray(img)
    train_images.append(img)

train_images = np.array(train_images)  # converting list to array

In [None]:
def prep_mask(image, crop_size, size_y, size_x):
    """Resize a mask array with OpenCV and crop ``crop_size`` pixels off every side.

    ``cv2.resize`` takes ``dsize`` as ``(width, height)``, so the resized mask has
    ``size_x`` rows and ``size_y`` columns. The returned array is therefore
    ``(size_x - 2*crop_size, size_y - 2*crop_size)`` in its first two axes.

    Args:
        image: mask as a NumPy array (read with ``cv2.imread(path, 0)`` in this notebook).
        crop_size: pixels removed from each border after resizing; may be 0.
        size_y: target width passed to ``cv2.resize``.
        size_x: target height passed to ``cv2.resize``.

    Returns:
        The resized and cropped array.
    """
    # Resize the image
    prepd_image = cv2.resize(image, (size_y, size_x))
    # Crop the image to remove the border black pixels.
    # Use explicit end indices: `[crop_size:-crop_size]` would be empty for crop_size == 0.
    rows, cols = prepd_image.shape[:2]
    return prepd_image[crop_size:rows - crop_size, crop_size:cols - crop_size]

In [None]:
################# Storing Train Masks into an array #############
# Paths are sorted so the mask order is deterministic: glob returns entries in
# arbitrary filesystem order. The images must be read in the same sorted order for
# train_images[i] to correspond to train_masks[i].
train_masks = []

mask_dir = "/kaggle/input/satellite-images-of-water-bodies/Water Bodies Dataset/Masks"
for mask_path in sorted(glob.glob(os.path.join(mask_dir, "*.jpg"))):
    mask = cv2.imread(mask_path, 0)  # flag 0: read as a single-channel grayscale image
    if mask is None:
        # cv2.imread reports unreadable files by returning None rather than raising.
        raise ValueError(f"Could not read mask image: {mask_path}")
    mask = prep_mask(mask, crop_size, size_y, size_x)
    mask = mask.reshape(size_x - 2*crop_size, size_y - 2*crop_size, 1)
    train_masks.append(mask)

train_masks = np.array(train_masks)  # converting list to array

In [None]:
# Number of loaded images vs. masks; the two are paired index by index.
(len(train_images), len(train_masks))

In [None]:
# Shape of one preprocessed image and one preprocessed mask.
(train_images[0].shape, train_masks[0].shape)

In [None]:
# Normalise 8-bit pixel values to [0, 1]; true division yields floating-point arrays.
x, y = train_images / 255, train_masks / 255

In [None]:
########## Displaying a random image from x and its mask from y #########
# Draw a valid index from the actual dataset size and use it for both arrays.
# The original code hard-coded an upper bound and then always displayed sample 100.
random_num = random.randrange(len(x))
imshow(x[random_num])
plt.show()
imshow(y[random_num])
plt.show()

# Models

## Preparing dataset

In [None]:
from sklearn.model_selection import train_test_split
# Keep 90% for training; the remaining 10% is split in half into validation and test.
x_train, x_val, y_train, y_val = train_test_split(
    x, y, test_size=0.1, random_state=17
)
x_val, x_test, y_val, y_test = train_test_split(
    x_val, y_val, test_size=0.5, random_state=23
)

In [None]:
from torch.utils.data import DataLoader
batch_size = 256


def _make_loader(images, masks, shuffle):
    """Wrap paired image/mask arrays in a DataLoader of channels-first (image, mask) tuples."""
    # Arrays are stored channels-last (N, H, W, C); move C to axis 1 for PyTorch convolutions.
    pairs = list(zip(np.moveaxis(images, 3, 1), np.moveaxis(masks, 3, 1)))
    return DataLoader(pairs, batch_size=batch_size, shuffle=shuffle)


data_tr = _make_loader(x_train, y_train, shuffle=True)
# Evaluation order does not affect learning. A fixed order makes the validation
# preview and the reported metrics reproducible.
data_val = _make_loader(x_val, y_val, shuffle=False)
data_ts = _make_loader(x_test, y_test, shuffle=False)

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Randomly initialised U-Net from torch.hub: 3 input channels, 1 output channel, 64 base features.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=64,
    pretrained=False,
)

Let's check whether the model runs on one batch. (If you use this model, rerun the DataLoader cell first.)

In [None]:
from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Mean Dice coefficient of ``input`` against ``target``.

    Without ``reduce_batch_first``, sums are taken over the last two axes, giving
    one Dice value per leading index (or a single one for 2-D input). With
    ``reduce_batch_first`` (3-D input only), sums also cover the first axis.
    Positions where both sums are zero score ``epsilon / epsilon == 1``.

    Raises:
        AssertionError: if the shapes differ, or ``reduce_batch_first`` is set
            for input that is not 3-D.
    """
    assert input.size() == target.size()
    assert not reduce_batch_first or input.dim() == 3

    if reduce_batch_first and input.dim() != 2:
        dims = (-1, -2, -3)
    else:
        dims = (-1, -2)

    numerator = 2 * (input * target).sum(dim=dims)
    denominator = input.sum(dim=dims) + target.sum(dim=dims)
    # Empty prediction and empty target: swap in the (zero) numerator so the ratio becomes 1.
    denominator = torch.where(denominator == 0, numerator, denominator)

    ratio = (numerator + epsilon) / (denominator + epsilon)
    return ratio.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over classes.

    Merges the first two axes (presumably batch and class) of both tensors and
    delegates to ``dice_coeff`` with the same options.
    """
    merged_input = input.flatten(0, 1)
    merged_target = target.flatten(0, 1)
    return dice_coeff(merged_input, merged_target, reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Return ``1 - Dice`` as an objective to minimise.

    The result is between 0 and 1 for inputs in [0, 1]. The Dice coefficient is
    computed with ``reduce_batch_first=True``, i.e. pooled over the whole batch.
    """
    if multiclass:
        coeff = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coeff = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coeff

In [None]:
from torchvision import transforms

input_batch, true_output_batch = next(iter(data_tr))

input_batch = input_batch.float()

if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model = model.to('cuda')

# eval(): in train mode, a forward pass updates BatchNorm running statistics
# (the hub UNet uses BatchNorm) even under no_grad.
model.eval()
with torch.no_grad():
    output = model(input_batch).cpu()

print(output.shape)
# Masks are float64 after the /255 normalisation; match the model's float32 output.
pred = output[0]
target = true_output_batch[0].float()
print(dice_loss(torch.round(pred), target))
# The hub UNet ends in a sigmoid, so its output is already a probability:
# use plain BCE (as in training) rather than the logits variant.
print(F.binary_cross_entropy(pred, target) + dice_loss(pred, target))

prediction = np.rollaxis(torch.round(output[0]).numpy(), 0, 3)
image = np.rollaxis(input_batch[0].cpu().numpy(), 0, 3)
mask = np.rollaxis(true_output_batch[0].numpy(), 0, 3)

imshow(prediction)
plt.show()

imshow(mask)
plt.show()

imshow(image)
plt.show()

In [None]:
# PyTorch version

SMOOTH = 1e-3

def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Per-sample IoU of binary masks, mapped to a threshold-averaged score.

    Nonzero pixels count as foreground. The union is |A or B|; the previous
    ``(A + B) / 2`` form equalled the mean mask size, which turned the ratio
    into the Dice coefficient instead of IoU.

    Args:
        outputs: predicted mask, BATCH x 1 x H x W (the channel axis is squeezed);
            values are truncated with ``.int()``.
        labels: reference mask, BATCH x H x W; rounded to integers.

    Returns:
        Tensor of shape (BATCH,). For each sample, it holds the fraction of the
        thresholds 0.50, 0.55, ..., 0.95 that its IoU strictly exceeds.
    """
    outputs = outputs.squeeze(1).int()  # BATCH x 1 x H x W => BATCH x H x W
    labels = torch.round(labels).int()
    intersection = torch.logical_and(outputs, labels).float().sum((1, 2))  # zero if Truth=0 or Prediction=0
    union = torch.logical_or(outputs, labels).float().sum((1, 2))  # zero only if both are 0

    iou = (intersection + SMOOTH) / (union + SMOOTH)  # smoothing avoids 0/0

    # ceil(20 * (iou - 0.5)) counts how many 0.05-spaced thresholds in [0.5, 0.95] iou exceeds.
    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10

    return thresholded  # or thresholded.mean() for the average across the batch
    

In [None]:
first_pred = torch.round(output[0].cpu())
iou_pytorch(first_pred, true_output_batch[0]).item()

In [None]:
first_pred_bin = torch.round(output[0].cpu())
first_true_bin = torch.round(true_output_batch[0])
dice_coeff(first_pred_bin, first_true_bin).item()

In [None]:
# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
def iou_score_model(model, data):
    """Mean ``iou_pytorch`` score of ``model`` over every sample in ``data``.

    Predictions and labels are binarised at 0.5 after taking the max over
    axis 1. Scores are averaged per sample, so a short final batch is not
    weighted like a full one. Puts ``model`` in eval mode and uses the
    notebook-level ``device``.
    """
    model.eval()  # testing mode
    total = 0.0
    n_samples = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for X_batch, Y_label in data:
            probs = model(X_batch.float().to(device))
            Y_pred = (probs.max(dim=1)[0] > 0.5).float().unsqueeze(1)
            Y_lab = (Y_label.max(dim=1)[0] > 0.5).int().to(device)
            total += iou_pytorch(Y_pred, Y_lab).sum().item()
            n_samples += X_batch.shape[0]
    return total / n_samples

In [None]:
def dice_score_model(model, data):
    """Mean per-sample Dice coefficient of ``model`` over every sample in ``data``.

    Predictions and labels are binarised at 0.5 after taking the max over
    axis 1. ``dice_coeff`` returns a batch mean, so it is re-weighted by the
    batch size to give a true per-sample average. Puts ``model`` in eval mode
    and uses the notebook-level ``device``.
    """
    model.eval()  # testing mode
    total = 0.0
    n_samples = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for X_batch, Y_label in data:
            probs = model(X_batch.float().to(device))
            Y_pred = (probs.max(dim=1)[0] > 0.5).float()
            Y_lab = (Y_label.max(dim=1)[0] > 0.5).int().to(device)
            batch_n = X_batch.shape[0]
            total += dice_coeff(Y_pred, Y_lab).item() * batch_n
            n_samples += batch_n
    return total / n_samples

In [None]:
# Dice score on the test loader at this point in the notebook (before the training cell runs).
dice_score_model(model, data_ts)

In [None]:
# Move the model's parameters to `device`; the returned module is what the notebook displays.
model.to(device)

In [None]:
# Number of training samples (first axis of the training array).
n_train = x_train.shape[0]
n_train

In [None]:
from tqdm import tqdm

from torch import optim

from IPython.display import clear_output

In [None]:
def train(
    model,
    loss_fn,
    epochs,
    data_tr,
    data_val,
    amp: bool = False,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-8,
    momentum: float = 0.999,
    gradient_clipping: float = 1.0
):
    """Train a single-output-channel segmentation model and track validation metrics.

    Each epoch performs these steps:

    1. One RMSprop pass over ``data_tr``.
    2. Loss on the first validation batch.
    3. IoU and Dice scores on ``data_val``.
    4. A ReduceLROnPlateau step on the validation Dice score.
    5. A preview plot of the first validation samples.

    Args:
        model: segmentation network; its output is squeezed on dim 1 before the loss.
        loss_fn: pixel-wise loss called as ``loss_fn(predictions, targets)``;
            the soft Dice loss is added to it.
        epochs: number of passes over ``data_tr``.
        data_tr: DataLoader of (image, mask) training batches.
        data_val: DataLoader of (image, mask) validation batches.
        amp: run the forward pass under ``torch.autocast`` and scale gradients.
        learning_rate, weight_decay, momentum: RMSprop hyper-parameters.
        gradient_clipping: maximum total gradient norm.

    Returns:
        ``(loss_train, loss_val, scores)``:
        - the mean training loss per epoch;
        - the loss on the first validation batch per epoch;
        - ``(iou_score, dice_score)`` tuples on ``data_val``.

    Uses the notebook-level ``device``, ``dice_loss``, ``iou_score_model`` and
    ``dice_score_model``.
    """
    # Fixed batch used for the validation loss and for the per-epoch preview.
    X_val, Y_val = next(iter(data_val))

    loss_train = []
    loss_val = []
    scores = []

    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    n_samples = len(data_tr.dataset)

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        with tqdm(total=n_samples, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for images, true_masks in data_tr:
                images = images.float().to(device)
                targets = true_masks.float().squeeze(1).to(device)

                with torch.autocast(device_type=device.type, enabled=amp):
                    masks_pred = model(images)
                # Loss in float32 outside autocast: BCELoss is not allowed inside an autocast region.
                masks_pred = masks_pred.float().squeeze(1)
                loss = loss_fn(masks_pred, targets) + dice_loss(masks_pred, targets, multiclass=False)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Unscale first so the clipping threshold applies to the real gradient norm.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                epoch_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})

        loss_train.append(epoch_loss / len(data_tr))

        model.eval()  # testing mode
        with torch.no_grad():
            temp = model(X_val.float().to(device)).float().cpu()
            y_true = Y_val.float().squeeze(1)
            loss_valid = loss_fn(temp.squeeze(1), y_true) + dice_loss(temp.squeeze(1), y_true, multiclass=False)
        loss_val.append(float(loss_valid))
        Y_hat = torch.round(temp).numpy()

        iou_val = iou_score_model(model, data_val)
        dice_val = dice_score_model(model, data_val)
        scores.append((iou_val, dice_val))
        scheduler.step(dice_val)

        _plot_val_predictions(X_val, Y_hat, Y_val)

    return loss_train, loss_val, scores


def _plot_val_predictions(X_val, Y_hat, Y_val, max_cols=6):
    """Plot input image, predicted mask and true mask for up to ``max_cols`` validation samples."""
    n_cols = min(max_cols, len(X_val))
    clear_output(wait=True)
    plt.figure(figsize=(18, 6))
    for k in range(n_cols):
        plt.subplot(3, n_cols, k + 1)
        # Samples are channels-first; matplotlib expects H x W x C.
        plt.imshow(np.rollaxis(X_val[k].numpy(), 0, 3), cmap='gray')
        plt.title('Real')
        plt.axis('off')
        plt.subplot(3, n_cols, n_cols + k + 1)
        plt.imshow(Y_hat[k, 0], cmap='gray')
        plt.title('Output')
        plt.axis('off')
        plt.subplot(3, n_cols, 2 * n_cols + k + 1)
        plt.imshow(Y_val[k, 0], cmap='gray')
        plt.title('Real Mask')
        plt.axis('off')
    plt.show()

In [None]:
model.to(device)
max_epochs = 300
# The torch.hub UNet applies a sigmoid at the end of its forward pass, so its
# outputs are probabilities. BCELoss is the matching loss; BCEWithLogitsLoss
# would apply a second sigmoid.
loss_func = nn.BCELoss()
tr_loss, val_loss, scores = train(model, loss_func, max_epochs, data_tr, data_val)

In [None]:
# Final (IoU-based score, Dice score) on the held-out test loader.
(iou_score_model(model, data_ts), dice_score_model(model, data_ts))

In [None]:
import numpy as np


# Derive the x-axis from the recorded history so its length always matches the curves.
t = np.arange(len(tr_loss))
iou_scores = [iou for iou, _ in scores]
dice_scores = [dice for _, dice in scores]

# One figure with the requested size. The y-axes are not shared because losses
# and [0, 1] scores live on different scales.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6), sharex=True)

ax1.plot(t, tr_loss, 'r', lw=2, label='train loss')
ax1.plot(t, val_loss, 'b', lw=2, label='validation loss')
ax2.plot(t, iou_scores, 'g', label='IoU score')
ax2.plot(t, dice_scores, 'm', label='Dice score')

ax1.set_ylabel('loss')
ax2.set_ylabel('score')
for ax in (ax1, ax2):
    ax.set_xlabel('epoch')
    ax.legend()

fig.suptitle('Training history')

plt.show()

In [None]:
import pickle

# NOTE: a pickle can execute code when loaded; only unpickle unet.pkl from a trusted source.
with open("unet.pkl", 'wb') as file:
    # nn.Module.to() moves parameters in place and returns the same module,
    # so `model` itself stays on the CPU after this cell.
    cpu_model = model.to('cpu')
    pickle.dump(cpu_model, file)

In [None]:
def make_mask_from_image(image_path, model):
    """Predict a rounded segmentation mask for one image file.

    The image is preprocessed like the training data: ``prep_image`` with the
    notebook-level ``crop_size``/``size_y``/``size_x``, then divided by 255.

    Args:
        image_path: path to the image file.
        model: segmentation network; it is moved to ``device`` and set to eval mode.

    Returns:
        ``(img, mask)``: the preprocessed H x W x C float array, and the
        prediction rounded to integers as an H x W x channels array.
    """
    with Image.open(image_path) as source:  # close the file once the resized copy exists
        img = prep_image(source, crop_size, size_y, size_x)
    img = np.asarray(img) / 255

    model.to(device)
    # eval(): use BatchNorm running statistics instead of single-image batch statistics.
    model.eval()
    tensor = torch.as_tensor(np.moveaxis(img, 2, 0), dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        mask = np.moveaxis(torch.round(model(tensor).cpu()[0]).numpy(), 0, -1)
    return img, mask

In [None]:
sample_path = "/kaggle/input/satellite-images-of-water-bodies/Water Bodies Dataset/Images/water_body_102.jpg"
orig, mask = make_mask_from_image(sample_path, model)

# Show the preprocessed input first, then the predicted mask.
for panel in (orig, mask):
    imshow(panel)
    plt.show()