In [None]:
import os
import sys
import os.path as path
sys.path.append("D:\\ASGaze")

import cv2
import numpy as np
import configparser
from glob import glob

import torch
import torch.utils.data

import import_ipynb
from iris_boundary_detector.data_sources.transform import aug,raw
from iris_boundary_detector.graph.losses import one_hot2dist

In [None]:
class ASGaze_data(torch.utils.data.Dataset):
    """PyTorch ``Dataset`` over one split of an ASGaze-style directory tree.

    Layout read by this class::

        <datapath>/<name>/<split>/*                              flag == 0
        <datapath>/<name>/<split>/image/*                        flag == 1
        <datapath>/<name>/<split>/mask/<id>.png                  flag == 1
        <datapath>/<name>/<split>/ellipse_params/<id>.ini        flag == 1
    """

    def __init__(self, datapath, name, out_res=(321,321), split='train', aug=True, flag=0):
        """ Pytorch Dataset definition for ASGaze dataset
        KeyArgs:
            datapath(str): root directory containing the dataset folders
            name(str): dataset sub-directory under ``datapath``
            out_res(tuple(int,int)): desired image resolution
            split(str): split sub-directory under ``name`` (e.g. 'train')
            aug(bool): data augmentation; stored as ``self.aug`` but not read by
                any method of this class (flag 1 always calls ``aug()``)
            flag(int): Type of return value for different usages.
                    0: inference
                    1: train
        """
        self.datapath, self.aug = path.join(datapath,name,split), aug
        print("datapath",os.path.abspath(self.datapath))
        # Inference reads every file of the split directory; training reads only
        # the ``image`` sub-directory (masks/ellipses are looked up by sample id).
        if flag == 0:
            pattern = path.join(self.datapath, '*')
        else:
            pattern = path.join(self.datapath, 'image', '*')
        # glob order is filesystem-dependent; sort for a reproducible sample order.
        self.img_fns = sorted(glob(pattern))
        self.out_res = out_res
        # ``config`` holds the most recently parsed ellipse file (see parse_ellipse).
        self.name, self.config = name,configparser.ConfigParser()
        # Augmentation strength (set via set_difficulty) and return-value mode.
        self._difficulty, self.flag = 0.0, flag

        print("Load {} samples of {} Dataset {} split from {}".format(len(self.img_fns),name,split,self.datapath))

    def set_difficulty(self, difficulty):
        """Set level of data augmentation.

        Args:
            difficulty(float): value in [0, 1], forwarded to ``aug`` as ``diff``.
                Ints (e.g. 0 or 1) are accepted and converted to float.

        Raises:
            TypeError: if ``difficulty`` is not a real number (bools rejected).
            ValueError: if ``difficulty`` lies outside [0, 1] (or is NaN).
        """
        # Raise instead of assert: asserts disappear under ``python -O``.
        if isinstance(difficulty, bool) or not isinstance(difficulty, (int, float)):
            raise TypeError("difficulty must be a float, got {!r}".format(type(difficulty)))
        difficulty = float(difficulty)
        if not 0.0 <= difficulty <= 1.0:
            raise ValueError("difficulty must be in [0, 1], got {}".format(difficulty))
        self._difficulty = difficulty

    def rect_transform(self, rect, trans):
        """Convert a rotated rect in crop coordinates back to the original image.

        ``rect`` is ``[center, size, angle]``. The center is shifted by
        ``trans[1:3]`` and divided by ``trans[0]``; the size is divided by
        ``trans[0]``; the angle is unchanged. Presumably ``trans`` is
        ``[scale, offset_x, offset_y, ...]`` as produced by ``raw``/``aug`` —
        TODO confirm against those helpers.
        """
        return [(rect[0] + trans[1:3])/trans[0],rect[1]/trans[0],rect[-1]]

    def parse_ellipse(self, datapath):
        """Parse an INI-format iris ellipse annotation.

        The ``[iris]`` section provides ``center_x``, ``center_y``,
        ``long_radius``, ``short_radius`` and ``rad_phi`` (radians).

        Returns:
            [center, size, degree]: ``center`` is ``[x, y]``; ``size`` holds
            the full axis lengths (2 * radius), larger first; ``degree`` is
            ``-rad_phi`` in degrees, plus 90 when the axes were swapped
            (presumably so the angle refers to the first ``size`` entry).

        Raises:
            FileNotFoundError: if ``datapath`` could not be read.
        """
        # A fresh parser per call: ConfigParser.read merges into existing state
        # and silently skips unreadable files, which would otherwise return the
        # previous sample's ellipse for a missing/partial annotation.
        config = configparser.ConfigParser()
        if not config.read(datapath):
            raise FileNotFoundError("Cannot read ellipse file: {}".format(datapath))
        self.config = config  # expose the last parsed annotation

        center = [
            float(config.get('iris','center_x')),
            float(config.get('iris','center_y'))
        ]

        l,s = float(config.get('iris','long_radius')),float(config.get('iris','short_radius'))
        degree = float(config.get('iris','rad_phi'))/np.pi*180.0*-1

        if l > s:
            radius = [l*2,s*2]
        else:
            degree += 90
            radius = [s*2,l*2]

        return [center, radius, degree]

    def __len__(self):
        return len(self.img_fns)

    @staticmethod
    def _sample_id(filename):
        """Return ``filename`` stripped of its directory and extension."""
        # os.path handles the platform separator (and '/' on Windows), and keeps
        # inner dots in names such as ``img.001.png``.
        return path.splitext(path.basename(filename))[0]

    @staticmethod
    def _iris_edge_weights(gt, weight=20):
        """Per-pixel loss weights: ``weight`` near label boundaries, 1 elsewhere.

        Boundaries are Canny edges of the label map, dilated with a 3x3 kernel.
        """
        edges = cv2.Canny(gt.astype(np.uint8),0,2)/255
        # cv2.dilate expects a structuring element, not a kernel size; passing
        # the tuple (3, 3) would be used as a tiny 2x1 kernel.
        edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)*weight
        edges[edges==0] = 1
        return edges

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            flag 0: ``(data_id, crop_img, rect_trans, full_img)``
            flag 1: ``(data_id, crop_img, gt, one_hot, rotated_rect, rect_trans,
            iris_missing_weights, dist_map)`` where ``dist_map`` is float32 with
            one ``one_hot2dist`` map per label 0..2 stacked on axis 0.

        Raises:
            OSError: if the image (or, for flag 1, the mask) cannot be read.
            FileNotFoundError: if the ellipse annotation is missing (flag 1).
            ValueError: for an unsupported ``flag``.
        """
        img_fn = self.img_fns[idx]
        bgr = cv2.imread(img_fn)
        if bgr is None:
            raise OSError("Cannot read image: {}".format(img_fn))
        full_img = cv2.cvtColor(bgr,cv2.COLOR_BGR2RGB)
        data_id = self._sample_id(img_fn)

        if self.flag == 0: # For inference
            crop_img, rect_trans = raw(full_img,self.out_res, 0.0)
            return data_id,crop_img,rect_trans,full_img

        if self.flag == 1: # For training segmentation network
            # Ground-truth ellipse parameters and mask share the image's id.
            rotated_rect = self.parse_ellipse(path.join(self.datapath,'ellipse_params',data_id+".ini"))
            mask_fn = path.join(self.datapath,'mask',data_id+".png")
            gt_img = cv2.imread(mask_fn)
            if gt_img is None:
                raise OSError("Cannot read mask: {}".format(mask_fn))
            crop_img, gt ,one_hot, rotated_rect, rect_trans = aug(
                full_img, [gt_img], rotated_rect, self.out_res, flag=self.flag, diff=self._difficulty)

            # For iris missing loss
            iris_missing_weights = self._iris_edge_weights(gt)

            # For distance map loss: one map per label value (assumes labels 0..2).
            dist_map = np.stack([one_hot2dist(gt==i) for i in range(3)], 0)

            return data_id, crop_img, gt, one_hot, rotated_rect, rect_trans, iris_missing_weights, np.float32(dist_map)

        raise ValueError("Unsupported flag: {}".format(self.flag))

    def get_test(self, idx):
        """Alias for ``__getitem__``."""
        return self.__getitem__(idx)