# FCN Training Notebook

## Import Libraries 

In [1]:
import os
import sys

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# `__file__` is not defined inside a notebook, so the project root is taken to be
# the parent of the current working directory (assumes the notebook is launched
# from a direct subfolder of the repository — TODO confirm).
BASE_DIR = os.path.dirname(os.getcwd())
# Guard so that re-running this cell does not append duplicate sys.path entries.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

# Project-local imports must come after the sys.path update above.
from src.models.Segmentation.FCN import FCN
from src.dataset.bdd_drivable_segmentation import BDDDrivableSegmentation
from src.config.defaults import cfg
from src.utils.DataLoaders import get_loader

# Seed the Python, NumPy and PyTorch RNGs so that loader shuffling and weight
# initialisation are reproducible across Restart & Run All.
SEED = 42
seed_everything(SEED)

## Load Datasets and DataLoaders

In [2]:
# Training split of the BDD drivable-area segmentation dataset, built from the shared config.
bdd_train_params = dict(cfg=cfg, stage='train')
bdd_train = BDDDrivableSegmentation(**bdd_train_params)

In [3]:
# Validation split of the BDD drivable-area segmentation dataset, built from the shared config.
bdd_val_params = dict(cfg=cfg, stage='val')
bdd_val = BDDDrivableSegmentation(**bdd_val_params)

In [4]:
# Single batch size shared by both loaders so train and val cannot silently drift apart.
BATCH_SIZE = 32

# Only the training split is shuffled; validation keeps a fixed order.
train_dataloader_args = {
    'dataset': bdd_train,
    'batch_size': BATCH_SIZE,
    'shuffle': True,
}
train_dataloader = get_loader(**train_dataloader_args)

val_dataloader_args = {
    'dataset': bdd_val,
    'batch_size': BATCH_SIZE,
    'shuffle': False,
}
val_dataloader = get_loader(**val_dataloader_args)

## Load FCN Model

In [5]:
# Constructor arguments for the FCN module. The number of output classes is taken
# from the training dataset's class-to-index mapping so the head matches the labels.
fcn_model_params = dict(
    cfg=cfg,
    num_classes=len(bdd_train.cls_to_idx),
    backbone='resnet101',
    learning_rate=1e-5,
    weight_decay=1e-3,
    pretrained_backbone=True,
    checkpoint_path=None,  # NOTE(review): presumably "do not resume from a checkpoint" — verify in FCN
    train_loader=train_dataloader,
    val_loader=val_dataloader,
)
model = FCN(**fcn_model_params)

In [8]:
# Display the inner `model` attribute of the FCN wrapper; its repr (shown below) lists the backbone layers.
model.model

FCN(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequenti

## Model Summary

In [None]:
# Layer-by-layer parameter summary; max_depth=-1 expands every nested submodule.
ModelSummary(model=model, max_depth=-1)

## Training Section

In [None]:
# Wall-clock profiler for the training run.
profiler = SimpleProfiler()

# Both callbacks watch the validation loss (the model must log it as "val_loss"); lower is better.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,  # stop after 5 validation checks without improvement
    verbose=False,
)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

In [None]:
# NOTE(review): no max_epochs or accelerator/devices are passed, so Lightning's
# defaults apply — confirm that is intended for this run.
# In PyTorch Lightning 1.x, auto_lr_find only takes effect via trainer.tune(...).
trainer = Trainer(
    auto_lr_find=True,
    profiler=profiler,
    callbacks=[early_stop_callback, checkpoint_callback],
)

In [None]:
# In PyTorch Lightning 1.x, `auto_lr_find=True` on the Trainer is applied only by
# `trainer.tune`; `trainer.fit` alone ignores it. Run the tuner first so the
# learning-rate search actually updates the model's learning rate before training.
# NOTE(review): the finder needs the module to expose `lr`/`learning_rate` as an
# attribute or hparam — FCN receives `learning_rate`, confirm it stores it.
trainer.tune(model)
trainer.fit(model)