In [None]:
import torch
import os, copy, glob
import sys
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from monai.losses import *
import torchmetrics
import pytorch_lightning as pl

import warnings
warnings.filterwarnings("ignore")

In [None]:
import mlflow
from pytorch_lightning.loggers import MLFlowLogger

# Point the global MLflow client at the tracking server used for this project
# (a server is expected to be listening locally on port 8005).
mlflow.set_tracking_uri("http://127.0.0.1:8005")

In [None]:
# Directory holding the project's helper module (segmentation_utils.py).
UTILS_DIR = r"E:\Projects\Deep_Learning\Car_Segmentation\Utils"

# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path for the lifetime of the kernel.
if UTILS_DIR not in sys.path:
    sys.path.append(UTILS_DIR)

from segmentation_utils import SegmentationDataset, DataAugmentation

### Model Parameters

In [None]:
# ---- Paths and runtime ------------------------------------------------------
DATASET_PATH = "../Data/car_dataset"
BASE_PROJECT_DIR = "E:/Projects/Deep_Learning/Car_Segmentation/"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CORES = 0

# ---- Data -------------------------------------------------------------------
IMG_SIZE = 512
IN_CHANNELS = 3
NUM_CLASSES = 5
BATCH_SIZE = 4
OVERSAMPLE = False

# ---- Model / optimisation ---------------------------------------------------
ENCODER = "xception"
MODEL_TYPE = "Unet"
EPOCHS = 150
LEARNING_RATE = 0.0001
PRECISION = 32
TRAIN_BOOL = True
MODEL_WEIGHTS_PATH = ""

# ---- Loss -------------------------------------------------------------------
# Weighted sum of Dice and cross-entropy terms (0.1 * dice + 0.9 * ce);
# softmax is applied to the network output inside the loss.
LOSS_NAME = "DiceCELoss"
LOSS_FN = DiceCELoss(
    include_background=True,
    softmax=True,
    lambda_dice=0.1,
    lambda_ce=0.9,
)

# ---- Experiment naming ------------------------------------------------------
EXPERIMENT_NAME = "Car_Segmentation"
MODEL_VERSION = 1
RUN_NAME = f"{EXPERIMENT_NAME}_{MODEL_TYPE}_{ENCODER}_{LOSS_NAME}_{MODEL_VERSION}"
print(RUN_NAME)

# Optimizer and LR scheduler are not configurable from here: they are
# hard-coded (AdamW + ReduceLROnPlateau) in
# SegmentationModule.configure_optimizers.

# ---- Callbacks --------------------------------------------------------------
CALLBACKS = [
    {
        "name": "EarlyStopping",
        "monitor": "val_loss",
        "min_delta": 0.001,
        "patience": 15,
        "verbose": True,
        "mode": "min",
    },
]

# ---- Training-time augmentations --------------------------------------------
# Each entry is handed to DataAugmentation; "name" selects the transform and
# the remaining keys are its settings (presumably forwarded as keyword
# arguments - confirm against segmentation_utils.DataAugmentation).
AUGMENTATIONS = [
    {"name": "RandomAffine", "degrees": 360, "align_corners": True, "p": 0.6},
    {"name": "RandomHorizontalFlip", "p": 0.6},
    {"name": "RandomVerticalFlip", "p": 0.6},
    {"name": "RandomRotation", "degrees": 360, "p": 0.6},
]



### Lightning Module

In [None]:
import torchmetrics.functional.classification
import torchmetrics.functional.segmentation


