In [3]:
# !pip install lightning

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import albumentations as A
import torchvision.transforms as T
import albumentations.pytorch as pytorch
import albumentations as albu
from typing import Union

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Accuracy, JaccardIndex, FBetaScore
from typing import Any, Union

In [45]:

from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

import os
from typing import Any

from PIL import Image

In [None]:
import segmentation_models_pytorch as smp


# Encoder backbone for the segmentation network ('se_resnext50_32x4d' was also tried).
ENCODER = 'mobilenet_v2'
ENCODER_WEIGHTS = 'imagenet'

CLASSES = ['Background', 'Thermal Event']
# Output activation: 'sigmoid' for a single-channel binary mask; None returns raw
# logits, 'softmax2d' would suit multiclass segmentation.
ACTIVATION = 'sigmoid'
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

n_cpu = os.cpu_count()

model_name = 'Unet'

# smp.DeepLabV3Plus(...) with the same encoder/activation arguments is a drop-in
# alternative architecture.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,  # single output channel for the binary "Thermal Event" mask
    activation=ACTIVATION,
)

# Input normalization function for the chosen encoder / pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

from segmentation_models_pytorch.losses import FocalLoss, DiceLoss, JaccardLoss

# Binary focal loss; DiceLoss(mode='binary') and JaccardLoss(mode='binary') are the
# imported alternatives.
# NOTE(review): smp losses appear to expect logits by default, while the model above
# ends in a sigmoid — confirm before using this loss with that model.
loss = FocalLoss(mode='binary')
loss.__name__ = 'focal_loss'

In [None]:
class ThermalDataset(Dataset):
    """Image/mask pairs for binary thermal-event segmentation.

    Image files are listed from ``<images_path>/<stage>/images``. For each image
    the mask path is derived by replacing ``'_NIR_SWIR'`` with ``'_mask'`` in the
    file name.

    Args:
        stage: ``"train"``, ``"val"`` or ``"test"``; selects the sub-directory.
        masks_path: directory that must exist; validated but not otherwise read
            by this class.
        images_path: root directory that contains ``<stage>/images``.
        augmentation: optional callable invoked as ``augmentation(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'``; applied first.
        preprocessing: optional callable with the same interface, applied after
            augmentation.
        test_size, train_size, val_size: split fractions; must be floats summing
            to 1 (within floating-point tolerance). Stored, not used for splitting.
        shuffle, random_state: validated and stored; not used by this class.

    Raises:
        AssertionError: if any argument fails validation.
    """

    def __init__(self,
                 stage: str,
                 masks_path: str,
                 images_path: str,
                 augmentation: Any,
                 preprocessing: Any,
                 test_size: float = 0.1,
                 train_size: float = 0.8,
                 val_size: float = 0.1,
                 shuffle: bool = True,
                 random_state: int = 42):

        self.__attribute_checking(masks_path, images_path, test_size,
                                  train_size, val_size,
                                  stage, shuffle, random_state)

        self.masks_path = masks_path
        self.images_path = images_path

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.test_size = test_size
        self.train_size = train_size
        self.val_size = val_size
        self.stage = stage
        self.shuffle = shuffle
        self.random_state = random_state
        self.total_len = None
        self._images, self._masks = self.__create_dataset()

    @staticmethod
    def __type_checking(masks_path: str, images_path: str,
                        stage: str, shuffle: bool,
                        test_size: float, train_size: float,
                        val_size: float, random_state: int) -> None:
        """Assert that every constructor argument has its declared type."""
        assert isinstance(masks_path, str)
        assert isinstance(images_path, str)
        assert isinstance(test_size, float)
        assert isinstance(train_size, float)
        assert isinstance(val_size, float)
        assert isinstance(stage, str)
        assert isinstance(shuffle, bool)
        assert isinstance(random_state, int)

    @staticmethod
    def __split_checking(train_size: float, test_size: float, val_size: float) -> None:
        """Assert that the three split fractions sum to 1."""
        total_size = train_size + test_size + val_size
        # Tolerance instead of ``== 1``: float sums such as 0.7 + 0.2 + 0.1 are
        # not exactly 1.0.
        assert abs(total_size - 1.0) < 1e-9

    @staticmethod
    def __path_checking(masks_path: str, images_path: str) -> None:
        """Assert that both root directories exist."""
        assert os.path.isdir(images_path)
        assert os.path.isdir(masks_path)

    @staticmethod
    def __stage_checking(stage: str) -> None:
        """Assert that ``stage`` is one of the supported split names."""
        assert stage in ["train", "test", "val"]

    @classmethod
    def __attribute_checking(cls, masks_path: str,
                             images_path: str,
                             test_size: float,
                             train_size: float,
                             val_size: float,
                             stage: str,
                             shuffle: bool,
                             random_state: int) -> None:
        """Run all argument validations (types, split sum, paths, stage)."""
        cls.__type_checking(masks_path=masks_path,
                            images_path=images_path,
                            train_size=train_size,  # was test_size: train_size went unchecked
                            test_size=test_size,
                            val_size=val_size,
                            stage=stage,
                            shuffle=shuffle,
                            random_state=random_state)

        cls.__split_checking(train_size=train_size,
                             test_size=test_size,
                             val_size=val_size)

        cls.__path_checking(masks_path=masks_path,
                            images_path=images_path)

        cls.__stage_checking(stage=stage)

    def __create_dataset(self) -> tuple:
        """List this stage's images and derive the matching mask paths.

        Returns:
            ``(image_paths, mask_paths)`` as object ndarrays of equal length.
            Also sets ``self.total_len``.
        """
        images_dir = self.__split_data(self.stage)

        # os.listdir order is arbitrary; sorting makes sample order reproducible.
        image_names = sorted(os.listdir(images_dir))

        image_paths = np.array(
            [os.path.join(images_dir, name) for name in image_names],
            dtype=object,
        )
        # NOTE(review): masks are looked up in the *images* directory, not under
        # ``self.masks_path`` (which is only validated). Confirm the on-disk layout.
        mask_paths = np.array(
            [os.path.join(images_dir, name.replace('_NIR_SWIR', '_mask')) for name in image_names],
            dtype=object,
        )

        self.total_len = len(image_paths)
        return image_paths, mask_paths

    def __split_data(self, stage: str) -> str:
        """Return ``<images_path>/<stage>/images``."""
        return os.path.join(self.images_path, stage, 'images')

    def __len__(self) -> int:
        return self.total_len

    def __getitem__(self, idx) -> tuple:
        """Load sample ``idx``; returns ``(image, mask)`` after augmentation/preprocessing.

        Before the optional transforms, ``image`` is the raw pixel array and
        ``mask`` is a float32 array of shape (H, W, 1) scaled by 1/255.
        """
        # Context managers close the underlying file handles after loading.
        with Image.open(self._images[idx]) as image_file:
            image = np.array(image_file)

        with Image.open(self._masks[idx]) as mask_file:
            # FOR FOCAL LOSS: convert to a single band ('L') for binary segmentation,
            # then add a trailing channel axis.
            mask = np.array(mask_file.convert('L'))[:, :, np.newaxis]

        mask = np.divide(mask, 255).astype('float32')

        # apply augmentation
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask


In [None]:
def get_training_augmentation():
    """Build the training-time augmentation pipeline.

    Returns:
        albumentations.Compose with a horizontal flip and a vertical flip,
        each applied with probability 0.5.
    """
    flips = [
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
    ]
    return albu.Compose(flips)

def to_tensor(x, **kwargs):
    """Reorder an HWC array to CHW and cast it to float32 (extra kwargs are ignored)."""
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied to every sample.

    Args:
        preprocessing_fn (callable): normalization applied to the image only
            (typically the encoder-specific function from smp).

    Returns:
        albumentations.Compose that normalizes the image and then converts both
        image and mask to float32 CHW arrays via ``to_tensor``.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    to_channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, to_channels_first])

In [None]:
class ThermalDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the ThermalDataset splits and their DataLoaders.

    Args:
        target_path: directory forwarded to ThermalDataset as ``masks_path``.
        data_path: root forwarded to ThermalDataset as ``images_path``.
        augmentation: transform applied to the training split only.
        preprocessing: transform applied to every split.
        train_size, val_size, test_size: split fractions forwarded to ThermalDataset.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count (0 if ``os.cpu_count()`` is unknown).
        seed: forwarded to ThermalDataset as ``random_state``.
        eval_augmentation: transform for the val/test/predict splits; the default
            None disables augmentation there.
    """

    def __init__(self, target_path: str,
                 data_path: str,
                 augmentation: Union[T.Compose, A.Compose],
                 preprocessing: Any,
                 train_size: float = 0.8,
                 val_size: float = 0.1,
                 test_size: float = 0.1,
                 batch_size: int = 5,
                 num_workers: int = os.cpu_count() or 0,
                 seed: int = 42,
                 eval_augmentation: Any = None):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.target_path = target_path
        self.data_path = data_path
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.seed = seed
        self.data_train = None
        self.data_val = None
        self.data_test = None
        self.data_predict = None
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.train_augmentation = augmentation
        # Previously referenced in setup() without ever being defined.
        self.eval_augmentation = eval_augmentation
        self.preprocessing = preprocessing

    def _make_dataset(self, stage: str, augmentation: Any) -> "ThermalDataset":
        """Build one ThermalDataset split with this module's shared settings."""
        return ThermalDataset(
            stage=stage,
            # ThermalDataset takes masks_path/images_path, not target_path/data_path.
            masks_path=self.target_path,
            images_path=self.data_path,
            augmentation=augmentation,
            preprocessing=self.preprocessing,
            train_size=self.train_size,
            test_size=self.test_size,
            val_size=self.val_size,
            shuffle=True,
            random_state=self.seed,
        )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader with this module's batch/worker settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle
        )

    def setup(self, stage: str = None) -> None:
        """Create all splits; ``stage`` is ignored. Predict reuses the test split."""
        self.data_train = self._make_dataset("train", self.train_augmentation)
        self.data_val = self._make_dataset("val", self.eval_augmentation)
        self.data_test = self._make_dataset("test", self.eval_augmentation)
        self.data_predict = self._make_dataset("test", self.eval_augmentation)

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.data_test, shuffle=False)

    def predict_dataloader(self) -> DataLoader:
        return self._make_loader(self.data_predict, shuffle=False)

