<a href="https://colab.research.google.com/github/Codamaze/OilSpillDetection/blob/main/Gradio_testing_oil_spill.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from albumentations import Compose, RandomCrop, HorizontalFlip, Rotate, GaussNoise, RandomBrightnessContrast, HueSaturationValue
from albumentations.pytorch import ToTensorV2
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from tqdm import tqdm

# Define color-to-class mapping (RGB values for each class)
COLOR_MAP = [
    [0, 0, 0],        # Class 0: Black
    [0, 255, 255],    # Class 1: Cyan
    [153, 76, 0],      # Class 2: Brown
    [255, 0, 0],     # Class 3: Red
    [0, 153, 0],      # Class 4: Green
]

def process_mask(rgb_mask, colormap):
    """
    Converts an RGB mask to a one-hot encoded class mask using the provided colormap.
    Args:
        rgb_mask: RGB mask (H, W, 3)
        colormap: List of RGB values for each class.
    Returns:
        One-hot encoded mask (H, W, num_classes)
    """
    output_mask = []
    for i, color in enumerate(colormap):
        cmap = np.all(np.equal(rgb_mask, color), axis=-1).astype(np.uint8)  # Check if pixel matches color and cast to uint8
        output_mask.append(cmap)

    output_mask = np.stack(output_mask, axis=-1)  # Stack the individual class masks to create a one-hot mask
    return output_mask

# Define the augmentation pipeline
transform = Compose([
    RandomCrop(height=256, width=256),             # Random cropping
    # HorizontalFlip(p=0.5),                         # Horizontal flipping
    # Rotate(limit=30, p=0.5),                       # Random rotation
    HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.4),  # Adjust colors slightly
    RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),  # Mild adjustments
    ToTensorV2(),                      # Convert to PyTorch tensors
])

# Dataset class
class MultiClassOilSpillDataset(Dataset):
    def __init__(self, image_dir, mask_dir, color_mapping, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = os.listdir(image_dir)
        self.color_mapping = color_mapping
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # Read unchanged for SAR images

        mask_path = os.path.join(self.mask_dir, self.images[idx].replace(".jpg", ".png"))
        mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
        mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)  # Ensure RGB format

        ''' to apply visulaisation
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(image, cmap='gray')
        plt.title("Raw Image")

        plt.subplot(1, 2, 2)
        plt.imshow(mask_rgb)
        plt.title("Raw Mask")

        plt.show()'''

        # Load image and mask
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # Load SAR image as is
        mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)      # Load mask as RGB

        # Validate data
        if image is None or mask is None:
            raise FileNotFoundError(f"Missing file: {img_path} or {mask_path}")

        # Normalize SAR image dynamically
        image = (image - np.min(image)) / (np.max(image) - np.min(image))
        image = image.astype(np.float32) # Cast the image to float32


        # Resize image and mask
        image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        # Convert mask to class indices
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        processed_mask = process_mask(mask, self.color_mapping)

        # Convert one-hot encoded mask to class indices
        processed_mask = np.argmax(processed_mask, axis=-1)


        # Apply augmentations if specified
        if self.transform:
            augmented = self.transform(image=image, mask=processed_mask)
            image = augmented['image']
            mask = augmented['mask']

            image = np.clip(image, 0, 1)

        return image, mask



In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.resnet import resnet18
from torchvision.models.resnet import ResNet18_Weights


