## Import Libraries and Seeds Configuration

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter, distance_transform_edt
import random
import os
import cv2
import tarfile
from PIL import Image

from tqdm import tqdm
import time

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize, Resize, Compose
from torchvision.transforms.functional import to_pil_image
from torch.optim.lr_scheduler import ExponentialLR

import torch.nn as nn
import torch.nn.functional as F

In [None]:
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

try:
    import rasterio
except:
    print("[INFO] Couldn't find rasterio... installing it.")
    !pip install rasterio
    import rasterio

In [None]:
# Seed every RNG used in this notebook (PyTorch, NumPy, Python's `random`)
# from one value so runs are reproducible. `random_state` is also passed to
# the scikit-learn splits.
random_state = 44
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(random_state)

## Data Loading / Filtering


In [None]:
# Change only these three paths accordingly to run-environment
raw_data = r'/kaggle/input/kelpsegmentationdataset/data/data/raw'
processed_data = r'/kaggle/input/kelpsegmentationdataset/data/data/processed'
outputs = r'/kaggle/working/'

# --- Raw data: original TIFF folders and the metadata table ---
train_satellite, train_kelp, test_satellite = (
    os.path.join(raw_data, folder)
    for folder in ('train_satellite', 'train_kelp', 'test_satellite')
)
metadata_csv = os.path.join(raw_data, 'metadata_fTq0l2T.csv')

# --- Processed data: per-image statistics CSV and normalization statistics ---
(
    data_statistics,
    normalization_statistics_sb,
    normalization_statistics_dist,
    normalization_statistics_dem,
    normalization_statistics_ndvi,
) = (
    os.path.join(processed_data, fname)
    for fname in (
        'train_img_statistics.csv',
        'norm_stats_sb_low5_up5_perc_v2.npz',
        'norm_stats_dist_total.npz',
        'norm_stats_dem_total.npz',
        'norm_stats_ndvi_low5_up5_perc_v2.npz',
    )
)

# --- Processed data: per-tile .npz folders (each variable matches its folder name) ---
(
    train_satellite_np,
    train_distance_maps_np,
    train_ndvi_np,
    train_kelp_np,
    test_satellite_np,
    test_distance_maps_np,
    test_ndvi_np,
) = (
    os.path.join(processed_data, folder)
    for folder in (
        'train_satellite_np',
        'train_distance_maps_np',
        'train_ndvi_np',
        'train_kelp_np',
        'test_satellite_np',
        'test_distance_maps_np',
        'test_ndvi_np',
    )
)

# --- Outputs written by this notebook ---
predictions, checkpoints = (
    os.path.join(outputs, sub) for sub in ('predictions', 'checkpoints')
)

# Make sure every output folder exists; existing folders are left untouched.
for out_dir in (predictions, checkpoints):
    os.makedirs(out_dir, exist_ok=True)
    print(f"Created or verified the existence of the directory: {out_dir}")

### Loading data and metadata files

In [None]:
# Read the tile metadata and the pre-computed per-image statistics
metadata = pd.read_csv(metadata_csv)
train_statistics = pd.read_csv(data_statistics)

# Labelled training tiles vs. the provided (test) satellite tiles
_in_train = metadata['in_train']
train_metadata = metadata.loc[_in_train == True]
test_given_metadata = metadata.loc[(_in_train == False) & metadata['type'].eq('satellite')]

# Attach the per-image statistics to each training tile (tile_id <-> image_id)
train_metadata_stats = train_metadata.merge(
    train_statistics,
    how='inner',
    left_on='tile_id',
    right_on='image_id',
)

train_metadata_stats.head()

### Filter data according to statistics

In [None]:
# Quality filter: keep tiles that contain kelp and are mostly free of clouds
# and corrupted pixels. This filtered set is split into train/val/test later.
quality_mask = (
    (train_metadata_stats['num_kelp_px'] > 0)
    & (train_metadata_stats['perc_clouds'] < 0.2)
    & (train_metadata_stats['perc_corrupt'] < 0.1)
)
filtered_train_metadata = train_metadata_stats[quality_mask]
print(f'filtered_train_dataset: {filtered_train_metadata.shape}')

# test_un_2: every tile NOT kept by the filter above. The exact complement
# guarantees the two sets partition the data. Re-writing the inverse by hand
# with `> 0.2` / `> 0.1` would drop tiles lying exactly on a threshold (and
# tiles with NaN statistics) from both sets.
test_un_2 = train_metadata_stats[~quality_mask]
print(f'test_un_2: {test_un_2.shape}')

# Sanity check: equals len(train_metadata_stats) because the sets partition it
print(f'sum: {len(test_un_2) + len(filtered_train_metadata)}')

## Train-Test-Val Split

In [None]:
# Stratification labels: assign each tile to one of (at most) `bin_count`
# quantile bins of its kelp coverage 'perc_kelp'. Duplicate bin edges are
# dropped, so fewer bins can result when many tiles share the same value.
bin_count = 20
bin_numbers = pd.qcut(
    filtered_train_metadata['perc_kelp'],
    bin_count,
    labels=False,
    duplicates='drop',
)

In [None]:
# Define training way
# run_on = 'sample' or 'train_split' or 'train_total'
run_on = 'train_split'

if run_on == 'sample':
    # Debug mode: keep a 5% subset of the filtered dataset, then split it
    # 70/15/15 into train, val and test.
    subset_metadata, _ = train_test_split(filtered_train_metadata, test_size=0.95, random_state=random_state)
    train_metadata, temp_metadata = train_test_split(subset_metadata, test_size=0.3, random_state=random_state)
    val_metadata, test_metadata = train_test_split(temp_metadata, test_size=0.5, random_state=random_state)

elif run_on == 'train_split':
    # Hold out 960 tiles (stratified on kelp coverage) and split them evenly
    # into val and test; all remaining filtered tiles are used for training.
    train_metadata, temp_metadata = train_test_split(
        filtered_train_metadata, test_size=960, random_state=random_state, stratify=bin_numbers)
    # Re-bin the held-out tiles so the val/test split is stratified as well
    bin_numbers_temp = pd.qcut(x=temp_metadata['perc_kelp'], q=bin_count, labels=False, duplicates='drop')
    val_metadata, test_metadata = train_test_split(
        temp_metadata, test_size=0.5, random_state=random_state, stratify=bin_numbers_temp)

elif run_on == 'train_total':
    # Train on the filtered tiles minus 640 stratified validation tiles;
    # evaluation is meant to use the provided test set.
    # NOTE(review): `test_metadata` is not assigned in this mode, so building
    # `test_dataset` from it afterwards raises NameError - confirm intent.
    train_metadata, val_metadata = train_test_split(
        filtered_train_metadata, test_size=640, random_state=random_state, stratify=bin_numbers)

else:
    # Fail loudly: otherwise a typo leaves `train_metadata` bound to the
    # unfiltered training metadata and the val/test names undefined.
    raise ValueError(
        f"Unknown run_on={run_on!r}; expected 'sample', 'train_split' or 'train_total'")

## Conversion from TIFF to NPZ

### Conversion of data


