# Advanced Computational Techniques for Big Imaging and Signal Data
The Airbus Ship Detection Challenge: /kaggle/input/airbus-ship-detection

_A.Y. 2023-2024_

_Alessio De Luca [919790]_

## Kaggle function to run on colab

In [None]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

#NOTE: this takes several minutes, if possible run on kaggle

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'airbus-ship-detection:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-competitions-data%2Fkaggle-v2%2F9988%2F868324%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240617%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240617T081844Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3Db7cd8f1060f63ee891f15c1da87af83e90a91a820360881cc45235e7534c09436f7786c8e71c2367e425d9cdfc308d626d9556e0ad633fc67851d4590f2c8a69bcce02bfed9c46321a081c1d807e21ddd7b81896835cfd1eeccfc6db0b2f94afaafcedb4f04502e908d7e66bd9c97eedd7378fc379ec160766ff01b6a968f652198eb9318b73d24a2f65d086069018f0eeda4f7e908378ed99d959a283867e601235ca8bf1c2db9e16813692cbaf85f61f4815a24d4e404952b23ab182f6236cb9f0b76316d0ac9f7d71701407f35b7416e8ff5e6860774b244853aa2f6743563391b97c028f9bba9206e59ce36ce0ac3f66a88ada4e5afe3934d4fe6874ee51'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