class SegmentationModule(pl.LightningModule):
    """LightningModule wrapping a ``segmentation_models_pytorch`` network.

    The network class is looked up on ``smp`` by ``model_type`` and built with
    the given encoder, input channels and number of output classes. Training
    batches are passed through ``DataAugmentation`` before the forward pass;
    validation batches are not. Both stages log epoch-level ``<stage>_loss``,
    ``<stage>_iou`` and ``<stage>_dice``.

    Args:
        augmentations: Augmentation configuration forwarded to
            ``DataAugmentation``.
        in_channels: Number of input image channels.
        out_channels: Number of segmentation classes.
        encoder_name: Encoder backbone name passed to the smp model.
        encoder_weights: Pretrained-weights identifier passed to the smp
            model (``None`` for random initialisation).
        model_type: Name of the model class on ``smp`` (e.g. ``'Unet'``).
        img_size: Stored on the module; not used by the module itself.
        loss_function: Callable invoked as ``loss_function(logits, mask)``.
            The default dict is only a placeholder: it is not callable, so a
            real loss must be supplied before any step runs.
        lr: Learning rate for the AdamW optimizer.

    Raises:
        ValueError: If ``model_type`` does not name a callable on ``smp``.
    """

    def __init__(self, augmentations, in_channels = 1, out_channels = 1, encoder_name = 'resnet34', encoder_weights = 'imagenet', model_type = 'Unet',
                    img_size = 512, loss_function = {"name":'CELoss', "ALPHA":1.0,"BETA":1.0}, lr = 0.01):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels # number of Classes
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.model_type = model_type
        self.img_size = img_size
        self.lr = lr

        self.augmentations = DataAugmentation(augmentations)
        self.loss_function = loss_function

        self.model = self._build_network()

        # Stores the constructor arguments in the checkpoint so that
        # load_from_checkpoint can rebuild the module without extra arguments.
        self.save_hyperparameters()

    def _build_network(self):
        """Instantiate the smp model class named by ``self.model_type``."""
        # Plain attribute lookup instead of exec() on an interpolated string:
        # no arbitrary code execution and a clear error for a bad name.
        model_cls = getattr(smp, self.model_type, None)
        if not callable(model_cls):
            raise ValueError(f"Unknown segmentation model type: {self.model_type!r}")
        return model_cls(encoder_name=self.encoder_name, encoder_weights=self.encoder_weights,
                         in_channels=self.in_channels, classes=self.out_channels)

    def forward(self, x):
        """Return the raw (pre-softmax) network output for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage, augment = False):
        """Run one batch, log ``<stage>_loss/_iou/_dice`` and return the loss."""
        img, mask = batch

        if augment:
            img, mask = self.augmentations(img, mask)

        logits = self(img)
        loss = self.loss_function(logits, mask)

        # No bare .squeeze() here: it would also drop the batch dimension
        # whenever a batch holds a single sample (e.g. the last partial batch).
        probs = torch.softmax(logits, dim=1)

        # assumes mask is one-hot with classes on dim 1, same shape as the
        # network output (the multilabel get_stats call requires equal shapes)
        # - TODO confirm against SegmentationDataset
        tp, fp, fn, tn = smp.metrics.get_stats(probs, mask.int(), mode='multilabel', threshold=0.5, num_classes=self.out_channels)
        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        # torchmetrics expects class-index targets of shape (N, ...); only
        # then are float predictions of shape (N, C, ...) argmax-ed over C.
        # Passing a one-hot target of the same rank made it interpret the
        # probabilities themselves as labels. Pixels whose one-hot row is all
        # zero map to class 0.
        target_labels = mask.float().argmax(dim=1)
        dice_score = torchmetrics.functional.classification.multiclass_f1_score(
            preds = probs, target = target_labels, num_classes = self.out_channels)

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'{stage}_iou', iou_score, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'{stage}_dice', dice_score, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        return loss

    def training_step(self, batch, batch_idx = None):
        """Augment the batch, then compute and log training loss and metrics."""
        return self._shared_step(batch, 'train', augment=True)

    def validation_step(self, batch, batch_idx = None):
        """Compute and log validation loss and metrics (no augmentation)."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = {'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, mode = "min", patience =  3, cooldown = 1, verbose=True),
                        'monitor': 'val_loss'}
        return [optimizer], [scheduler]


In [None]:
# Lightning logger that records the run under EXPERIMENT_NAME / RUN_NAME and
# uploads model checkpoints as artifacts (log_model=True).
# NOTE(review): no tracking_uri is passed here - confirm whether this logger
# is meant to write to the server configured via mlflow.set_tracking_uri or
# to the local save_dir below.
mlflow_logger = MLFlowLogger(
    experiment_name=EXPERIMENT_NAME,
    run_name=RUN_NAME,
    log_model=True,
    save_dir="E:/Projects/Deep_Learning/Car_Segmentation/mlruns",
    artifact_location="E:/Projects/Deep_Learning/Car_Segmentation/artifacts",
)

