In [2]:
from torch.utils.data import Dataset
import os
from pathlib import Path
import glob
import random
import cv2 as cv
import numpy as np

In [8]:
class EndoscopyDataset(Dataset):
    """In-memory object-detection dataset read from a flat directory.

    Every ``.png`` / ``.jpg`` file located directly under ``root`` that has a
    sibling annotation named ``<image file name>.txt`` (for example
    ``frame.png.txt``) becomes one sample. Images are loaded eagerly and kept
    as RGB arrays; each target is a dict with

    * ``"boxes"``  - list of ``(x_center, y_center, w, h)`` tuples scaled by
      the image width/height,
    * ``"labels"`` - list of integer class ids.

    Args:
        root: Directory holding the images and their ``.txt`` annotations.
        slice_vals: Optional ``(start, stop)`` fractions. When given, the
            annotation list is shuffled with a fixed seed and only the
            ``[start, stop)`` fraction is kept, so calls with complementary
            fractions on the same directory yield disjoint splits.

    Raises:
        OSError: If a selected image cannot be decoded.
        ValueError: If an annotation line has fewer than five fields.
    """

    # File suffixes (case-sensitive) that are treated as images.
    IMAGE_SUFFIXES = ('.png', '.jpg')
    # Seed of the private RNG used for the reproducible split.
    SPLIT_SEED = 42

    def __init__(self, root, slice_vals = None):
        self.root = root
        self.image_idx = 0  # cursor used by __iter__ / __next__

        annotations = self._collect_annotations(root)

        # Cut the annotation list down to the requested fraction.
        if slice_vals:
            size = len(annotations)
            # A private Random instance yields the same permutation as
            # random.seed(42) + random.shuffle, but leaves the global RNG
            # state of the notebook untouched.
            random.Random(self.SPLIT_SEED).shuffle(annotations)
            annotations = annotations[int(size*slice_vals[0]) : int(size*slice_vals[1])]

        # Load every selected image up front.
        self.images = [
            self._load_rgb(Path(root, annotation['image_path']))
            for annotation in annotations
        ]

        # Build targets with boxes scaled by each image's own size.
        self.targets = []
        for img, annotation in zip(self.images, annotations):
            height, width = img.shape[:2]
            boxes = [self.__convert((width, height), box) for box in annotation["bndboxes"]]
            self.targets.append({"boxes": boxes, "labels": annotation["labels"]})

    def _collect_annotations(self, root):
        """Return the parsed annotation dicts for all annotated images in ``root``.

        Images without a matching ``<image file name>.txt`` are skipped.
        """
        # iterdir() yields entries in arbitrary order; sorting makes the
        # seeded shuffle (and therefore the split) reproducible.
        image_files = sorted(
            entry.name
            for entry in Path(root).iterdir()
            if entry.suffix in self.IMAGE_SUFFIXES
        )

        annotations = []
        for file in image_files:
            annotation_path = Path(root, file + '.txt')
            if annotation_path.is_file():
                dict_annotation = self.__txt_to_dict(str(annotation_path))
                dict_annotation['image_path'] = file
                annotations.append(dict_annotation)
        return annotations

    @staticmethod
    def _load_rgb(img_path):
        """Read one image from disk and return it as an RGB array."""
        img = cv.imread(str(img_path))
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {img_path}")
        return cv.cvtColor(img, cv.COLOR_BGR2RGB)

    def __iter__(self):
        """Restart iteration from the first sample and return ``self``."""
        self.image_idx = 0
        return self

    def __next__(self):
        """Return the next ``(image, target)`` pair or raise ``StopIteration``."""
        if self.image_idx >= len(self.images):
            raise StopIteration
        img = self.images[self.image_idx].copy()
        target = dict(self.targets[self.image_idx])
        self.image_idx += 1
        return img, target

    def __getitem__(self, idx):
        """Return a copy of the image at ``idx`` and a shallow copy of its target."""
        img = self.images[idx].copy()
        target = dict(self.targets[idx])

        return img, target

    def __txt_to_dict(self, file: str):
        """Parse one annotation file.

        Each non-blank line must hold at least five whitespace-separated
        numbers: ``class v1 v2 v3 v4``. The class goes to ``"labels"`` as an
        ``int`` and the next four values to ``"bndboxes"`` as floats.

        Returns:
            ``{"bndboxes": [[v1, v2, v3, v4], ...], "labels": [class, ...]}``.

        Raises:
            ValueError: If a non-blank line has fewer than five fields.
        """
        with open(file, 'r') as f:
            lines = f.read().splitlines()

        target = {"bndboxes": [], "labels": []}

        for line_no, line in enumerate(lines, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing empty line.
                continue
            if len(fields) < 5:
                raise ValueError(
                    f"{file}:{line_no}: expected 'class v1 v2 v3 v4', got {line!r}"
                )
            value = list(map(float, fields))
            target["bndboxes"].append(value[1:5])
            # Class ids are integral; keeping them as int makes save() write
            # "0" rather than "0.0" into the label files.
            target["labels"].append(int(value[0]))

        return target

    def __convert(self, size: tuple, box: list):
        """Scale a pixel-space box by the image size and move its anchor to the centre.

        Args:
            size: ``(width, height)`` of the image.
            box: ``(x, y, w, h)`` in pixels.

        Returns:
            ``(x/width + w/(2*width), y/height + h/(2*height), w/width, h/height)``,
            i.e. a centre-based box relative to the image size.

        NOTE(review): half the box size is added to ``x``/``y``, so the input
        is treated as a top-left corner even though the annotation fields were
        previously documented as ``x_cent, y_cent`` - confirm against the
        annotation tool's format.
        """
        dw = 1./size[0]
        dh = 1./size[1]
        x = box[0]*dw
        w = box[2]*dw
        y = box[1]*dh
        h = box[3]*dh
        x = x+1/2*w
        y = y+1/2*h

        return (x, y, w, h)

    def __len__(self):
        """Number of loaded samples."""
        return len(self.images)

    def save(self, save_path, dataset_type, on_append=False):
        """Write the dataset to ``<save_path>/images|labels/<dataset_type>``.

        Sample ``i`` is written as ``<n>.png`` plus ``<n>.txt`` where ``n`` is
        ``i``, offset by the number of entries already present in the image
        directory when ``on_append`` is true. Each label line is
        ``class x y w h`` separated by tabs.

        Raises:
            OSError: If an image could not be written.
        """
        image_path = os.path.join(save_path, "images", dataset_type)
        labels_path = os.path.join(save_path, "labels", dataset_type)
        os.makedirs(image_path, exist_ok=True)
        os.makedirs(labels_path, exist_ok=True)

        # Continue numbering after the existing files when appending.
        offset = len(os.listdir(image_path)) if on_append else 0

        # Index-based access keeps the __iter__/__next__ cursor untouched.
        for i in range(len(self)):
            img, target = self[i]
            file_number = i + offset

            img = cv.cvtColor(img, cv.COLOR_RGB2BGR)
            image_file = os.path.join(image_path, str(file_number)+".png")
            if not cv.imwrite(image_file, img):
                # imwrite reports failure through its return value only.
                raise OSError(f"Could not write image: {image_file}")

            with open(os.path.join(labels_path, str(file_number)+".txt"), 'w') as output:
                for label, box in zip(target["labels"], target["boxes"]):
                    output.write("\t".join(map(str, [label, *box])) + '\n')

In [9]:

# Build both splits from the same source directory: the first 80% of the
# shuffled annotation list is used for training, the remaining 20% for validation.
train_face_dataset = EndoscopyDataset(root="../data/detection/", slice_vals=(0, 0.8))
valid_face_dataset = EndoscopyDataset(root="../data/detection/", slice_vals=(0.8, 1.0))

# Export each split into its own images/<split> and labels/<split> sub-directory.
for split_name, split_dataset in (('train', train_face_dataset), ('valid', valid_face_dataset)):
    split_dataset.save(save_path='../data_prep/detection', dataset_type=split_name)