Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kitti Datamodule #248

Merged
merged 28 commits into from Sep 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions pl_bolts/datamodules/__init__.py
Expand Up @@ -18,5 +18,8 @@
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule

from pl_bolts.datamodules.kitti_dataset import KittiDataset
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ImportError:
pass
99 changes: 99 additions & 0 deletions pl_bolts/datamodules/kitti_datamodule.py
@@ -0,0 +1,99 @@
import os
import torch

from pytorch_lightning import LightningDataModule
from pl_bolts.datamodules.kitti_dataset import KittiDataset

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data.dataset import random_split


class KittiDataModule(LightningDataModule):

name = 'kitti'

def __init__(
self,
data_dir: str,
val_split: float = 0.2,
test_split: float = 0.1,
num_workers: int = 16,
batch_size: int = 32,
seed: int = 42,
*args,
**kwargs,
):
"""
Kitti train, validation and test dataloaders.

Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

Specs:
- 200 samples
- Each image is (3 x 1242 x 376)

In total there are 34 classes but some of these are not useful so by default we use only 19 of the classes
specified by the `valid_labels` parameter.

Example::

from pl_bolts.datamodules import KittiDataModule

dm = KittiDataModule(PATH)
model = LitModel()

Trainer().fit(model, dm)

Args::
data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
val_split: size of validation test (default 0.2)
test_split: size of test set (default 0.1)
num_workers: how many workers to use for loading data
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
super().__init__(*args, **kwargs)
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.batch_size = batch_size
self.num_workers = num_workers
self.seed = seed

self.default_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])

# split into train, val, test
kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)

val_len = round(val_split * len(kitti_dataset))
test_len = round(test_split * len(kitti_dataset))
train_len = len(kitti_dataset) - val_len - test_len

self.trainset, self.valset, self.testset = random_split(kitti_dataset,
lengths=[train_len, val_len, test_len],
generator=torch.Generator().manual_seed(self.seed))

def train_dataloader(self):
loader = DataLoader(self.trainset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers)
return loader

def val_dataloader(self):
loader = DataLoader(self.valset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)
return loader

def test_dataloader(self):
loader = DataLoader(self.testset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)
return loader
92 changes: 92 additions & 0 deletions pl_bolts/datamodules/kitti_dataset.py
@@ -0,0 +1,92 @@
import os
import numpy as np
from PIL import Image

from torch.utils.data import Dataset

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KittiDataset(Dataset):
"""
Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
in `valid_labels`.

The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
(250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
`len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
the loss function when comparing with the output.

Args:
data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
img_size: image dimensions (width, height)
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')

def __init__(
self,
data_dir: str,
img_size: tuple = (1242, 376),
void_labels: list = DEFAULT_VOID_LABELS,
valid_labels: list = DEFAULT_VALID_LABELS,
transform=None
):
self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.transform = transform

self.data_dir = data_dir
self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH)
self.mask_path = os.path.join(self.data_dir, self.MASK_PATH)
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

def __len__(self):
return len(self.img_list)

def __getitem__(self, idx):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
img = np.array(img)

mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
img = self.transform(img)

return img, mask

def encode_segmap(self, mask):
"""
Sets void classes to zero so they won't be considered for training
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# remove extra idxs from updated dataset
mask[mask > 18] = self.ignore_index
return mask

def get_filenames(self, path):
"""
Returns a list of absolute paths to images inside given `path`
"""
files_list = list()
for filename in os.listdir(path):
files_list.append(os.path.join(path, filename))
return files_list