In [None]:
# Google Colab only: mount Google Drive, then extract the dataset archive
# stored there into the current working directory.

from google.colab import drive
drive.mount('/content/drive')

!unzip '/content/drive/MyDrive/HydroLens/Dataset/WaterBodies_Dataset.zip'

In [None]:
# Required dependencies.
# NOTE(review): none of these installs is version-pinned, and albumentations is
# installed straight from the GitHub repository rather than a tagged release,
# so the environment may differ from one run to the next.

!pip install --upgrade opencv-contrib-python
!pip install segmentation-models-pytorch
!pip install -U git+https://github.com/albumentations-team/albumentations

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import glob
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from albumentations import HorizontalFlip, VerticalFlip, Rotate
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.image as mpimg
import albumentations as A

# Side lengths (in pixels) that every image and mask is resized to.
height = 256
width = 256

# Show whether PyTorch can see a CUDA device in this runtime.
torch.cuda.is_available()

In [None]:
class LoadData(Dataset):
    """Dataset of (image, mask) pairs for binary segmentation.

    Every sample is resized to ``height`` x ``width``.  When ``augment`` is
    true the sample is additionally flipped, brightness/contrast-jittered and
    shifted/scaled/rotated at random.  Image and mask are passed through the
    same Albumentations call, so both receive the same transform.

    Args:
        images_path: Sequence of image file paths.
        masks_path: Sequence of mask file paths; entry ``i`` must belong to
            image ``i``.
        augment: Apply the random augmentations.  Defaults to ``True`` (the
            historical behaviour); pass ``False`` for resize only, e.g. for
            validation or inference data.

    Raises:
        ValueError: If the two sequences do not have the same length.
    """

    def __init__(self, images_path, masks_path, augment=True):
        super().__init__()

        # A length mismatch would silently pair images with the wrong masks
        # (or fail much later with an IndexError inside a DataLoader worker).
        if len(images_path) != len(masks_path):
            raise ValueError(
                f'Got {len(images_path)} images but {len(masks_path)} masks'
            )

        self.images_path = images_path
        self.masks_path = masks_path
        self.len = len(images_path)

        steps = [A.Resize(height, width)]
        if augment:
            steps += [
                A.HorizontalFlip(),
                A.RandomBrightnessContrast(p=0.5),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            ]
        self.transform = A.Compose(steps)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors with values scaled to [0, 1].

        The image has shape ``(3, height, width)`` and the mask
        ``(1, height, width)``.
        """
        # Force exactly three channels: a grayscale or RGBA file would
        # otherwise break the HWC -> CHW transpose below or the 3-channel
        # model input.  The context managers close the file handles promptly.
        with Image.open(self.images_path[idx]) as image_file:
            img = np.array(image_file.convert('RGB'))
        with Image.open(self.masks_path[idx]) as mask_file:
            mask = np.array(mask_file.convert('L'))

        transformed = self.transform(image=img, mask=mask)
        img = transformed['image']
        mask = transformed['mask']

        # HWC uint8 -> CHW float in [0, 1].
        img = np.transpose(img, (2, 0, 1))
        img = img/255.0
        img = torch.tensor(img)

        # HW uint8 -> 1HW float in [0, 1].
        mask = np.expand_dims(mask, axis=0)
        mask = mask/255.0
        mask = torch.tensor(mask)

        return img, mask

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return self.len

In [None]:
# For Google Colab (archive extracted into the working directory).
# Both listings are sorted so that index i of X lines up with index i of y.
# NOTE(review): this pairing assumes images and masks share file names - confirm.
X = glob.glob('WaterBodiesDataset/Images/*')
X.sort()
y = glob.glob('WaterBodiesDataset/Masks/*')
y.sort()

# For Running Locally
#X = sorted(glob.glob('Dataset/WaterBodies_Dataset/WaterBodiesDataset/Images/*'))
#y = sorted(glob.glob('Dataset/WaterBodies_Dataset/WaterBodiesDataset/Masks/*'))

In [None]:
# Display how many mask files the glob above matched (0 means the dataset
# directory was not found relative to the working directory).
len(y)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the (image, mask) path pairs for validation; the fixed
# random_state makes the split identical on every run.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=42)

In [None]:
# Training samples are resized and randomly augmented by LoadData.
train_dataset = LoadData(X_train, y_train)

# Validation has to be deterministic: with the random flips / brightness
# jitter / rotations left on, the per-epoch validation losses (and the
# "best model" decision derived from them) change from run to run for reasons
# unrelated to the model.  Keep only the resize step for this dataset.
valid_dataset = LoadData(X_val, y_val)
valid_dataset.transform = A.Compose([A.Resize(height, width)])

In [None]:
# Visual sanity check: show one training sample (left) beside its mask (right).
img, mask = train_dataset[18]

mask_2d = mask.squeeze().numpy()           # (1, H, W) -> (H, W)
img_hwc = img.permute(1, 2, 0).numpy()     # (3, H, W) -> (H, W, 3)

f, axarr = plt.subplots(1, 2)
axarr[1].imshow(mask_2d, cmap='gray')
axarr[0].imshow(img_hwc)

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs (slowly) on machines without CUDA instead of failing at
# the first .to(DEVICE) call.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation settings.
EPOCHS = 45
BATCH_SIZE = 32
LR = 0.001

ratio = 0.5  # Probability threshold for the binary mask; other values may visualise better
sample_num = 18  # Index (within the first validation batch) of the sample plotted after each epoch

# Backbone of the U-Net and the weights it is initialised with.
ENCODER = 'resnet50'
WEIGHTS = 'imagenet'

In [None]:
# Settings shared by both loaders.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)

# Training batches are reshuffled every epoch; validation keeps a fixed order
# so that "the first validation batch" always refers to the same indices.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)

In [None]:
from torch import nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

In [None]:
class SegmentationModel(nn.Module):
    """U-Net with an ``ENCODER`` backbone producing one-channel logits.

    The network is built with ``activation=None``, so its output is raw
    logits; apply a sigmoid to turn them into probabilities.
    """

    def __init__(self):
        super().__init__()

        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None
        )
        # Build the loss modules once instead of on every forward pass.
        self.dice_loss = DiceLoss(mode='binary')
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, images, masks=None):
        """Run the network and, when targets are given, compute both losses.

        Args:
            images: Input batch for the U-Net (3 input channels).
            masks: Optional target masks matching the logits.

        Returns:
            ``logits`` when ``masks`` is ``None``; otherwise the tuple
            ``(logits, dice_loss, bce_loss)``.
        """
        logits = self.arc(images)

        # Identity test: comparing a tensor to None with != relies on the
        # tensor's comparison operator instead of simply checking for absence.
        if masks is None:
            return logits

        return logits, self.dice_loss(logits, masks), self.bce_loss(logits, masks)

In [None]:
# Build the network and move it to DEVICE.
# model.to(...) is the cell's last expression, so the notebook also prints the
# module structure.
model=SegmentationModel()
model.to(DEVICE)

In [None]:
def train_fn(data_loader, model, optimizer):
    """Train ``model`` for one epoch.

    Args:
        data_loader: Iterable of ``(images, masks)`` batches.
        model: Module whose ``model(images, masks)`` call returns
            ``(logits, dice_loss, bce_loss)``.
        optimizer: Optimizer over the model parameters.

    Returns:
        ``(mean_dice_loss, mean_bce_loss)`` averaged over the batches.
    """
    model.train()
    total_diceloss = 0.0
    total_bceloss = 0.0
    for images, masks in tqdm(data_loader):
        images = images.to(DEVICE, dtype=torch.float32)
        masks = masks.to(DEVICE, dtype=torch.float32)

        optimizer.zero_grad()

        _, diceloss, bceloss = model(images, masks)
        # One backward pass over the summed loss accumulates the same
        # gradients as two separate passes, without retaining the graph and
        # walking it twice.
        (diceloss + bceloss).backward()
        optimizer.step()

        total_diceloss += diceloss.item()
        total_bceloss += bceloss.item()

    num_batches = len(data_loader)
    return total_diceloss/num_batches, total_bceloss/num_batches

In [None]:
def eval_fn(data_loader, model):
    """Evaluate ``model`` on ``data_loader`` and plot one sample prediction.

    The preview uses the first batch of ``data_loader`` and reads the
    module-level ``sample_num`` (which item to show) and ``ratio``
    (probability threshold).

    Args:
        data_loader: Iterable of ``(images, masks)`` batches.
        model: Module whose ``model(images, masks)`` call returns
            ``(logits, dice_loss, bce_loss)`` and whose ``model(images)`` call
            returns logits.

    Returns:
        ``(mean_dice_loss, mean_bce_loss)`` averaged over the batches.
    """
    model.eval()
    total_diceloss = 0.0
    total_bceloss = 0.0
    first_batch = None
    with torch.no_grad():
        for images, masks in tqdm(data_loader):
            if first_batch is None:
                # Keep the (CPU) first batch for the preview instead of
                # starting a second pass over a loader afterwards.
                first_batch = (images, masks)

            images = images.to(DEVICE, dtype=torch.float32)
            masks = masks.to(DEVICE, dtype=torch.float32)

            _, diceloss, bceloss = model(images, masks)
            total_diceloss += diceloss.item()
            total_bceloss += bceloss.item()

        #Visualization
        if first_batch is not None:
            _show_prediction(model, first_batch[0], first_batch[1])

    num_batches = len(data_loader)
    return total_diceloss/num_batches, total_bceloss/num_batches


def _show_prediction(model, images, masks):
    """Plot input image, ground-truth mask and thresholded prediction for one batch item."""
    # sample_num may point past the end of a short batch; clamp to the last item.
    idx = min(sample_num, images.shape[0] - 1)
    image = images[idx]
    mask = masks[idx]

    logits_mask = model(image.to(DEVICE, dtype=torch.float32).unsqueeze(0))
    pred_mask = torch.sigmoid(logits_mask)
    pred_mask = (pred_mask > ratio)*1.0

    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(np.transpose(image.numpy(), (1, 2, 0)))
    axarr[1].imshow(np.squeeze(mask.numpy()), cmap='gray')
    # (1, 1, H, W) -> (H, W)
    axarr[2].imshow(pred_mask.detach().cpu().numpy()[0, 0])
    plt.show()

In [None]:
# SGD with momentum over all model parameters; the learning rate comes from LR.
optimizer=torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
# Alternative optimizer kept for reference:
#torch.optim.Adam(model.parameters(),lr=LR)

In [None]:
# Training loop - this is the slow cell.  To only check that the notebook
# works, run every cell except this one.

# np.inf (lower case): the np.Inf alias was removed in NumPy 2.0.
best_val_dice_loss = np.inf
best_val_bce_loss = np.inf

for epoch in range(EPOCHS):
    train_dice, train_bce = train_fn(train_loader, model, optimizer)
    valid_dice, valid_bce = eval_fn(valid_loader, model)

    print(f'Epochs:{epoch+1}\nTrain_loss --> Dice: {train_dice} BCE: {train_bce} \nValid_loss --> Dice: {valid_dice} BCE: {valid_bce}')
    if valid_dice < best_val_dice_loss or valid_bce < best_val_bce_loss:
        # For Google Colab
        torch.save(model.state_dict(),'/content/drive/MyDrive/HydroLens/hydrolens_best_model.pt')

        # For Running Locally
        #torch.save(model.state_dict(),'hydrolens_best_model.pt')
        print('Model Saved')
        # Track each best value on its own: when only one metric improved,
        # the other one must not be replaced by a worse value.
        best_val_dice_loss = min(best_val_dice_loss, valid_dice)
        best_val_bce_loss = min(best_val_bce_loss, valid_bce)

In [None]:
num = 12  # Change for other images in the dataset (Choose number between 0 and 32(excluded))
ratio = 0.5

# For Google Colab
# map_location lets a checkpoint that was saved from a GPU load onto whatever
# DEVICE is in this session.
model.load_state_dict(torch.load('/content/drive/MyDrive/HydroLens/hydrolens_best_model.pt', map_location=DEVICE))

# For Running Locally
#model.load_state_dict(torch.load('hydrolens_best_model.pt', map_location=DEVICE))

# Switch to inference mode explicitly: if the training cell was skipped the
# model is still in training mode, which changes how layers such as batch
# normalisation behave on this single-image batch.
model.eval()

image, mask = next(iter(valid_loader))
image = image[num]
mask = mask[num]
with torch.no_grad():
    logits_mask = model(image.to(DEVICE, dtype=torch.float32).unsqueeze(0))
pred_mask = torch.sigmoid(logits_mask)
pred_mask = (pred_mask > ratio)*1.0

# Left: input image, middle: ground-truth mask, right: predicted mask.
f, axarr = plt.subplots(1, 3)
axarr[1].imshow(np.squeeze(mask.numpy()), cmap='gray')
axarr[0].imshow(np.transpose(image.numpy(), (1, 2, 0)))
# (1, 1, H, W) -> (H, W)
axarr[2].imshow(pred_mask.cpu().numpy()[0, 0])

In [None]:
# Trying the prediction on a random single image downloaded from google

# Load your single satellite image
image_path = '/content/drive/MyDrive/wallpaperflare.com_wallpaper.jpg'
# Force three channels: a grayscale or RGBA download would otherwise break the
# HWC -> CHW transpose below or the 3-channel model input.
with Image.open(image_path) as source_image:
    image = np.array(source_image.convert('RGB'))

# Inference preprocessing: resize only (no random augmentation), then scale to
# [0, 1] and reorder to CHW exactly as LoadData.__getitem__ does.
transform = A.Compose([
     A.Resize(height,width),
     #Add any other transformations here
])

image = transform(image=image)['image']
image = np.transpose(image, (2, 0, 1)) / 255.0
image = torch.tensor(image).unsqueeze(0).to(DEVICE, dtype=torch.float32)

# Make a prediction
model.eval()
with torch.no_grad():
    logits = model(image)
    pred_mask = torch.sigmoid(logits)
    pred_mask = (pred_mask > ratio).float()

# Visualize the results

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(np.transpose(image.squeeze(0).cpu().numpy(), (1, 2, 0)))

# Drop the batch and channel axes explicitly: (1, 1, H, W) -> (H, W).
pred_mask_np = pred_mask[0, 0].cpu().numpy()

axarr[1].imshow(pred_mask_np)
plt.show()

# Save the binary mask as an 8-bit image (0 = background, 255 = predicted foreground).
output_mask = Image.fromarray((pred_mask_np * 255).astype(np.uint8))
output_mask.save('/content/drive/MyDrive/predicted_mask.png')
