In [None]:
import os, sys
# Make the project root importable so the `src.*` imports below resolve.
# Notebooks have no `__file__`, so the root is taken to be the parent of the
# kernel's working directory (presumably the notebook sits one level below the
# repo root, e.g. in notebooks/ — confirm if the notebook is moved).
BASE_DIR = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not pile up duplicate entries.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

from src.models.Detection.Faster_RCNN import Faster_RCNN
from src.dataset.bdd_detetcion import BDD_Detection
from src.config.defaults import cfg
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from src.utils.DataLoaders import get_loader

In [None]:
# Training split of the BDD detection dataset, driven by the shared `cfg`.
bdd_train_params = dict(cfg=cfg, stage='train')

bdd_train = BDD_Detection(**bdd_train_params)

In [None]:
# Validation split of the BDD detection dataset, driven by the shared `cfg`.
bdd_val_params = dict(cfg=cfg, stage='val')

bdd_val = BDD_Detection(**bdd_val_params)

In [None]:
# Batch size shared by both loaders; change it here instead of in two places.
BATCH_SIZE = 32

# Training batches are shuffled; validation order is kept fixed.
# (Assumes get_loader forwards `shuffle` to the underlying DataLoader — confirm.)
train_dataloader_args = {
    'dataset': bdd_train,
    'batch_size': BATCH_SIZE,
    'shuffle': True,
}
train_dataloader = get_loader(**train_dataloader_args)

val_dataloader_args = {
    'dataset': bdd_val,
    'batch_size': BATCH_SIZE,
    'shuffle': False,
}
val_dataloader = get_loader(**val_dataloader_args)

In [None]:
# Seed Python, NumPy and torch RNGs before the model is built, so any randomly
# initialised layers (presumably the head sized for num_classes — confirm) and
# the later batch shuffling are reproducible between runs.
SEED = 42
seed_everything(SEED)

# Faster R-CNN with a ResNet-101 backbone, fed by the loaders built above.
# NOTE(review): learning_rate is only a starting value — the Trainer cell enables
# auto_lr_find, which presumably overwrites it if Faster_RCNN exposes a
# `learning_rate` attribute; confirm.
faster_rcnn_params = {
    'cfg': cfg,
    'num_classes': 7,
    'backbone': 'resnet101',
    'learning_rate': 1e-5,
    'weight_decay': 1e-3,
    'pretrained': True,
    'pretrained_backbone': True,
    'checkpoint_path': None,
    'train_loader': train_dataloader,
    'val_loader': val_dataloader
}
model = Faster_RCNN(**faster_rcnn_params)

In [None]:
# Layer-by-layer summary; max_depth=-1 expands every nested submodule.
model_summary = ModelSummary(model, max_depth=-1)
model_summary

In [None]:
# Profiler plus two callbacks that both watch "val_loss" (lower is better):
# keep the best checkpoint, and stop after 5 validation checks with no improvement.
# NOTE(review): assumes the model logs "val_loss" during validation — confirm.
profiler = SimpleProfiler()

checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode='min')
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=False,
)

In [None]:
# Disabled quick smoke-test configuration: float limit_*_batches values run only
# a small fraction of the train/val batches, to check the pipeline end to end.
# It omits the callbacks; the Trainer actually used is built in the next cell.
#trainer = Trainer(auto_lr_find=True, limit_train_batches=0.0001, limit_val_batches=0.01, profiler=profiler)

In [None]:
# Trainer with the LR finder enabled (run by `trainer.tune` below), the simple
# profiler, and the early-stopping / checkpoint callbacks.
# NOTE(review): no max_epochs or device (gpus/accelerator) argument is passed,
# so Lightning's defaults apply — confirm that is intended.
trainer_callbacks = [early_stop_callback, checkpoint_callback]
trainer = Trainer(
    callbacks=trainer_callbacks,
    profiler=profiler,
    auto_lr_find=True,
)

In [None]:
# Run the learning-rate finder enabled by auto_lr_find; the tuning results are
# displayed as this cell's output.
tuning_results = trainer.tune(model)
tuning_results

In [None]:
# Full training run; the callbacks attached to `trainer` monitor "val_loss".
trainer.fit(model=model)