In [None]:
def calculate_distance_from_coast(image, sigma=1, constant_distance=1000):
    """Compute a per-pixel distance map to the land found in the DEM band.

    The DEM (band index 6, i.e. the 7th layer) is smoothed and thresholded at
    0 to get a land mask. Coastline pixels found by Canny are removed from it,
    and the Euclidean distance transform gives every other pixel its distance
    (in pixels) to the nearest remaining land pixel.

    Args:
        image: array indexed as ``image[band, row, col]``.
        sigma: standard deviation of the Gaussian used to smooth the DEM.
        constant_distance: value assigned to every pixel when no land pixel
            remains (e.g. an all-water tile).

    Returns:
        2-D float64 array with the DEM's height and width. Land pixels are 0;
        all other pixels hold the distance to the nearest land pixel. The
        constant map is float64 as well, so callers always get the same dtype.
    """
    # Extract the DEM from the 7th layer of the image
    dem = image[6, :, :]

    # Smooth the DEM.
    # NOTE(review): gaussian_filter returns the input dtype by default, so an
    # integer DEM is smoothed in integer precision - confirm this is intended.
    dem_sm = gaussian_filter(dem, sigma=sigma)

    # Land mask: smoothed elevation above sea level
    dem_bw = dem_sm > 0.0

    # Canny edges of the land mask approximate the coastline
    edges_canny = cv2.Canny((dem_bw * 255).astype(np.uint8), 50, 150)

    # Land pixels that are not on the detected coastline
    background_binarized_dem = dem_bw & ~edges_canny.astype(bool)

    # No land left (whole tile is water / no coastline): constant distance
    if not background_binarized_dem.any():
        return np.full(dem_bw.shape, constant_distance, dtype=np.float64)

    # EDT measures, for each non-zero pixel, the distance to the nearest zero,
    # so invert the mask to measure distance to the nearest land pixel.
    return distance_transform_edt(~background_binarized_dem)

In [None]:
# def convert_tiff_to_npz(tif_files_path, original_data_target_path, distance_map_target_path=None, is_labels=False):
#     # List all TIFF files in the source directory
#     tiff_files = [f for f in os.listdir(tif_files_path) if f.endswith('.tif')]

#     # Process each TIFF file
#     for filename in tqdm(tiff_files, desc="Converting TIFF to NPZ"):
#         file_path = os.path.join(tif_files_path, filename)

#         with rasterio.open(file_path) as src:
#             image_array = src.read()  # Read the image as a NumPy array

#             # Determine the correct data type based on whether the image is a label
#             data_type = np.int8 if is_labels else np.int16

#             # Convert the image array to the determined data type
#             image_array = image_array.astype(data_type)

#             # Save the original image data
#             original_npz_filename = os.path.splitext(filename)[0] + '.npz'
#             original_npz_file_path = os.path.join(original_data_target_path, original_npz_filename)
#             np.savez_compressed(original_npz_file_path, image_array)

#             # Only calculate and save distance maps for non-label images
#             if not is_labels:
#                 # Calculate the distance map, assume the function `calculate_distance_from_coast` is defined
#                 distance_map = calculate_distance_from_coast(image_array)

#                 # Save the distance map
#                 distance_map_npz_filename = os.path.splitext(filename)[0] + '_distance_map.npz'
#                 distance_map_npz_file_path = os.path.join(distance_map_target_path, distance_map_npz_filename)
#                 np.savez_compressed(distance_map_npz_file_path, distance_map)

# convert_tiff_to_npz(train_satellite, train_satellite_np, train_distance_maps_np, is_labels=False)
# convert_tiff_to_npz(train_kelp, train_kelp_np, None, is_labels=True)
# convert_tiff_to_npz(test_satellite, test_satellite_np, test_distance_maps_np, is_labels=False)

In [None]:
def count_and_size_npz_images(directory_path):
    """Count the ``.npz`` entries directly inside *directory_path* and their total size.

    Returns:
        (count, size_gb): how many entries have a name ending in ``.npz`` and
        their combined size in gigabytes (1 GB = 1024**3 bytes).
    """
    npz_sizes = [
        os.path.getsize(os.path.join(directory_path, entry))
        for entry in os.listdir(directory_path)
        if entry.endswith('.npz')
    ]
    return len(npz_sizes), sum(npz_sizes) / 1024 ** 3


# Report how many .npz files each processed folder holds and their total size
npz_folders = [
    ("Train Satellite Data", train_satellite_np),
    ("Train Distance Maps", train_distance_maps_np),
    ("Train NDVI", train_ndvi_np),
    ("Train Kelp Labels", train_kelp_np),
    ("Test Satellite Data", test_satellite_np),
    ("Test Distance Maps", test_distance_maps_np),
    ("Test NDVI", test_ndvi_np),
]
for label_text, folder in npz_folders:
    count, size = count_and_size_npz_images(folder)
    print(f"{label_text}: {count} files, {size:.2f} GB")

## Custom Dataset Class and Dataloader

In [None]:
class KelpDataset(Dataset):
    """Map-style dataset returning one kelp tile per index.

    Each item is ``(data, label, cloud_mask, tile_id)``:
      * ``data``: float32 tensor stacking, along dim 0, bands 0-4 of the
        satellite array, the distance map, band 6 (DEM) and the NDVI layer
        (8 channels, assuming the distance map and NDVI are 2-D arrays);
      * ``label``: uint8 mask loaded from ``label_path``, or an all-zeros
        ``(1, H, W)`` mask when no label path/file is available;
      * ``cloud_mask``: band 5 of the satellite array as a ``(1, H, W)`` uint8 tensor;
      * ``tile_id``: the ``tile_id`` value of the metadata row.

    Files are looked up from the metadata columns ``filename`` / ``tile_id``:
      * ``data_path/<filename with '.tif' replaced by '.npz'>``
      * ``distance_map_path/<tile_id>_satellite_distance_map.npz``
      * ``ndvi_path/<tile_id>_satellite.npz``
      * ``label_path/<tile_id>_kelp.npz``

    ``distance_map_path`` and ``ndvi_path`` default to None but are always
    read, so in practice they must be supplied.
    """

    def __init__(self, metadata, data_path, distance_map_path=None, ndvi_path=None, label_path=None, data_transforms=None, label_transforms=None):
        self.metadata = metadata
        self.data_path = data_path
        self.distance_map_path = distance_map_path
        self.ndvi_path = ndvi_path
        self.label_path = label_path
        self.data_transforms = data_transforms
        self.label_transforms = label_transforms

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _first_array(npz_path):
        """Return the first array stored in the .npz archive at *npz_path*."""
        with np.load(npz_path) as archive:
            return archive[archive.files[0]]

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        tile_id = row['tile_id']

        # Satellite stack: bands 0-4 spectral, band 5 cloud mask, band 6 DEM
        raw = self._first_array(os.path.join(self.data_path, row['filename'].replace('.tif', '.npz')))
        spectral_bands = torch.tensor(raw[[0, 1, 2, 3, 4]], dtype=torch.float32)
        cloud_mask = torch.tensor(raw[5][None], dtype=torch.uint8)  # [1, height, width]
        dem = torch.tensor(raw[6][None], dtype=torch.float32)       # [1, height, width]

        # Distance map and NDVI: [height, width] -> [height, width, 1] -> [1, height, width]
        distance_map = self._first_array(
            os.path.join(self.distance_map_path, tile_id + '_satellite_distance_map.npz'))
        distance_map = torch.tensor(distance_map[..., None], dtype=torch.float32).permute(2, 0, 1)

        ndvi = self._first_array(os.path.join(self.ndvi_path, tile_id + '_satellite.npz'))
        ndvi = torch.tensor(ndvi[..., None], dtype=torch.float32).permute(2, 0, 1)

        # Kelp label, falling back to an empty mask when it is unavailable
        label_file = os.path.join(self.label_path, tile_id + '_kelp.npz') if self.label_path else None
        if label_file is not None and os.path.exists(label_file):
            label = torch.tensor(self._first_array(label_file), dtype=torch.uint8)
        else:
            label = torch.zeros((1, raw.shape[1], raw.shape[2]), dtype=torch.uint8)

        # Model input channels: spectral bands, distance map, DEM, NDVI
        data = torch.cat([spectral_bands, distance_map, dem, ndvi], dim=0)

        if self.data_transforms is not None:
            data = self.data_transforms(data)
        if self.label_transforms is not None:
            label = self.label_transforms(label)

        return data, label, cloud_mask, tile_id


