In [None]:
import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping
from lightning.pytorch.callbacks import ModelSummary, LearningRateFinder, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from torchsummary import summary
import yaml
from lib.datasets.cityscapes import CityscapesDataModule
from lib.models.base_module import SegmentationModule

In [None]:
# Load the experiment configuration and expose its values as notebook-level constants.
print('Reading configuration from config yaml')

with open('config/Cityscapes.yaml', 'r') as config_file:
    # yaml.safe_load returns None for an empty document; normalise that to an empty mapping.
    config: dict = yaml.safe_load(config_file) or {}

# These three sections are dereferenced unconditionally below, so fail here with an
# explicit message instead of an AttributeError raised on None further down.
missing_sections = [section for section in ('model_config', 'dataset_config', 'train_config')
                    if not isinstance(config.get(section), dict)]
if missing_sections:
    raise KeyError(f'config/Cityscapes.yaml is missing required section(s): {missing_sections}')

# TODO: Add default values for the remaining keys that may be absent from the config file

LOGS_DIR = config.get('logs_dir')
model_config: dict = config['model_config']
dataset_config: dict = config['dataset_config']
train_config: dict = config['train_config']
augmentation_config: dict = train_config.get('augmentations')

# Dataset Configuration
DATASET = dataset_config.get('name')
NUM_TRAIN_BATCHES = dataset_config.get('num_train_batches', 1.0)
NUM_EVAL_BATCHES = dataset_config.get('num_eval_batches', 1.0)

# Model Configuration
MODEL_TYPE = model_config.get('architecture')
MODEL_NAME = model_config.get('name')

# Training Configuration
EPOCHS = train_config.get('epochs')
# Note: a missing 'precision' key ends up as the string 'None' because of the str() call.
PRECISION = str(train_config.get('precision'))

# The 'distribute' section is treated as optional: when it (or one of its keys) is
# absent, fall back to 'auto' rather than crashing on None.
distribute_config: dict = train_config.get('distribute') or {}
DISTRIBUTE_STRATEGY = distribute_config.get('strategy', 'auto')
DEVICES = distribute_config.get('devices', 'auto')

# Stochastic weight averaging parameters.
# SWA stays None when the config has no 'swa' entry; the two derived names are always
# bound so that later cells can reference them without a NameError.
SWA = train_config.get('swa')
SWA_LRS = SWA.get('lr', 1e-3) if SWA is not None else None
SWA_EPOCH_START = SWA.get('epoch_start', 0.7) if SWA is not None else None

In [None]:
# Checkpoint callback: monitors 'val_loss' in 'min' mode.
# LOGS_DIR is the target directory and the file name carries the
# 'saved_models/<architecture>/<name>' sub-path.
model_checkpoint_path = f'saved_models/{MODEL_TYPE}/{MODEL_NAME}'
model_checkpoint_callback = ModelCheckpoint(
    dirpath=LOGS_DIR,
    filename=model_checkpoint_path,
    monitor='val_loss',  # alternative tried before: monitor='MeanIoU' with mode='max'
    mode='min',
    save_weights_only=False,
    verbose=True,
)

# Early-stopping callback on the same metric, with a patience of 6 and a minimum
# improvement of 1e-6. `mode` is not passed here, so the library default applies.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=6,
    min_delta=1e-6,
    strict=True,
    check_finite=True,
    log_rank_zero_only=True,
    verbose=True,
)

# Optional extras, currently disabled:
# profiler = AdvancedProfiler(dirpath=LOGS_DIR, filename="perf_logs")
# lr_finder_callback = LearningRateFinder()

In [None]:
# Callbacks handed to the Trainer: checkpointing plus a model summary (max depth 3).
# NOTE(review): early_stopping_callback is constructed in an earlier cell but is not
# part of this list — confirm whether leaving it out is intentional.
# (DeviceStatsMonitor() is another candidate, currently disabled.)
callbacks = [
    model_checkpoint_callback,
    ModelSummary(max_depth=3),
]

# Stochastic weight averaging is registered only when SWA is set.
if SWA is not None:
    swa_callback = StochasticWeightAveraging(
        swa_lrs=SWA_LRS,
        swa_epoch_start=SWA_EPOCH_START,
    )
    callbacks += [swa_callback]

In [None]:
# TensorBoard logger: writes under '<LOGS_DIR>/Tensorboard_logs', with the run name
# built from the architecture and the model name.
logger = TensorBoardLogger(
    save_dir=f'{LOGS_DIR}/Tensorboard_logs',
    name=f'{MODEL_TYPE}/{MODEL_NAME}',
)

In [None]:
# Segmentation module, built from the model and training sections of the config.
model = SegmentationModule(
    model_config=model_config,
    train_config=train_config,
    logs_dir=LOGS_DIR,
)

# Data module, built from the dataset and augmentation sections of the config.
data_module = CityscapesDataModule(dataset_config, augmentation_config)

# Trainer: the GPU accelerator is hard-coded; devices, epoch count and batch limits
# come from the config values read earlier in the notebook.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=DEVICES,
    max_epochs=EPOCHS,
    limit_train_batches=NUM_TRAIN_BATCHES,
    limit_val_batches=NUM_EVAL_BATCHES,
    deterministic=False,
    default_root_dir=LOGS_DIR,
    logger=logger,
    callbacks=callbacks,
    # Disabled options, kept for reference:
    # precision=PRECISION,
    # strategy=DISTRIBUTE_STRATEGY,
    # profiler='simple',
    # sync_batchnorm=True,
)

In [None]:
# Select the 'high' float32 matrix-multiplication precision setting in PyTorch
# before training starts.
torch.set_float32_matmul_precision(precision='high')

In [None]:
# Start training: fit the model on the data supplied by the data module.
trainer.fit(model=model, datamodule=data_module)

## Pure PyTorch

In [None]:
# Inspect the raw checkpoint file with plain PyTorch: print its top-level keys and
# display the stored hyper-parameters.
# map_location='cpu' keeps every tensor on the CPU, so this inspection neither needs
# the GPU the checkpoint was saved from nor allocates GPU memory for weights that are
# only being looked at.
# NOTE(review): the path is hard-coded — confirm it matches the checkpoint written by
# the training run configured above.
checkpoint = torch.load('/mnt/logs/saved_models/DeepLabV3/Full1.ckpt', map_location='cpu')
print(checkpoint.keys())
checkpoint['hyper_parameters']

## Lightning

In [None]:
# Rebuild the segmentation module from the saved checkpoint.
# This rebinds `model`, replacing the instance created for training.
model = SegmentationModule.load_from_checkpoint(
    '/mnt/logs/saved_models/DeepLabV3/Full1.ckpt',
)

In [None]:
# Run prediction with the restored model on the data module's data.
# return_predictions=False is passed, so the call is not asked to hand the
# predictions back.
trainer.predict(
    model=model,
    datamodule=data_module,
    return_predictions=False,
)