# Download and unpack every data source listed in DATA_SOURCE_MAPPING.
# Each comma-separated entry has the form "<directory>:<url-encoded URL>".
for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    directory, download_url_encoded = data_source_mapping.split(':')
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            content_length = fileres.headers['content-length']
            # The header may be absent; fall back to 0 so the progress bar
            # is simply left empty instead of raising on int(None).
            total_bytes = int(content_length) if content_length else 0
            print(f'Downloading {directory}, {content_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                done = min(50, int(50 * dl / total_bytes)) if total_bytes else 0
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            # Push buffered bytes to disk: the tar branch reopens the file by name.
            tfile.flush()
            if filename.endswith('.zip'):
                with ZipFile(tfile) as zfile:
                    zfile.extractall(destination_path)
            else:
                # Use a distinct name so the `tarfile` module is not rebound
                # (the original `as tarfile` broke any later iteration).
                with tarfile.open(tfile.name) as tar_archive:
                    tar_archive.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError:
        # HTTPError is a subclass of OSError, so it must be handled first.
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
    except OSError:
        print(f'Failed to load {download_url} to path {destination_path}')

print('Data source import complete.')


## Libraries

In [None]:
!pip install torchsummary
!pip install segmentation-models-pytorch

In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from skimage.measure import label, regionprops
from skimage.color import label2rgb
from skimage.measure import label as skimage_label
from skimage.measure import find_contours
from collections import defaultdict
import random
from PIL import Image

import segmentation_models_pytorch as smp

from torchsummary import summary
from torchvision import models


In [None]:
# run this if you are using colab
# from google.colab import drive
# drive.mount('/content/drive')

## Data Processing

### Data Loading

In [None]:
# verify if gpu is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device used:", device)

In [None]:
# Get all images from the train set
train_image_dir = "/kaggle/input/airbus-ship-detection/train_v2/"

# Load the data
train_df = pd.read_csv("/kaggle/input/airbus-ship-detection/train_ship_segmentations_v2.csv")
train_images_paths = glob.glob(train_image_dir + '*.jpg')
print("Dataset Dimension:", len(train_images_paths), "images")

In [None]:
train_df.head()

In [None]:
#Visualize random image shape
random_image_path = np.random.choice(train_images_paths)


random_image = cv2.imread(random_image_path)

image_shape = random_image.shape

print("Images shape is:", image_shape)

### Data Preprocessing

In [None]:
# Function to create a random subset (modify seed value to change 'randomness')
def create_random_subset(image_paths, subset_fraction, seed=42):
    """Return a reproducible random sample of ``image_paths``.

    Args:
        image_paths: Indexable sequence of paths to sample from.
        subset_fraction: Fraction of the sequence to keep; the sample size
            is ``int(len(image_paths) * subset_fraction)``.
        seed: Value used to seed NumPy's global random generator.

    Returns:
        A list of the selected paths, drawn without replacement.
    """
    # Seeding the global generator makes the draw below repeatable.
    np.random.seed(seed)
    total = len(image_paths)
    n_selected = int(total * subset_fraction)
    chosen = np.random.choice(total, n_selected, replace=False)
    return [image_paths[position] for position in chosen]

# Creation of subset
subset_fraction = 0.01 # modify to increase/decrease subset fraction (values over 4 % don't work due to RAM limits)
train_paths = create_random_subset(train_images_paths, subset_fraction)

print("Dimension of reduced dataset:", len(train_paths), "images")

In [None]:
# Function that filters images with no ships
def filter_images_with_ships(train_df, train_paths, keep_ratio=0.1):  # modify keep_ratio to increase/decrease the number of shipless images
    """Keep every image with ships and only a fraction of the shipless ones.

    An image counts as shipless when ``train_df`` holds exactly one row for
    it and that row's ``EncodedPixels`` is NaN. Images that do not appear in
    ``train_df`` at all are treated as images with ships.

    Args:
        train_df: DataFrame with ``ImageId`` and ``EncodedPixels`` columns.
        train_paths: Iterable of image paths; the part after the last ``/``
            is matched against ``ImageId``.
        keep_ratio: Fraction of the shipless images to keep.

    Returns:
        List of paths: all images with ships followed by the first
        ``int(n_shipless * keep_ratio)`` shipless images.
    """
    # Classify all image ids in one pass instead of filtering the whole
    # DataFrame once per image: `size` counts rows, `count` counts non-NaN rows.
    per_image = train_df.groupby('ImageId')['EncodedPixels'].agg(['size', 'count'])
    is_shipless = (per_image['size'] == 1) & (per_image['count'] == 0)
    shipless_ids = set(per_image.index[is_shipless])

    images_with_ships = []
    images_without_ships = []

    for image_path in tqdm(train_paths):
        image_id = image_path.split('/')[-1]
        if image_id in shipless_ids:
            images_without_ships.append(image_path)
        else:
            images_with_ships.append(image_path)

    n_keep = int(len(images_without_ships) * keep_ratio)
    filtered_images = images_with_ships + images_without_ships[:n_keep]

    return filtered_images

# Apply filter function
filtered_images_paths = filter_images_with_ships(train_df, train_paths, keep_ratio=0.1)
delta_images= len(train_paths) - len(filtered_images_paths)

In [None]:
print("Dimension of reduced dataset after filter:", len(filtered_images_paths))
print("A total of ", delta_images, "images without ships were removed.")

In [None]:
# function to decode csv format masks to images
def rle_decode(mask_rle, shape):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T  

In [None]:
# Train-val-test split
train_paths, val_paths = train_test_split(filtered_images_paths, test_size=0.6, random_state=42)
test_paths, val_paths = train_test_split(val_paths, test_size=0.5, random_state=42)

print(f"Train set: {len(train_paths)}, Validation set: {len(val_paths)}, Test set: {len(test_paths)}")

### Data Augmentation

In [None]:
#augmentation functions
def apply_augmentation(image, mask, augmentation_type):
    """Apply one augmentation to an image and its segmentation mask.

    Args:
        image: Image array with three dimensions (rows, cols, channels).
        mask: Mask array matching the image's rows and cols.
        augmentation_type: One of ``'flipping'``, ``'rotation'``, ``'zoom'``
            or ``'clouds'``.

    Returns:
        Tuple ``(image, mask)`` after the augmentation. For ``'clouds'`` the
        mask is returned unchanged because only pixel intensities change.

    Raises:
        ValueError: If ``augmentation_type`` is not one of the supported names.
    """
    if augmentation_type == 'flipping':
        # flipCode=1 mirrors around the vertical axis (horizontal flip).
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)
    elif augmentation_type == 'rotation':
        angle = random.choice([90, 180, 270])
        rows, cols, _ = image.shape
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        image = cv2.warpAffine(image, M, (cols, rows), flags=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps mask values from being blended.
        mask = cv2.warpAffine(mask, M, (cols, rows), flags=cv2.INTER_NEAREST)
    elif augmentation_type == 'zoom':
        zoom_factor = random.uniform(1.1, 1.5)
        rows, cols, _ = image.shape
        # Angle 0 with a scale > 1 zooms in around the image centre.
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, zoom_factor)
        image = cv2.warpAffine(image, M, (cols, rows), flags=cv2.INTER_LINEAR)
        mask = cv2.warpAffine(mask, M, (cols, rows), flags=cv2.INTER_NEAREST)
    elif augmentation_type == 'clouds':
        noise = np.random.normal(loc=128, scale=64, size=image.shape)
        # Clip before the uint8 cast: samples below 0 or above 255 would
        # otherwise wrap around and turn into arbitrary intensities.
        noise = np.clip(noise, 0, 255).astype(np.uint8)
        image = cv2.addWeighted(image, 0.5, noise, 0.5, 0)
    else:
        raise ValueError("Augmentation type not valid")
    return image, mask

def augment_dataset(train_paths, train_df):
    """Load every training image, build its mask and add augmented copies.

    For each path the original image/mask pair is stored first, followed by
    one pair per augmentation ('flipping', 'rotation', 'zoom', 'clouds').

    Args:
        train_paths: Iterable of image paths; the part after the last ``/``
            is matched against the ``ImageId`` column.
        train_df: DataFrame with ``ImageId`` and ``EncodedPixels`` columns.

    Returns:
        Tuple ``(augmented_images, augmented_masks)`` of equally long lists.

    Raises:
        FileNotFoundError: If an image cannot be read from disk.
    """
    augmented_images = []
    augmented_masks = []
    augmentation_types = ['flipping', 'rotation', 'zoom', 'clouds']

    for image_path in tqdm(train_paths):
        image_id = image_path.split('/')[-1]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on failure.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # Size the mask from this image rather than from a global shape
        # taken from an unrelated, randomly chosen image.
        height, width = image.shape[:2]
        mask = np.zeros((height, width), dtype=np.uint8)

        mask_rle = train_df[train_df['ImageId'] == image_id]['EncodedPixels'].values
        for rle in mask_rle:
            if pd.notna(rle):  # NaN marks an image without ships
                # rle_decode transposes its result, so it is given (width, height).
                # np.maximum keeps the mask binary even if runs overlap.
                mask = np.maximum(mask, rle_decode(rle, (width, height)))

        augmented_images.append(image)
        augmented_masks.append(mask)

        for aug_type in augmentation_types:
            aug_image, aug_mask = apply_augmentation(image.copy(), mask.copy(), aug_type)
            augmented_images.append(aug_image)
            augmented_masks.append(aug_mask)

    return augmented_images, augmented_masks

