Skip to content

Commit

Permalink
Faster RCNN Model + Pascal VOC DataModule (#157)
Browse files Browse the repository at this point in the history
* VOCDetection DataModule

* faster_rcnn

* fixed vocdetection dataset

* Faster RCNN, Pascal VOC complete

* pep8 docs fix

* added dummy detection dataset and testing fasterrcnn training

* Update pl_bolts/models/detection/faster_rcnn.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix pep8

* Change torchvision >= 0.7

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
teddykoker and Borda committed Aug 22, 2020
1 parent efa1db0 commit 620bfd1
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 7 deletions.
22 changes: 17 additions & 5 deletions pl_bolts/datamodules/__init__.py
@@ -1,12 +1,24 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.dummy_dataset import DummyDataset
from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
DiscountedExperienceSource)
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
TensorDataModule,
)
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
27 changes: 26 additions & 1 deletion pl_bolts/datamodules/dummy_dataset.py
Expand Up @@ -3,7 +3,6 @@


class DummyDataset(Dataset):

def __init__(self, *shapes, num_samples=10000):
"""
Generate a dummy dataset
Expand Down Expand Up @@ -41,3 +40,29 @@ def __getitem__(self, idx):
samples.append(sample)

return samples


class DummyDetectionDataset(Dataset):
def __init__(
self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000
):
super().__init__()
self.img_shape = img_shape
self.num_samples = num_samples
self.num_boxes = num_boxes
self.num_classes = num_classes

def __len__(self):
return self.num_samples

def _random_bbox(self):
c, h, w = self.img_shape
xs = torch.randint(w, (2,))
ys = torch.randint(h, (2,))
return [min(xs), min(ys), max(xs), max(ys)]

def __getitem__(self, idx):
img = torch.rand(self.img_shape)
boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)])
labels = torch.randint(self.num_classes, (self.num_boxes,))
return img, {"boxes": boxes, "labels": labels}
199 changes: 199 additions & 0 deletions pl_bolts/datamodules/vocdetection_datamodule.py
@@ -0,0 +1,199 @@
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
import torchvision.transforms as T


class Compose(object):
"""
Like `torchvision.transforms.compose` but works for (image, target)
"""

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target


def _collate_fn(batch):
return tuple(zip(*batch))


CLASSES = (
"__background__ ",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
)


def _prepare_voc_instance(image, target):
"""
Prepares VOC dataset into appropriate target for fasterrcnn
https://github.com/pytorch/vision/issues/1097#issuecomment-508917489
"""
anno = target["annotation"]
h, w = anno["size"]["height"], anno["size"]["width"]
boxes = []
classes = []
area = []
iscrowd = []
objects = anno["object"]
if not isinstance(objects, list):
objects = [objects]
for obj in objects:
bbox = obj["bndbox"]
bbox = [int(bbox[n]) - 1 for n in ["xmin", "ymin", "xmax", "ymax"]]
boxes.append(bbox)
classes.append(CLASSES.index(obj["name"]))
iscrowd.append(int(obj["difficult"]))
area.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))

boxes = torch.as_tensor(boxes, dtype=torch.float32)
classes = torch.as_tensor(classes)
area = torch.as_tensor(area)
iscrowd = torch.as_tensor(iscrowd)

image_id = anno["filename"][5:-4]
image_id = torch.as_tensor([int(image_id)])

target = {}
target["boxes"] = boxes
target["labels"] = classes
target["image_id"] = image_id

# for conversion to coco api
target["area"] = area
target["iscrowd"] = iscrowd

return image, target


class VOCDetectionDataModule(LightningDataModule):
name = "vocdetection"

def __init__(
self,
data_dir: str,
year: str = "2012",
num_workers: int = 16,
normalize: bool = False,
*args,
**kwargs,
):
"""
TODO(teddykoker) docstring
"""

super().__init__(*args, **kwargs)
self.year = year
self.data_dir = data_dir
self.num_workers = num_workers
self.normalize = normalize

@property
def num_classes(self):
"""
Return:
21
"""
return 21

def prepare_data(self):
"""
Saves VOCDetection files to data_dir
"""
VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

def train_dataloader(self, batch_size=1, transforms=None):
"""
VOCDetection train set uses the `train` subset
Args:
batch_size: size of batch
transforms: custom transforms
"""
t = [_prepare_voc_instance]
transforms = transforms or self.train_transforms or self._default_transforms()
if transforms is not None:
t.append(transforms)
transforms = Compose(t)

dataset = VOCDetection(
self.data_dir, year=self.year, image_set="train", transforms=transforms
)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=_collate_fn,
)
return loader

def val_dataloader(self, batch_size=1, transforms=None):
"""
VOCDetection val set uses the `val` subset
Args:
batch_size: size of batch
transforms: custom transforms
"""
t = [_prepare_voc_instance]
transforms = transforms or self.val_transforms or self._default_transforms()
if transforms is not None:
t.append(transforms)
transforms = Compose(t)
dataset = VOCDetection(
self.data_dir, year=self.year, image_set="val", transforms=transforms
)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=_collate_fn,
)
return loader

def _default_transforms(self):
if self.normalize:
return (
lambda image, target: (
T.Compose(
[
T.ToTensor(),
T.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)(image),
target,
),
)
return lambda image, target: (T.ToTensor()(image), target)
1 change: 1 addition & 0 deletions pl_bolts/models/detection/__init__.py
@@ -0,0 +1 @@
from pl_bolts.models.detection.faster_rcnn import FasterRCNN

0 comments on commit 620bfd1

Please sign in to comment.