In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torchvision.transforms import Resize, InterpolationMode, ToPILImage, RandomCrop
import torchmetrics
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
from src.callbacks.SaveRandomImagesCallback import SaveRandomImagesCallback
from src.callbacks.SaveTestPreds import SaveTestPreds
import segmentation_models_pytorch as smp

from src.evaluation.evaluate_result import evaluate_result
from src.datasets.UAVidSemanticSegmentationDataset import (
    UAVidSemanticSegmentationDataset,
)
from src.utils import view_and_save_images_shapes

## Prepare environment

In [None]:
# Show (as the cell output) whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
# If you hit a cryptic CUDA error, force the CPU (see DEBUG below) and retry:
# CPU tracebacks are usually much more readable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DEBUG
# device = torch.device("cpu")
print(device)

In [None]:
# Run configuration: everything a reader might want to tune lives here.
# NOTE(review): VAL_SIZE, IMAGE_WIDTH and IMAGE_HEIGHT are not referenced in the
# visible part of this notebook — confirm whether they are still needed.
VAL_SIZE = 0.2
BATCH_SIZE = 1
SEED = 42
SAVE_VAL_DIR = "outputs/UAVid/val"
SAVE_TEST_DIR = "outputs/UAVid/test"
UAVID_DATASET_PATH = "data/UAVidSemanticSegmentationDataset"
IMAGE_WIDTH = 1024
IMAGE_HEIGHT = 576

# Actually apply SEED: seeds Python, NumPy and PyTorch RNGs (and, with
# workers=True, DataLoader worker processes) so repeated runs are comparable.
# The trailing semicolon keeps the returned seed out of the cell output.
pl.seed_everything(SEED, workers=True);

## Prepare data

In [None]:
# Training split of the UAVid dataset; print its size as a quick sanity check.
train_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="train")
print(len(train_dataset))

In [None]:
# Validation split of the UAVid dataset; print its size as a quick sanity check.
val_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="valid")
print(len(val_dataset))

In [None]:
# Test split of the UAVid dataset; print its size as a quick sanity check.
test_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="test")
print(len(test_dataset))

### Train, val, test data loaders

In [None]:
# One DataLoader per split. Only the training loader shuffles: a fixed sample
# order every epoch hurts SGD-style optimisation, while validation and test
# stay in dataset order so their outputs are comparable between runs.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8
)

### Sanity check

Note: these shape checks are only meaningful when the datasets apply no transforms.

In [None]:
# view_and_save_images_shapes(train_loader, 'uavid_train_images_shapes', verbose=True)

In [None]:
# view_and_save_images_shapes(val_loader, 'uavid_val_images_shapes')

In [None]:
# view_and_save_images_shapes(test_loader, 'uavid_test_images_shapes', verbose=True, train=False)

In [None]:
# Fetch exactly one training batch and show its tensor shapes.
# next(iter(...)) raises immediately if the loader is empty, instead of silently
# leaving `images` / `masks` undefined for the cells below.
images, masks = next(iter(train_loader))
print(images.shape)
print(masks.shape)

In [None]:
# Converter used below to turn sanity-check tensors into PIL images for viewing.
to_pil_transform = ToPILImage()

In [None]:
# Convert the first image of the sanity-check batch to PIL.
# Indexing the batch dimension works for any BATCH_SIZE, so `img` is always
# defined (the previous BATCH_SIZE == 1 guard left it unset otherwise).
img = to_pil_transform(images[0])

In [None]:
# img.show()

In [None]:
# Convert the first mask of the sanity-check batch to PIL.
# Indexing the batch dimension works for any BATCH_SIZE, so `msk` is always
# defined (the previous BATCH_SIZE == 1 guard left it unset otherwise).
msk = to_pil_transform(masks[0])

In [None]:
# msk.show()

## Basic training loop

In [None]:
# Run a garbage-collection pass before training to drop unreachable objects
# left over from earlier cells; the cell output is gc.collect()'s return value.
import gc

gc.collect()

In [None]:
# Release cached GPU memory that PyTorch is holding but no longer using.
torch.cuda.empty_cache()

## Training module

