In [None]:
import lightning.pytorch as pl
import os
import torch
import torchvision
import numpy as np
from PIL import Image
from torch import nn
import torch.nn.functional as F
from torchvision.models.segmentation.deeplabv3 import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
from torch.optim import Adam, SGD, LBFGS, Adadelta, Adamax, Adagrad, ASGD
from torch.optim.lr_scheduler import CyclicLR, PolynomialLR, CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import ReduceLROnPlateau, ConstantLR, StepLR, CosineAnnealingLR
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping
from lightning.pytorch.callbacks import ModelSummary, LearningRateFinder, TQDMProgressBar, DeviceStatsMonitor
from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler
from lightning.pytorch.loggers import TensorBoardLogger
import sys
#from utils.losses import IoULoss, DiceLoss, TverskyLoss, FocalTverskyLoss, HybridLoss, FocalHybridLoss
from utils.datasets import CityscapesDataModule, cityscapes_color_map #, MapillaryDataset
from torch.utils.data import DataLoader
from torchsummary import summary
#from utils.eval import MeanIoU
#from utils.models import  Unet, Residual_Unet, Attention_Unet, Unet_plus, DeepLabV3plus
import yaml
from torchvision.utils import draw_segmentation_masks
from torchmetrics import JaccardIndex

In [None]:
# Read YAML file
print('Reading configuration from config yaml')

with open('config/Cityscapes.yaml', 'r') as config_file:
    # An empty YAML file loads as None; treat it as an empty config.
    config = yaml.safe_load(config_file) or {}

# TODO: Add default values if a variable is not defined in the config file
# Missing sections become empty dicts so the `.get` lookups below fall back to
# their defaults (or None) instead of raising AttributeError on None.
LOGS_DIR = config.get('logs_dir')
model_config = config.get('model_config') or {}
dataset_config = config.get('dataset_config') or {}
train_config = config.get('train_config') or {}

# Dataset Configuration
DATASET = dataset_config.get('name')
NUM_TRAIN_BATCHES = dataset_config.get('num_train_batches', 1.0)
NUM_EVAL_BATCHES = dataset_config.get('num_eval_batches', 1.0)
BATCH_SIZE = dataset_config.get('batch_size')
SEED = dataset_config.get('seed')

# Apply the configured seed to the Python, NumPy and torch RNGs so that
# weight initialisation and data shuffling are repeatable across runs.
if SEED is not None:
    pl.seed_everything(SEED)

# Model Configuration
MODEL_TYPE = model_config.get('architecture')
MODEL_NAME = model_config.get('name')
BACKBONE = model_config.get('backbone')
UNFREEZE_AT = model_config.get('unfreeze_at')
INPUT_SHAPE = model_config.get('input_shape')
OUTPUT_STRIDE = model_config.get('output_stride')
FILTERS = model_config.get('filters')
ACTIVATION = model_config.get('activation')
DROPOUT_RATE = model_config.get('dropout_rate')

# Training Configuration
# PRETRAINED_WEIGHTS = model_config['pretrained_weights']

EPOCHS = train_config.get('epochs')
AUGMENTATION = train_config.get('augment')
# NOTE: str() turns a missing precision into the string 'None'.
PRECISION = str(train_config.get('precision'))

# Stochastic weight averaging parameters (SWA_LRS / SWA_EPOCH_START are only
# defined when an `swa` section is configured).
SWA = train_config.get('swa')
if SWA is not None:
    SWA_LRS = SWA.get('lr', 1e-3)
    SWA_EPOCH_START = SWA.get('epoch_start', 0.7)

# Distribution settings. A missing `distribute` section (or key) falls back to
# Lightning's own 'auto' default instead of crashing on None.get().
distribute_config = train_config.get('distribute') or {}
DISTRIBUTE_STRATEGY = distribute_config.get('strategy', 'auto')
DEVICES = distribute_config.get('devices', 'auto')

# save the config in the hparams.yaml file
# with open(f'{LOGS_DIR}/my.yaml', 'w') as config_file:
#     config = yaml.safe_dump(config, config_file)

In [None]:
from typing import Any

# Cityscapes label mapping: position i in _TRAIN_IDS corresponds to position i
# in _EVAL_IDS. Train id 19 is the void class and is written back as eval id 0
# (black in the grayscale output).
_EVAL_IDS = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 0]
_TRAIN_IDS = list(range(20))

