<a href="https://colab.research.google.com/github/1byxero/pytorch-OCID-Dataloder/blob/master/OCID-Dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np

import urllib
import tarfile
import cv2

from os import getcwd, listdir
from os.path import join as ospj, isdir, isfile

class OCIDDataset(data.Dataset):
    """PyTorch dataset over the OCID RGB-D image collection.

    On construction the dataset archive is downloaded and extracted when the
    ``OCID-dataset`` directory is not already present, and the directory tree
    is then scanned to build ``self.samples``: a list of dicts holding the
    ``label``, ``rgb``, ``depth`` and ``pcd`` file paths of each frame.

    Args:
        root: Directory that contains (or will receive) ``OCID-dataset``.
            Defaults to the current working directory.
        max_samples: Keep only the first ``max_samples`` discovered frames.
            Defaults to 10; pass ``None`` to keep every frame.
    """

    _DATASET_URL = 'https://data.acin.tuwien.ac.at/index.php/s/g3EkcgcPioolQmJ/download'

    # Directory levels that are walked, in traversal order.
    _DATASET_TYPES = ('ARID10', 'ARID20', 'YCB10')
    _BASE_TYPES = ('floor', 'table')
    _ANGLE_TYPES = ('bottom', 'top')
    # Object-group sub-directory per dataset; ``None`` means the sequences sit
    # directly below the angle directory (no object-group level).
    _OBJECT_GROUPS = {
        'ARID10': ('box', 'curved', 'fruits', 'mixed', 'non-fruits'),
        'ARID20': (None,),
        'YCB10': ('cuboid', 'curved', 'mixed'),
    }

    def custom_listdir(self, path):
        """Return the entries of ``path``, or ``[]`` if it is not a directory.

        Returning an empty list (rather than ``None``) lets callers iterate
        over the result without checking, so missing folders are skipped.
        """
        return listdir(path) if isdir(path) else []

    def popolate_sample_dict(self, path, fl_name):
        """Build the per-frame path dict for ``fl_name`` inside sequence ``path``.

        The method name keeps its original spelling for compatibility.

        Args:
            path: Sequence directory holding the ``label``/``rgb``/``depth``/
                ``pcd`` sub-directories.
            fl_name: File name of the frame in any of those sub-directories.

        Returns:
            Dict with keys ``label``, ``rgb``, ``depth`` and ``pcd``.
        """
        # Strip only the final extension so names containing dots survive.
        stem = fl_name.rsplit(".", 1)[0]
        pcd_file = ospj(path, 'pcd', stem + ".pcd")
        if not isfile(pcd_file):
            # one pcd folder is incorrectly named as pd
            # this takes care of it
            pcd_file = ospj(path, 'pd', stem + ".pcd")
        return {
            'label': ospj(path, 'label', stem + ".png"),
            'rgb': ospj(path, 'rgb', stem + ".png"),
            'depth': ospj(path, 'depth', stem + ".png"),
            'pcd': pcd_file,
        }

    def create_list_of_samples_from_path(self, path):
        """Walk the dataset tree under ``path`` and collect every frame.

        The tree is visited as dataset / base / angle / [object group] /
        sequence. Frame names are taken from the first sub-directory listed
        in each sequence; directories that do not exist are skipped.

        Returns:
            List of dicts as produced by ``popolate_sample_dict``.
        """
        samples = []
        for ds in self._DATASET_TYPES:
            for bs in self._BASE_TYPES:
                for angle in self._ANGLE_TYPES:
                    for obj_type in self._OBJECT_GROUPS[ds]:
                        group_dir = ospj(path, ds, bs, angle)
                        if obj_type is not None:
                            group_dir = ospj(group_dir, obj_type)
                        samples.extend(self._samples_in_group(group_dir))
        return samples

    def _samples_in_group(self, group_dir):
        """Collect the frames of every sequence directory inside ``group_dir``."""
        found = []
        for seq in self.custom_listdir(group_dir):
            seq_dir = ospj(group_dir, seq)
            modalities = self.custom_listdir(seq_dir)
            if not modalities:
                continue
            # Any one modality folder is enough to enumerate the frame names.
            for fl_name in self.custom_listdir(ospj(seq_dir, modalities[0])):
                found.append(self.popolate_sample_dict(seq_dir, fl_name))
        return found

    def _fetch_and_extract(self, root):
        """Download the archive into ``root`` if needed, then extract it there."""
        dataset_tar = ospj(root, 'OCID-dataset.tar.gz')
        if not isfile(dataset_tar):
            # ``import urllib`` alone does not guarantee that the ``request``
            # submodule is loaded, so import it explicitly before use.
            import urllib.request
            print("Downloading the OCID dataset...")
            urllib.request.urlretrieve(self._DATASET_URL, dataset_tar)
            print("Download Complete")
        print("Extracting dataset tar")
        print("This may take a while ...")
        with tarfile.open(dataset_tar, "r:gz") as tar:
            # The archive comes from the network: use the "data" extraction
            # filter where this Python provides it, so members cannot escape
            # the destination directory.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=root, filter="data")
            else:
                tar.extractall(path=root)
        print("Completed extracting!")

    def __init__(self, root=None, max_samples=10):
        super().__init__()
        if root is None:
            root = getcwd()
        dataset_dir = ospj(root, 'OCID-dataset')
        if not isdir(dataset_dir):
            self._fetch_and_extract(root)
        print("OCID dataset is available locally")
        samples = self.create_list_of_samples_from_path(dataset_dir)
        if max_samples is not None:
            samples = samples[:max_samples]
        self.samples = samples
        print("Dataset object created")

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.samples)

    @staticmethod
    def _read_image(path, flags=None):
        """Read ``path`` with OpenCV and return it as an int16 tensor.

        Raises:
            OSError: If OpenCV cannot read the file (``imread`` returned None).
        """
        image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if image is None:
            raise OSError("Could not read image file: {}".format(path))
        # NOTE(review): the int16 cast would wrap values above 32767 -- confirm
        # that the depth/label PNGs stay within that range.
        return torch.from_numpy(image.astype(np.int16))

    def __getitem__(self, index):
        """Generates one sample of data.

        Returns:
            Tuple ``(rgbd, rgb, depth, label)``. ``rgb``, ``depth`` and
            ``label`` are int16 tensors as loaded from disk; ``rgbd`` is the
            float tensor obtained by stacking ``depth`` behind ``rgb`` on the
            last axis and reversing the axis order.
        """
        sample = self.samples[index]
        label = self._read_image(sample['label'], cv2.IMREAD_UNCHANGED)
        # NOTE(review): OpenCV loads colour images in BGR channel order and no
        # conversion is applied here -- confirm downstream code expects that.
        rgb = self._read_image(sample['rgb'])
        depth = self._read_image(sample['depth'], cv2.IMREAD_UNCHANGED)
        # Assumes depth is single-channel (H, W) so it can become a 4th channel.
        # permute(2, 1, 0) reverses the axes, i.e. (H, W, C) -> (C, W, H).
        rgbd = torch.cat(
            [rgb, torch.unsqueeze(depth, dim=2)], dim=2
        ).permute(2, 1, 0).float()
        return rgbd, rgb, depth, label


In [0]:
import time

# Measure how long it takes to construct the dataset (the constructor also
# downloads and extracts the archive when it is not available locally) and
# to wrap it in a DataLoader that yields batches of 10 samples.
sttime = time.time()
ds = OCIDDataset()
training_generator = data.DataLoader(dataset=ds, batch_size=10)
elapsed = time.time() - sttime
print("time taken {} seconds".format(elapsed))

Downloading the OCID dataset...