### Normalization: Band Statistics for Training Set

In [None]:
# Per-channel normalization statistics, listed in the same order as the
# dataset stacks its channels: spectral bands, distance map, DEM, NDVI.
normalization_npz_files = {
    'sb': normalization_statistics_sb,
    'dist': normalization_statistics_dist,
    'dem': normalization_statistics_dem,
    'ndvi': normalization_statistics_ndvi,
}

all_means = []
all_stds = []

# np.load on an .npz returns an NpzFile that keeps the file open; the `with`
# block closes it once the arrays have been read.
for file_path in normalization_npz_files.values():
    with np.load(file_path) as stats:
        # atleast_1d turns scalar statistics into length-1 arrays so they concatenate
        all_means.append(np.atleast_1d(stats['means']))
        all_stds.append(np.atleast_1d(stats['stds']))

# One mean/std per input channel
all_means = np.concatenate(all_means, axis=0)
all_stds = np.concatenate(all_stds, axis=0)

# Display the concatenated means and stds
print("\nConcatenated Means:\n", all_means)
print("\nConcatenated Stds:\n", all_stds)

### Transforms

In [None]:
# Input preprocessing: standardize each channel of the stacked tensor with the
# per-channel means/stds in `all_means` / `all_stds`.
_channel_normalizer = Normalize(mean=all_means, std=all_stds)
data_transforms = Compose([_channel_normalizer])

### Instantiate Datasets and Dataloaders

In [None]:
# Train/val/test are carved from the labelled training tiles and share folders
_labelled_kwargs = dict(
    data_path=train_satellite_np,
    distance_map_path=train_distance_maps_np,
    ndvi_path=train_ndvi_np,
    label_path=train_kelp_np,
    data_transforms=data_transforms,
    label_transforms=None,
)
train_dataset = KelpDataset(metadata=train_metadata, **_labelled_kwargs)
val_dataset = KelpDataset(metadata=val_metadata, **_labelled_kwargs)
test_dataset = KelpDataset(metadata=test_metadata, **_labelled_kwargs)

# The provided test tiles come without kelp labels
test_given_dataset = KelpDataset(
    metadata=test_given_metadata,
    data_path=test_satellite_np,
    distance_map_path=test_distance_maps_np,
    ndvi_path=test_ndvi_np,
    label_path=None,
    data_transforms=data_transforms,
    label_transforms=None,
)

# Only the training loader shuffles
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader, test_given_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (val_dataset, test_dataset, test_given_dataset)
)

### Datasets Inspection

In [None]:
for split_name, split_ds in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
    ("Test Given", test_given_dataset),
):
    print(f"Length of {split_name} Dataset: {len(split_ds)}")

In [None]:
# Peek at the first training batch (if any) to check tensor shapes and dtypes
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    data, label, cloud_mask, _ = first_batch
    print("Data:", data.shape, "data type:", data.dtype)
    print("Label shape:", label.shape, "data type:", label.dtype)
    print("Cloud Mask shape:", cloud_mask.shape, "data type:", cloud_mask.dtype)

In [None]:
def visualize_data(dataloader, num_rows=1):
    """Plot samples of the first batch of *dataloader*, one sample per row.

    Each row shows an RGB composite (channels 2, 3, 4 of ``data``), the label
    mask, the cloud mask and the distance map (channel 5 of ``data``).

    Args:
        dataloader: yields ``(data, label, cloud_mask, tile_id)`` batches.
        num_rows: number of samples to plot; capped at the batch size.

    Nothing is plotted if the dataloader yields no batches.
    """
    batch = next(iter(dataloader), None)
    if batch is None:
        return
    data, label, cloud_mask, tile_id = batch

    # Ensure num_rows does not exceed the batch size
    num_rows = min(num_rows, data.shape[0])

    _, axes = plt.subplots(num_rows, 4, figsize=(20, 5 * num_rows))
    # plt.subplots returns a 1-D array for a single row; keep 2-D indexing uniform
    if num_rows == 1:
        axes = axes.reshape(-1, 4)

    for row in range(num_rows):
        ax_rgb, ax_label, ax_cloud, ax_distance = axes[row]

        # RGB composite, min-max scaled to [0, 1] for imshow. A constant image
        # (max == min) would divide by zero and produce NaNs, so show it as black.
        rgb = np.transpose(data[row][[2, 3, 4], :, :].cpu().numpy(), (1, 2, 0))
        value_range = rgb.max() - rgb.min()
        rgb = (rgb - rgb.min()) / value_range if value_range > 0 else np.zeros_like(rgb)
        ax_rgb.imshow(rgb)
        ax_rgb.set_title(f"RGB Image - Sample {tile_id[row]}")

        # Label mask
        label_np = label[row].squeeze().cpu().numpy()
        ax_label.imshow(label_np, cmap='gray')
        ax_label.set_title(f"Kelp Mask - Sample {tile_id[row]}")

        # Cloud mask
        cloud_mask_np = cloud_mask[row].squeeze().cpu().numpy()
        ax_cloud.imshow(cloud_mask_np, cmap='gray')
        ax_cloud.set_title(f"Cloud Mask - Sample {tile_id[row]}")

        # Distance map (channel 5 of the stacked input)
        distance_map_np = data[row][5, :, :].squeeze().cpu().numpy()
        ax_distance.imshow(distance_map_np, cmap='viridis')
        ax_distance.set_title(f"Distance Map - Sample {tile_id[row]}")

    plt.tight_layout()
    plt.show()

# visualize_data(train_loader, num_rows=5)
# print("***"*100)
# visualize_data(val_loader, num_rows=5)
# print("***"*100)
# visualize_data(test_loader, num_rows=5)
# print("***"*100)
# visualize_data(test_given_loader, num_rows=5)

## Model Architectures




### SegNet
Code available at: https://github.com/vinceecws/SegNet_PyTorch/blob/master/Pavements/SegNet.py

<img src="https://drive.google.com/uc?export=view&id=1-J5p-sMU_LSX1VbbXo6P---LL6-KYL5d" width="750"/>

*Source: Badrinarayanan et al. (2017)*