# Output folders used by DeepLabV3.predict_step: grayscale eval-id maps and
# RGB overlays, both under <LOGS_DIR>/predictions/<MODEL_TYPE>/<MODEL_NAME>/test.
grayscale_path, rgb_path = (
    f'{LOGS_DIR}/predictions/{MODEL_TYPE}/{MODEL_NAME}/test/{subdir}'
    for subdir in ('grayscale', 'rgb')
)

os.makedirs(grayscale_path, exist_ok=True)
os.makedirs(rgb_path, exist_ok=True)

from torchvision.transforms.v2 import functional as F
class DeepLabV3(pl.LightningModule):
    """LightningModule wrapping a torchvision-style DeepLabV3 network.

    Builds the loss, optimizer and LR schedule from ``train_config``, logs
    loss and Mean IoU for training/validation, and in ``predict_step`` writes
    grayscale eval-id maps and RGB overlays to ``grayscale_path`` / ``rgb_path``.

    Args:
        model: Segmentation network whose output is a dict with an ``'out'``
            entry holding per-class logits (it is indexed as ``['out']``).
        train_config: The ``train_config`` section of the YAML config. Keys
            read: ``loss``, ``optimizer``, ``lr_schedule``, ``batch_size``.
            ``None`` is treated as an empty dict.
    """

    # 20 output classes; train id 19 is void (written back as eval id 0 in
    # predict_step) and is excluded from the IoU metrics.
    NUM_CLASSES = 20
    VOID_CLASS = 19

    def __init__(self,
                 model: nn.Module,
                 train_config: dict = None
                 ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore='model')

        # Treat a missing config as empty so every lookup below uses its default.
        train_config = train_config or {}
        loss = train_config.get('loss', 'CrossEntropy')

        self.model = model
        self.loss = self.get_loss(loss)
        self.optimizer_config = train_config.get('optimizer') or {}
        self.lr_schedule_config = train_config.get('lr_schedule') or {}
        # NOTE(review): batch_size may live in dataset_config rather than
        # train_config, in which case this is None; it is unused in this class.
        self.batch_size = train_config.get('batch_size')

        # Both metrics use 'macro' averaging (mean of per-class IoU) so the
        # logged train and val "Mean IoU" are computed the same way; 'micro'
        # would pool pixels across classes and is not a per-class mean.
        self.train_mean_iou = self._build_mean_iou()
        self.val_mean_iou = self._build_mean_iou()
        #self.example_input_array = torch.Tensor(4, 3, 1024, 2048)

    @classmethod
    def _build_mean_iou(cls) -> JaccardIndex:
        """Create a macro-averaged multiclass IoU metric that ignores the void class."""
        return JaccardIndex(task='multiclass',
                            num_classes=cls.NUM_CLASSES,
                            average='macro',
                            ignore_index=cls.VOID_CLASS)

    def _get_learning_rate(self) -> float:
        """Return the base learning rate from the optimizer config.

        Prefers ``learning_rate``; falls back to the misspelled ``learnin_rate``
        key (kept for existing config files), then to 1e-3.
        """
        if 'learning_rate' in self.optimizer_config:
            return self.optimizer_config['learning_rate']
        return self.optimizer_config.get('learnin_rate', 1e-3)

    def get_lr_schedule(self, optimizer):
        """Build the LR scheduler named in ``lr_schedule_config``.

        Args:
            optimizer: The optimizer the schedule drives.

        Returns:
            A ``PolynomialLR`` or ``CyclicLR`` instance, or ``None`` when no
            (recognised) schedule is configured.
        """
        # num of steps in cyclic lr should be cycle_epochs * steps_per_epoch
        # steps_per_epoch is defined depended on the length of the dataset
        # so maybe define the dataset inside the Lightning Module using
        # the DataModule object
        schedule = self.lr_schedule_config.get('name')

        if schedule in ['Polynomial', 'PolynomialLr', 'PolynomialLR', 'polynomial']:
            # Defaults match PolynomialLR's own (total_iters=5, power=1.0).
            return PolynomialLR(
                optimizer=optimizer,
                total_iters=self.lr_schedule_config.get('decay_epochs', 5),  #*steps_per_epoch,
                power=self.lr_schedule_config.get('power', 1.0),
                verbose=True
            )

        if schedule in ['CyclicLR', 'Cyclic', 'CyclicLr', 'cyclic']:
            # NOTE(review): Lightning steps schedulers once per epoch unless an
            # 'interval' is given, while CyclicLR's step_size_up (default 2000)
            # counts scheduler steps — confirm the intended stepping.
            return CyclicLR(
                optimizer=optimizer,
                base_lr=self._get_learning_rate(),
                max_lr=self.lr_schedule_config.get('max_lr', 1e-2),
                # step_size_up=
                # step_size_down=
                gamma=self.lr_schedule_config.get('gamma', 1.0),
                verbose=True
            )

        # No schedule configured (or an unrecognised name): constant LR.
        return None

    def get_loss(self, loss: str):
        """Return the loss module for ``loss``.

        Raises:
            ValueError: If ``loss`` is not a supported loss name.
        """
        if loss in ['CrossEntropy', 'CrossEntropyLoss', 'crossentropy']:
            return nn.CrossEntropyLoss()
        # TODO: Dice / Hybrid / RMI losses (utils.losses) are currently disabled.
        raise ValueError(f"Unsupported loss '{loss}'. Supported: 'CrossEntropy'.")

    def configure_optimizers(self):
        """Create the optimizer (and LR scheduler, if configured) for Lightning.

        Raises:
            KeyError: If the configured optimizer name is not supported.
        """
        optimizer_name = self.optimizer_config.get('name', 'Adam')
        lr = self._get_learning_rate()
        weight_decay = self.optimizer_config.get('weight_decay', 0)
        momentum = self.optimizer_config.get('momentum', 0)

        # Factories so only the requested optimizer is constructed.
        optimizer_factories = {
            'Adam': lambda: Adam(params=self.model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay),
            'Adadelta': lambda: Adadelta(params=self.model.parameters(),
                                         lr=lr,
                                         weight_decay=weight_decay),
            'SGD': lambda: SGD(params=self.model.parameters(),
                               lr=lr,
                               momentum=momentum,
                               weight_decay=weight_decay),
        }
        if optimizer_name not in optimizer_factories:
            raise KeyError(f"Unsupported optimizer '{optimizer_name}'. "
                           f"Supported: {sorted(optimizer_factories)}")

        optimizer = optimizer_factories[optimizer_name]()
        lr_schedule = self.get_lr_schedule(optimizer)

        if lr_schedule is None:
            return optimizer
        # Return the scheduler built above rather than constructing a second
        # one against the same optimizer.
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_schedule
        }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped network and return its ``'out'`` logits."""
        return self.model(*args, **kwargs)['out']

    def _shared_step(self, batch, stage: str, mean_iou: JaccardIndex) -> torch.Tensor:
        """Compute the loss and update/log loss and Mean IoU for one batch.

        Args:
            batch: An ``(images, target)`` pair.
            stage: Log-key prefix (``'train'`` or ``'val'``).
            mean_iou: The metric instance for this stage.

        Returns:
            The loss tensor for the batch.
        """
        images, target = batch
        logits = self(images)
        loss = self.loss(logits, target)
        # NOTE(review): argmax over dim=1 of `target` assumes one-hot
        # (N, C, H, W) targets from the datamodule — confirm.
        mean_iou(torch.argmax(logits, dim=1), torch.argmax(target, dim=1))
        self.log(f"{stage}_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
        self.log(f'{stage}_Mean_IoU', mean_iou, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Training step: returns the loss for Lightning to backpropagate."""
        return self._shared_step(train_batch, 'train', self.train_mean_iou)

    def validation_step(self, val_batch, batch_idx):
        """Validation step: logs val loss and Mean IoU; returns nothing."""
        self._shared_step(val_batch, 'val', self.val_mean_iou)

    @staticmethod
    def _to_eval_ids(predictions: torch.Tensor) -> torch.Tensor:
        """Map train-id predictions to eval ids with a single lookup table.

        Returns a uint8 tensor of the same shape. A gather avoids the collision
        in chained ``torch.where`` remaps, where void (train 19 -> eval 0) was
        later rewritten by the ``0 -> 7`` step.
        """
        device = predictions.device
        lut = torch.zeros(max(_TRAIN_IDS) + 1, dtype=torch.uint8, device=device)
        lut[torch.tensor(_TRAIN_IDS, device=device)] = torch.tensor(
            _EVAL_IDS, dtype=torch.uint8, device=device)
        return lut[predictions.long()]

    def predict_step(self, predict_batch: dict, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Save grayscale eval-id predictions and RGB overlays for one batch.

        Expects ``predict_batch`` to be a dict with ``'image'`` (a batched
        tensor) and ``'filename'`` (one output filename per image). Results are
        written to ``grayscale_path`` and ``rgb_path``; nothing is returned.
        """
        from torchvision.transforms.v2.functional import to_pil_image

        images = predict_batch.get('image')
        filenames = list(predict_batch.get('filename'))

        predictions = self._to_eval_ids(torch.argmax(self(images), dim=1))

        # One boolean mask per eval id (0 .. max eval id) for the overlay.
        eval_id_range = torch.arange(max(_EVAL_IDS) + 1)[:, None, None]
        colors = list(cityscapes_color_map.values())

        for idx, filename in enumerate(filenames):
            pred = predictions[idx].cpu()
            input_img = images[idx].cpu()

            # Grayscale eval-id map from the 2-D uint8 prediction.
            Image.fromarray(pred.numpy()).save(f'{grayscale_path}/{filename}')

            # Draw segmentation mask on top of original image.
            # NOTE(review): the uint8 cast assumes un-normalised 0-255 images,
            # and `colors` is assumed to hold one color per mask — confirm.
            overlayed_mask = draw_segmentation_masks(input_img.to(dtype=torch.uint8),
                                                     pred == eval_id_range,
                                                     alpha=0.4,
                                                     colors=colors)
            to_pil_image(overlayed_mask, mode='RGB').save(f'{rgb_path}/{filename}')

In [None]:
# Keep the checkpoint with the lowest validation loss. `filename` is relative
# to `dirpath`, so checkpoints land under <LOGS_DIR>/saved_models/<MODEL_TYPE>/.
# (Monitoring 'MeanIoU' with mode='max' is the alternative that was tried.)
model_checkpoint_path = f'saved_models/{MODEL_TYPE}/{MODEL_NAME}'
model_checkpoint_callback = ModelCheckpoint(
    dirpath=LOGS_DIR,
    filename=model_checkpoint_path,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    verbose=True,
)

# Stop when val_loss has not improved by more than 1e-6 for 6 validation runs.
# NOTE(review): this callback is not appended to the Trainer callbacks in this
# notebook — confirm whether early stopping is meant to be active.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=6,
    min_delta=1e-6,
    strict=True,
    check_finite=True,
    verbose=True,
    log_rank_zero_only=True,
)

# Optional extras: AdvancedProfiler(dirpath=LOGS_DIR, filename="perf_logs"),
# LearningRateFinder()

In [None]:
# Callbacks handed to the Trainer. DeviceStatsMonitor() can be added here for
# device utilisation stats; SWA is added only when configured.
callbacks = [
    model_checkpoint_callback,
    ModelSummary(max_depth=3),
]

if SWA is not None:
    swa_callback = StochasticWeightAveraging(
        swa_lrs=SWA_LRS,
        swa_epoch_start=SWA_EPOCH_START,
    )
    callbacks += [swa_callback]

In [None]:
# TensorBoard logs under <LOGS_DIR>/Tensorboard_logs, grouped by model type/name.
# NOTE(review): log_graph=True only has an effect if the LightningModule defines
# `example_input_array` — confirm it is set.
logger = TensorBoardLogger(
    save_dir=f'{LOGS_DIR}/Tensorboard_logs',
    name=f'{MODEL_TYPE}/{MODEL_NAME}',
    log_graph=True,
)

In [None]:
# DeepLabV3 with a MobileNetV3-Large backbone and 20 output classes.
# NOTE(review): MODEL_TYPE / BACKBONE from the config do not select the network here.
model = DeepLabV3(model=deeplabv3_mobilenet_v3_large(num_classes=20),
                  train_config=train_config)

data_module = CityscapesDataModule(dataset_config)

# NOTE(review): PRECISION and DISTRIBUTE_STRATEGY are read from the config but
# not passed (precision=, strategy=), nor are profiler='simple' / sync_batchnorm=True.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=DEVICES,
    max_epochs=EPOCHS,
    limit_train_batches=NUM_TRAIN_BATCHES,
    limit_val_batches=NUM_EVAL_BATCHES,
    deterministic=False,
    default_root_dir=LOGS_DIR,
    logger=logger,
    callbacks=callbacks,
)

In [None]:
# Allow float32 matrix multiplications to use faster, lower-precision internal
# kernels (e.g. TF32 on GPUs that support it), trading a little precision for speed.
torch.set_float32_matmul_precision('high')

In [None]:
# Train `model` with the datamodule's dataloaders; checkpointing (and SWA, when
# configured) is driven by the callbacks passed to the Trainer.
trainer.fit(model, datamodule=data_module)

In [None]:
# Run DeepLabV3.predict_step over the datamodule's predict dataloader. That step
# writes images to disk, so return_predictions=False avoids collecting outputs.
# NOTE(review): this uses the weights `model` holds after fit (presumably the
# last epoch, or SWA-averaged weights when SWA is enabled), not the best
# val_loss checkpoint; pass ckpt_path='best' if that is what is intended.
trainer.predict(model, datamodule=data_module, return_predictions=False)