class UNetResNet18(nn.Module):
    def __init__(self, input_channels, IMG_CLASSES):
        super(UNetResNet18, self).__init__()

        # ResNet18 as the encoder (feature extractor)
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Remove the fully connected layer and the classification layer
        self.encoder = nn.Sequential(
            resnet.conv1,  # c1
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1, # c2
            resnet.layer2, # c3
            resnet.layer3, # c4
            resnet.layer4  # c5
        )

        # Decoder blocks (matching the original U-Net structure)
        self.upconv4 = self.upconv_block(512, 256)
        self.upconv3 = self.upconv_block(256 + 256, 128)  # Concatenate with c4
        self.upconv2 = self.upconv_block(128 + 128, 64)   # Concatenate with c3
        self.upconv1 = self.upconv_block(64 + 64, 32)     # Concatenate with c2

        # Final convolution to output IMG_CLASSES channels
        self.final_conv = nn.Conv2d(32, IMG_CLASSES, kernel_size=1)

    def forward(self, x):
        # Encoder forward pass (ResNet18 backbone)
        c1 = self.encoder[0:4](x)    # Conv1 to maxpool
        c2 = self.encoder[4](c1)     # Layer1 (64 channels -> 64)
        c3 = self.encoder[5](c2)     # Layer2 (128 channels -> 128)
        c4 = self.encoder[6](c3)     # Layer3 (256 channels -> 256)
        c5 = self.encoder[7](c4)     # Layer4 (512 channels -> 512)

        # Decoder forward pass (Upsample and concatenate with encoder layers)
        u4 = self.upconv4(c5)
        u3 = self.upconv3(torch.cat([u4, c4], dim=1))  # Skip connection with c4
        u2 = self.upconv2(torch.cat([u3, c3], dim=1))  # Skip connection with c3
        u1 = self.upconv1(torch.cat([u2, c2], dim=1))  # Skip connection with c2

        # Final output layer
        out = self.final_conv(u1)
        return out

    def upconv_block(self, in_channels, out_channels):
        """ Upsampling block (ConvTranspose + Conv2d) """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_large
from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights

# Atrous Spatial Pyramid Pooling (ASPP) module
class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ASPP, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.gn1 = nn.GroupNorm(16, out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, dilation=1, bias=False)
        self.gn2 = nn.GroupNorm(16, out_channels)

        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=6, padding=6, bias=False)
        self.gn3= nn.GroupNorm(16, out_channels)

        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=12, padding=12, bias=False)
        self.gn4= nn.GroupNorm(16, out_channels)

        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=18, padding=18, bias=False)
        self.gn5= nn.GroupNorm(16, out_channels)

        self.output = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1, bias=False)
        self.out_gn= nn.GroupNorm(16, out_channels)

    def forward(self, x):
        h, w = x.size(2), x.size(3)

        y1 = self.avg_pool(x)
        y1 = self.conv1(y1)
        y1 = self.gn1(y1)
        y1 = self.relu(y1)
        y1 = F.interpolate(y1, size=(h, w), mode="bilinear", align_corners=False)

        y2 = self.conv2(x)
        y2 = self.gn2(y2)
        y2 = self.relu(y2)

        y3 = self.conv3(x)
        y3 = self.gn3(y3)
        y3 = self.relu(y3)

        y4 = self.conv4(x)
        y4 = self.gn4(y4)
        y4 = self.relu(y4)

        y5 = self.conv5(x)
        y5 = self.gn5(y5)
        y5 = self.relu(y5)

        y = torch.cat([y1, y2, y3, y4, y5], dim=1)
        y = self.output(y)
        y = self.out_gn(y)
        return self.relu(y)


# DeepLabV3+ module
class DeepLabV3Plus(nn.Module):
    def __init__(self, num_classes):
        super(DeepLabV3Plus, self).__init__()
        self.encoder = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT).features

        # Extract intermediate layers
        self.low_level_idx = 3  # Low-level features from MobileNetV3
        self.high_level_idx = 16  # High-level features from MobileNetV3

        self.aspp = ASPP(in_channels=160, out_channels=256) # 160 channels for MobileNetV3 model
        # Low-level feature projection
        self.low_level_conv = nn.Conv2d(24, 48, kernel_size=1, bias=False)
        self.low_level_gn = nn.GroupNorm(16, 48)
        self.low_level_relu = nn.ReLU()

        # Decoder
        self.concat_conv1 = nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False)
        self.concat_gn1 = nn.GroupNorm(32, 256)
        self.concat_relu1 = nn.ReLU()

        self.concat_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        self.concat_gn2= nn.GroupNorm(32, 256)
        self.concat_relu2 = nn.ReLU()

        self.final_conv = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        h, w = x.size(2), x.size(3)

        # Encoder
        low_level_features = self.encoder[:self.low_level_idx](x)
        high_level_features = self.encoder[:self.high_level_idx](x)

        # ASPP
        x = self.aspp(high_level_features)
        x = F.interpolate(x, size=(h // 4, w // 4), mode="bilinear", align_corners=False)

        # Low-level features
        low_level_features = self.low_level_conv(low_level_features)
        low_level_features = self.low_level_gn(low_level_features)
        low_level_features = self.low_level_relu(low_level_features)

        # Concatenate low-level and ASPP features
        x = torch.cat([x, low_level_features], dim=1)

        # Decoder
        x = self.concat_conv1(x)
        x = self.concat_gn1(x)
        x = self.concat_relu1(x)

        x = self.concat_conv2(x)
        x = self.concat_gn2(x)
        x = self.concat_relu2(x)

        # Upsample to original size
        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)

        # Final classification layer
        x = self.final_conv(x)
        return x


In [6]:
from albumentations import Compose, RandomCrop, HorizontalFlip, Rotate, GaussNoise, RandomBrightnessContrast, HueSaturationValue
from albumentations.pytorch import ToTensorV2
transform = Compose([
    RandomCrop(height=256, width=256),             # Random cropping
    HorizontalFlip(p=0.5),                         # Horizontal flipping
    Rotate(limit=30, p=0.5),                       # Random rotation
    HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.4),  # Adjust colors slightly
    RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),  # Mild adjustments
    ToTensorV2(),                      # Convert to PyTorch tensors
])