In [None]:
class SEGNET(nn.Module):
    """SegNet encoder-decoder for per-pixel prediction.

    Adapted from https://github.com/vinceecws/SegNet_PyTorch. The encoder has
    five stages of 3x3 Conv + BatchNorm + ReLU layers (2, 2, 3, 3, 3 convs per
    stage), each stage followed by 2x2 max pooling whose indices are kept. The
    decoder mirrors it and upsamples with ``MaxUnpool2d`` at those indices.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels. ``forward`` returns the raw output
            of the last convolution (no BatchNorm or activation applied).
        BN_momentum: momentum of every BatchNorm layer.
    """

    def __init__(self, in_chn, out_chn, BN_momentum):
        super(SEGNET, self).__init__()

        # Construct a string to represent the model configuration for identification
        self.modelname = f"SEGNET_in-chn={in_chn}_out-chn={out_chn}_BN_momentum={BN_momentum}"

        self.in_chn = in_chn
        self.out_chn = out_chn

        # ENCODER: shared 2x2 max pool; indices are returned for the decoder's unpooling
        self.MaxEn = nn.MaxPool2d(2, stride=2, return_indices=True)

        # Stage 1: in_chn -> 64
        self.ConvEn11 = nn.Conv2d(self.in_chn, 64, kernel_size=3, padding=1)
        self.BNEn11 = nn.BatchNorm2d(64, momentum=BN_momentum)
        self.ConvEn12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.BNEn12 = nn.BatchNorm2d(64, momentum=BN_momentum)

        # Stage 2: 64 -> 128
        self.ConvEn21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.BNEn21 = nn.BatchNorm2d(128, momentum=BN_momentum)
        self.ConvEn22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.BNEn22 = nn.BatchNorm2d(128, momentum=BN_momentum)

        # Stage 3: 128 -> 256
        self.ConvEn31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.BNEn31 = nn.BatchNorm2d(256, momentum=BN_momentum)
        self.ConvEn32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.BNEn32 = nn.BatchNorm2d(256, momentum=BN_momentum)
        self.ConvEn33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.BNEn33 = nn.BatchNorm2d(256, momentum=BN_momentum)

        # Stage 4: 256 -> 512
        self.ConvEn41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.BNEn41 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvEn42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNEn42 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvEn43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNEn43 = nn.BatchNorm2d(512, momentum=BN_momentum)

        # Stage 5: 512 -> 512
        self.ConvEn51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNEn51 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvEn52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNEn52 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvEn53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNEn53 = nn.BatchNorm2d(512, momentum=BN_momentum)

        # DECODER: each stage mirrors its encoder counterpart
        self.MaxDe = nn.MaxUnpool2d(2, stride=2)

        self.ConvDe53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNDe53 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvDe52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNDe52 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvDe51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNDe51 = nn.BatchNorm2d(512, momentum=BN_momentum)

        self.ConvDe43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNDe43 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvDe42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.BNDe42 = nn.BatchNorm2d(512, momentum=BN_momentum)
        self.ConvDe41 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.BNDe41 = nn.BatchNorm2d(256, momentum=BN_momentum)

        self.ConvDe33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.BNDe33 = nn.BatchNorm2d(256, momentum=BN_momentum)
        self.ConvDe32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.BNDe32 = nn.BatchNorm2d(256, momentum=BN_momentum)
        self.ConvDe31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.BNDe31 = nn.BatchNorm2d(128, momentum=BN_momentum)

        self.ConvDe22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.BNDe22 = nn.BatchNorm2d(128, momentum=BN_momentum)
        self.ConvDe21 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.BNDe21 = nn.BatchNorm2d(64, momentum=BN_momentum)

        self.ConvDe12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.BNDe12 = nn.BatchNorm2d(64, momentum=BN_momentum)
        self.ConvDe11 = nn.Conv2d(64, self.out_chn, kernel_size=3, padding=1)
        # NOTE: BNDe11 is not used in forward (the last conv output is returned
        # as-is); it is kept so existing state_dicts/checkpoints still load.
        self.BNDe11 = nn.BatchNorm2d(self.out_chn, momentum=BN_momentum)

    def forward(self, x):
        # Each unpooling gets the exact size of the tensor that was pooled, so
        # odd spatial sizes (where 2x upsampling falls one pixel short) are restored.
        size0 = x.size()

        # ENCODER
        # Stage 1
        x = F.relu(self.BNEn11(self.ConvEn11(x)))
        x = F.relu(self.BNEn12(self.ConvEn12(x)))
        x, ind1 = self.MaxEn(x)
        size1 = x.size()

        # Stage 2
        x = F.relu(self.BNEn21(self.ConvEn21(x)))
        x = F.relu(self.BNEn22(self.ConvEn22(x)))
        x, ind2 = self.MaxEn(x)
        size2 = x.size()

        # Stage 3
        x = F.relu(self.BNEn31(self.ConvEn31(x)))
        x = F.relu(self.BNEn32(self.ConvEn32(x)))
        x = F.relu(self.BNEn33(self.ConvEn33(x)))
        x, ind3 = self.MaxEn(x)
        size3 = x.size()

        # Stage 4
        x = F.relu(self.BNEn41(self.ConvEn41(x)))
        x = F.relu(self.BNEn42(self.ConvEn42(x)))
        x = F.relu(self.BNEn43(self.ConvEn43(x)))
        x, ind4 = self.MaxEn(x)
        size4 = x.size()

        # Stage 5
        x = F.relu(self.BNEn51(self.ConvEn51(x)))
        x = F.relu(self.BNEn52(self.ConvEn52(x)))
        x = F.relu(self.BNEn53(self.ConvEn53(x)))
        x, ind5 = self.MaxEn(x)

        # DECODER
        # Stage 5
        x = self.MaxDe(x, ind5, output_size=size4)
        x = F.relu(self.BNDe53(self.ConvDe53(x)))
        x = F.relu(self.BNDe52(self.ConvDe52(x)))
        x = F.relu(self.BNDe51(self.ConvDe51(x)))

        # Stage 4
        x = self.MaxDe(x, ind4, output_size=size3)
        x = F.relu(self.BNDe43(self.ConvDe43(x)))
        x = F.relu(self.BNDe42(self.ConvDe42(x)))
        x = F.relu(self.BNDe41(self.ConvDe41(x)))

        # Stage 3
        x = self.MaxDe(x, ind3, output_size=size2)
        x = F.relu(self.BNDe33(self.ConvDe33(x)))
        x = F.relu(self.BNDe32(self.ConvDe32(x)))
        x = F.relu(self.BNDe31(self.ConvDe31(x)))

        # Stage 2
        x = self.MaxDe(x, ind2, output_size=size1)
        x = F.relu(self.BNDe22(self.ConvDe22(x)))
        x = F.relu(self.BNDe21(self.ConvDe21(x)))

        # Stage 1: restore the original input size
        x = self.MaxDe(x, ind1, output_size=size0)
        x = F.relu(self.BNDe12(self.ConvDe12(x)))
        x = self.ConvDe11(x)

        return x

In [None]:
# Check sizes of Segnet

# segnet_test = SEGNET(in_chn=6, out_chn=1, BN_momentum=0.5)

# Create random input sizes
# random_input_image = (1, 6, 350, 350)

# Get a summary of the input and outputs of Unet
# summary(model=segnet_test,
#         input_size= random_input_image,
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])

### Unet Modified

Code available at: https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py

In [None]:
class VGGBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages that keep height/width.

    Channels go ``in_channels -> middle_channels -> out_channels``.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return x

    
def pad_and_concat(x1, x2):
    """Zero-pad *x1* to the spatial size of *x2* and concatenate along channels.

    Sizes are read from dims 2 (height) and 3 (width). Padding is split as
    evenly as possible, with any extra pixel on the bottom/right. The result
    is ``cat([x2, x1], dim=1)``, i.e. the channels of *x2* come first.
    """
    pad_h = x2.size(2) - x1.size(2)
    pad_w = x2.size(3) - x1.size(3)
    top, left = pad_h // 2, pad_w // 2
    padded = F.pad(x1, [left, pad_w - left, top, pad_h - top])
    return torch.cat([x2, padded], dim=1)

