In [1]:
import sys

# Later cells import local modules (image_utils, dataset_utils, train_model, ...);
# presumably they live in this directory — confirm.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
if '/content' not in sys.path:
    sys.path.append('/content')

In [2]:
# Switch the working directory so that the relative paths used by later cells
# ("dataset/", "model/...") resolve against /content.
%cd /content

/content


# Setup

In [None]:
!pip install -r requirements.txt

In [4]:
import os
import torch
import glob
import time
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, jaccard_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import random
import matplotlib.colors as mcolors
from skimage.io import imread, imshow

from skimage.io import imread
import albumentations as A
from typing import List

import datetime

In [None]:
def set_all_seeds(seed: int = 42) -> None:
    """Seed every random number generator used by the notebook.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and current CUDA
    device), switches cuDNN to deterministic mode with benchmarking disabled,
    and exports the seed as ``PYTHONHASHSEED``.

    Args:
        seed: Value handed to every generator (42 by default).
    """
    # Same seed, same order as before: NumPy, random, torch CPU, torch CUDA.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN needs two extra switches to behave deterministically.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # this presumably only affects processes spawned afterwards — confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


set_all_seeds()

In [6]:
# DATASET
# Dataset root, given relative to the current working directory.
dataset_dir = "dataset/"
# ADJUST THE VALUES BELOW DEPENDING ON THE TRAINING TIME
# NOTE(review): the original note said "these 2 values" — presumably
# n_samples_per_zone and n_zones; confirm.
# All three values are passed to get_train_test_paths_by_zone in the dataset cell.
train_test_split = 0.5
n_samples_per_zone = 20 # 50 recommended
n_zones = 1 # up to 10 zones
# Order of bands: B, G, R, nir, nir_vegetation, swir, ndvi, ndwi, ndmi
# Keep all bands by default and reduce if training takes too long
# (keep at minimum the first 4 + ndvi and ndmi)
bands_to_keep = list(range(9))

# MODEL
# One input channel per retained band.
num_channels = len(bands_to_keep)
num_classes = 1 # Mangrove class
encoder_name = "resnet50"
encoder_weights = None
activation = 'sigmoid' # Mangrove vs Non-Mangrove
use_augmentation = True
# Name used to build the output folders below.
name_model = "UNet-Resnet50"
# Output folders are created immediately so that later cells can write into them.
model_save_path_epochs = f'model/{name_model}/epochs/'
os.makedirs(model_save_path_epochs, exist_ok=True)
model_save_path_metrics = f'model/{name_model}/metrics/'
os.makedirs(model_save_path_metrics, exist_ok=True)

# Passed to train_final_model; presumably the number of epochs between two
# saved checkpoints — confirm against train_model.
save_interval = 1

# TRAINING
batch_size = 16
learning_rate = 0.0001
num_epochs = 100


# COMPUTATION & PRINTS (PYTORCH LIGHTNING)
# NOTE(review): none of the settings in this section are referenced by the
# cells visible in this notebook — confirm whether they are still needed.
accelerator = 'gpu'
strategy =  'auto' # 'ddp' if multiple GPUs, otherwise leave empty for single-GPU training
num_nodes = 1
gpus_per_node = 1
num_workers = 1
enable_progress_bar = True
progress_rate = 10

# Display random examples

In [None]:
import image_utils  # project-local module providing the two helpers used below

# Both folders sit directly under the dataset root.
path_to_2020_sentinel_images_folder, path_to_2020_masks_folder = (
    os.path.join(dataset_dir, subfolder)
    for subfolder in ("satellite-images", "masks")
)

# Collect the file paths of each folder (images first, then masks).
images, masks = map(
    image_utils.get_all_file_paths,
    (path_to_2020_sentinel_images_folder, path_to_2020_masks_folder),
)

# Show five samples from the collected image and mask paths.
image_utils.display_samples(images, masks, nb_samples=5)


# Mangrove Dataset

In [8]:
from dataset_utils import get_train_test_paths_by_zone, MangroveSegmentationDataset
import torch
from torch.utils.data import DataLoader

# Train / test path lists as returned by get_train_test_paths_by_zone.
full_paths_train, full_paths_test = get_train_test_paths_by_zone(
    dataset_dir,
    train_test_split,
    n_samples_per_zone,
    n_zones,
)

# Augmentation follows the config flag for training and is always off for testing.
dataset_train = MangroveSegmentationDataset(
    full_paths_train,
    bands_to_keep=bands_to_keep,
    use_augmentation=use_augmentation,
)
dataset_test = MangroveSegmentationDataset(
    full_paths_test,
    bands_to_keep=bands_to_keep,
    use_augmentation=False,
)

# Same batch size for both loaders; only the training loader shuffles.
dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)
dataloader_test = DataLoader(dataset_test, shuffle=False, batch_size=batch_size)


# Model

In [None]:
# To use Unet
# Every argument comes from the configuration cell.
# - encoder_weights is forwarded explicitly: when the argument is omitted the
#   library applies its own default (pre-trained ImageNet weights), which
#   silently overrides the configured `encoder_weights = None`.
# - num_classes replaces the previously hard-coded 1 so the configuration
#   stays the single source of truth.
model = smp.Unet(
    encoder_name=encoder_name,
    encoder_weights=encoder_weights,
    in_channels=num_channels,
    classes=num_classes,
    activation=activation,
)

In [None]:
# To use PAN

# model = smp.PAN(
#     encoder_output_stride=16,
#     upsampling=4,
#     encoder_name = encoder_name,
#     #decoder_channels=decoder_channels,
#     in_channels=num_channels,
#     decoder_channels=512,
#     activation = activation,
#     classes=1,
# )