In [None]:
class ThermalModel(pl.LightningModule):
    """LightningModule that trains and evaluates a segmentation network on thermal-event masks.

    Per-batch loss and metrics are collected separately for each stage. At epoch
    end they are averaged and logged as ``<stage>_loss``, ``<stage>_accuracy``,
    ``<stage>_jaccard_index`` and ``<stage>_fbeta_score``.

    Args:
        model: segmentation network; called as ``model(x)``.
        example_input_array: shape of the zero tensor stored as ``self.example_input_array``.
        optim_dict: if given, returned as-is from ``configure_optimizers``.
        lr: Adam learning rate; ``DEFAULT_LR`` is used when None.
        num_classes: forwarded to the torchmetrics constructors and used as the
            upper label bound checked in ``predict_step``.
        criterion: loss module called as ``criterion(outputs, y)``. Defaults to
            ``nn.BCELoss()``, which expects probabilities in [0, 1] and so suits a
            network ending in a sigmoid. Pass ``nn.BCEWithLogitsLoss()`` (or
            similar) for a network that returns raw logits.
    """

    DEFAULT_LR = 1e-3
    STAGES = ("train", "val", "test")

    def __init__(self,
                 model: nn.Module,
                 example_input_array: Union[list, tuple],
                 optim_dict: dict = None,
                 lr: float = None,
                 num_classes: int = 23,
                 criterion: nn.Module = None):
        super().__init__()
        # nn.Module arguments are already stored in the state_dict; keep them out
        # of hparams so they are not pickled a second time.
        self.save_hyperparameters(ignore=["model", "criterion"])
        self.example_input_array = torch.zeros(size=example_input_array)
        self.num_classes = num_classes
        self.model = model
        # The previous nn.CrossEntropyLoss is identically zero on a single-channel
        # output (log-softmax over one channel is 0), so nothing was learned.
        self.criterion = criterion if criterion is not None else nn.BCELoss()
        self.optim_dict = optim_dict
        # NOTE: do not assign self._device here. LightningModule uses that attribute
        # internally for its ``device`` property; placement is handled by Lightning.

        # One buffer per stage. Validation runs inside the training epoch, so a
        # shared buffer would mix train and val batches.
        self.step_outputs = {
            stage: {"loss": [], "accuracy": [], "jaccard_index": [], "fbeta_score": []}
            for stage in self.STAGES
        }

        # ModuleDict registers the metrics as submodules, so they follow the model
        # to whatever device Lightning chooses.
        self.metrics = nn.ModuleDict({
            "accuracy": Accuracy(task="binary",
                                 threshold=0.5,
                                 num_classes=num_classes,
                                 validate_args=True,
                                 ignore_index=None,
                                 average="micro"),

            "jaccard_index": JaccardIndex(task="binary",
                                          threshold=0.5,
                                          num_classes=num_classes,
                                          validate_args=True,
                                          ignore_index=None,
                                          average="macro"),

            "fbeta_score": FBetaScore(task="binary",
                                      beta=1.0,
                                      threshold=0.5,
                                      num_classes=num_classes,
                                      average="micro",
                                      ignore_index=None,
                                      validate_args=True)
        })

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _predict(outputs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Convert network outputs to hard labels.

        Single-channel outputs are thresholded, which keeps the channel dim so the
        result matches (N, 1, H, W) masks. The threshold assumes probabilities
        (a sigmoid-terminated model). Multi-channel outputs use softmax + argmax
        over dim 1.
        """
        if outputs.shape[1] == 1:
            return (outputs > threshold).int()
        return torch.argmax(F.softmax(input=outputs, dim=1), dim=1)

    def shared_step(self, batch, stage: str) -> torch.Tensor:
        """Compute the loss and batch metrics for ``stage`` and buffer them (detached)."""
        x, y = batch

        assert x.ndim == 4
        assert x.max() <= 3 and x.min() >= -3
        assert y.ndim == 4
        assert y.max() <= 1 and y.min() >= 0

        outputs = self.forward(x)  # was self.forward() -> TypeError
        loss = self.criterion(outputs, y)
        predictions = self._predict(outputs)
        target = y.int()  # masks are float {0, 1}; binary metrics take integer targets

        buffer = self.step_outputs[stage]
        # detach(): keeping the graph of every batch until epoch end leaks memory.
        buffer["loss"].append(loss.detach())
        for name, metric in self.metrics.items():
            buffer[name].append(metric(predictions, target).detach())
        return loss

    def shared_epoch_end(self, stage: Any):
        """Average and log the buffered values for ``stage``, then clear its buffer."""
        buffer = self.step_outputs.get(stage)
        if not buffer or not buffer["loss"]:
            return  # no batches were run for this stage

        metrics = {
            f"{stage}_{name}": torch.stack(values).float().mean()
            for name, values in buffer.items()
        }

        for values in buffer.values():
            values.clear()

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="train")

    def on_train_epoch_end(self) -> None:
        return self.shared_epoch_end(stage="train")

    def validation_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="val")

    def on_validation_epoch_end(self) -> None:
        return self.shared_epoch_end(stage="val")

    def test_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="test")

    def on_test_epoch_end(self) -> None:
        return self.shared_epoch_end(stage="test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Return hard label predictions for ``batch`` (see ``_predict``)."""
        x, y = batch

        assert x.ndim == 4
        assert x.max() <= 3 and x.min() >= -3
        assert y.ndim == 4
        assert y.max() <= self.num_classes - 1 and y.min() >= 0

        return self._predict(self.forward(x))

    def configure_optimizers(self):
        """Return ``optim_dict`` if provided, else Adam + ReduceLROnPlateau on ``val_loss``."""
        if self.optim_dict:
            return self.optim_dict

        # Adam rejects lr=None, which was the constructor default.
        lr = self.hparams.lr if self.hparams.lr is not None else self.DEFAULT_LR
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=lr
        )

        scheduler_dict = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                patience=5
            ),
            "interval": "epoch",
            "monitor": "val_loss"
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
import torch
import segmentation_models_pytorch as smp
import warnings
from typing import Union, Any
from colorama import Fore