class UNetModified(nn.Module):
    """U-Net style encoder-decoder built from ``VGGBlock`` units.

    Adapted from https://github.com/4uiiurz1/pytorch-nested-unet. Five encoder
    levels (64, 128, 256, 512, 1024 filters) are linked by 2x2 max pooling. Each
    decoder step upsamples bilinearly by 2, pads the result to the matching
    encoder map's size and concatenates it via ``pad_and_concat`` (encoder
    channels first), then refines it with a ``VGGBlock``. A final 1x1
    convolution produces ``num_classes`` output channels (raw scores).

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        # Construct a string to represent the model configuration for identification
        self.modelname = f"UNetMod_num-classes={num_classes}_input-channels={input_channels}"

        f0, f1, f2, f3, f4 = (64, 128, 256, 512, 1024)

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder path
        self.conv0_0 = VGGBlock(input_channels, f0, f0)
        self.conv1_0 = VGGBlock(f0, f1, f1)
        self.conv2_0 = VGGBlock(f1, f2, f2)
        self.conv3_0 = VGGBlock(f2, f3, f3)
        self.conv4_0 = VGGBlock(f3, f4, f4)

        # Decoder path: input = [encoder skip, upsampled deeper features]
        self.conv3_1 = VGGBlock(f3 + f4, f3, f3)
        self.conv2_2 = VGGBlock(f2 + f3, f2, f2)
        self.conv1_3 = VGGBlock(f1 + f2, f1, f1)
        self.conv0_4 = VGGBlock(f0 + f1, f0, f0)

        self.final = nn.Conv2d(f0, num_classes, kernel_size=1)

    def forward(self, input):
        # Encoder: keep every level's output for the skip connections
        skips = [self.conv0_0(input)]
        for encoder in (self.conv1_0, self.conv2_0, self.conv3_0, self.conv4_0):
            skips.append(encoder(self.pool(skips[-1])))

        # Decoder: start at the bottleneck and climb back up level by level
        features = skips.pop()
        for decoder in (self.conv3_1, self.conv2_2, self.conv1_3, self.conv0_4):
            features = decoder(pad_and_concat(self.up(features), skips.pop()))

        return self.final(features)


In [None]:
# unet_modified_test = UNetModified(num_classes=1, input_channels=6)

# # Create random input sizes
# random_input_image = (1, 6, 350, 350)

# # Get a summary of the input and outputs of Unet
# summary(model=unet_modified_test,
#         input_size= random_input_image,
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])

## Training and Testing

### Loss Classes

In [None]:
# Dice Loss. Source: https://github.com/Mr-TalhaIlyas/Loss-Functions-Package-Tensorflow-Keras-PyTorch
# INPUTS: Logits

class DiceLoss(nn.Module):
    """Soft Dice loss on raw logits, pooled over every element of the batch.

    ``loss = 1 - (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth)``
    with ``p = sigmoid(logits)``. Elements of all samples are summed together
    (a single global Dice, not a per-sample average).
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, logits, labels, smooth=1):
        """Return the scalar Dice loss for *logits* against binary *labels*.

        Args:
            logits: raw model outputs with the same number of elements as *labels*.
            labels: binary mask; integer/bool masks are accepted.
            smooth: additive smoothing term in numerator and denominator.
        """
        preds = torch.sigmoid(logits).reshape(-1)
        # reshape (unlike view) also works for non-contiguous tensors. Cast the
        # mask to the prediction dtype so the sums are computed in floating point.
        labels = labels.reshape(-1).to(preds.dtype)

        intersection = (preds * labels).sum()
        dice = (2. * intersection + smooth) / (preds.sum() + labels.sum() + smooth)

        return 1 - dice

### Helper Functions

In [None]:
def train_epoch(model, optimizer, loss_func, dataloader, device, use_distance_maps, use_dems, use_ndvi):
    """Run one optimization pass of *model* over *dataloader*.

    Input channels 5 (distance map), 6 (DEM) and 7 (NDVI) are removed from
    each batch when the corresponding ``use_*`` flag is False. Loss and IoU
    (from ``calculate_iou``, which is not defined in this block) are weighted
    by batch size and divided by the dataset length.

    Returns:
        (epoch_loss, epoch_iou); both values are also printed.
    """
    model.train()

    optional_channels = {5: use_distance_maps, 6: use_dems, 7: use_ndvi}
    dropped = [channel for channel, enabled in optional_channels.items() if not enabled]

    loss_sum = 0.0
    iou_sum = 0.0
    for inputs, targets, _, _ in dataloader:
        if dropped:
            kept = [c for c in range(inputs.shape[1]) if c not in dropped]
            inputs = inputs[:, kept, :, :]
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_func(logits, targets)
        iou = calculate_iou(logits, targets)
        loss.backward()
        optimizer.step()

        batch_n = inputs.size(0)
        loss_sum += loss.item() * batch_n
        iou_sum += iou.item() * batch_n

    n_samples = len(dataloader.dataset)
    epoch_loss = loss_sum / n_samples
    epoch_iou = iou_sum / n_samples
    print(f'Train - Loss: {epoch_loss:.4f}, IoU: {epoch_iou:.4f}')

    return epoch_loss, epoch_iou

In [None]:
def val_epoch(model, loss_func, dataloader, device, use_distance_maps, use_dems, use_ndvi):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Segmentation network; switched to evaluation mode here.
        loss_func: Callable ``(logits, labels) -> scalar tensor``.
        dataloader: Yields ``(data, labels, _, tile_ids)`` tuples; ``data`` is
            indexed as ``(batch, band, H, W)``.
        device: Device the batches are moved to.
        use_distance_maps: Keep band 5 (distance maps) when True.
        use_dems: Keep band 6 (DEMs) when True.
        use_ndvi: Keep band 7 (NDVI) when True.

    Returns:
        ``(val_epoch_loss, val_epoch_iou, all_y_true, all_y_pred, all_tile_ids)``.
        Loss and IoU are per-batch values weighted by batch size and averaged over
        the samples actually iterated. ``all_y_true`` / ``all_y_pred`` are NumPy
        arrays of labels and raw model outputs concatenated along the batch axis.
        ``all_tile_ids`` is a flat list aligned with them.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    running_iou = 0.0
    n_samples = 0

    # Initialize lists to store true labels, predictions, and tile IDs
    all_y_true, all_y_pred, all_tile_ids = [], [], []

    # Determine indices to drop
    drop_indices = []
    if not use_distance_maps:
        drop_indices.append(5)  # Index 5 is for distance maps
    if not use_dems:
        drop_indices.append(6)  # Index 6 is for DEMs
    if not use_ndvi:
        drop_indices.append(7)  # Index 7 is for NDVI

    keep_indices = None  # built lazily from the first batch's band count
    with torch.no_grad():
        for data_batch, labels_batch, _, tile_ids_batch in dataloader:
            # Exclude specified bands based on flags
            if drop_indices:
                if keep_indices is None:
                    keep_indices = [i for i in range(data_batch.shape[1]) if i not in drop_indices]
                data_batch = data_batch[:, keep_indices, :, :]

            data_batch = data_batch.to(device)
            labels_batch = labels_batch.to(device)

            # Forward pass
            y_hat = model(data_batch)
            loss = loss_func(y_hat, labels_batch)
            iou = calculate_iou(y_hat, labels_batch)

            # Aggregate loss and IoU
            batch_size = data_batch.size(0)
            running_loss += loss.item() * batch_size
            running_iou += iou.item() * batch_size
            n_samples += batch_size

            # Store true labels and predictions
            all_y_true.append(labels_batch.cpu().numpy())
            all_y_pred.append(y_hat.cpu().numpy())
            all_tile_ids.extend(tile_ids_batch)

    # Average over the samples actually seen (robust to drop_last / subset samplers)
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")

    val_epoch_loss = running_loss / n_samples
    val_epoch_iou = running_iou / n_samples
    print(f'Eval - Loss: {val_epoch_loss:.4f}, IoU: {val_epoch_iou:.4f}')

    # Concatenate arrays of true labels and predictions
    all_y_true = np.concatenate(all_y_true, axis=0)
    all_y_pred = np.concatenate(all_y_pred, axis=0)

    return val_epoch_loss, val_epoch_iou, all_y_true, all_y_pred, all_tile_ids


