# Train your own object detector with Faster-RCNN & PyTorch: Heads detector

In [1]:
!unzip input.zip
!unzip target.zip

Archive:  input.zip
  inflating: input/000.jpg           
  inflating: input/001.jpg           
  inflating: input/002.jpg           
  inflating: input/003.png           
  inflating: input/004.jpg           
  inflating: input/005.jpg           
  inflating: input/006.jpg           
  inflating: input/007.jpg           
  inflating: input/008.jpg           
  inflating: input/009.jpg           
  inflating: input/010.jpg           
  inflating: input/011.jpg           
  inflating: input/012.jpg           
  inflating: input/013.jpg           
  inflating: input/014.jpg           
  inflating: input/015.jpg           
  inflating: input/016.jpg           
  inflating: input/017.jpg           
  inflating: input/018.jpg           
  inflating: input/019.jpg           
Archive:  target.zip
  inflating: target/000.pt           
  inflating: target/001.pt           
  inflating: target/002.pt           
  inflating: target/003.pt           
  inflating: target/004.pt           
  inflati

In [2]:
!pip install pytorch_lightning

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/c2/a1/a991780873b5fd760fb99dfda01916fe9e5b186f0ba70a120e6b4f79cfaa/pytorch_lightning-1.3.1-py3-none-any.whl (805kB)
[K     |████████████████████████████████| 808kB 4.9MB/s 
[?25hCollecting torchmetrics>=0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/3b/e8/513cd9d0b1c83dc14cd8f788d05cd6a34758d4fd7e4f9e5ecd5d7d599c95/torchmetrics-0.3.2-py3-none-any.whl (274kB)
[K     |████████████████████████████████| 276kB 37.2MB/s 
[?25hCollecting fsspec[http]>=2021.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/bc/52/816d1a3a599176057bf29dfacb1f8fadb61d35fbd96cb1bab4aaa7df83c0/fsspec-2021.5.0-py3-none-any.whl (111kB)
[K     |████████████████████████████████| 112kB 38.5MB/s 
Collecting pyDeprecate==0.3.0
  Downloading https://files.pythonhosted.org/packages/14/52/aa227a0884df71ed1957649085adf2b8bc2a1816d037c2f18b3078854516/pyDeprecate-0.3.0-py3-none-any.whl
Collecting PyYAML

# Dataset

In [43]:
import pathlib

import albumentations as A
import numpy as np
from skimage.io import imread
from typing import List, Dict, Callable, Tuple

import torch
import torchvision.models as models
from torchvision.ops import box_convert
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.faster_rcnn import FasterRCNN
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping #callbacks
from torch.utils.data import Dataset, DataLoader

In [44]:
def get_filenames_of_path(path: pathlib.Path, ext: str = '*') -> List[pathlib.Path]:
    """
    Returns a sorted list of files in a directory/path. Uses pathlib.

    Args:
        path: Directory to search. A plain string is accepted as well and is
            converted to a pathlib.Path.
        ext: Glob pattern matched against the directory's entries
            (default '*': every entry).

    Returns:
        The matching entries that are regular files, in sorted order
        (glob itself gives no ordering guarantee).
    """
    directory = pathlib.Path(path)
    return sorted(entry for entry in directory.glob(ext) if entry.is_file())


In [45]:

class Compose:
    """Base class that holds an ordered collection of transforms."""

    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __repr__(self):
        # Rendered as a list literal made of the individual transform reprs.
        return str(list(self.transforms))
    
class ComposeDouble(Compose):
    """Runs the stored transforms in order on an (input, target) pair."""

    def __call__(self, inp: np.ndarray, target: dict):
        current_inp, current_target = inp, target
        for transform in self.transforms:
            # Every transform consumes the pair produced by the previous one.
            current_inp, current_target = transform(current_inp, current_target)
        return current_inp, current_target


In [46]:
class Repr:
    """Mixin that renders an object as '<ClassName>: <instance attribute dict>'."""

    def __repr__(self):
        class_name = type(self).__name__
        attributes = vars(self)
        return f'{class_name}: {attributes}'
    
class FunctionWrapperDouble(Repr):
    """
    Wraps a function (with extra arguments pre-bound) so it can be applied to
    the input and/or the target of an input-target pair.
    """

    def __init__(self, function: Callable, input: bool = True, target: bool = False, *args, **kwargs):
        from functools import partial
        # Extra positional/keyword arguments are bound once, at construction.
        self.function = partial(function, *args, **kwargs)
        # Flags selecting which member(s) of the pair the function is applied to.
        self.input = input
        self.target = target

    def __call__(self, inp: np.ndarray, tar: dict):
        new_inp = self.function(inp) if self.input else inp
        new_tar = self.function(tar) if self.target else tar
        return new_inp, new_tar
        
class AlbumentationWrapper(Repr):
    """
    Adapter that applies one albumentations transform to an image together
    with its bounding boxes and labels.
    Bounding boxes are expected to be in xyxy format (pascal_voc).
    Bounding boxes cannot be larger than the spatial image's dimensions.
    Use Clip() if your bounding boxes are outside of the image, before using this wrapper.
    """
    def __init__(self, albumentation: Callable, format: str = 'pascal_voc'):
        self.albumentation = albumentation
        self.format = format

    def __call__(self, inp: np.ndarray, tar: dict):
        # Labels are passed alongside the boxes under the 'class_labels' field.
        bbox_params = A.BboxParams(format=self.format, label_fields=['class_labels'])
        pipeline = A.Compose([self.albumentation], bbox_params=bbox_params)

        augmented = pipeline(image=inp, bboxes=tar['boxes'], class_labels=tar['labels'])

        image_out = np.array(augmented['image'])
        boxes_out = np.array(augmented['bboxes'])
        labels_out = np.array(augmented['class_labels'])

        # The incoming target dict is updated in place and handed back.
        tar['boxes'] = boxes_out
        tar['labels'] = labels_out

        return image_out, tar


class Clip(Repr):
    """
    Transform that clips the target's bounding boxes to the image extent.
    Bounding boxes are expected to be in xyxy format.
    Example: x_value=224 but x_shape=200 -> x1=199
    """
    def __call__(self, inp: np.ndarray, tar: dict):
        # The 'boxes' entry of the target dict is replaced in place; the image
        # is passed through untouched.
        tar['boxes'] = clip_bbs(inp=inp, bbs=tar['boxes'])
        return inp, tar


In [47]:
def map_class_to_int(labels: List[str], mapping: dict):
    """
    Maps class-name labels to their integer ids.

    Args:
        labels: Array-like of class names.
        mapping: Dict from class name to integer id.

    Returns:
        np.ndarray of dtype uint8 with the same shape as `labels`.
        (The final uint8 cast is kept from the original implementation, so
        ids are expected to fit into 0..255.)

    Raises:
        ValueError: If at least one label has no entry in `mapping`.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return np.zeros(labels.shape, dtype=np.uint8)

    # Use an integer buffer. Writing ids into a string array (what
    # np.empty_like(labels) yields for string labels) truncates them to the
    # width of the longest label, e.g. id 10 for label 'a' is stored as '1'.
    encoded = np.zeros(labels.shape, dtype=np.int64)
    assigned = np.zeros(labels.shape, dtype=bool)
    for key, value in mapping.items():
        hits = labels == key
        encoded[hits] = value
        assigned |= hits

    if not assigned.all():
        unknown = sorted(str(item) for item in set(labels[~assigned].tolist()))
        raise ValueError(f'Labels without an entry in mapping: {unknown}')

    return encoded.astype(np.uint8)

In [48]:
def clip_bbs(inp: np.ndarray,
             bbs: np.ndarray):
    """
    If the bounding boxes exceed one dimension, they are clipped to the dim's maximum.
    Bounding boxes are expected to be in xyxy format.
    Example: x_value=224 but x_shape=200 -> x1=199

    Args:
        inp: Image array; axis 0 is used as the y extent and axis 1 as the x extent.
        bbs: Boxes as rows of (x1, y1, x2, y2).

    Returns:
        np.ndarray with the clipped boxes, in the dtype of `bbs`.
        An empty input yields an array of shape (0, 4).
    """
    boxes = np.asarray(bbs)
    if boxes.size == 0:
        # Keep the (N, 4) layout even when there is nothing to clip.
        return boxes.reshape(0, 4)

    y_shape, x_shape = inp.shape[:2]
    # Largest valid coordinate per column: x1, y1, x2, y2.
    upper = np.array([x_shape - 1, y_shape - 1, x_shape - 1, y_shape - 1])

    # Cast back so the result dtype does not depend on whether a value was clipped.
    return np.clip(boxes, 0, upper).astype(boxes.dtype, copy=False)

In [49]:
class ObjectDetectionDataSet(torch.utils.data.Dataset):
    """
    Dataset that pairs every image file with its target file.
    A target is expected to be a pickled file of a dict
    and should contain at least a 'boxes' and a 'labels' key
    (an optional 'scores' key is carried along when present).
    inputs and targets are expected to be a list of pathlib.Path objects.

    In case your labels are strings, you can use mapping (a dict) to int-encode them.
    Returns a dict with the following keys: 'x', 'x_name', 'y', 'y_name'
    """

    def __init__(self,
                 inputs: List[pathlib.Path],
                 targets: List[pathlib.Path],
                 transform: ComposeDouble = None,
                 use_cache: bool = False,
                 convert_to_format: str = None,
                 mapping: Dict = None
                 ):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.use_cache = use_cache
        self.convert_to_format = convert_to_format
        self.mapping = mapping

        if self.use_cache:
            # Read every (image, target) pair once, in a process pool, and keep
            # the results in memory.
            from multiprocessing import Pool
            file_pairs = zip(inputs, targets)
            with Pool() as pool:
                self.cached_data = pool.starmap(self.read_images, file_pairs)

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _as_tensor(data, dtype):
        """Converts `data` to a tensor of `dtype`; falls back to torch.tensor when from_numpy rejects it."""
        try:
            return torch.from_numpy(data).to(dtype)
        except TypeError:
            return torch.tensor(data).to(dtype)

    def __getitem__(self,
                    index: int):
        # Fetch the raw sample, either from the in-memory cache or from disk.
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            x, y = self.read_images(self.inputs[index], self.targets[index])

        # From RGBA to RGB
        if x.shape[-1] == 4:
            from skimage.color import rgba2rgb
            x = rgba2rgb(x)

        # Boxes, and scores when the target provides them
        boxes = self._as_tensor(y['boxes'], torch.float32)
        has_scores = 'scores' in y.keys()
        if has_scores:
            scores = self._as_tensor(y['scores'], torch.float32)

        # Labels, int-encoded through the mapping when one was given
        raw_labels = map_class_to_int(y['labels'], mapping=self.mapping) if self.mapping else y['labels']
        labels = self._as_tensor(raw_labels, torch.int64)

        # Optional box format conversion
        if self.convert_to_format == 'xyxy':
            boxes = box_convert(boxes, in_fmt='xywh', out_fmt='xyxy')  # from xywh to xyxy
        elif self.convert_to_format == 'xywh':
            boxes = box_convert(boxes, in_fmt='xyxy', out_fmt='xywh')  # from xyxy to xywh

        target = {'boxes': boxes,
                  'labels': labels}
        if has_scores:
            target['scores'] = scores

        # The transforms operate on np.ndarrays, so convert before applying them.
        target = {name: tensor.numpy() for name, tensor in target.items()}

        if self.transform is not None:
            x, target = self.transform(x, target)  # returns np.ndarrays

        # Back to tensors for the model
        x = torch.from_numpy(x).type(torch.float32)
        target = {name: torch.from_numpy(array) for name, array in target.items()}

        return {'x': x, 'y': target, 'x_name': self.inputs[index].name, 'y_name': self.targets[index].name}

    @staticmethod
    def read_images(inp, tar):
        # NOTE(review): torch.load unpickles the target file - only load files from a trusted source.
        return imread(inp), torch.load(tar)


In [50]:
def normalize_01(inp: np.ndarray):
    """
    Squash image input to the value range [0, 1] (no clipping).

    A constant input (max == min) is mapped to all zeros instead of
    dividing by zero.
    """
    shifted = inp - np.min(inp)
    value_range = np.ptp(inp)
    if value_range == 0:
        # 0/0 would fill the output with NaN. Dividing the (all-zero) shifted
        # array by 1 keeps the zeros and the dtype of the regular branch.
        value_range = 1
    return shifted / value_range

In [51]:
# Base directory that holds the extracted `input/` and `target/` folders.
# An empty Path is relative, so it resolves against the current working directory.
root = pathlib.Path('')

In [52]:
# Sort both listings so that inputs[i] and targets[i] refer to the same sample.
inputs = sorted(get_filenames_of_path(root / 'input'))
targets = sorted(get_filenames_of_path(root / 'target'))

# Fail early if the two folders do not pair up one-to-one by file stem
# (e.g. input/000.jpg <-> target/000.pt); a silent mismatch would train on
# wrong annotations.
assert len(inputs) == len(targets), f'{len(inputs)} inputs but {len(targets)} targets'
mismatched = [(i.name, t.name) for i, t in zip(inputs, targets) if i.stem != t.stem]
assert not mismatched, f'Input/target files do not pair up: {mismatched}'

In [53]:
# Class name -> integer id; handed to the datasets through their `mapping` argument.
mapping = {
    'head': 1,
}

### Transformations

In [54]:
# Transform pipeline applied to each (image, target) pair, in order:
# clip boxes to the image, move the image's last axis to the front, rescale to [0, 1].
transforms = ComposeDouble([
    Clip(),
    # Optional augmentations, disabled for this dataset:
    # AlbumentationWrapper(albumentation=A.HorizontalFlip(p=0.5)),
    # AlbumentationWrapper(albumentation=A.RandomScale(p=0.5, scale_limit=0.5)),
    # AlbumentationWrapper(albumentation=A.VerticalFlip(p=0.5)),
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),  # presumably channels-last -> channels-first
    FunctionWrapperDouble(normalize_01)
])

In [55]:
# Dataset over all files. use_cache=False reads each sample from disk on access;
# convert_to_format=None leaves the stored box format untouched.
dataset = ObjectDetectionDataSet(inputs=inputs,
                                 targets=targets,
                                 transform=transforms,
                                 use_cache=False,
                                 convert_to_format=None,
                                 mapping=mapping)

# Training

In [56]:
def collate_double(batch):
    """
    Collate function for the ObjectDetectionDataSet (used by the dataloaders).

    Turns a list of sample dicts into a 4-tuple of lists:
    (x, y, x_name, y_name), each holding that field for every sample.
    """
    fields = ('x', 'y', 'x_name', 'y_name')
    return tuple([sample[field] for sample in batch] for field in fields)

In [57]:
# hyper-parameters
# Single place for every tunable value; the cells below read from this dict.
params = {'BATCH_SIZE': 2,  # batch size of the training dataloader
          'LR': 0.001,  # learning rate handed to the Lightning module
          'PRECISION': 32,  # Trainer precision
          'CLASSES': 2,  # num_classes of the model (the Lightning module treats it as background inclusive)
          'SEED': 42,  # passed to seed_everything
          'PROJECT': 'Heads',  # not read by any cell visible in this notebook
          'EXPERIMENT': 'heads',  # not read by any cell visible in this notebook
          'MAXEPOCHS': 500,  # assigned to trainer.max_epochs before fitting
          'BACKBONE': 'resnet34',  # name understood by get_resnet_backbone
          'FPN': False,  # selects the plain (non-FPN) backbone branch
          'ANCHOR_SIZE': ((32, 64, 128, 256, 512),),  # one tuple of sizes per feature map
          'ASPECT_RATIOS': ((0.5, 1.0, 2.0),),  # repeated per anchor-size tuple
          'MIN_SIZE': 1024,  # forwarded to FasterRCNN as min_size
          'MAX_SIZE': 1024,  # forwarded to FasterRCNN as max_size
          'IMG_MEAN': [0.485, 0.456, 0.406],  # same values as get_fasterRCNN's default image_mean
          'IMG_STD': [0.229, 0.224, 0.225],  # same values as get_fasterRCNN's default image_std
          'IOU_THRESHOLD': 0.5  # IoU threshold used when computing the PASCAL VOC metrics
          }

In [58]:
# mapping
# NOTE(review): identical to the `mapping` defined in the Dataset section;
# re-declared here so the training cells read top to bottom.
mapping = {
    'head': 1,
}

In [59]:
# training transformations and augmentations
# Random flip/scale are applied after clipping (the wrapper expects boxes inside the
# image) and before the axis move and [0, 1] rescaling shared by all three pipelines.
transforms_training = ComposeDouble([
    Clip(),
    AlbumentationWrapper(albumentation=A.HorizontalFlip(p=0.5)),
    AlbumentationWrapper(albumentation=A.RandomScale(p=0.5, scale_limit=0.5)),
    # AlbumentationWrapper(albumentation=A.VerticalFlip(p=0.5)),
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
    FunctionWrapperDouble(normalize_01)
])

# validation transformations
# No random augmentation: only the deterministic steps.
transforms_validation = ComposeDouble([
    Clip(),
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
    FunctionWrapperDouble(normalize_01)
])

# test transformations
# Same steps as the validation pipeline.
transforms_test = ComposeDouble([
    Clip(),
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
    FunctionWrapperDouble(normalize_01)
])

In [60]:
# random seed
# Seeded before the datasets, dataloaders and model are created below.
from pytorch_lightning import seed_everything

seed_everything(params['SEED'])

Global seed set to 42


42

In [61]:
# training validation test split
# Fixed split by position in the sorted file lists: first 12 samples for training,
# the next 4 for validation, everything after index 16 for testing.
# inputs and targets are sliced identically so the pairs stay aligned.
inputs_train, inputs_valid, inputs_test = inputs[:12], inputs[12:16], inputs[16:]
targets_train, targets_valid, targets_test = targets[:12], targets[12:16], targets[16:]

In [62]:
# dataset training
# use_cache=True: all samples are read once at construction and kept in memory.
dataset_train = ObjectDetectionDataSet(inputs=inputs_train,
                                       targets=targets_train,
                                       transform=transforms_training,
                                       use_cache=True,
                                       convert_to_format=None,
                                       mapping=mapping)

# dataset validation
dataset_valid = ObjectDetectionDataSet(inputs=inputs_valid,
                                       targets=targets_valid,
                                       transform=transforms_validation,
                                       use_cache=True,
                                       convert_to_format=None,
                                       mapping=mapping)

# dataset test
dataset_test = ObjectDetectionDataSet(inputs=inputs_test,
                                      targets=targets_test,
                                      transform=transforms_test,
                                      use_cache=True,
                                      convert_to_format=None,
                                      mapping=mapping)

# dataloader training
# Only the training loader shuffles and uses the configured batch size.
# collate_double keeps the samples as lists instead of stacking them.
dataloader_train = DataLoader(dataset=dataset_train,
                              batch_size=params['BATCH_SIZE'],
                              shuffle=True,
                              num_workers=0,
                              collate_fn=collate_double)

# dataloader validation
# Evaluation loaders go through the data one sample at a time, in order.
dataloader_valid = DataLoader(dataset=dataset_valid,
                              batch_size=1,
                              shuffle=False,
                              num_workers=0,
                              collate_fn=collate_double)

# dataloader test
dataloader_test = DataLoader(dataset=dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=collate_double)

In [63]:
def get_resnet_backbone(backbone_name: str):
    """
    Returns a resnet backbone pretrained on ImageNet.
    Removes the average-pooling layer and the linear layer at the end.

    Args:
        backbone_name: One of 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152'.

    Returns:
        torch.nn.Sequential of the remaining layers, with an `out_channels`
        attribute set to the channel count of its output.

    Raises:
        ValueError: If `backbone_name` is not one of the supported names.
    """
    # Output channels of the truncated network for every supported variant.
    out_channels_by_name = {
        'resnet18': 512,
        'resnet34': 512,
        'resnet50': 2048,
        'resnet101': 2048,
        'resnet152': 2048,
    }
    if backbone_name not in out_channels_by_name:
        raise ValueError(f"Unsupported backbone '{backbone_name}'. "
                         f"Choose one of: {sorted(out_channels_by_name)}")

    # torchvision.models exposes each constructor under the variant's own name.
    constructor = getattr(models, backbone_name)
    pretrained_model = constructor(pretrained=True, progress=False)

    # Drop the last two children (average pooling and the linear classifier).
    backbone = torch.nn.Sequential(*list(pretrained_model.children())[:-2])
    backbone.out_channels = out_channels_by_name[backbone_name]

    return backbone

In [64]:
def get_anchor_generator(anchor_size: Tuple[tuple] = None, aspect_ratios: Tuple[tuple] = None):
    """
    Returns the anchor generator.

    Omitted arguments fall back to four single-size levels (16, 32, 64, 128)
    and to the ratios (0.5, 1.0, 2.0) repeated once per size tuple.
    """
    sizes = ((16,), (32,), (64,), (128,)) if anchor_size is None else anchor_size
    ratios = ((0.5, 1.0, 2.0),) * len(sizes) if aspect_ratios is None else aspect_ratios

    return AnchorGenerator(sizes=sizes, aspect_ratios=ratios)

In [65]:
def get_roi_pool(featmap_names: List[str] = None, output_size: int = 7, sampling_ratio: int = 2):
    """
    Returns the multi-scale RoIAlign pooler.

    When `featmap_names` is omitted, the four names '0'..'3' are used
    (the default for resnet with FPN).
    """
    names = ['0', '1', '2', '3'] if featmap_names is None else featmap_names

    return MultiScaleRoIAlign(featmap_names=names,
                              output_size=output_size,
                              sampling_ratio=sampling_ratio)

In [66]:
def get_fasterRCNN(backbone: torch.nn.Module,
                   anchor_generator: AnchorGenerator,
                   roi_pooler: MultiScaleRoIAlign,
                   num_classes: int,
                   image_mean: List[float] = [0.485, 0.456, 0.406],
                   image_std: List[float] = [0.229, 0.224, 0.225],
                   min_size: int = 512,
                   max_size: int = 1024,
                   **kwargs
                   ):
    """
    Returns the Faster-RCNN model. Default normalization: ImageNet

    Args:
        backbone: Feature extractor handed to FasterRCNN.
        anchor_generator: Anchor generator for the RPN.
        roi_pooler: RoI pooling module for the box head.
        num_classes: Number of classes passed to FasterRCNN.
        image_mean / image_std: Per-channel normalization values.
        min_size / max_size: Image size limits passed to FasterRCNN.
        **kwargs: Forwarded unchanged to FasterRCNN.

    Returns:
        The FasterRCNN model, with num_classes, image_mean, image_std,
        min_size and max_size additionally stored as attributes on it.
    """
    # Copy the lists: the signature defaults are shared objects, and the model
    # keeps a reference to whatever is stored on it below.
    image_mean = list(image_mean)
    image_std = list(image_std)

    model = FasterRCNN(backbone=backbone,
                       rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler,
                       num_classes=num_classes,
                       image_mean=image_mean,  # ImageNet
                       image_std=image_std,  # ImageNet
                       min_size=min_size,
                       max_size=max_size,
                       **kwargs
                       )
    # Keep the configuration on the model so wrappers can read it back.
    model.num_classes = num_classes
    model.image_mean = image_mean
    model.image_std = image_std
    model.min_size = min_size
    model.max_size = max_size

    return model

In [67]:
def get_fasterRCNN_resnet(num_classes: int,
                          backbone_name: str,
                          anchor_size: Tuple[tuple],
                          aspect_ratios: Tuple[tuple],
                          fpn: bool = True,
                          min_size: int = 512,
                          max_size: int = 1024,
                          **kwargs
                          ):
    """
    Returns the Faster-RCNN model with resnet backbone with and without fpn.

    Args:
        num_classes: Number of classes passed on to the model.
        backbone_name: Name of the resnet variant.
        anchor_size: Tuple of anchor-size tuples, one per feature map.
        aspect_ratios: Tuple of aspect-ratio tuples. Either one entry per
            anchor-size tuple, or a shorter tuple that is repeated
            len(anchor_size) times.
        fpn: Whether to build the FPN variant of the backbone.
        min_size / max_size: Image size limits passed on to the model.
        **kwargs: Forwarded to get_fasterRCNN.
    """

    # Backbone
    if fpn:
        # NOTE(review): get_resnet_fpn_backbone is not defined in the visible part
        # of this notebook - confirm it is available before using fpn=True.
        backbone = get_resnet_fpn_backbone(backbone_name=backbone_name)
    else:
        backbone = get_resnet_backbone(backbone_name=backbone_name)

    # Anchors: the generator needs one aspect-ratio tuple per anchor-size tuple.
    # Repeat the given ratios unless the caller already supplied one per size.
    if len(aspect_ratios) != len(anchor_size):
        aspect_ratios = aspect_ratios * len(anchor_size)
    anchor_generator = get_anchor_generator(anchor_size=anchor_size, aspect_ratios=aspect_ratios)

    # ROI Pool: run a dummy batch through the backbone to discover the names
    # of the feature maps it produces.
    with torch.no_grad():
        backbone.eval()
        random_input = torch.rand(size=(1, 3, 512, 512))
        features = backbone(random_input)

    if isinstance(features, torch.Tensor):
        from collections import OrderedDict

        # A single-tensor output is treated as one feature map named '0'.
        features = OrderedDict([('0', features)])

    featmap_names = [key for key in features.keys() if key.isnumeric()]

    roi_pool = get_roi_pool(featmap_names=featmap_names)

    # Model
    return get_fasterRCNN(backbone=backbone,
                          anchor_generator=anchor_generator,
                          roi_pooler=roi_pool,
                          num_classes=num_classes,
                          min_size=min_size,
                          max_size=max_size,
                          **kwargs)


In [68]:
# Build the detector from the hyper-parameter dict. IMG_MEAN / IMG_STD are passed
# explicitly (forwarded through **kwargs to get_fasterRCNN) so that editing them in
# `params` actually takes effect instead of the function defaults being used.
model = get_fasterRCNN_resnet(num_classes=params['CLASSES'],
                              backbone_name=params['BACKBONE'],
                              anchor_size=params['ANCHOR_SIZE'],
                              aspect_ratios=params['ASPECT_RATIOS'],
                              fpn=params['FPN'],
                              min_size=params['MIN_SIZE'],
                              max_size=params['MAX_SIZE'],
                              image_mean=params['IMG_MEAN'],
                              image_std=params['IMG_STD'])

In [73]:
def from_dict_to_BoundingBox(file: dict, name: str, groundtruth: bool = True):
    """
    Returns list of BoundingBox objects from groundtruth or prediction.

    `file` must provide 'boxes' and 'labels'; for predictions
    (groundtruth=False) it must also provide 'scores', which are used as the
    confidence of each box. Ground-truth boxes get a confidence of None.
    """
    from metrics.bounding_box import BoundingBox
    from metrics.enumerators import BBFormat, BBType

    labels = file['labels']
    boxes = file['boxes']

    if groundtruth:
        scores = [None] * len(boxes)
        box_type = BBType.GROUND_TRUTH
    else:
        scores = np.array(file['scores'].cpu())
        box_type = BBType.DETECTED

    bounding_boxes = []
    for box, label, score in zip(boxes, labels, scores):
        bounding_boxes.append(BoundingBox(image_name=name,
                                          class_id=int(label),
                                          coordinates=tuple(box),
                                          format=BBFormat.XYX2Y2,
                                          bb_type=box_type,
                                          confidence=score))
    return bounding_boxes


In [74]:
class FasterRCNN_lightning(pl.LightningModule):
    """
    Lightning wrapper around a Faster-RCNN model.

    Training logs the loss dict returned by the model. Validation and test
    collect ground-truth and predicted boxes per batch and log the PASCAL VOC
    mAP and the per-class AP at the end of the epoch.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 lr: float = 0.0001,
                 iou_threshold: float = 0.5
                 ):
        super().__init__()

        # Model
        self.model = model

        # Classes (background inclusive)
        self.num_classes = self.model.num_classes

        # Learning rate
        self.lr = lr

        # IoU threshold
        self.iou_threshold = iou_threshold

        # Transformation parameters
        self.mean = model.image_mean
        self.std = model.image_std
        self.min_size = model.min_size
        self.max_size = model.max_size

        # Save hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        # Inference entry point: the wrapped model is switched to eval mode first.
        self.model.eval()
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Batch
        x, y, x_name, y_name = batch  # tuple unpacking

        # Called with targets, the model returns a dict of losses.
        loss_dict = self.model(x, y)
        loss = sum(loss_dict.values())

        self.log_dict(loss_dict)
        return loss

    def _collect_boxes(self, batch):
        """Runs inference on a batch and returns its ground-truth and predicted boxes as flat lists."""
        x, y, x_name, y_name = batch

        # Inference
        preds = self.model(x)

        # Flattened with plain comprehensions, so no itertools import is needed.
        gt_boxes = [box
                    for target, name in zip(y, x_name)
                    for box in from_dict_to_BoundingBox(target, name=name, groundtruth=True)]
        pred_boxes = [box
                      for pred, name in zip(preds, x_name)
                      for box in from_dict_to_BoundingBox(pred, name=name, groundtruth=False)]

        return {'pred_boxes': pred_boxes, 'gt_boxes': gt_boxes}

    def _log_pascal_voc_metrics(self, outs, prefix: str):
        """Computes the PASCAL VOC metrics over all step outputs and logs them as '<prefix>_mAP' / '<prefix>_AP_<class>'."""
        gt_boxes = [box for out in outs for box in out['gt_boxes']]
        pred_boxes = [box for out in outs for box in out['pred_boxes']]

        from metrics.pascal_voc_evaluator import get_pascalvoc_metrics
        from metrics.enumerators import MethodAveragePrecision
        metric = get_pascalvoc_metrics(gt_boxes=gt_boxes,
                                       det_boxes=pred_boxes,
                                       iou_threshold=self.iou_threshold,
                                       method=MethodAveragePrecision.EVERY_POINT_INTERPOLATION,
                                       generate_table=True)

        per_class, mAP = metric['per_class'], metric['mAP']
        self.log(f'{prefix}_mAP', mAP)

        for key, value in per_class.items():
            self.log(f'{prefix}_AP_{key}', value['AP'])

    def validation_step(self, batch, batch_idx):
        return self._collect_boxes(batch)

    def validation_epoch_end(self, outs):
        # 'Validation_mAP' is the value monitored by the LR scheduler and the callbacks.
        self._log_pascal_voc_metrics(outs, prefix='Validation')

    def test_step(self, batch, batch_idx):
        return self._collect_boxes(batch)

    def test_epoch_end(self, outs):
        self._log_pascal_voc_metrics(outs, prefix='Test')

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=self.lr,
                                    momentum=0.9,
                                    weight_decay=0.005)
        # mode='max': the monitored validation mAP is a higher-is-better metric.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  mode='max',
                                                                  factor=0.75,
                                                                  patience=30,
                                                                  min_lr=0)
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'Validation_mAP'}

In [75]:
# Wrap the detector in the Lightning module, using the configured learning rate
# and the IoU threshold for the validation/test metrics.
task = FasterRCNN_lightning(model=model, lr=params['LR'], iou_threshold=params['IOU_THRESHOLD'])

In [76]:
# All callbacks follow the validation mAP logged by the Lightning module (higher is better).
checkpoint_callback = ModelCheckpoint(monitor='Validation_mAP', mode='max')
learningrate_callback = LearningRateMonitor(logging_interval='step', log_momentum=False)
early_stopping_callback = EarlyStopping(monitor='Validation_mAP', patience=50, mode='max')

# trainer init
from pytorch_lightning import Trainer

# Use one GPU when CUDA is available, otherwise fall back to CPU instead of failing.
trainer = Trainer(gpus=1 if torch.cuda.is_available() else None,
                  precision=params['PRECISION'],  # try 16 with enable_pl_optimizer=False
                  callbacks=[checkpoint_callback, learningrate_callback, early_stopping_callback],
                  default_root_dir='heads',  # where checkpoints are saved to
                  log_every_n_steps=1,
                  num_sanity_val_steps=0
                  )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [77]:
# start training
# The epoch limit is set on the already-constructed trainer; early stopping on
# Validation_mAP may end the run before MAXEPOCHS is reached.
trainer.max_epochs = params['MAXEPOCHS']
trainer.fit(task,
            train_dataloader=dataloader_train,
            val_dataloaders=dataloader_valid)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | FasterRCNN | 50.4 M
-------------------------------------
50.4 M    Trainable params
0         Non-trainable params
50.4 M    Total params
201.736   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

ModuleNotFoundError: ignored

In [None]:
# start testing
# ckpt_path='best' evaluates the checkpoint selected by the ModelCheckpoint callback
# (highest Validation_mAP) on the held-out test dataloader.
trainer.test(ckpt_path='best', test_dataloaders=dataloader_test)