In [None]:
# Apply augmentation only on train dataset
augmented_images, augmented_masks = augment_dataset(train_paths, train_df)


In [None]:
print(f"Total train images after augmentation: {len(augmented_images)}")
train_split= len(augmented_images)/(len(augmented_images) + len(val_paths)+ len(test_paths))
print(f"New train-val-split is:", round(train_split,2)*100,"% train", (100-round(train_split,2)*100)/2,"% val", (100-round(train_split,2)*100)/2,"% test",)

In [None]:
def show_images_and_masks(images, masks, titles, title):
    """Plot images on the top row and their masks on the bottom row.

    Args:
        images: Sequence of BGR images (converted to RGB for display).
        masks: Sequence of masks, one per image, drawn in grayscale.
        titles: Sequence of per-column titles.
        title: Overall figure title.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single column,
    # so the axs[row, col] indexing below always works.
    fig, axs = plt.subplots(2, len(images), figsize=(20, 10), squeeze=False)
    for i, (img, mask, ttl) in enumerate(zip(images, masks, titles)):
        axs[0, i].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        axs[0, i].set_title(ttl)
        axs[0, i].axis('off')
        axs[1, i].imshow(mask, cmap='gray')
        axs[1, i].axis('off')
    plt.suptitle(title)
    plt.show()

def show_original_and_augmented_with_masks(image, augmented_images, masks, augmented_masks, titles, title):
    """Display the original image/mask followed by the augmented ones.

    Args:
        image: Original image.
        augmented_images: List of augmented images.
        masks: Mask of the original image.
        augmented_masks: List of masks matching ``augmented_images``.
        titles: List of titles for the augmented columns.
        title: Overall figure title.
    """
    # The original pair is prepended so it appears in the first column.
    show_images_and_masks(
        [image] + augmented_images,
        [masks] + augmented_masks,
        ["Immagine originale"] + titles,
        title,
    )

In [None]:
# Take an example from the dataset and display the original image, augmented images, and associated masks
random_idx = random.randint(0, len(train_paths) - 1)
image_path = train_paths[random_idx]
image = cv2.imread(image_path)

# Get the mask for the original image
image_id = image_path.split('/')[-1]
mask_rle = train_df[train_df['ImageId'] == image_id]['EncodedPixels'].values
if len(mask_rle) == 1 and pd.isna(mask_rle[0]):
    mask = np.zeros((image_shape[1], image_shape[0]), dtype=np.uint8)
else:
    mask = np.zeros((image_shape[1], image_shape[0]), dtype=np.uint8)
    for rle in mask_rle:
        mask += rle_decode(rle, (image_shape[1], image_shape[0]))

# Generate augmented images and their masks
aug_images = []
aug_masks = []
titles = ['Flipping', 'Rotation', 'Zoom', 'Clouds']
for aug_type in titles:
    aug_image, aug_mask = apply_augmentation(image.copy(), mask.copy(), aug_type.lower())
    aug_images.append(aug_image)
    aug_masks.append(aug_mask)

# Display the images and masks with titles
show_original_and_augmented_with_masks(image, aug_images, mask, aug_masks, titles, title="Original and Augmented Images with Masks")


### Dataset creation

In [None]:
class AugmentedDataset(Dataset):
    """Dataset over in-memory image and mask arrays.

    Args:
        images: Sequence of BGR image arrays.
        masks: Sequence of mask arrays, one per image.
        transform: Optional callable applied to the image only.
    """

    def __init__(self, images, masks, transform=None):
        self.images, self.masks, self.transform = images, masks, transform

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``.

        The image is converted to RGB and scaled to [0, 1] as float32, then
        passed through ``transform`` if one was given. The mask is cast to
        float32 and never transformed.
        """
        raw_image = self.images[idx]
        raw_mask = self.masks[idx]

        rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
        image = rgb.astype(np.float32) / 255.0
        mask = raw_mask.astype(np.float32)

        if self.transform:
            image = self.transform(image)
        return image, mask