In [None]:
def load_checkpoint(checkpoints_path, model, optimizer, scheduler, map_location=None):
    """Restore model/optimizer/scheduler state and training history from disk.

    The checkpoint file is ``<checkpoints_path>/<model.modelname>.pth``.

    Args:
        checkpoints_path: Directory holding the checkpoint.
        model: Module whose weights are restored; must have a ``modelname`` attribute.
        optimizer: Optimizer whose state is restored.
        scheduler: LR scheduler to restore, or None. Skipped when the checkpoint
            stored no scheduler state.
        map_location: Forwarded to ``torch.load``. When None, tensors are mapped
            onto the device of ``model``'s first parameter (CPU if it has none).
            This lets a GPU-saved checkpoint be resumed on a CPU-only machine.

    Returns:
        ``(start_epoch, min_val_loss, train_loss_history, val_loss_history,
        train_iou_history, val_iou_history, best_epoch)``. Without a checkpoint
        file, returns ``(1, np.inf, [], [], [], [], 0)``.
    """
    filepath = os.path.join(checkpoints_path, f'{model.modelname}.pth')
    if not os.path.isfile(filepath):
        print(f"No checkpoint found at '{filepath}', starting from scratch")
        return 1, np.inf, [], [], [], [], 0  # Include default best_epoch

    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else torch.device('cpu')

    checkpoint = torch.load(filepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # train_model stores None here when it ran without a scheduler; passing None
    # to scheduler.load_state_dict would fail, so only load a real state dict.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    start_epoch = checkpoint['epoch'] + 1  # Continue from next epoch
    min_val_loss = checkpoint['min_val_loss']
    train_loss_history = checkpoint['train_loss_history']
    val_loss_history = checkpoint['val_loss_history']
    train_iou_history = checkpoint['train_iou_history']
    val_iou_history = checkpoint['val_iou_history']
    # Older checkpoints may lack best_epoch; fall back to the saved epoch
    best_epoch = checkpoint.get('best_epoch', checkpoint['epoch'])
    print(f"Loaded checkpoint '{filepath}' (epoch {checkpoint['epoch']})")
    return start_epoch, min_val_loss, train_loss_history, val_loss_history, train_iou_history, val_iou_history, best_epoch


In [None]:
def train_model(model, optimizer, loss_func, scheduler, train_dataloader, val_dataloader, device, num_epochs, checkpoints_path, use_distance_maps, use_dems, use_ndvi, use_checkpoint=False, patience=10, delta_p=0.001):
    """Train ``model`` with early stopping, checkpointing the best validation loss.

    After every epoch the validation loss is compared with the best seen so far.
    An improvement larger than ``delta_p`` resets the patience counter and saves
    a checkpoint to ``<checkpoints_path>/<model.modelname>.pth``. Training stops
    after ``patience`` consecutive epochs without such an improvement, or after
    epoch ``num_epochs``. If the checkpoint file exists at the end, its weights
    are loaded back into ``model`` before returning.

    Args:
        model: Network to train; must have a ``modelname`` attribute.
        optimizer: Optimizer for ``model``'s parameters.
        loss_func: Callable ``(logits, labels) -> scalar tensor``.
        scheduler: LR scheduler stepped once per epoch, or None.
        train_dataloader / val_dataloader: Loaders passed to ``train_epoch`` / ``val_epoch``.
        device: Device used for training.
        num_epochs: Last epoch number to run (epochs are 1-based).
        checkpoints_path: Directory where the best checkpoint is written.
        use_distance_maps, use_dems, use_ndvi: Band-selection flags.
        use_checkpoint: Resume from an existing checkpoint via ``load_checkpoint``.
        patience: Epochs without improvement tolerated before stopping early.
        delta_p: Minimum decrease in validation loss that counts as improvement.

    Returns:
        ``(model, time_epoch, train_loss_history, val_loss_history,
        train_iou_history, val_iou_history, best_epoch)``. ``time_epoch`` is the
        mean wall-clock time per epoch actually run, formatted as ``"<m>m <s>s"``.
    """
    since = time.time()
    model.to(device)

    # Initialize training variables
    start_epoch, min_val_loss = 1, np.inf
    train_loss_history, val_loss_history = [], []
    train_iou_history, val_iou_history = [], []
    best_epoch = 0
    counter = 0      # consecutive epochs without sufficient improvement
    epochs_run = 0   # epochs executed by this call (used for the timing report)
    best_checkpoint_path = os.path.join(checkpoints_path, f'{model.modelname}.pth')

    # Load checkpoint if specified and exists
    if use_checkpoint:
        (start_epoch, min_val_loss, train_loss_history, val_loss_history,
         train_iou_history, val_iou_history, best_epoch) = load_checkpoint(checkpoints_path, model, optimizer, scheduler)

    for epoch in range(start_epoch, num_epochs + 1):
        print(f'Epoch {epoch}/{num_epochs}')
        print('-' * 10)
        epochs_run += 1

        # Training phase
        train_epoch_loss, train_epoch_iou = train_epoch(model, optimizer, loss_func, train_dataloader, device, use_distance_maps, use_dems, use_ndvi)
        train_loss_history.append(train_epoch_loss)
        train_iou_history.append(train_epoch_iou)

        # Validation phase
        val_epoch_loss, val_epoch_iou, _, _, _ = val_epoch(model, loss_func, val_dataloader, device, use_distance_maps, use_dems, use_ndvi)
        val_loss_history.append(val_epoch_loss)
        val_iou_history.append(val_epoch_iou)

        # Update the learning rate according to the scheduler, if it exists
        if scheduler is not None:
            scheduler.step()

        # Checkpoint only on a meaningful improvement (more than delta_p)
        if val_epoch_loss < min_val_loss - delta_p:
            min_val_loss = val_epoch_loss
            best_epoch = epoch
            counter = 0
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,  # Save the scheduler state only if it exists
                'min_val_loss': min_val_loss,
                'train_loss_history': train_loss_history,
                'val_loss_history': val_loss_history,
                'train_iou_history': train_iou_history,
                'val_iou_history': val_iou_history,
                'best_epoch': best_epoch
            }
            torch.save(checkpoint, best_checkpoint_path)
            print(f"Checkpoint saved at: {best_checkpoint_path}")
        else:
            counter += 1

        if counter >= patience:
            print(f'\nEarly stopping after {patience} epochs without improvement.')
            break

    # Training Time Calculation: average over the epochs actually executed,
    # since early stopping or resuming can run fewer than num_epochs.
    time_elapsed = time.time() - since
    seconds_per_epoch = time_elapsed / max(epochs_run, 1)
    time_epoch = f'{seconds_per_epoch // 60:.0f}m {seconds_per_epoch % 60:.0f}s'
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    # Load best model weights before returning
    if os.path.exists(best_checkpoint_path):
        best_state = torch.load(best_checkpoint_path, map_location=device)['model_state_dict']
        model.load_state_dict(best_state)

    return model, time_epoch, train_loss_history, val_loss_history, train_iou_history, val_iou_history, best_epoch


