In [None]:
import os
import glob 
import json
import math
import numpy as np
from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from torch.utils.data import Dataset

In [None]:
class ContentOrientedDataset(Dataset):
    """Dataset of images accompanied by per-image face coordinates.

    On construction the dataset
      * recursively collects every ``.jpg`` / ``.png`` file found under
        ``<root>/images`` into ``self.imgs`` (sorted, so that an index always
        refers to the same file), and
      * loads ``<root>/face_coords.json`` into ``self.face_coords``.

    NOTE: ``_transforms``, ``get_face_mask``, ``get_structure_mask`` and
    ``__getitem__`` are unimplemented placeholders in this version; they
    return ``None``.
    """

    def __init__(self, root: str = '', crop_size: int = 256,
                 normalize: bool = False, **kwargs):
        """Index the images below ``root`` and load the face coordinates.

        Args:
            root: Dataset directory. Must contain ``face_coords.json``;
                images are looked up under its ``images`` sub-directory.
            crop_size: Side length (pixels) of the square crop produced by
                ``_augment``.
            normalize: Stored on the instance as ``self.normalize``; it is
                not consumed by any method implemented in this class yet.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            FileNotFoundError: If ``<root>/face_coords.json`` does not exist.
            json.JSONDecodeError: If that file is not valid JSON.
        """
        super().__init__()

        self.data_dir = root
        img_extensions = ['.jpg', '.png']
        found = []
        for ext in img_extensions:
            pattern = os.path.join(self.data_dir, f'images/**/*{ext}')
            found.extend(glob.glob(pattern, recursive=True))
        # glob yields paths in arbitrary, filesystem-dependent order; sorting
        # makes the idx -> image mapping reproducible across runs/machines.
        self.imgs = sorted(found)

        self.crop_size = crop_size
        self.image_dims = (3, self.crop_size, self.crop_size)
        self.normalize = normalize

        json_file_path = os.path.join(self.data_dir, "face_coords.json")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_file_path, 'r', encoding='utf-8') as json_file:
            self.face_coords = json.load(json_file)

    def _augment(self, img: np.ndarray, face_masks: np.ndarray,
                 structure_masks: np.ndarray):
        """Randomly flip, rescale and crop an image together with its masks.

        The pipeline is: horizontal flip with probability 0.5, resize by a
        random factor, then crop to ``crop_size`` x ``crop_size``. Image and
        masks go through the augmenter in a single call so they receive the
        same geometric transformation.

        Args:
            img: Image array whose first two axes are height and width.
            face_masks: Face mask for ``img``.
                NOTE(review): assumed to be an ``(H, W)`` bool/int array as
                required by ``SegmentationMapsOnImage`` -- confirm with caller.
            structure_masks: Structure mask for ``img`` (same assumption).

        Returns:
            Tuple ``(img, face_masks, structure_masks)`` after augmentation.
        """
        SCALE_MIN = 0.75
        SCALE_MAX = 0.95
        H, W = img.shape[:2]

        # Never scale below the factor at which the shortest side would become
        # smaller than the crop; for small images this may exceed SCALE_MAX
        # (or even 1.0), in which case the range collapses to that one value.
        shortest_side_length = min(H, W)
        minimum_scale_factor = float(self.crop_size) / float(shortest_side_length)
        scale_low = max(minimum_scale_factor, SCALE_MIN)
        scale_high = max(scale_low, SCALE_MAX)
        scale = np.random.uniform(scale_low, scale_high)

        # The target size depends on this image's shape and on the freshly
        # sampled scale, hence the pipeline is rebuilt on every call.
        # Resize must receive a dict here: a 2-tuple of ints is interpreted by
        # imgaug as an interval from which ONE value is sampled and used for
        # both height and width, which would distort the aspect ratio.
        target_size = {
            "height": math.ceil(scale * H),
            "width": math.ceil(scale * W),
        }
        self.augmentations = iaa.Sequential([
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.Resize(target_size),
            iaa.size.CropToFixedSize(self.crop_size, self.crop_size),
        ])

        # Stack both masks as channels of one segmentation map so that they
        # are transformed together; ``shape`` is the shape of the image the
        # map belongs to.
        masks = np.dstack([face_masks, structure_masks])
        masks = SegmentationMapsOnImage(masks, shape=img.shape)
        img, masks = self.augmentations(image=img, segmentation_maps=masks)
        masks = masks.get_arr()
        face_masks, structure_masks = masks[:, :, 0], masks[:, :, 1]

        return img, face_masks, structure_masks

    def _transforms(self, img, face_mask, structure_mask):
        """Placeholder -- not implemented yet; returns ``None``."""
        pass

    def get_face_mask(self, idx, shape):
        """Placeholder -- not implemented yet; returns ``None``."""
        pass

    def get_structure_mask(self, idx):
        """Placeholder -- not implemented yet; returns ``None``."""
        pass

    def __getitem__(self, idx):
        """Placeholder -- not implemented yet; returns ``None``."""
        pass


        