class OriginalDataset(Dataset):
    """Dataset that reads images from disk and decodes their masks on demand.

    Args:
        image_paths: Sequence of image paths; the part after the last ``/``
            is matched against the ``ImageId`` column.
        masks_df: DataFrame with ``ImageId`` and ``EncodedPixels`` columns.
        transform: Optional callable applied to the image only.
    """

    def __init__(self, image_paths, masks_df, transform=None):
        self.image_paths = image_paths
        self.masks_df = masks_df
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``.

        The image is converted to RGB and scaled to [0, 1] as float32, then
        passed through ``transform`` if one was given. The mask is a float32
        array with 1 on pixels covered by any encoded run.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on failure.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Ensure image is RGB
        image = image.astype(np.float32) / 255.0  # Convert image to float32 and normalize

        image_id = image_path.split('/')[-1]
        encoded_masks = self.masks_df[self.masks_df['ImageId'] == image_id]['EncodedPixels'].tolist()

        # Size the mask from the image instead of a hard-coded 768x768.
        height, width = image.shape[:2]
        mask_combined = np.zeros((height, width), dtype=np.float32)
        for encoded in encoded_masks:
            if pd.notna(encoded):  # NaN marks an image without ships
                # rle_decode transposes its result, so it is given (width, height).
                # np.maximum keeps the mask binary even if runs overlap.
                mask_combined = np.maximum(mask_combined, rle_decode(encoded, (width, height)))

        if self.transform:
            image = self.transform(image)

        # Always return the combined mask: the original only assigned it when
        # a transform was set, otherwise leaking the loop variable.
        return image, mask_combined


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# Dataset creation
augmented_dataset = AugmentedDataset(augmented_images, augmented_masks, transform=transform)
val_dataset = OriginalDataset(val_paths, train_df, transform=transform)
test_dataset = OriginalDataset(test_paths, train_df, transform=transform)

In [None]:
batch_size = 8

# Create DataLoader
train_loader = DataLoader(augmented_dataset, batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size, shuffle=True, num_workers=4)

In [None]:
print("Train loader size:", len(train_loader))
print("Validation loader size:", len(val_loader))
print("Test loader size:", len(test_loader))

In [None]:
def visualize_batch(loader, title_prefix):
    """Plot the first batch of a loader as rows of (image, mask) pairs.

    Args:
        loader: Iterable whose first item holds images at index 0 and masks
            at index 1.
        title_prefix: Text placed at the start of every subplot title.
    """
    first_batch = next(iter(loader))
    images, masks = first_batch[0], first_batch[1]

    print("Batch size of images:", images.size())
    print("Batch size of masks:", masks.size())

    n_rows = len(images)
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, n_rows * 5))
    # With a single row subplots returns a 1-D array; force (rows, 2).
    axes = axes.reshape(-1, 2)

    for row in range(n_rows):
        # Move channels last for imshow.
        picture = images[row].permute(1, 2, 0).numpy()
        target = masks[row].numpy()
        image_ax, mask_ax = axes[row, 0], axes[row, 1]

        image_ax.imshow(picture)
        image_ax.set_title(f"{title_prefix} Image {row+1}")
        image_ax.axis('off')

        mask_ax.imshow(target, cmap='gray')
        mask_ax.set_title(f"{title_prefix} Mask {row+1}")
        mask_ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Visualize a batch of training data
visualize_batch(train_loader, title_prefix="Train")

# # Visualize a batch of validation data
# visualize_batch(val_loader, title_prefix="Validation")

# # Visualize a batch of test data
# visualize_batch(test_loader, title_prefix="Test")


### Ship counting

In [None]:
# Function to count the number of ships in an image and obtain the colored mask
def count_and_color_ships(mask):
    """Label the connected components of a mask and colour them.

    Args:
        mask: Mask array passed to ``label``.

    Returns:
        Tuple ``(num_ships, colored_mask)``: the highest label value and the
        ``label2rgb`` rendering with label 0 treated as background.
    """
    components = label(mask)
    # Labels are numbered from 1, so the maximum equals the component count.
    return components.max(), label2rgb(components, bg_label=0)


In [None]:
MAX_EXAMPLES_PER_CLASS = 1  # Set the maximum limit of examples per class

def get_ship_counts_and_examples(dataset):
    """Count images per ship number and collect example samples.

    Args:
        dataset: Indexable dataset yielding ``(image, mask)`` pairs.

    Returns:
        Tuple ``(ship_counts, example_images)``: a mapping from ship count
        to number of images, and a mapping from ship count to a list of up
        to ``MAX_EXAMPLES_PER_CLASS`` ``(image, mask, colored_mask)`` tuples.
    """
    ship_counts = defaultdict(int)
    example_images = defaultdict(list)

    for index in tqdm(range(len(dataset)), desc="Processing dataset"):
        image, mask = dataset[index]
        num_ships, colored_mask = count_and_color_ships(mask)
        ship_counts[num_ships] += 1

        # Stop collecting examples for a class once its quota is reached.
        bucket = example_images[num_ships]
        if len(bucket) >= MAX_EXAMPLES_PER_CLASS:
            continue
        bucket.append((image, mask, colored_mask))

    return ship_counts, example_images