In [None]:
# IoU (Jaccard index) on thresholded predictions
def calculate_iou(logits, labels, p_threshold=0.5, smooth=1e-6):
    """Return the IoU between binarised ``sigmoid(logits)`` and ``labels``.

    All elements are pooled into a single score (no per-image averaging).

    Args:
        logits: Raw model outputs of any shape.
        labels: Target mask with the same number of elements as ``logits``.
        p_threshold: A pixel is positive when ``sigmoid(logit) > p_threshold``.
        smooth: Additive term; makes the score 1 when both masks are empty.

    Returns:
        A 0-dim tensor with the IoU.
    """
    preds = (torch.sigmoid(logits) > p_threshold).float()

    # reshape (unlike view) also works on non-contiguous inputs
    preds = preds.reshape(-1)
    labels = labels.reshape(-1)

    intersection = (preds * labels).sum()
    union = preds.sum() + labels.sum() - intersection

    return (intersection + smooth) / (union + smooth)

In [None]:
# Dice coefficient on thresholded predictions
def calculate_dice_coefficient(logits, labels, p_threshold=0.5, smooth=1e-6):
    """Return the Dice coefficient between binarised ``sigmoid(logits)`` and ``labels``.

    All elements are pooled into a single score (no per-image averaging).

    Args:
        logits: Raw model outputs of any shape.
        labels: Target mask with the same number of elements as ``logits``.
        p_threshold: A pixel is positive when ``sigmoid(logit) > p_threshold``.
        smooth: Additive term; makes the score 1 when both masks are empty.

    Returns:
        A 0-dim tensor with the Dice coefficient.
    """
    preds = (torch.sigmoid(logits) > p_threshold).float()

    # reshape (unlike view) also works on non-contiguous inputs
    preds = preds.reshape(-1)
    labels = labels.reshape(-1)

    intersection = (preds * labels).sum()
    return (2. * intersection + smooth) / (preds.sum() + labels.sum() + smooth)