In [None]:
def main(seed = None):
    """Train the segmentation model and export the best checkpoint.

    Steps: build the dataset and an 80/20 train/validation split, create the
    dataloaders and the ``SegmentationModule``, train with early stopping and
    checkpointing on ``val_loss``, reload the best checkpoint, then save its
    ``state_dict`` and (best effort) a traced TorchScript model next to the
    checkpoints.

    Args:
        seed: Optional integer used to seed the train/validation split so it
            is reproducible. ``None`` (default) keeps the split unseeded.

    Raises:
        FileNotFoundError: If training produced no checkpoint to export.
    """
    print("---Training Staring---")

    # Define datasets
    whole_dataset = SegmentationDataset(dirPath=DATASET_PATH, imageDir='images/', masksDir='masks/', img_size=IMG_SIZE)

    train_size = int(0.8 * len(whole_dataset))
    val_size = len(whole_dataset) - train_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_ds, val_ds = torch.utils.data.random_split(whole_dataset, [train_size, val_size], **split_kwargs)

    print(f"Length of Train dataset = {len(train_ds)}")
    print(f"Length of Val dataset = {len(val_ds)}")

    # Define DataLoaders (worker count comes from the config cell)
    train_dataloader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=NUM_CORES, pin_memory=False)
    val_dataloader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=NUM_CORES, pin_memory=False)

    print("Created Train and Val dataloaders")

    # The augmentation config is deep-copied so the module-level list is never
    # shared with (or altered through) the model.
    model = SegmentationModule(augmentations=copy.deepcopy(AUGMENTATIONS), in_channels=IN_CHANNELS, out_channels=NUM_CLASSES, encoder_name=ENCODER, encoder_weights=None, model_type=MODEL_TYPE, img_size=IMG_SIZE,
                                loss_function=LOSS_FN, lr=LEARNING_RATE)

    # Create model directory
    checkpoint_dir = f'{BASE_PROJECT_DIR}/Model Checkpoints/{RUN_NAME}/'
    os.makedirs(checkpoint_dir, exist_ok=True)
    print(f"Created directory: {checkpoint_dir}")

    early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss', min_delta = 0.001, patience = 15, verbose = True, mode = "min")
    model_checkpoint = pl.callbacks.ModelCheckpoint(monitor = 'val_loss', dirpath = checkpoint_dir, verbose = True)
    callback_list = [early_stopping, model_checkpoint]

    # Follow the DEVICE setting instead of hard-coding 'gpu', which makes the
    # Trainer fail on machines without CUDA.
    accelerator = 'gpu' if DEVICE == "cuda" else 'cpu'

    # Gradient accumulation keeps the effective batch size at ~16.
    trainer = pl.Trainer(accelerator=accelerator, devices='auto', accumulate_grad_batches= max(1, 16 // BATCH_SIZE),
                            benchmark=True, precision=PRECISION, min_epochs=5, max_epochs=EPOCHS,
                            logger=mlflow_logger, num_sanity_val_steps=5, callbacks=callback_list)

    # Train model
    trainer.fit(model, train_dataloader, val_dataloader)
    print("\nTraining finished")

    # -- Save model --
    # Ask the checkpoint callback for the best checkpoint of THIS run; picking
    # the newest file in the directory can return a stale file from an earlier
    # run with the same RUN_NAME.
    best_chkpt = model_checkpoint.best_model_path
    if not best_chkpt:
        # Fallback: newest checkpoint on disk, if any.
        candidates = glob.glob(f'{checkpoint_dir}*.ckpt')
        if not candidates:
            raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
        best_chkpt = max(candidates, key=os.path.getctime)
    print(f"Found best checkpoint at:\n{best_chkpt}")

    # load_from_checkpoint is a classmethod; Lightning 2.x raises TypeError
    # when it is called on an instance.
    model = SegmentationModule.load_from_checkpoint(best_chkpt)

    # Convert the model and weights to torchscript and Save
    model.eval()

    torch.save(model.state_dict(), f'{checkpoint_dir}{RUN_NAME}_weights.pth')
    print(f"Model Weights saved in {checkpoint_dir}")

    try:
        model.to_torchscript(
            file_path=f'{checkpoint_dir}{RUN_NAME}_traceModel.pth',
            method='trace', example_inputs=torch.rand(1, IN_CHANNELS, IMG_SIZE, IMG_SIZE))
        print(f"TorchScript Trace Model saved in {checkpoint_dir}")
    except Exception as exc:
        # Tracing is best effort, but report why it failed instead of hiding
        # the cause (a bare except would also swallow KeyboardInterrupt).
        print(f"Trace Model could not be created for {RUN_NAME} ")
        print(f"Reason: {exc!r}")


In [None]:
# Entry point: run the full training/export pipeline when this notebook (or
# its exported script) is executed directly rather than imported.
if __name__ == "__main__":
    main()