# pl_bolts/datamodules/kitti_datamodule.py
import os

import torch
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from pl_bolts.datamodules.kitti_dataset import KittiDataset


class KittiDataModule(LightningDataModule):
    """
    Kitti train, validation and test dataloaders.

    Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
    You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    Specs:
        - 200 samples
        - Each image is resized by ``KittiDataset`` to 1242 x 376 (width x height) by default

    In total there are 34 classes but some of these are not useful, so ``KittiDataset`` keeps only
    19 of them by default (its ``valid_labels`` argument).

    The full dataset is built and split inside ``__init__``; the split is reproducible for a
    given ``seed``.

    Example::

        from pl_bolts.datamodules import KittiDataModule

        dm = KittiDataModule(PATH)
        model = LitModel()

        Trainer().fit(model, dm)

    Args:
        data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'.
            ``None`` falls back to the current working directory.
        val_split: fraction of the samples used for the validation set (default 0.2)
        test_split: fraction of the samples used for the test set (default 0.1)
        num_workers: how many workers to use for loading data
        batch_size: the batch size
        seed: random seed to be used for train/val/test splits
    """

    name = 'kitti'

    def __init__(
        self,
        data_dir: str,
        val_split: float = 0.2,
        test_split: float = 0.1,
        num_workers: int = 16,
        batch_size: int = 32,
        seed: int = 42,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.data_dir = os.getcwd() if data_dir is None else data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        # Per-channel normalisation statistics applied to the input image only
        # (the dataset returns the mask untouched by `transform`).
        normalize = transforms.Normalize(
            mean=[0.35675976, 0.37380189, 0.3764753],
            std=[0.32064945, 0.32098866, 0.32325324],
        )
        self.default_transforms = transforms.Compose([transforms.ToTensor(), normalize])

        # split into train, val, test
        full_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)
        total = len(full_dataset)
        n_val = round(val_split * total)
        n_test = round(test_split * total)
        # Whatever is left after rounding goes to training so the three parts sum to `total`.
        n_train = total - n_val - n_test

        split_rng = torch.Generator().manual_seed(self.seed)
        self.trainset, self.valset, self.testset = random_split(
            full_dataset,
            lengths=[n_train, n_val, n_test],
            generator=split_rng,
        )

    def _build_loader(self, subset, shuffle: bool) -> DataLoader:
        """Wrap ``subset`` in a DataLoader using this module's batch size and worker count."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._build_loader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Return the non-shuffled loader over the validation split."""
        return self._build_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Return the non-shuffled loader over the test split."""
        return self._build_loader(self.testset, shuffle=False)
# pl_bolts/datamodules/kitti_dataset.py
import os
from typing import Sequence

import numpy as np
from torch.utils.data import Dataset

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KittiDataset(Dataset):
    """
    Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
    You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
    useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
    in `valid_labels`.

    The `encode_segmap` function sets every pixel that is not one of the `valid_labels` (which includes all
    `void_labels`) to `ignore_index` (250). It sets each valid pixel to the position of its label inside
    `valid_labels`, i.e. a value in `[0, len(valid_labels) - 1]`, so it can be used properly by the loss
    function when comparing with the output.

    Images are read from `<data_dir>/training/image_2` and masks from `<data_dir>/training/semantic`.
    Both directory listings are sorted so that the image and the mask at the same index belong together,
    assuming the two folders use matching file names.

    Args:
        data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
        img_size: image dimensions (width, height)
        void_labels: useless classes to be excluded from training
        valid_labels: useful classes to include
        transform: optional callable applied to the image array (not to the mask)
    """
    IMAGE_PATH = os.path.join('training', 'image_2')
    MASK_PATH = os.path.join('training', 'semantic')

    def __init__(
        self,
        data_dir: str,
        img_size: tuple = (1242, 376),
        void_labels: Sequence[int] = DEFAULT_VOID_LABELS,
        valid_labels: Sequence[int] = DEFAULT_VALID_LABELS,
        transform=None
    ):
        self.img_size = img_size
        self.void_labels = void_labels
        self.valid_labels = valid_labels
        self.ignore_index = 250
        # raw dataset label -> contiguous training class id
        self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
        self.transform = transform

        self.data_dir = data_dir
        self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH)
        self.mask_path = os.path.join(self.data_dir, self.MASK_PATH)
        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """
        Load, resize and encode the sample at position `idx`.

        Returns:
            `(img, mask)` where `img` is the resized image (passed through `transform` when one
            was given) and `mask` is the resized, encoded segmentation map as a numpy array.
        """
        # Imported here so the module can be imported (and the label helpers used)
        # without Pillow; it is only required once samples are actually read.
        from PIL import Image

        with Image.open(self.img_list[idx]) as raw_img:
            img = np.array(raw_img.resize(self.img_size))

        with Image.open(self.mask_list[idx]) as raw_mask:
            # Nearest-neighbour keeps every pixel an existing label id; an interpolating
            # filter would blend neighbouring ids into labels that are not in the image.
            resized_mask = raw_mask.convert('L').resize(self.img_size, resample=Image.NEAREST)
        mask = self.encode_segmap(np.array(resized_mask))

        if self.transform:
            img = self.transform(img)

        return img, mask

    def encode_segmap(self, mask):
        """
        Convert raw label ids to training class ids, in place.

        Every pixel whose raw label is in `valid_labels` becomes that label's index in
        `valid_labels`; every other pixel (void labels and anything unknown) becomes
        `ignore_index` so it won't be considered for training.

        Args:
            mask: numpy array of raw label ids; it is overwritten.

        Returns:
            The same `mask` array, now holding the encoded values.
        """
        # Compare against an untouched copy: remapping in place label by label would
        # re-map pixels that were already encoded whenever a class id equals a raw label
        # that is processed later.
        raw = mask.copy()
        mask[...] = self.ignore_index
        for raw_label, class_id in self.class_map.items():
            mask[raw == raw_label] = class_id
        return mask

    def get_filenames(self, path):
        """
        Returns a sorted list with the path of every entry inside the given `path`.

        Each entry name is joined onto `path`. Sorting matters because `os.listdir`
        returns names in arbitrary order, and images and masks are paired by position.
        """
        return sorted(os.path.join(path, filename) for filename in os.listdir(path))