In [None]:
def plot_learning_curves(train_loss, val_loss, train_iou, val_iou, best_epoch):
    """Show loss and IoU learning curves side by side, marking ``best_epoch``.

    Training points are drawn half an epoch to the left of the validation
    points. Both training curves use x positions derived from ``len(train_loss)``.
    """
    _, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 5))

    train_x = [epoch - 0.5 for epoch in range(1, len(train_loss) + 1)]
    val_x = range(1, len(val_loss) + 1)
    best_label = f'Best Epoch: {best_epoch}'

    # One row per panel: axis, title, train series + label, val series + label, y label
    panels = (
        (loss_ax, 'Training and Validation Loss', train_loss, "Train Loss", val_loss, "Validation Loss", 'Loss'),
        (iou_ax, 'Training and Validation IoU', train_iou, "Train IoU", val_iou, "Validation IoU", 'IoU'),
    )
    for ax, title, train_series, train_label, val_series, val_label, y_label in panels:
        ax.set_title(title)
        ax.plot(train_x, train_series, label=train_label)
        ax.plot(val_x, val_series, label=val_label)
        ax.axvline(x=best_epoch, color='grey', linestyle='--', label=best_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def evaluate_segmentation_performance_global(y_true, y_pred):
    """Print and return IoU and Dice computed over all pixels at once.

    Args:
        y_true: Ground-truth masks (NumPy array or tensor).
        y_pred: Model outputs aligned with ``y_true``. They are expected to be raw
            logits, because the metric helpers apply a sigmoid and threshold at 0.5.

    Returns:
        ``(iou, dice_coefficient)`` as Python floats.
    """
    # as_tensor avoids an extra copy (and the warning torch.tensor emits)
    # when the inputs are already tensors
    y_true = torch.as_tensor(y_true, dtype=torch.float32)
    y_pred = torch.as_tensor(y_pred, dtype=torch.float32)

    iou = calculate_iou(y_pred, y_true).item()
    dice_coefficient = calculate_dice_coefficient(y_pred, y_true).item()

    print("Global Metrics in Testing Set: ")
    print(f"IoU: {iou:.4f}")
    print(f"Dice Coefficient: {dice_coefficient:.4f}")

    return iou, dice_coefficient


In [None]:
def evaluate_and_plot_performance_individual(y_true, y_pred, tile_ids, num_extremes=5, p_threshold=0.5):
    """Score every image with Dice, plot the score histogram and the extreme tiles.

    Args:
        y_true: Ground-truth masks; the first axis indexes images.
        y_pred: Raw model outputs (logits) aligned with ``y_true``.
        tile_ids: Tile identifiers aligned with the first axis.
        num_extremes: Number of lowest- and of highest-Dice images to display.
            Values <= 0 display none. When the two selections overlap (fewer than
            ``2 * num_extremes`` images), each image is shown once.
        p_threshold: Probability threshold used both for the Dice scores and
            for the binarised predictions that are displayed.
    """
    y_true = torch.as_tensor(y_true, dtype=torch.float32)
    y_pred = torch.as_tensor(y_pred, dtype=torch.float32)
    num_images = y_true.shape[0]

    # Per-image Dice as Python floats (for sorting, the histogram and titles).
    # p_threshold is forwarded so the scores match the masks plotted below.
    dice_coefficients = [
        calculate_dice_coefficient(y_pred[i], y_true[i], p_threshold=p_threshold).item()
        for i in range(num_images)
    ]

    # Sorting indices by Dice Coefficient
    sorted_indices = sorted(range(num_images), key=lambda i: dice_coefficients[i])
    k = max(0, min(num_extremes, num_images))
    # Slice with num_images - k rather than -k: [-0:] would select every image.
    # dict.fromkeys drops duplicates from overlapping selections, keeping order.
    extremes_indices = list(dict.fromkeys(sorted_indices[:k] + sorted_indices[num_images - k:]))

    # Plotting the distribution of Dice coefficients
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.hist(dice_coefficients, bins=40, color='skyblue')
    ax.set_title('Distribution of Individual Dice Coefficients')
    ax.set_xlabel('Dice Coefficient')
    ax.set_ylabel('Frequency')
    plt.show()

    if not extremes_indices:
        return

    # Plotting the extremes; squeeze=False keeps axs 2-D even for a single row
    fig, axs = plt.subplots(len(extremes_indices), 2, figsize=(10, 5 * len(extremes_indices)), squeeze=False)

    for row, index in enumerate(extremes_indices):
        # Prediction
        pred_mask = (torch.sigmoid(y_pred[index]) > p_threshold).float()
        axs[row, 0].imshow(pred_mask.squeeze().cpu().numpy(), cmap='viridis')
        axs[row, 0].set_title(f'Prediction (Binary) - Tile {tile_ids[index]} - Dice: {dice_coefficients[index]:.4f}')
        axs[row, 0].axis('off')

        # Ground Truth
        axs[row, 1].imshow(y_true[index].squeeze().cpu().numpy(), cmap='viridis')
        axs[row, 1].set_title(f'Ground Truth - Tile {tile_ids[index]}')
        axs[row, 1].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
def postprocess_and_export_predictions(model, test_given_loader, device, predictions_path, use_distance_maps, use_dems, use_ndvi, postprocess=True, threshold=0, p_threshold=0.5, all_means=all_means, all_stds=all_stds):
    """Predict binary masks, post-process them, save them as TIFFs and archive them.

    Per tile:
      * the model's sigmoid output is thresholded at ``p_threshold``;
      * if ``postprocess`` is True, pixels whose DEM value (band 6) exceeds
        ``threshold`` are set to 0. The DEM is first un-normalised with
        ``all_means[6]`` / ``all_stds[6]`` when both are given;
      * pixels whose red band (band 2) equals -32768 are always set to 0;
      * the mask is saved as ``<tile_id>_kelp.tif`` (uint8) in ``predictions_path``.
    Finally, only the TIFFs written by this call are bundled into
    ``predictions_path/predictions.tar.gz``.

    Args:
        model: Trained segmentation network.
        test_given_loader: Yields ``(data, _, _, tile_ids)`` batches containing all bands.
        device: Device used for inference.
        predictions_path: Output directory (created if missing).
        use_distance_maps, use_dems, use_ndvi: Band flags used in training; disabled
            bands (5, 6, 7) are dropped before the forward pass.
        postprocess: Apply the DEM-based masking.
        threshold: DEM value above which predictions are zeroed.
        p_threshold: Probability threshold for a positive pixel.
        all_means, all_stds: Per-band normalisation statistics (index 6 = DEM),
            or None to use the DEM band as stored in the batch.
    """
    model.to(device)
    model.eval()

    # Ensure the predictions directory exists
    os.makedirs(predictions_path, exist_ok=True)

    # Determine indices to drop
    drop_indices = []
    if not use_distance_maps:
        drop_indices.append(5)  # Index 5 is for distance maps
    if not use_dems:
        drop_indices.append(6)  # Index 6 is for DEMs
    if not use_ndvi:
        drop_indices.append(7)  # Index 7 is for NDVI

    written_files = []

    with torch.no_grad():
        for data_batch, _, _, tile_ids_batch in tqdm(test_given_loader, desc='Processing', leave=True):
            # Auxiliary bands are read before any band is dropped, so indices 2 and 6 are stable
            dem_band_batch = data_batch[:, 6:7, :, :]
            if all_means is not None and all_stds is not None:
                dem_band_batch = (dem_band_batch * all_stds[6]) + all_means[6]
            red_band_batch = data_batch[:, 2:3, :, :]

            # Drop bands not included in training
            if drop_indices:
                keep_indices = [i for i in range(data_batch.shape[1]) if i not in drop_indices]
                data_batch = data_batch[:, keep_indices, :, :]

            # 'logits'/'binary_masks' avoid shadowing the notebook globals 'outputs'/'predictions'
            logits = model(data_batch.to(device))
            binary_masks = torch.sigmoid(logits) > p_threshold

            for idx, prediction in enumerate(binary_masks):
                # Masks are reshaped to the prediction's shape so that a rank
                # mismatch (e.g. (1, H, W) vs (H, W)) with equal element count works.
                # NOTE(review): -32768 looks like a raw no-data value; if the loader
                # normalises band 2 this comparison may never match -- confirm.
                invalid_mask = (red_band_batch[idx] == -32768).to(prediction.device).reshape(prediction.shape)

                if postprocess:
                    above_threshold = (dem_band_batch[idx] > threshold).to(prediction.device).reshape(prediction.shape)
                    prediction[above_threshold] = 0

                prediction[invalid_mask] = 0

                # Prepare the prediction for saving
                mask_array = prediction.cpu().squeeze().numpy().astype(np.uint8)

                # Save the prediction as a TIFF file
                tiff_filename = os.path.join(predictions_path, f'{tile_ids_batch[idx]}_kelp.tif')
                Image.fromarray(mask_array).save(tiff_filename, format='TIFF')
                written_files.append(tiff_filename)

    # Archive only the files produced by this call, so stale TIFFs left in the
    # directory by earlier runs cannot leak into the archive.
    archive_name = 'predictions.tar.gz'
    with tarfile.open(os.path.join(predictions_path, archive_name), 'w:gz') as tar:
        for tiff_path in dict.fromkeys(written_files):
            tar.add(tiff_path, arcname=os.path.basename(tiff_path))

    print(f'All predictions are saved and archived in {predictions_path}, archive name: {archive_name}')

### Model Configurations

In [None]:
# Feature flags: which optional bands, on top of the base bands, the model is trained on
use_distance_maps = True  # distance-map band
use_dems = False  # DEM band
use_ndvi = False  # NDVI band

# Bands that are always used, independent of the flags above
base_bands = 5

# One extra input channel per enabled optional band (each flag counts as 0 or 1)
n_bands = base_bands + sum(int(flag) for flag in (use_distance_maps, use_dems, use_ndvi))

In [None]:
# Build the segmentation network: one output channel and one input channel per
# band selected in the configuration cell.
# Alternative architecture: SEGNET(in_chn=n_bands, out_chn=1, BN_momentum=0.5)
unet_modified = UNetModified(input_channels=n_bands, num_classes=1)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Report the chosen device
print(device)

In [None]:
# Set Up Loss Function: soft Dice loss on raw logits (DiceLoss applies the sigmoid itself)
loss_func = DiceLoss()

In [None]:
# Optimizer: Adam over all network parameters, with a small learning rate and a
# very light weight-decay (L2) term
lr = 0.0001
weight_decay = 1e-7
optimizer = torch.optim.Adam(
    params=unet_modified.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
# Learning-rate scheduler: currently DISABLED (scheduler = None), so the learning
# rate stays constant. ``gamma`` only takes effect if an exponential-decay
# scheduler is enabled instead, e.g.:
#     scheduler = ExponentialLR(optimizer, gamma=gamma)
gamma = 0.95
scheduler = None

In [None]:
# Set Up the number of epochs: an upper bound, since train_model may stop earlier via early stopping
num_epochs = 80

### Training

In [None]:
# True resumes training from a previously saved checkpoint; False trains from scratch
use_checkpoint = False
model, time_epoch, train_loss_history, val_loss_history, train_iou_history, val_iou_history, best_epoch = train_model(
    model=unet_modified,
    optimizer=optimizer,
    loss_func=loss_func,
    scheduler=scheduler,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    device=device,
    num_epochs=num_epochs,
    checkpoints_path=checkpoints,
    use_distance_maps=use_distance_maps,
    use_dems=use_dems,
    use_ndvi=use_ndvi,
    use_checkpoint=use_checkpoint,
)

In [None]:
# Loss and IoU curves for both splits, with the epoch whose weights were kept marked
plot_learning_curves(
    train_loss=train_loss_history,
    val_loss=val_loss_history,
    train_iou=train_iou_history,
    val_iou=val_iou_history,
    best_epoch=best_epoch,
)

### Testing

In [None]:
# Score the trained model on the test split (loss and IoU). Also collect the
# targets, raw outputs and tile ids used by the metric cells that follow.
val_epoch_loss, val_epoch_iou, y_true, y_pred, tile_ids = val_epoch(
    model=model,
    loss_func=loss_func,
    dataloader=test_loader,
    device=device,
    use_distance_maps=use_distance_maps,
    use_dems=use_dems,
    use_ndvi=use_ndvi,
)

In [None]:
# Dataset-level metrics: IoU and Dice pooled over every test pixel
evaluate_segmentation_performance_global(y_true=y_true, y_pred=y_pred)

In [None]:
# Per-tile Dice distribution, plus the 5 lowest- and 5 highest-scoring tiles
evaluate_and_plot_performance_individual(y_true=y_true, y_pred=y_pred, tile_ids=tile_ids, num_extremes=5)

### Submission Folder

In [None]:
# Export binary masks as TIFFs (DEM values above 0 are zeroed) and archive them
postprocess_and_export_predictions(
    model=model,
    test_given_loader=test_loader,
    device=device,
    predictions_path=predictions,
    use_distance_maps=use_distance_maps,
    use_dems=use_dems,
    use_ndvi=use_ndvi,
    postprocess=True,
    threshold=0,
)

## References

* Badrinarayanan, V., Kendall, A., & Cipolla, R. (2017). SegNet: A deep convolutional encoder-decoder architecture for image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 39(12), 2481–2495.
* Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015: 18th International Conference, Munich, Germany, October 5–9, 2015, Proceedings, Part III (pp. 234–241). Springer International Publishing.