def plot_class_histogram(ship_counts, dataset_name):
    """Draw a bar chart of the number of images per ship count.

    Args:
        ship_counts: Mapping from ship count to number of images.
        dataset_name: Name appended to the plot title.
    """
    plt.figure(figsize=(12, 6))
    palette = plt.cm.tab10.colors[:len(ship_counts)]
    bars = plt.bar(ship_counts.keys(), ship_counts.values(), color=palette)

    # Write each bar's value just above it, centred horizontally.
    for rect in bars:
        height = rect.get_height()
        x_center = rect.get_x() + rect.get_width()/2
        plt.text(x_center, height, int(height), va='bottom', ha='center')

    highest_class = max(ship_counts.keys())
    plt.xticks(range(highest_class + 1))
    plt.xlabel("Number of ships per image")
    plt.ylabel("Number of images")
    plt.title(f"Distribution of images by number of ships - {dataset_name}")
    plt.show()

def plot_class_examples(example_images, dataset_name):
    """Show image, mask and coloured mask for each collected example.

    Args:
        example_images: Mapping from ship count to a list of
            ``(image, mask, colored_mask)`` tuples; ``image`` must support
            ``permute(1, 2, 0).numpy()``.
        dataset_name: Name appended to every subplot title.
    """
    for num_ships in sorted(example_images.keys()):
        for image, original_mask, colored_mask in example_images[num_ships]:
            # Move channels last for imshow.
            original_image = image.permute(1, 2, 0).numpy()

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            panels = [
                (original_image, {}, f"Original Image - {dataset_name}"),
                (original_mask, {'cmap': 'gray'}, f"Original Label Mask - {dataset_name}"),
                (colored_mask, {}, f"Mask with {num_ships} {'ship' if num_ships == 1 else 'ships'} - {dataset_name}"),
            ]
            for ax, (content, options, heading) in zip(axes, panels):
                ax.imshow(content, **options)
                ax.set_title(heading)
                ax.axis('off')

            plt.tight_layout()
            plt.show()

def plot_ship_presence_histogram(ship_counts, dataset_name):
    """Draw a two-bar chart: images without ships vs. images with ships.

    Args:
        ship_counts: Mapping from ship count to number of images. It is
            only read, never modified.
        dataset_name: Name appended to the plot title.
    """
    # .get avoids inserting a 0 key into a defaultdict and also works for
    # plain dicts that have no shipless images.
    no_ships = ship_counts.get(0, 0)
    one_or_more_ships = sum(count for ships, count in ship_counts.items() if ships > 0)
    total_images = no_ships + one_or_more_ships

    categories = ['0 Ships', '1+ Ships']
    values = [no_ships, one_or_more_ships]
    # Guard against an empty dataset, which would otherwise divide by zero.
    percentages = [
        f"{(v/total_images)*100:.2f}%" if total_images else "0.00%"
        for v in values
    ]

    plt.figure(figsize=(10, 5))
    bars = plt.bar(categories, values, color=['skyblue', 'orange'])

    # Label each bar at half its height with the count and percentage.
    for bar, value, percentage in zip(bars, values, percentages):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval/2, f'{int(value)}\n({percentage})', va='center', ha='center')

    plt.xlabel("Categories")
    plt.ylabel("Number of images")
    plt.title(f"Number of images with 0 ships vs 1+ ships - {dataset_name}")
    plt.show()


In [None]:
# count ships and get one sample for each count
ship_counts_train, example_images_train = get_ship_counts_and_examples(augmented_dataset)

In [None]:
#histogram plotting
plot_ship_presence_histogram(ship_counts_train, "Train")

plot_class_histogram(ship_counts_train, "Train")

In [None]:
# show images for each ship count
plot_class_examples(example_images_train, "Train")

In [None]:
# count ships and get one sample for each count
ship_counts_val, example_images_val = get_ship_counts_and_examples(val_dataset)

In [None]:
#histogram plotting
plot_ship_presence_histogram(ship_counts_val, "Validation")

plot_class_histogram(ship_counts_val, "Validation")


In [None]:
# show images for each ship count
plot_class_examples(example_images_val, "Validation")

In [None]:
# count ships and get one sample for each count
ship_counts_test, example_images_test = get_ship_counts_and_examples(test_dataset)

In [None]:
#histogram plotting
plot_ship_presence_histogram(ship_counts_test, "Test")

plot_class_histogram(ship_counts_test, "Test")

In [None]:
# show images for each ship count
plot_class_examples(example_images_test, "Test")

## Segmentation

### Model definition

In [None]:
# Create a U-Net segmentation model with a ResNet101 encoder
model = smp.Unet(
    encoder_name="resnet101",       # encoder architecture
    encoder_weights="imagenet",     # initialise the encoder with ImageNet weights
    in_channels=3,                  # three input channels
    classes=1                       # single output channel
)


model = model.to(device)

# Freeze the encoder: its parameters receive no gradient updates,
# so only the remaining layers are trained
for param in model.encoder.parameters():
    param.requires_grad = False

## Optional: also freeze the decoder
# for param in model.decoder.parameters():
#     param.requires_grad = False