In [7]:
import torch
from PIL import Image
import numpy as np

# Load your models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate your models
unet_model = UNetResNet18(input_channels=3, IMG_CLASSES=5)
deeplab_model = DeepLabV3Plus(num_classes=5)


unet_model.load_state_dict(torch.load("/content/drive/MyDrive/unetresnet18_100.pth", map_location='cpu'))
unet_model.to(device)
deeplab_model.load_state_dict(torch.load("/content/drive/MyDrive/deeplabv3+_mobilentv3_50.pth", map_location='cpu'))
deeplab_model.to(device)

unet_model.eval()
deeplab_model.eval()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 104MB/s]


Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth


100%|██████████| 21.1M/21.1M [00:00<00:00, 92.2MB/s]


DeepLabV3Plus(
  (encoder): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), b

In [8]:
import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from albumentations import Compose, HueSaturationValue, RandomBrightnessContrast, ToTensorV2

# Albumentations Transform (Your Original Preprocessing)
transform = Compose([
    HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.4),
    RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    ToTensorV2(),
])

COLOR_MAP = [
    [0, 0, 0],        # Class 0: Black
    [0, 255, 255],    # Class 1: Cyan
    [153, 76, 0],     # Class 2: Brown
    [255, 0, 0],      # Class 3: Red
    [0, 153, 0],      # Class 4: Green
]

def segment_image(image):
    """ Runs segmentation using both DeepLabV3+ and U-Net. """

    image = np.array(image)
    image_normalized = (image - np.min(image)) / (np.max(image) - np.min(image))
    image_normalized = image_normalized.astype(np.float32)
    image_resized = cv2.resize(image_normalized, (256, 256), interpolation=cv2.INTER_LINEAR)

    augmented = transform(image=image_resized)
    input_tensor = augmented['image'].unsqueeze(0).to(device).float()

    # Get predictions from both models
    def get_prediction(model,target_size=(256,256)):
        with torch.no_grad():
            pred = model(input_tensor)
            pred = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()

        pred_mask = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)

        for class_idx, color in enumerate(COLOR_MAP):
            pred_mask[pred == class_idx] = color

        pred_mask_resized = cv2.resize(pred_mask, target_size, interpolation=cv2.INTER_NEAREST)

        return Image.fromarray(pred_mask_resized)
        # return Image.fromarray(pred_mask)

    deeplab_pred = get_prediction(deeplab_model)
    unet_pred = get_prediction(unet_model)

    return deeplab_pred, unet_pred

# Gradio Interface
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="DeepLabV3+ Prediction", type="pil"),
        gr.Image(label="U-Net Prediction", type="pil")
    ],
    title="Oil Spill Detection",
    description="Upload an image to see segmentation results from DeepLabV3+ and U-Net.",
    allow_flagging="never"
)


iface.launch(debug=True)




It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://ab50a944f1af76cac6.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://ab50a944f1af76cac6.gradio.live




In [9]:
!pip install gradio