In [None]:
# To use MAnet

# model = smp.MAnet(
#     encoder_name = encoder_name,
#     encoder_depth = encoder_depth,
#     decoder_channels=decoder_channels,
#     in_channels=num_channels,
#     activation = activation,
#     classes=1,
# )

In [None]:
def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total / 1e6

# Shown as the cell output: trainable parameters of the model, in millions.
count_parameters(model)
# Reference values noted by the author: 32.54 for resnet50, 24.45 for resnet34, 14.34 for resnet18

# Training

In [None]:
# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# Move the model before building the optimizer: the torch.optim documentation
# recommends placing the model on its target device before constructing an
# optimizer over its parameters. Assigning the result (Module.to returns the
# module itself) also keeps the full module repr out of the cell output.
model = model.to(device)

# Binary cross-entropy on probabilities, paired with the model's configured
# sigmoid activation.
criterion = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# factor=0.5 halves the learning rate once the monitored value ('min' mode)
# has stopped improving for longer than `patience` scheduler steps.
# The `verbose` keyword is no longer passed: it is deprecated in recent PyTorch
# releases, and the value used before (False) was the historical default, so
# scheduling behaviour is unchanged.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=7, factor=0.5)

In [None]:
from train_model import train_final_model

# Run the training and unpack the eleven returned values in order:
# four train metrics, four test metrics (presumably per-epoch histories —
# confirm against train_model), then timing and best-checkpoint information.
(
    mean_loss_train,
    mean_f1_train,
    mean_iou_train,
    mean_accuracy_train,
    mean_loss_test,
    mean_f1_test,
    mean_iou_test,
    mean_accuracy_test,
    elapsed_time,
    best_model_filename,
    best_mean_iou_test,
) = train_final_model(
    model, dataloader_train, dataloader_test,            # network and data
    num_epochs, optimizer, scheduler, criterion,         # optimisation setup
    device, model_save_path_epochs, save_interval,       # device and checkpointing
)

In [None]:
# One-line recap of the run: elapsed time and best test IoU as returned by
# train_final_model, together with the configured number of epochs.
print(f"Elapsed time of {elapsed_time} seconds for {num_epochs} epochs => best test IOU = {best_mean_iou_test}")

In [None]:
from train_model import save_metrics_to_file

# Gather every value returned by the training call under its own name,
# in the same order as they were returned.
lists = {}
# Training-set metrics
lists["mean_loss_train"] = mean_loss_train
lists["mean_f1_train"] = mean_f1_train
lists["mean_iou_train"] = mean_iou_train
lists["mean_accuracy_train"] = mean_accuracy_train
# Test-set metrics
lists["mean_loss_test"] = mean_loss_test
lists["mean_f1_test"] = mean_f1_test
lists["mean_iou_test"] = mean_iou_test
lists["mean_accuracy_test"] = mean_accuracy_test
# Run-level information
lists["elapsed_time"] = elapsed_time
lists["best_model_filename"] = best_model_filename
lists["best_mean_iou_test"] = best_mean_iou_test

# The metrics folder path already ends with a slash, so plain concatenation
# yields the final file path.
save_metrics_to_file(lists, model_save_path_metrics + "metrics.txt")

# Plot training results

In [None]:
import plot_training_results

# NOTE: the plotting helpers are named "train_val", but the second argument
# passed in each call below is the metric measured on the test split.

# Loss: training vs. test
plot_training_results.plot_train_val_loss(mean_loss_train, mean_loss_test)

# IoU score: training vs. test
plot_training_results.plot_train_val_iou(mean_iou_train, mean_iou_test)

# F1 score: training vs. test
plot_training_results.plot_train_val_f1(mean_f1_train, mean_f1_test)

# Accuracy: training vs. test
plot_training_results.plot_train_val_acc(mean_accuracy_train, mean_accuracy_test)


# Plot segmentation results

## Show inputs

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

from plot_segmentation_results import plot_augmented_images

# Output folder handed to the plotting helper for every sample.
save_bands = f'model/{name_model}/plots_bands'
os.makedirs(save_bands, exist_ok=True)

# Take the first test batch and plot each (image, mask) pair in it.
# NOTE(review): the helper is called plot_augmented_images, but the test
# dataset is built with use_augmentation=False — confirm the naming.
batch = next(iter(dataloader_test))
for idx, (test_image, true_mask) in enumerate(zip(*batch[:2])):
    # Move the first axis last (presumably channels-first -> channels-last —
    # confirm) and convert to a NumPy array.
    test_image_np = (
        test_image
        .permute(1, 2, 0)
        .cpu()
        .detach()
        .numpy()
    )
    # Drop singleton axes and cast the mask to integers.
    true_mask_np = (
        true_mask
        .cpu()
        .detach()
        .numpy()
        .squeeze()
        .astype('int')
    )
    plot_augmented_images(test_image_np, true_mask_np, idx, save_bands)



## Show predictions

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

from plot_segmentation_results import plot_segmentation_results

# SHOW PREDICTIONS ON FULL BATCH
# The in-memory model is handed over together with the checkpoint folder and
# the best checkpoint's filename (presumably so the helper can load those
# weights — confirm against plot_segmentation_results).
best_model = model
batch = next(iter(dataloader_test))

# Output folder handed to the plotting helper.
save_comparisons = f'model/{name_model}/plots_comparison'
os.makedirs(save_comparisons, exist_ok=True)

plot_segmentation_results(
    batch,
    device,
    model_save_path_epochs,
    best_model,
    best_model_filename,
    save_comparisons,
)