In [11]:
import os
import time

import numpy as np

import cv2

import torch
import torch.nn as nn
from torch.utils.data import Dataset

class KITTIDataset(Dataset):
    """PyTorch-ready dataset of resized stereo image pairs read from disk.

    The constructor walks ``rootdir`` and collects every file whose name ends
    with ``extension_of_images`` from directories whose path contains
    ``image_2`` (left RGB folder) or ``image_3`` (right RGB folder). Both path
    lists are sorted, and index ``i`` of each list is served as one pair.
    NOTE(review): pairing therefore assumes the left and right folders use
    matching file names -- confirm for your dataset layout.

    Arguments:
        Dataset {pytorch dataset} -- superclass
        rootdir {str} -- directory to scan (default: DEFAULT_ROOTDIR)
        extension_of_images {str} -- file-name suffix to keep, compared
            case-insensitively; an empty string keeps every file
            (default: ".png")
        cols {int} -- width of the returned images in pixels (default: 512)
        rows {int} -- height of the returned images in pixels (default: 256)

    Raises:
        AssertionError -- if the number of left and right images differs
    """

    # Kept as the default so that ``KITTIDataset()`` behaves as it always did.
    DEFAULT_ROOTDIR = "/home/sur/MonoDepth1_Implementation/dataset/training"

    def __init__(self, rootdir=DEFAULT_ROOTDIR, extension_of_images=".png",
                 cols=512, rows=256):
        self.rootdir = rootdir
        self.extension_of_images = extension_of_images

        self.left_images = []
        self.right_images = []
        self.cols = cols
        self.rows = rows

        print("Loading Dataset: from -> ", self.rootdir)
        start = time.time()
        suffix = self.extension_of_images.lower()
        for subdir, _dirs, files in os.walk(self.rootdir):
            is_left = "image_2" in subdir   # Left RGB Folder
            is_right = "image_3" in subdir  # Right RGB Folder
            if not (is_left or is_right):
                continue
            # Only keep real image files: a stray non-image file would
            # otherwise unbalance the pair count or fail later when decoded.
            image_paths = [
                os.path.join(subdir, file)
                for file in files
                if file.lower().endswith(suffix)
            ]
            if is_left:
                self.left_images.extend(image_paths)
            if is_right:
                self.right_images.extend(image_paths)

        # os.walk gives no ordering guarantee; sorting makes the index -> pair
        # mapping deterministic.
        self.left_images.sort()
        self.right_images.sort()
        self.total_left_images = len(self.left_images)
        self.total_right_images = len(self.right_images)

        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is kept for backward compatibility.
        if self.total_left_images != self.total_right_images:
            raise AssertionError(
                "Left/right image count mismatch: {} left vs {} right under {}".format(
                    self.total_left_images, self.total_right_images, self.rootdir
                )
            )
        print("Loading Dataset: COMPLETE! took ", time.time()-start, " seconds")
        print("Total Stereo images acquired: ", len(self.right_images))

    def __len__(self):
        """Return the number of stereo pairs."""
        return len(self.right_images)

    def _load_image(self, path):
        """Read one image, resize it to (rows, cols) and make it channels-first.

        Raises:
            OSError -- if OpenCV cannot read or decode ``path``
        """
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError("Could not read image: {}".format(path))

        # resize as per model requirement; cv2.resize takes (width, height)
        img = cv2.resize(img, (self.cols, self.rows))

        # reshape for pytorch [channels, rows, cols]
        return np.moveaxis(img, 2, 0)

    def __getitem__(self, idx):
        """Return the ``idx``-th stereo pair.

        Returns:
            dict -- ``{"left_img": ndarray, "right_img": ndarray}``, each array
            laid out as [channels, rows, cols]. Channel order is whatever
            ``cv2.imread`` produces (BGR by default per the OpenCV docs).

        Raises:
            IndexError -- if ``idx`` is out of range
            OSError -- if either image cannot be read
        """
        left_img = self._load_image(self.left_images[idx])
        right_img = self._load_image(self.right_images[idx])

        return {"left_img":left_img,"right_img":right_img}

In [12]:
# Build the dataset with its defaults and display how many stereo pairs it holds.
ds = KITTIDataset()
len(ds)

Loading Dataset: from ->  /home/sur/MonoDepth1_Implementation/dataset/training
Loading Dataset: COMPLETE! took  0.03135204315185547  seconds
Total Stereo images acquired:  4200


4200