# **COMP9517 - Group Project (Segmentation Models Pytorch)**

## **0. Add Imports**

In [None]:
# !conda env create -f environment.yaml

In [None]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, RichModelSummary, Timer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

from data import AerialDeadTreeSegDataModule, download_dataset
from lightning.pytorch.tuner import Tuner
from lightning_modules import SMPLitModule
from models import FreezeSMPEncoderUtils
from utils import paths
import segmentation_models_pytorch as smp


In [None]:
# --- Model architecture -----------------------------------------------------
ARCH = "Unet"
ENCODER_NAME = "efficientnet-b5"

# --- Data -------------------------------------------------------------------
MODALITY = "merged"
TARGET_SIZE = 256
BATCH_SIZE = 32
# Run tag built from modality and input size; reused for the checkpoint
# sub-directory and the logger version further down.
VERSION = f"{MODALITY}_{TARGET_SIZE}"

# --- Training schedule ------------------------------------------------------
MAX_EPOCHS = 100
EARLY_STOPPING_PATIENCE = 20
PRECISION = "bf16-mixed"

# --- Learning-rate finder search bounds --------------------------------------
MIN_LR = 1e-3
MAX_LR = 0.1

# --- Loss terms handed to the Lightning module as loss1 / loss2 --------------
LOSS1 = smp.losses.JaccardLoss(mode='binary', from_logits=True)
LOSS2 = smp.losses.FocalLoss(mode='binary')

## **1. Simple Summary of the Dataset**

In [None]:
from PIL import Image

# Resolve the dataset root folder via the project helper.
data_folder = download_dataset()

# Sub-folders of the dataset root: RGB images, NRG images and the masks.
rgb_dir, nrg_dir, mask_dir = (
    os.path.join(data_folder, subfolder)
    for subfolder in ("RGB_images", "NRG_images", "masks")
)

# Get the max and min resolution of the images
def get_max_min_resolution(image_dir):
    """Scan the ``.png`` files in ``image_dir`` and report their size extremes.

    Each axis is tracked independently, so the returned extremes need not
    belong to the same image.

    Args:
        image_dir: Directory whose ``.png`` files are inspected
            (non-recursive; other files are ignored).

    Returns:
        ``(max_res, min_res)`` where each element is a ``(width, height)``
        tuple, following the order of PIL's ``Image.size``. If the directory
        holds no ``.png`` file, ``max_res`` is ``(0, 0)`` and ``min_res`` is
        ``(inf, inf)``.
    """
    max_res = (0, 0)
    min_res = (float('inf'), float('inf'))
    for filename in os.listdir(image_dir):
        if not filename.endswith(".png"):
            continue
        # The context manager closes the underlying file once the size is read.
        with Image.open(os.path.join(image_dir, filename)) as img:
            # PIL reports size as (width, height), not (height, width).
            width, height = img.size
        max_res = (max(max_res[0], width), max(max_res[1], height))
        min_res = (min(min_res[0], width), min(min_res[1], height))
    return max_res, min_res
# Scan both image folders, then print the per-axis extremes for each one.
(max_rgb_res, min_rgb_res), (max_nrg_res, min_nrg_res) = (
    get_max_min_resolution(rgb_dir),
    get_max_min_resolution(nrg_dir),
)
print(f"Max RGB resolution: {max_rgb_res}, Min RGB resolution: {min_rgb_res}")
print(f"Max NRG resolution: {max_nrg_res}, Min NRG resolution: {min_nrg_res}")

## **2. Prepare Data Module**

In [None]:
# Build the data module from the notebook-level configuration.
data_module = AerialDeadTreeSegDataModule(
    # Per the original note: "merged" gives 4 input channels (RGB + NIR),
    # any other modality gives 3 (RGB) - confirm against the data module.
    modality=MODALITY,
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    val_split=0.1,
    test_split=0.2,
    seed=42,
    # Half of the reported CPU count; 0 workers when the count is unavailable.
    num_workers=(os.cpu_count() or 0) // 2,
)

## **3. Create Segmentation Models**