In [None]:
class SegmentationModel(pl.LightningModule):
    """LightningModule that trains a multiclass segmentation network.

    Wraps an arbitrary segmentation ``model`` and optimises it with a Jaccard
    (IoU) loss, logging loss and IoU for the training and validation stages.

    Args:
        model: Network mapping an image batch to per-class scores. Because the
            loss is built with ``from_logits=False``, the model is expected to
            output probabilities (not raw logits).
        learning_rate: Learning rate passed to the Adam optimizer.
        num_classes: Number of segmentation classes used by the IoU metrics.
            Defaults to 8, the value previously hard-coded here.
    """

    def __init__(self, model, learning_rate=1e-3, num_classes=8):
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate

        self.criterion = smp.losses.JaccardLoss(mode="multiclass", from_logits=False)
        # Separate metric objects per stage so their accumulated state never mixes.
        self.train_iou = JaccardIndex(num_classes=num_classes, task="multiclass")
        self.val_iou = JaccardIndex(num_classes=num_classes, task="multiclass")

        # self.save_hyperparameters()

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output.

        The input is moved to the device the wrapped model's parameters live on
        rather than to a notebook-level global, so the module keeps working
        wherever Lightning (or the caller) has placed it.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            x = x.to(first_param.device)
        return self.model(x)

    def _shared_step(self, batch, stage, iou_metric):
        """Compute and log loss and IoU for one batch of the given stage.

        Args:
            batch: ``(images, masks)`` pair from the dataloader.
                Assumes masks carry a singleton channel axis at dim 1 that is
                removed before use — TODO confirm against the dataset.
            stage: Prefix for the logged names (``"train"`` or ``"val"``).
            iou_metric: Metric object that accumulates IoU for this stage.

        Returns:
            The loss tensor for the batch.
        """
        images, masks = batch
        preds = self(images)
        # Cast once and share between loss and metric; keep the target on the
        # same device as the predictions, since forward() may have moved the input.
        target = masks.squeeze(1).long().to(preds.device)
        loss = self.criterion(preds, target)

        # Update the metric, then log the metric object itself: Lightning then
        # computes the epoch value from the accumulated state and resets it,
        # instead of averaging per-batch IoU values.
        iou_metric(preds, target)
        self.log(f"{stage}_loss", loss, on_epoch=True, on_step=True)
        self.log(f"{stage}_iou", iou_metric, on_epoch=True, on_step=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``; logs ``train_loss`` and ``train_iou``."""
        return self._shared_step(batch, "train", self.train_iou)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``; logs ``val_loss`` and ``val_iou``."""
        return self._shared_step(batch, "val", self.val_iou)

    def test_step(self, batch, batch_idx):
        """Intentionally do nothing.

        Defined only so the test loop runs; the SaveTestPreds callback does its
        work at the end of the test epoch.
        """
        pass

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

### Model selection

In [None]:
# U-Net with an ImageNet-pretrained ResNet-18 encoder, RGB input and one output
# channel per class, created on the selected device.
model = smp.Unet(
    encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=8,  # model output channels (number of classes in your dataset)
    # NOTE(review): presumably normalises over the class channel so the loss
    # (built with from_logits=False) receives probabilities — confirm.
    activation="softmax",
).to(device)

In [None]:
# model = smp.UnetPlusPlus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,  # model output channels (number of classes in your dataset)
# ).to(device)

In [None]:
# model = smp.DeepLabV3(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,
# ).to(device)

In [None]:
# Wrap the network in the Lightning training module (default learning rate).
segmentation_model = SegmentationModel(model)

In [None]:
# Track the lowest val_loss as the "best" checkpoint; save_top_k=-1 keeps every
# checkpoint rather than only the best ones.
model_checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=-1)

In [None]:
# Write logged metrics as CSV under lightning_logs/uavid_segmentation_model/.
logger = CSVLogger("lightning_logs", name="uavid_segmentation_model")

In [None]:
# Project callback configured to write into SAVE_VAL_DIR
# (presumably sample validation predictions — confirm in src/callbacks).
save_callback = SaveRandomImagesCallback(save_dir=SAVE_VAL_DIR)

# Dubai test dataset is also labeled
# so we can use different logic for test set evaluation
# NOTE(review): this notebook works on UAVid; the remark above looks carried
# over from another notebook — confirm it still applies here.
save_test_preds_callback = SaveTestPreds(save_dir=SAVE_TEST_DIR)

# Stop training once val_loss has not improved for 5 consecutive validation checks.
early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=5)

In [None]:
# Trainer wired with the checkpoint, image-saving, test-prediction and
# early-stopping callbacks, logging to the CSV logger every 10 steps.
# NOTE(review): with max_epochs=1 the early-stopping patience of 5 can never be
# reached — raise max_epochs for a real training run.
trainer = pl.Trainer(
    max_epochs=1,
    log_every_n_steps=10,
    callbacks=[
        model_checkpoint_callback,
        save_callback,
        save_test_preds_callback,
        early_stopping_callback,
    ],
    logger=logger,
)

# DEBUG
# trainer = pl.Trainer(
#     max_epochs=1,
#     callbacks=[model_checkpoint_callback, save_callback],
#     logger=logger,
#     accelerator="cpu"
#     )

In [None]:
# Train on the training loader, validating on the validation loader.
trainer.fit(
    segmentation_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)

In [None]:
# Run the test loop on the test loader using the "best" checkpoint (lowest
# val_loss per the checkpoint callback). No model is passed, so this relies on
# the module attached to the trainer by the preceding fit() call.
trainer.test(ckpt_path="best", dataloaders=test_loader)