In [None]:
### RUN THIS CELL TO LOAD THE MODEL WEIGHTS AND AVOID TRAIN/VAL ###
model_path = '/kaggle/input/resnet101deluca/pytorch/resnet101deluca/1/ResNet101.pth' #define your path to the model
# Load the state_dict
state_dict = torch.load(model_path, map_location=torch.device('cpu')) #set device to cpu
model.load_state_dict(state_dict)

In [None]:

# Define loss function
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE)

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
# Summary of the model
summary(model, input_size=(3, 768, 768))

In [None]:
# Define a threshold for binary segmentation
threshold = 0.5

### Training and validation

In [None]:
def _run_epoch(model, criterion, loader, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, pixel_accuracy)``.

    The model is trained when ``optimizer`` is given and evaluated (no
    gradients) otherwise. Uses the notebook-level ``device`` and ``threshold``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for images, masks in tqdm(loader, desc=desc):
            images = images.to(device).float()
            masks = masks.to(device).float().unsqueeze(1)  # Add channel dimension to masks

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()

            # The model in this notebook is built without an output
            # activation, so `outputs` are logits: convert them to
            # probabilities before comparing with the probability threshold.
            predicted_masks = (torch.sigmoid(outputs) > threshold).float()
            correct += (predicted_masks == masks).sum().item()
            total += masks.numel()

    return running_loss / len(loader), correct / total