In [None]:
# Wrap the segmentation network in the project's Lightning module.
model = SMPLitModule(
    # Network definition
    arch=ARCH,
    encoder_name=ENCODER_NAME,
    encoder_weights="imagenet",
    # Input/output shape: channel count comes from the data module,
    # a single output class for binary segmentation.
    in_channels=data_module.in_channels,
    out_classes=1,
    # The two loss terms configured at the top of the notebook
    loss1=LOSS1,
    loss2=LOSS2,
)

## **4. Create Trainer**

In [None]:
# Model summary limited to two levels of nesting.
model_sum_callback = RichModelSummary(max_depth=2)

# Log the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Timer callback; its recorded values are printed after fit/test below.
timer = Timer(interval="epoch", verbose=False)

# Early stopping watches the validation per-image IoU, where higher is better.
early_stop_callback = EarlyStopping(
    mode="max",
    monitor="per_image_iou/val",
    patience=EARLY_STOPPING_PATIENCE,
    verbose=True,
)
                
# Keep the two checkpoints with the highest validation per-image IoU.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(paths.checkpoint_dir, f"smp_{ENCODER_NAME}_{ARCH}", VERSION),
    monitor="per_image_iou/val",
    # The placeholder must name the logged metric key exactly
    # ("per_image_iou/val"). A placeholder Lightning cannot find is filled
    # with 0, which would stamp every file with "0.0000".
    # auto_insert_metric_name is turned off so the "/" in the key does not
    # leak into the file name; the "name=" prefixes are written out by hand.
    # NOTE(review): assumes the module does not separately log a
    # "per_image_iou_val" key - confirm against SMPLitModule.
    filename="epoch={epoch:02d}-per_image_iou_val={per_image_iou/val:.4f}",
    auto_insert_metric_name=False,
    mode="max",
    save_top_k=2,
    enable_version_counter=True,
)

In [None]:
# TensorBoard run named after the encoder/architecture, versioned by VERSION.
logger = TensorBoardLogger(
    paths.tensorboard_log_dir,
    name=f"smp_{ENCODER_NAME}_{ARCH}",
    version=VERSION,
)

# Callbacks defined in the previous cell, in the order they are registered.
training_callbacks = [
    model_sum_callback,
    lr_monitor,
    early_stop_callback,
    timer,
    checkpoint_callback,
]

trainer = L.Trainer(
    logger=logger,
    callbacks=training_callbacks,
    max_epochs=MAX_EPOCHS,
    precision=PRECISION,
    log_every_n_steps=5,
    enable_progress_bar=True,
)

## **5. Find Suggested Learning Rate**

In [None]:
# Sweep learning rates between MIN_LR and MAX_LR to find a suggested value.
# NOTE(review): with Lightning's default arguments lr_find may also write the
# suggestion back onto the module's learning-rate attribute - confirm this is
# intended for SMPLitModule.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(
    model,
    datamodule=data_module,
    min_lr=MIN_LR,
    max_lr=MAX_LR,
    num_training=100,
    early_stop_threshold=4,
)

# show=True already renders the figure, so no extra fig.show() is needed
# (a second call would display the same plot twice).
fig = lr_finder.plot(suggest=True, show=True)

## **6. Train and Test Segmentation Model**

In [None]:
# Train the model; the callbacks registered on the trainer run during fit.
trainer.fit(model, datamodule=data_module)

# Read back what the Timer callback recorded for the "train" stage.
train_started_at = timer.start_time("train")
train_elapsed = timer.time_elapsed("train")
print("Training starting time: ", train_started_at)
print("Time elapsed: ", train_elapsed)

In [None]:
# Evaluate on the test split with the best checkpoint rather than the
# last-epoch weights: early stopping only ends training after the monitored
# metric has stopped improving, so the in-memory model is usually not the best.
# Falls back to the in-memory weights when no best checkpoint was recorded
# (e.g. fit was not run in this session).
best_ckpt = "best" if checkpoint_callback.best_model_path else None
trainer.test(model, datamodule=data_module, ckpt_path=best_ckpt)
print("Time elapsed: ", timer.time_elapsed("test"))