In [None]:
def main(callbacks: list,
         model: nn.Module,
         logger: TensorBoardLogger,
         data_path: str,
         target_path: str,
         optim_dict: dict,
         example_input_array: Union[list, tuple],
         transforms_dict: dict,
         max_epochs: int = 80,
         min_epochs: int = 35
         ) -> None:
    """Build the Trainer, data module and LightningModule, then run training.

    Args:
        callbacks: Lightning callbacks passed to the Trainer.
        model: segmentation network wrapped by ThermalModel.
        logger: logger passed to the Trainer.
        data_path: image root, forwarded to ThermalDataModule.
        target_path: mask directory, forwarded to ThermalDataModule.
        optim_dict: optional optimizer configuration forwarded to ThermalModel.
        example_input_array: input shape forwarded to ThermalModel.
        transforms_dict: currently unused (the TransformPipelineModule hook that
            consumed it is disabled); kept so existing calls keep working.
        max_epochs, min_epochs: Trainer epoch bounds.

    Uses the module-level ``preprocessing_fn`` for input normalization.
    """
    trainer = pl.Trainer(
        fast_dev_run=False,
        accelerator="auto",
        strategy="auto",
        devices="auto",
        num_nodes=1,
        logger=logger,
        callbacks=callbacks,
        max_epochs=max_epochs,
        min_epochs=min_epochs
    )

    datamodule = ThermalDataModule(
        target_path=target_path,
        data_path=data_path,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        train_size=0.80,
        val_size=0.1,
        test_size=0.1,
        batch_size=2,
        num_workers=2
    )

    lightning_model = ThermalModel(
        model=model,
        optim_dict=optim_dict,
        lr=3e-4,
        example_input_array=example_input_array
    )

    trainer.fit(model=lightning_model, datamodule=datamodule)

In [None]:
# Training callbacks: keep the 10 best checkpoints by validation loss, stop early
# when val_loss stops improving, and log the learning rate every step.
checkpoint_cb = ModelCheckpoint(
    dirpath=f"models/{model_name}",
    filename="{epoch}_{val_loss:.2f}_{val_accuracy:.2f}",
    save_top_k=10,
    monitor="val_loss",
    mode="min",
)
early_stopping_cb = EarlyStopping(
    monitor="val_loss",
    min_delta=2e-4,
    patience=8,
    verbose=False,
    mode="min",
)
lr_monitor_cb = LearningRateMonitor(logging_interval="step")

callbacks = [checkpoint_cb, early_stopping_cb, lr_monitor_cb]