def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs, patience):
    """Train ``model`` with early stopping on the validation loss.

    Args:
        model: Network to train; left holding the best weights on return.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepping the model's parameters.
        train_loader: Loader yielding ``(images, masks)`` training batches.
        val_loader: Loader yielding ``(images, masks)`` validation batches.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without validation-loss
            improvement after which training stops.

    Returns:
        Tuple ``(train_losses, val_losses, train_accuracies, val_accuracies,
        best_model_wts)``: per-epoch metric lists and the state dict of the
        epoch with the lowest validation loss (``None`` if no epoch ran).
    """
    import copy

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    best_val_loss = float('inf')  # Initialize best_val_loss to infinity
    best_model_wts = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        train_loss, train_accuracy = _run_epoch(
            model, criterion, train_loader,
            desc=f"Epoch {epoch + 1}/{num_epochs}", optimizer=optimizer)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        val_loss, val_accuracy = _run_epoch(
            model, criterion, val_loader, desc="Validation")
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Report before the early-stopping check so the metrics of the
        # final epoch are not lost.
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live tensors; without
            # a deep copy the snapshot would keep changing as training goes on.
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            print("Early stopping!")
            break

    # Load best model weights
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
        print(f"Best model loaded with validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses, train_accuracies, val_accuracies, best_model_wts


In [None]:
# Train model
train_losses, val_losses, train_accuracies, val_accuracies, best_model_wts = train_model(
    model, criterion, optimizer, train_loader, val_loader, num_epochs=50, patience=10)

In [None]:
# Cell to save the model
PATH = '/kaggle/working/ResNet101.pth' # adjust this path if running on Colab

# Save the best weights returned by train_model (skipped if there are none)
if best_model_wts is not None:
    torch.save(best_model_wts, PATH)

In [None]:
#visualization
def plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot loss and accuracy curves for training and validation.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accuracies: Per-epoch training accuracies (numbers or tensors).
        val_accuracies: Per-epoch validation accuracies (numbers or tensors).
    """
    def _as_array(values):
        # Items exposing .cpu() are converted via NumPy; plain numbers pass through.
        return np.array([item.cpu().numpy() if hasattr(item, 'cpu') else item for item in values])

    epochs = range(1, len(train_losses) + 1)

    train_losses_np = np.array(train_losses)
    val_losses_np = np.array(val_losses)
    train_accuracies_np = _as_array(train_accuracies)
    val_accuracies_np = _as_array(val_accuracies)

    # (panel title / y label, training curve, its label, validation curve, its label)
    panels = [
        ('Loss', train_losses_np, 'Training Loss', val_losses_np, 'Validation Loss'),
        ('Accuracy', train_accuracies_np, 'Training Accuracy', val_accuracies_np, 'Validation Accuracy'),
    ]

    plt.figure(figsize=(12, 4))
    for position, (heading, train_curve, train_label, val_curve, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_curve, 'r-', label=train_label)
        plt.plot(epochs, val_curve, 'b-', label=val_label)
        plt.title(heading)
        plt.xlabel('Epochs')
        plt.ylabel(heading)
        plt.legend()

    plt.show()

plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies)

In [None]:
# Run to see the performance on validation images (just to see different images)
# def visualize_batch(images, masks, predicted_masks):
#     num_samples = images.size(0)

#     for i in range(num_samples):
#         plt.figure(figsize=(18, 6))

#         # Original Image
#         plt.subplot(1, 4, 1)
#         plt.imshow(images[i].permute(1, 2, 0).cpu())
#         plt.title('Original Image')
#         plt.axis('off')

#         # Ground Truth Mask
#         plt.subplot(1, 4, 2)
#         plt.imshow(masks[i].squeeze().cpu(), cmap='gray')
#         plt.title('Ground Truth Mask')
#         plt.axis('off')

#         # Predicted Mask
#         plt.subplot(1, 4, 3)
#         predicted_mask = predicted_masks[i].squeeze().cpu()
#         predicted_mask_binary = (predicted_mask > threshold).float()  # Apply threshold
#         plt.imshow(predicted_mask_binary, cmap='gray')
#         plt.title('Predicted Mask')
#         plt.axis('off')

#         # Original Image with Overlay Contours
#         plt.subplot(1, 4, 4)
#         plt.imshow(images[i].permute(1, 2, 0).cpu())
#         plt.title('Original Image with Contours')
#         plt.axis('off')

#         # Find contours in the predicted binary mask
#         contours = find_contours(predicted_mask_binary.cpu().numpy(), level=0.5)

#         # Draw contours on the original image
#         for contour in contours:
#             plt.plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')

#         plt.show()

# def evaluate_model(val_loader, model, device, threshold=threshold):
#     model.eval()  # Set the model to evaluation mode
#     with torch.no_grad():  # No gradient calculation during evaluation
#         for images, masks in val_loader:
#             images = images.to(device)
#             masks = masks.to(device)

#             # Forward pass
#             predicted_masks = model(images)

#             # Visualize a batch of data
#             visualize_batch(images, masks, predicted_masks, threshold)
#             break  # Stop after visualizing one batch

# # Usage
# evaluate_model(val_loader, model, device)


### Testing the model

In [None]:
def evaluate_model(test_loader, model, device):
    """Run the model on the first test batch and visualize the result.

    Args:
        test_loader: Iterable yielding ``(images, masks)`` batches.
        model: Network called on the images.
        device: Device the batch is moved to before the forward pass.
    """
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No gradient calculation during evaluation
        # Only the first batch is used; an empty loader does nothing.
        first_batch = next(iter(test_loader), None)
        if first_batch is None:
            return
        images, masks = first_batch
        images = images.to(device)
        masks = masks.to(device)

        predicted_masks = model(images)
        visualize_batch(images, masks, predicted_masks)

# Function to visualize a batch of data
def visualize_batch(images, masks, predicted_masks):
    """Show image, ground truth, prediction and contours for each sample.

    NOTE: this definition replaces the loader-based ``visualize_batch``
    defined earlier in the notebook once this cell has been run.

    Args:
        images: Batch tensor of images, channels first.
        masks: Batch tensor of ground-truth masks.
        predicted_masks: Batch tensor of raw model outputs (logits).
    """
    num_samples = images.size(0)

    for i in range(num_samples):
        plt.figure(figsize=(18, 6))

        # Original image
        plt.subplot(1, 4, 1)
        plt.imshow(images[i].permute(1, 2, 0).cpu())
        plt.title('Original Image')
        plt.axis('off')

        # Ground truth mask
        plt.subplot(1, 4, 2)
        plt.imshow(masks[i].squeeze().cpu(), cmap='gray')
        plt.title('Original Mask')
        plt.axis('off')

        # Predicted mask
        plt.subplot(1, 4, 3)
        predicted_mask = predicted_masks[i].squeeze().cpu()  # Remove the extra dimension
        # The model is built without an output activation, so its outputs
        # are logits: map them to probabilities before thresholding.
        predicted_mask_binary = (torch.sigmoid(predicted_mask) > threshold).float()
        plt.imshow(predicted_mask_binary, cmap='gray')
        plt.title('Predicted Mask')
        plt.axis('off')

        # Original image with overlaid contours
        plt.subplot(1, 4, 4)
        plt.imshow(images[i].permute(1, 2, 0).cpu())
        plt.title('Image with Contours')
        plt.axis('off')

        # Find contours in the predicted binary mask
        contours = find_contours(predicted_mask_binary.cpu().numpy(), level=0.5)

        # Draw contours on the original image (contour columns are row, col)
        for contour in contours:
            plt.plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')

        plt.show()

# Usage
evaluate_model(test_loader, model, device)


### Evaluation Metrics

In [None]:
# Define metric functions
def intersection_over_union(pred, target):
    """Return the intersection-over-union of two masks as a float.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.

    Returns:
        1.0 when both masks sum to zero, 0.0 when only the target sums to
        zero and the prediction sums to a positive value, otherwise
        ``|pred AND target| / |pred OR target|``.
    """
    target_is_empty = torch.sum(target) == 0
    pred_total = torch.sum(pred)

    if target_is_empty and pred_total == 0:
        return 1.0  # Nothing to find and nothing predicted
    if target_is_empty and pred_total > 0:
        return 0.0  # Nothing to find but something predicted

    overlap = torch.logical_and(pred, target).sum()
    combined = torch.logical_or(pred, target).sum()
    if combined == 0:
        return 0.0  # Avoid a NaN from 0 / 0
    return (overlap.float() / combined.float()).item()

def dice_coefficient(pred, target):
    """Return the Dice coefficient of two masks as a float.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.

    Returns:
        1.0 when both masks sum to zero, 0.0 when only the target sums to
        zero and the prediction sums to a positive value, otherwise
        ``2 * sum(pred * target) / (sum(pred) + sum(target) + eps)``.
    """
    eps = 1e-9  # keeps the denominator non-zero
    pred_total = torch.sum(pred)
    target_total = torch.sum(target)

    if target_total == 0:
        if pred_total == 0:
            return 1.0  # Nothing to find and nothing predicted
        if pred_total > 0:
            return 0.0  # Nothing to find but something predicted

    overlap = torch.sum(pred * target)
    return ((2 * overlap) / (pred_total + target_total + eps)).item()

def precision_score(pred, target):
    """Return the pixel precision ``tp / (tp + fp + eps)`` as a float.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.

    Returns:
        1.0 when both masks sum to zero, 0.0 when only the target sums to
        zero and the prediction sums to a positive value, otherwise the
        precision computed from true and false positives.
    """
    eps = 1e-9  # keeps the denominator non-zero
    pred_total = torch.sum(pred)

    if torch.sum(target) == 0:
        if pred_total == 0:
            return 1.0  # Nothing to find and nothing predicted
        if pred_total > 0:
            return 0.0  # Nothing to find but something predicted

    true_pos = torch.logical_and(pred, target).sum().float()
    false_pos = torch.logical_and(pred, 1 - target).sum().float()
    return (true_pos / (true_pos + false_pos + eps)).item()

def recall_score(pred, target):
    """Return the pixel recall ``tp / (tp + fn + eps)`` as a float.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.

    Returns:
        1.0 when both masks sum to zero, 0.0 when only the target sums to
        zero and the prediction sums to a positive value, otherwise the
        recall computed from true positives and false negatives.
    """
    eps = 1e-9  # keeps the denominator non-zero
    pred_total = torch.sum(pred)

    if torch.sum(target) == 0:
        if pred_total == 0:
            return 1.0  # Nothing to find and nothing predicted
        if pred_total > 0:
            return 0.0  # Nothing to find but something predicted

    true_pos = torch.logical_and(pred, target).sum().float()
    false_neg = torch.logical_and(1 - pred, target).sum().float()
    return (true_pos / (true_pos + false_neg + eps)).item()

def f1_score(pred, target):
    """Return the F1 score built from ``precision_score`` and ``recall_score``.

    Args:
        pred: Predicted mask tensor.
        target: Ground-truth mask tensor.

    Returns:
        ``2 * p * r / (p + r + 1e-9)`` as a float.
    """
    precision = precision_score(pred, target)
    recall = recall_score(pred, target)
    # The small constant prevents a division by zero when both are 0.
    return 2 * (precision * recall) / (precision + recall + 1e-9)

# Example usage in evaluation pipeline
def evaluate_model(test_loader, model, device, threshold):
    """Compute and print mean segmentation metrics over a test loader.

    Each sample is scored with IoU, Dice, precision, recall and F1, and the
    per-sample scores are averaged over the whole loader.

    Args:
        test_loader: Iterable yielding ``(images, masks)`` batches.
        model: Network producing one output channel of logits per image.
        device: Device the batches are moved to.
        threshold: Probability above which a pixel is predicted as ship.
    """
    def _mean(values):
        # NaN rather than ZeroDivisionError when the loader was empty.
        return sum(values) / len(values) if values else float('nan')

    iou_scores = []
    dice_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    model.eval()

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)

            # Calculate model predictions
            predicted_masks = model(images)

            # The model is built without an output activation, so its
            # outputs are logits: map them to probabilities before
            # applying the probability threshold.
            predicted_masks_binary = (torch.sigmoid(predicted_masks) > threshold).float()

            # Calculate IoU, DICE, Precision, Recall, and F1 for each sample in the batch
            for i in range(images.size(0)):
                pred_mask = predicted_masks_binary[i].squeeze()
                true_mask = masks[i].squeeze()

                iou_scores.append(intersection_over_union(pred_mask, true_mask))
                dice_scores.append(dice_coefficient(pred_mask, true_mask))
                precision_scores.append(precision_score(pred_mask, true_mask))
                recall_scores.append(recall_score(pred_mask, true_mask))
                f1_scores.append(f1_score(pred_mask, true_mask))

    # Calculate the mean of metrics across all samples in the test dataset
    mean_iou = _mean(iou_scores)
    mean_dice = _mean(dice_scores)
    mean_precision = _mean(precision_scores)
    mean_recall = _mean(recall_scores)
    mean_f1 = _mean(f1_scores)

    # Print the results
    print(f'Mean IoU: {mean_iou:.4f}')
    print(f'Mean DICE Coefficient: {mean_dice:.4f}')
    print(f'Mean Precision: {mean_precision:.4f}')
    print(f'Mean Recall: {mean_recall:.4f}')
    print(f'Mean F1 Score: {mean_f1:.4f}')

In [None]:
evaluate_model(test_loader, model, device, threshold)

In [None]:
Mean IoU: 0.5689
Mean DICE Coefficient: 0.6303
Mean Precision: 0.6805
Mean Recall: 0.6352
Mean F1 Score: 0.6303