Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster RCNN Model + Pascal VOC DataModule #157

Merged
merged 10 commits into from
Aug 22, 2020
22 changes: 17 additions & 5 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.dummy_dataset import DummyDataset
from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
DiscountedExperienceSource)
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
TensorDataModule,
)
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
27 changes: 26 additions & 1 deletion pl_bolts/datamodules/dummy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class DummyDataset(Dataset):

def __init__(self, *shapes, num_samples=10000):
"""
Generate a dummy dataset
Expand Down Expand Up @@ -41,3 +40,29 @@ def __getitem__(self, idx):
samples.append(sample)

return samples


class DummyDetectionDataset(Dataset):
    """
    Generate a dummy object-detection dataset of random images and boxes

    Every item is a tuple ``(img, target)``: ``img`` is a random tensor of shape
    ``img_shape`` and ``target`` is a dict with the keys ``"boxes"`` (integer
    tensor of shape ``(num_boxes, 4)`` laid out as ``[xmin, ymin, xmax, ymax]``)
    and ``"labels"`` (integer tensor of shape ``(num_boxes,)``).

    Args:
        img_shape: ``(channels, height, width)`` of the generated images
        num_boxes: number of boxes generated per image
        num_classes: labels are drawn from ``range(num_classes)``
        num_samples: how many samples the dataset reports via ``len``
    """

    def __init__(
        self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000
    ):
        super().__init__()
        self.img_shape = img_shape
        self.num_samples = num_samples
        self.num_boxes = num_boxes
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def _random_bbox(self):
        """
        Return one random ``[xmin, ymin, xmax, ymax]`` box inside the image

        The upper corner is pushed out by one pixel so the box always has a
        strictly positive width and height, even when both random draws on an
        axis coincide; ``xmax <= width`` and ``ymax <= height`` still hold.
        """
        _, h, w = self.img_shape
        x_lo, x_hi = sorted(torch.randint(w, (2,)).tolist())
        y_lo, y_hi = sorted(torch.randint(h, (2,)).tolist())
        return [x_lo, y_lo, x_hi + 1, y_hi + 1]

    def __getitem__(self, idx):
        """
        Return a freshly generated ``(img, target)`` pair; ``idx`` is ignored
        """
        img = torch.rand(self.img_shape)
        # explicit dtype and shape keep the result (num_boxes, 4) even for num_boxes == 0
        boxes = torch.tensor(
            [self._random_bbox() for _ in range(self.num_boxes)], dtype=torch.int64
        ).reshape(self.num_boxes, 4)
        labels = torch.randint(self.num_classes, (self.num_boxes,))
        return img, {"boxes": boxes, "labels": labels}
199 changes: 199 additions & 0 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import torch
Borda marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
import torchvision.transforms as T


class Compose:
    """
    Chain several paired transforms into one callable

    Counterpart of `torchvision.transforms.Compose` for transforms that take
    and return an (image, target) pair instead of a single image
    """

    def __init__(self, transforms):
        # sequence of callables, each mapping (image, target) -> (image, target)
        self.transforms = transforms

    def __call__(self, image, target):
        # feed the output pair of every step into the next one, in order
        for step in self.transforms:
            result = step(image, target)
            image, target = result
        return (image, target)


def _collate_fn(batch):
    """
    Regroup a batch of samples by field instead of stacking them

    A list such as ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; an empty batch yields ``()``
    """
    fields = zip(*batch)
    return tuple(fields)


# Class names used to turn an annotation's object name into an integer label:
# `_prepare_voc_instance` uses `CLASSES.index(name)`, so the position in this
# tuple is the label id. Index 0 is the background entry, followed by 20 object
# categories (21 entries in total, matching `VOCDetectionDataModule.num_classes`).
# NOTE(review): the background entry carries a trailing space inside the string;
# confirm whether that is intentional before anything compares against it.
CLASSES = (
    "__background__ ",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)


def _prepare_voc_instance(image, target):
    """
    Prepares VOC dataset into appropriate target for fasterrcnn

    https://github.com/pytorch/vision/issues/1097#issuecomment-508917489

    Args:
        image: the image belonging to the annotation; returned unchanged
        target: annotation dict whose ``target["annotation"]`` entry provides
            ``"filename"`` and ``"object"`` (one object dict or a list of them,
            each with ``"name"``, ``"difficult"`` and ``"bndbox"``)

    Return:
        ``(image, target)`` where ``target`` is a new dict with the keys
        ``"boxes"`` (float32, ``N x 4`` as ``[xmin, ymin, xmax, ymax]``),
        ``"labels"``, ``"image_id"``, ``"area"`` and ``"iscrowd"``

    Raises:
        ValueError: if an object name is not listed in ``CLASSES`` or the
            numeric part of the file name cannot be parsed as an integer
    """
    anno = target["annotation"]

    objects = anno["object"]
    # the annotation may hold a single object instead of a list of objects
    if not isinstance(objects, list):
        objects = [objects]

    boxes = []
    classes = []
    area = []
    iscrowd = []
    for obj in objects:
        bndbox = obj["bndbox"]
        # every coordinate is shifted down by one
        # (presumably VOC's 1-based pixel indices -> 0-based; confirm)
        xmin, ymin, xmax, ymax = (
            int(bndbox[key]) - 1 for key in ("xmin", "ymin", "xmax", "ymax")
        )
        boxes.append([xmin, ymin, xmax, ymax])
        classes.append(CLASSES.index(obj["name"]))
        # the "difficult" flag of the annotation is reported as "iscrowd"
        iscrowd.append(int(obj["difficult"]))
        area.append((xmax - xmin) * (ymax - ymin))

    num_objects = len(objects)
    # explicit dtypes/shapes so an empty object list still gives well-formed tensors
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(num_objects, 4)
    classes = torch.as_tensor(classes, dtype=torch.int64)
    area = torch.as_tensor(area, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)

    # Numeric id = digits after the last underscore of the file stem, so both
    # "2008_000008.jpg" -> 8 and "000012.jpg" -> 12 are handled; a fixed
    # character slice would truncate names that lack the "YYYY_" prefix.
    # NOTE(review): the part before the underscore is dropped, as before;
    # verify ids stay unique within the image set being loaded.
    stem = anno["filename"].rsplit(".", 1)[0]
    image_id = torch.as_tensor([int(stem.rsplit("_", 1)[-1])])

    target = {}
    target["boxes"] = boxes
    target["labels"] = classes
    target["image_id"] = image_id

    # for conversion to coco api
    target["area"] = area
    target["iscrowd"] = iscrowd

    return image, target


class VOCDetectionDataModule(LightningDataModule):
    """
    Data module wrapping ``torchvision.datasets.VOCDetection``

    Provides train and val dataloaders whose samples are first converted by
    ``_prepare_voc_instance`` and then passed through an ``(image, target)``
    transform. Batches are collated with ``_collate_fn``.
    """

    name = "vocdetection"

    def __init__(
        self,
        data_dir: str,
        year: str = "2012",
        num_workers: int = 16,
        normalize: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: root directory handed to ``VOCDetection``
            year: dataset year handed to ``VOCDetection``
            num_workers: number of workers used by the dataloaders
            normalize: if True the default transforms also normalize the image
                tensor per channel (see ``_default_transforms``)
        """

        super().__init__(*args, **kwargs)
        self.year = year
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.normalize = normalize

    @property
    def num_classes(self):
        """
        Return:
            21
        """
        return 21

    def prepare_data(self):
        """
        Saves VOCDetection files to data_dir
        """
        VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
        VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

    def train_dataloader(self, batch_size=1, transforms=None):
        """
        VOCDetection train set uses the `train` subset

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        # precedence: explicit argument, then the module's train_transforms, then defaults
        transforms = transforms or self.train_transforms or self._default_transforms()
        return self._make_dataloader("train", batch_size, transforms, shuffle=True)

    def val_dataloader(self, batch_size=1, transforms=None):
        """
        VOCDetection val set uses the `val` subset

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        # precedence: explicit argument, then the module's val_transforms, then defaults
        transforms = transforms or self.val_transforms or self._default_transforms()
        return self._make_dataloader("val", batch_size, transforms, shuffle=False)

    def _make_dataloader(self, image_set, batch_size, transforms, shuffle):
        """
        Build the dataloader for one image set

        Args:
            image_set: subset name handed to ``VOCDetection`` ("train" or "val")
            batch_size: size of batch
            transforms: callable mapping ``(image, target)`` to ``(image, target)``;
                applied after ``_prepare_voc_instance``
            shuffle: whether the loader shuffles the dataset
        """
        pipeline = Compose([_prepare_voc_instance, transforms])
        dataset = VOCDetection(
            self.data_dir, year=self.year, image_set=image_set, transforms=pipeline
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=_collate_fn,
        )

    def _default_transforms(self):
        """
        Return the default ``(image, target)`` transform

        The image is converted with ``ToTensor`` and, when ``self.normalize`` is
        set, normalized per channel; the target passes through untouched.
        Always returns a single callable. The previous ``normalize=True``
        branch returned a one-element tuple wrapping the callable, which
        failed as soon as the pipeline tried to call it.
        """
        steps = [T.ToTensor()]
        if self.normalize:
            # per-channel statistics; presumably the ImageNet values used by
            # pretrained torchvision models
            steps.append(
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )
        # built once here instead of on every sample
        image_transform = T.Compose(steps)
        return lambda image, target: (image_transform(image), target)
1 change: 1 addition & 0 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pl_bolts.models.detection.faster_rcnn import FasterRCNN