diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py
index c6315bad8b..659928af15 100644
--- a/pl_bolts/datamodules/__init__.py
+++ b/pl_bolts/datamodules/__init__.py
@@ -1,12 +1,24 @@
 from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
-from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
-from pl_bolts.datamodules.dummy_dataset import DummyDataset
-from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
-                                                    DiscountedExperienceSource)
+from pl_bolts.datamodules.cifar10_datamodule import (
+    CIFAR10DataModule,
+    TinyCIFAR10DataModule,
+)
+from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset
+from pl_bolts.datamodules.experience_source import (
+    ExperienceSourceDataset,
+    ExperienceSource,
+    DiscountedExperienceSource,
+)
 from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
 from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
 from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
 from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
-from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
+from pl_bolts.datamodules.sklearn_datamodule import (
+    SklearnDataset,
+    SklearnDataModule,
+    TensorDataset,
+    TensorDataModule,
+)
 from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
 from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
+from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
diff --git a/pl_bolts/datamodules/dummy_dataset.py b/pl_bolts/datamodules/dummy_dataset.py
index b90713f610..771a7fdbb7 100644
--- a/pl_bolts/datamodules/dummy_dataset.py
+++ b/pl_bolts/datamodules/dummy_dataset.py
@@ -3,7 +3,6 @@


 class DummyDataset(Dataset):
-
     def __init__(self, *shapes, num_samples=10000):
         """
         Generate a dummy dataset
@@ -41,3 +40,29 @@ def __getitem__(self, idx):
             samples.append(sample)

         return samples
+
+
+class DummyDetectionDataset(Dataset):
+    def __init__(
+        self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000
+    ):
+        super().__init__()
+        self.img_shape = img_shape
+        self.num_samples = num_samples
+        self.num_boxes = num_boxes
+        self.num_classes = num_classes
+
+    def __len__(self):
+        return self.num_samples
+
+    def _random_bbox(self):
+        c, h, w = self.img_shape
+        xs = torch.randint(w, (2,))
+        ys = torch.randint(h, (2,))
+        return [min(xs), min(ys), max(xs), max(ys)]
+
+    def __getitem__(self, idx):
+        img = torch.rand(self.img_shape)
+        boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)])
+        labels = torch.randint(self.num_classes, (self.num_boxes,))
+        return img, {"boxes": boxes, "labels": labels}
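For orientation, `DummyDetectionDataset` yields `(image, target)` pairs in the format torchvision detection models consume; a minimal sketch of one sample (shapes illustrative, not part of the patch)::

    from pl_bolts.datamodules import DummyDetectionDataset

    ds = DummyDetectionDataset(img_shape=(3, 256, 256), num_boxes=2)
    img, target = ds[0]
    # img:    FloatTensor of shape (3, 256, 256), values in [0, 1)
    # target: {"boxes": LongTensor[2, 4] as (x1, y1, x2, y2), "labels": LongTensor[2]}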
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
new file mode 100644
index 0000000000..d07607c974
--- /dev/null
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -0,0 +1,199 @@
+import torch
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+from torchvision.datasets import VOCDetection
+import torchvision.transforms as T
+
+
+class Compose(object):
+    """
+    Like `torchvision.transforms.Compose` but works on (image, target) pairs
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+
+def _collate_fn(batch):
+    return tuple(zip(*batch))
+
+
+CLASSES = (
+    "__background__ ",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+)
+
+
+def _prepare_voc_instance(image, target):
+    """
+    Converts a VOC annotation into the target dict expected by torchvision's Faster R-CNN
+
+    https://github.com/pytorch/vision/issues/1097#issuecomment-508917489
+    """
+    anno = target["annotation"]
+    h, w = anno["size"]["height"], anno["size"]["width"]
+    boxes = []
+    classes = []
+    area = []
+    iscrowd = []
+    objects = anno["object"]
+    if not isinstance(objects, list):
+        objects = [objects]
+    for obj in objects:
+        bbox = obj["bndbox"]
+        bbox = [int(bbox[n]) - 1 for n in ["xmin", "ymin", "xmax", "ymax"]]
+        boxes.append(bbox)
+        classes.append(CLASSES.index(obj["name"]))
+        iscrowd.append(int(obj["difficult"]))
+        area.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
+
+    boxes = torch.as_tensor(boxes, dtype=torch.float32)
+    classes = torch.as_tensor(classes)
+    area = torch.as_tensor(area)
+    iscrowd = torch.as_tensor(iscrowd)
+
+    image_id = anno["filename"][5:-4]
+    image_id = torch.as_tensor([int(image_id)])
+
+    target = {}
+    target["boxes"] = boxes
+    target["labels"] = classes
+    target["image_id"] = image_id
+
+    # for conversion to coco api
+    target["area"] = area
+    target["iscrowd"] = iscrowd
+
+    return image, target
+
+
+class VOCDetectionDataModule(LightningDataModule):
+    name = "vocdetection"
+
+    def __init__(
+        self,
+        data_dir: str,
+        year: str = "2012",
+        num_workers: int = 16,
+        normalize: bool = False,
+        *args,
+        **kwargs,
+    ):
+        """
+        PASCAL VOC detection datamodule, wrapping the `train` and `val` splits
+        of `torchvision.datasets.VOCDetection`.
+
+        Args:
+            data_dir: where to save/load the data
+            year: dataset year, e.g. "2012"
+            num_workers: how many workers to use for loading data
+            normalize: if true, applies ImageNet normalization in the default transforms
+        """
+        super().__init__(*args, **kwargs)
+        self.year = year
+        self.data_dir = data_dir
+        self.num_workers = num_workers
+        self.normalize = normalize
+
+    @property
+    def num_classes(self):
+        """
+        Return:
+            21
+        """
+        return 21
+
+    def prepare_data(self):
+        """
+        Saves VOCDetection files to data_dir
+        """
+        VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
+        VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
+
+    def train_dataloader(self, batch_size=1, transforms=None):
+        """
+        VOCDetection train set uses the `train` subset
+
+        Args:
+            batch_size: size of batch
+            transforms: custom transforms
+        """
+        t = [_prepare_voc_instance]
+        transforms = transforms or self.train_transforms or self._default_transforms()
+        if transforms is not None:
+            t.append(transforms)
+        transforms = Compose(t)
+
+        dataset = VOCDetection(
+            self.data_dir, year=self.year, image_set="train", transforms=transforms
+        )
+        loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            pin_memory=True,
+            collate_fn=_collate_fn,
+        )
+        return loader
+
+    def val_dataloader(self, batch_size=1, transforms=None):
+        """
+        VOCDetection val set uses the `val` subset
+
+        Args:
+            batch_size: size of batch
+            transforms: custom transforms
+        """
+        t = [_prepare_voc_instance]
+        transforms = transforms or self.val_transforms or self._default_transforms()
+        if transforms is not None:
+            t.append(transforms)
+        transforms = Compose(t)
+        dataset = VOCDetection(
+            self.data_dir, year=self.year, image_set="val", transforms=transforms
+        )
+        loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=True,
+            collate_fn=_collate_fn,
+        )
+        return loader
+
+    def _default_transforms(self):
+        if self.normalize:
+            return lambda image, target: (
+                T.Compose(
+                    [
+                        T.ToTensor(),
+                        T.Normalize(
+                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                        ),
+                    ]
+                )(image),
+                target,
+            )
+        return lambda image, target: (T.ToTensor()(image), target)
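A minimal sketch of the datamodule in use (`data_dir` and batch size are placeholders; `prepare_data` downloads the VOC archives)::

    from pl_bolts.datamodules import VOCDetectionDataModule

    dm = VOCDetectionDataModule(data_dir=".", year="2012", num_workers=4)
    dm.prepare_data()  # fetches the train and val splits
    images, targets = next(iter(dm.train_dataloader(batch_size=2)))
    # images:  tuple of 2 image tensors (the collate_fn keeps variable sizes)
    # targets: tuple of 2 dicts with "boxes", "labels", "image_id", "area", "iscrowd"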
diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py
new file mode 100644
index 0000000000..65df6945ec
--- /dev/null
+++ b/pl_bolts/models/detection/__init__.py
@@ -0,0 +1 @@
+from pl_bolts.models.detection.faster_rcnn import FasterRCNN
diff --git a/pl_bolts/models/detection/faster_rcnn.py b/pl_bolts/models/detection/faster_rcnn.py
new file mode 100644
index 0000000000..fbbd469b65
--- /dev/null
+++ b/pl_bolts/models/detection/faster_rcnn.py
@@ -0,0 +1,146 @@
+from argparse import ArgumentParser
+
+import torch
+from torchvision.models.detection import faster_rcnn, fasterrcnn_resnet50_fpn
+from torchvision.ops import box_iou
+
+import pytorch_lightning as pl
+
+from pl_bolts.datamodules import VOCDetectionDataModule
+
+
+def _evaluate_iou(target, pred):
+    """
+    Evaluate intersection over union (IOU) between a target from the dataset
+    and a prediction output by the model
+    """
+    if pred["boxes"].shape[0] == 0:
+        # no box detected, 0 IOU
+        return torch.tensor(0.0, device=pred["boxes"].device)
+    return box_iou(target["boxes"], pred["boxes"]).diag().mean()
+
+
+class FasterRCNN(pl.LightningModule):
+    def __init__(
+        self,
+        learning_rate: float = 0.0001,
+        num_classes: int = 91,
+        pretrained: bool = False,
+        pretrained_backbone: bool = True,
+        trainable_backbone_layers: int = 3,
+        replace_head: bool = True,
+        **kwargs,
+    ):
+        """
+        PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with
+        Region Proposal Networks <https://arxiv.org/abs/1506.01497>`_.
+
+        Paper authors: Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun
+
+        Model implemented by:
+            - `Teddy Koker <https://github.com/teddykoker>`_
+
+        During training, the model expects both the input tensors and targets (a list of dictionaries)
+        containing:
+            - boxes (`FloatTensor[N, 4]`): the ground-truth boxes in `[x1, y1, x2, y2]` format
+            - labels (`Int64Tensor[N]`): the class label for each ground-truth box
+
+        CLI command::
+
+            # PascalVOC
+            python faster_rcnn.py --gpus 1 --pretrained True
+
+        Args:
+            learning_rate: the learning rate
+            num_classes: number of detection classes (including background)
+            pretrained: if true, returns a model pre-trained on COCO train2017
+            pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
+            trainable_backbone_layers: number of trainable resnet layers starting from final block
+            replace_head: if true, replaces the box predictor head with one sized for `num_classes`
+        """
+        super().__init__()
+
+        model = fasterrcnn_resnet50_fpn(
+            # num_classes=num_classes,
+            pretrained=pretrained,
+            pretrained_backbone=pretrained_backbone,
+            trainable_backbone_layers=trainable_backbone_layers,
+        )
+
+        if replace_head:
+            in_features = model.roi_heads.box_predictor.cls_score.in_features
+            head = faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+            model.roi_heads.box_predictor = head
+        else:
+            assert num_classes == 91, "replace_head must be true to change num_classes"
+
+        self.model = model
+        self.learning_rate = learning_rate
+
+    def forward(self, x):
+        self.model.eval()
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        images, targets = batch
+        targets = [{k: v for k, v in t.items()} for t in targets]
+
+        # fasterrcnn takes both images and targets for training; returns a dict of losses
+        loss_dict = self.model(images, targets)
+        loss = sum(loss for loss in loss_dict.values())
+        return {"loss": loss, "log": loss_dict}
+
+    def validation_step(self, batch, batch_idx):
+        images, targets = batch
+        # fasterrcnn takes only images in eval() mode
+        outs = self.model(images)
+        iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
+        return {"val_iou": iou}
+
+    def validation_epoch_end(self, outs):
+        avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
+        logs = {"val_iou": avg_iou}
+        return {"avg_val_iou": avg_iou, "log": logs}
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            momentum=0.9,
+            weight_decay=0.005,
+        )
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--learning_rate", type=float, default=0.0001)
+        parser.add_argument("--num_classes", type=int, default=91)
+        parser.add_argument("--pretrained", type=bool, default=False)
+        parser.add_argument("--pretrained_backbone", type=bool, default=True)
+        parser.add_argument("--trainable_backbone_layers", type=int, default=3)
+        parser.add_argument("--replace_head", type=bool, default=True)
+
+        parser.add_argument("--data_dir", type=str, default=".")
+        parser.add_argument("--batch_size", type=int, default=1)
+        return parser
+
+
+def run_cli():
+    pl.seed_everything(42)
+    parser = ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = FasterRCNN.add_model_specific_args(parser)
+
+    args = parser.parse_args()
+
+    datamodule = VOCDetectionDataModule.from_argparse_args(args)
+    args.num_classes = datamodule.num_classes
+
+    model = FasterRCNN(**vars(args))
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, datamodule)
+
+
+if __name__ == "__main__":
+    run_cli()
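End to end, model and datamodule combine as in `run_cli`; a minimal sketch (trainer flags are placeholders)::

    import pytorch_lightning as pl
    from pl_bolts.datamodules import VOCDetectionDataModule
    from pl_bolts.models.detection import FasterRCNN

    dm = VOCDetectionDataModule(".")
    model = FasterRCNN(num_classes=dm.num_classes)  # replace_head=True swaps in a 21-class predictor
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, dm)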
diff --git a/requirements.txt b/requirements.txt
index 037801df54..3c314ca158 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 pytorch-lightning>=0.9.0
 torch>=1.6
-torchvision>=0.5
+torchvision>=0.7
 scikit-learn>=0.23
 opencv-python
 test_tube>=0.7.5
diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py
new file mode 100644
index 0000000000..9665e4a22f
--- /dev/null
+++ b/tests/models/test_detection.py
@@ -0,0 +1,29 @@
+import torch
+import pytorch_lightning as pl
+
+from pl_bolts.models.detection import FasterRCNN
+from pl_bolts.datamodules import DummyDetectionDataset
+from torch.utils.data import DataLoader
+
+
+def _collate_fn(batch):
+    return tuple(zip(*batch))
+
+
+def test_fasterrcnn(tmpdir):
+    model = FasterRCNN()
+
+    image = torch.rand(1, 3, 400, 400)
+    model(image)
+
+
+def test_fasterrcnn_train(tmpdir):
+    model = FasterRCNN()
+
+    train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
+    valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
+
+    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
+    trainer.fit(model, train_dl, valid_dl)
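Since `forward` puts the wrapped torchvision model into eval mode, calling the module directly performs inference; a quick sketch mirroring `test_fasterrcnn` (input size arbitrary)::

    import torch
    from pl_bolts.models.detection import FasterRCNN

    model = FasterRCNN(pretrained=False)
    preds = model(torch.rand(1, 3, 400, 400))
    # preds: list with one dict per image, each holding "boxes", "labels" and "scores"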