# In [None]:
import sys
sys.path.append("D:\\ASGaze")

import os
import numpy as np
import cv2
import json
import torch
import torch.nn.functional as f

import import_ipynb
import iris_boundary_detector.utils.refinement as refinement
from iris_boundary_detector.data_sources.ASGaze_data import ASGaze_data
from iris_boundary_detector.graph.vgg_unet import get_model
from iris_boundary_detector.utils.load_model import data_gpu,load_checkpoint

# In [None]:
def _safe_plogp(p):
    """Return the elementwise entropy term ``-p * log(p)``, taking ``0 * log(0)`` as 0.

    A plain ``-p * np.log(p)`` is NaN wherever a softmax probability has
    underflowed to exactly 0; a single such pixel used to turn the whole
    normalised entropy map into NaN.
    """
    positive = p > 0
    return np.where(positive, -p * np.log(np.where(positive, p, 1)), 0)


def _normalize_to_255(arr):
    """Min-max rescale *arr* to [0, 255].

    Returns an all-zero array of the same dtype when *arr* is constant (or its
    range is NaN) instead of dividing by zero and producing NaN.
    """
    lo, hi = arr.min(), arr.max()
    if not hi > lo:  # constant map, e.g. nothing survived the background filter
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo) * 255


def _crop_to_bgr(crop):
    """Convert a CHW network-input crop into an HxWx3 uint8 BGR image for drawing.

    Assumes the crop is RGB scaled to [-1, 1] (inferred from the
    ``(x + 1) * 255 / 2`` mapping) -- TODO confirm against ASGaze_data.
    """
    rgb = (np.transpose(crop + 1, (1, 2, 0)) * 255 / 2).astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _video_writers(left_or_right):
    """Return the (entropy, partial-entropy, refine) VideoWriters for one eye.

    The writers are module globals created by ``inference_main(video=True)``;
    any value other than "left" selects the right-eye writers.
    """
    if left_or_right == "left":
        return video_entropy_left, video_entropy_partial_left, video_refine_left
    return video_entropy_right, video_entropy_partial_right, video_refine_right


def _write_gray_frame(writer, gray):
    """Append a single-channel uint8 image to *writer* as a 3-channel frame."""
    writer.write(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))


def _fit_and_save(idx, points, rect_trans, crop, full_img, out_dir, left_or_right, save_full_img):
    """Fit an ellipse to *points* for frame *idx* and write its outputs.

    The ellipse is fitted in crop coordinates and passed through
    ``refinement.rect_transform`` (the result is what gets drawn on the full
    image). Writes ``visual_<eye>/<idx>_crop.png`` and
    ``ellipse_params_<eye>/<idx>.ini``; also ``visual_<eye>/<idx>.png`` (full
    image overlay) when *save_full_img* is true.
    """
    name = "{:05}".format(idx)
    visual_dir = os.path.join(out_dir, "visual_" + left_or_right)

    ellipse_crop = cv2.fitEllipse(np.array(points).T)
    ellipse_full = refinement.rect_transform(ellipse_crop, rect_trans[0].numpy())

    if save_full_img:
        full_bgr = cv2.cvtColor(full_img[0].numpy(), cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(visual_dir, name + ".png"),
                    cv2.ellipse(full_bgr, ellipse_full, (0, 255, 255), 1))

    crop_bgr = cv2.ellipse(_crop_to_bgr(crop), ellipse_crop, (0, 255, 255), 1)
    cv2.imwrite(os.path.join(visual_dir, name + "_crop.png"), crop_bgr)
    refinement.write_ellipse_params(
        ellipse_full, os.path.join(out_dir, "ellipse_params_" + left_or_right, name + ".ini"))


def inference(val_loader, model, left_or_right, out_dir, video, *, save_full_img=False):
    """Segment one eye's frame sequence, then fit and save a per-frame iris ellipse.

    Pass 1 runs *model* on every frame and builds two uncertainty maps from the
    softmax output: the full per-pixel entropy, and a "partial" entropy over
    classes 1:3 ([iris, sclera]) kept only where background (class 0) is the
    least likely class and then cleaned with an Otsu threshold. Both maps are
    written as ``<data_id>_entropy.png`` / ``<data_id>_entropy_p.png``.

    Pass 2 tracks the partial-entropy maps through the sequence with
    ``refinement.is_feature_matching_all`` / ``feature_remove``, falling back to
    ``refinement.feature_detect`` when not enough features match, and fits and
    saves one ellipse per frame.

    Args:
        val_loader: iterable of ``(data_id, img, rect_trans, full_img)`` batches.
            Only the first item of each batch is used and frames are tracked in
            iteration order, so it should yield single, ordered frames.
        model: segmentation network whose output has one channel per class on dim 1.
        left_or_right: "left" or "right"; selects the ``visual_<eye>`` /
            ``ellipse_params_<eye>`` sub-directories and the video writers
            (anything other than "left" is treated as right).
        out_dir: directory that already contains those sub-directories.
        video: when truthy, also append frames to the module-level VideoWriters
            created by ``inference_main``.
        save_full_img: when True, also write the ellipse drawn on the full
            image as ``visual_<eye>/<idx>.png``.

    Raises:
        ValueError: if the loader yields fewer than two frames; tracking starts
            by matching frame 0 against frame 1.
    """
    global temp_img  # last refined map; kept as a module global as before
    model.eval()
    # Run on whatever device the model lives on (inference_main moves it there)
    # rather than depending on a module-level ``device`` global.
    dev = next(model.parameters()).device
    visual_dir = os.path.join(out_dir, "visual_" + left_or_right)

    entropy_writer = partial_writer = refine_writer = None
    if video:
        entropy_writer, partial_writer, refine_writer = _video_writers(left_or_right)

    # ---------------- Pass 1: segmentation + uncertainty maps ---------------- #
    rect_trans_list, full_img_list, crop_img_list, entropy_partial_list = [], [], [], []
    for data_id, img, rect_trans, full_img in val_loader:
        rect_trans_list.append(rect_trans)
        full_img_list.append(full_img)
        crop_img_list.append(img.cpu().numpy()[0])

        prob_t = f.softmax(model(data_gpu(img, dev)), dim=1)
        prob = prob_t.cpu().numpy()
        # Per-pixel least likely class; 0 means background is least likely,
        # i.e. the pixel is most likely iris or sclera.
        least_likely = torch.argmin(prob_t, dim=1).cpu().numpy()[0]

        # Complete entropy over all classes, rescaled to [0, 255].
        entropy = _normalize_to_255(_safe_plogp(prob).sum(axis=1)[0])
        cv2.imwrite(os.path.join(visual_dir, data_id[0] + "_entropy.png"), np.abs(entropy - 255))

        # Partial entropy across [iris, sclera] (classes 1 and 2), kept only
        # where background is the least likely class.
        entropy_partial = _safe_plogp(prob[:, 1:3]).sum(axis=1)[0]
        entropy_partial = entropy_partial * (least_likely == 0).astype(np.float64)
        entropy_partial = _normalize_to_255(entropy_partial)
        # Otsu binarisation: zero every pixel not above the automatic threshold.
        binary = cv2.threshold(entropy_partial.astype(np.uint8), 0, 255,
                               cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        entropy_partial[binary == 0] = 0
        entropy_partial_list.append(entropy_partial)
        cv2.imwrite(os.path.join(visual_dir, data_id[0] + "_entropy_p.png"),
                    np.abs(entropy_partial - 255))

        if video:
            _write_gray_frame(entropy_writer, np.abs(entropy).astype(np.uint8))
            _write_gray_frame(partial_writer, np.abs(entropy_partial - 255).astype(np.uint8))

    if len(entropy_partial_list) < 2:
        raise ValueError("inference() needs at least 2 frames for the {} eye, got {}".format(
            left_or_right, len(entropy_partial_list)))

    # ---------------- Pass 2: tracking + ellipse fitting ---------------- #
    # Each refinement call gets a fresh uint8 copy, as before, in case the
    # helpers modify their inputs (not verifiable from here).
    flag_l, pt1_l, pt2_l, flag_r, pt1_r, pt2_r = refinement.is_feature_matching_all(
        entropy_partial_list[0].astype(np.uint8), entropy_partial_list[1].astype(np.uint8))
    if flag_l and flag_r:
        # Enough matches in both the left and right parts of the first two maps.
        refined1, points1, refined2, points2 = refinement.feature_remove(
            entropy_partial_list[0].astype(np.uint8), entropy_partial_list[1].astype(np.uint8),
            pt1_l, pt2_l, pt1_r, pt2_r)
        suffix = "_entropy_p_fr.png"
    else:
        # Not enough matches: detect features in each frame independently.
        print("initial else")
        refined1, points1 = refinement.feature_detect(entropy_partial_list[0].astype(np.uint8))
        refined2, points2 = refinement.feature_detect(entropy_partial_list[1].astype(np.uint8))
        suffix = "_entropy_p_fd.png"

    for idx, (refined, points) in enumerate([(refined1, points1), (refined2, points2)]):
        cv2.imwrite(os.path.join(visual_dir, "{:05}".format(idx) + suffix), np.abs(255 - refined))
        _fit_and_save(idx, points, rect_trans_list[idx], crop_img_list[idx],
                      full_img_list[idx], out_dir, left_or_right, save_full_img)
        if video:
            _write_gray_frame(refine_writer, np.abs(255 - refined).astype(np.uint8))
    temp_img = refined2

    # Remaining frames: match each against the previous refined map.
    for idx in range(2, len(entropy_partial_list)):
        flag_l, pt1_l, pt2_l, flag_r, pt1_r, pt2_r = refinement.is_feature_matching_all(
            temp_img.astype(np.uint8), entropy_partial_list[idx].astype(np.uint8))
        if flag_l and flag_r:
            _, _, temp_img, points = refinement.feature_remove(
                temp_img.astype(np.uint8), entropy_partial_list[idx].astype(np.uint8),
                pt1_l, pt2_l, pt1_r, pt2_r)
            suffix = "_entropy_p_fr.png"
        else:
            print("else")
            temp_img, points = refinement.feature_detect(entropy_partial_list[idx].astype(np.uint8))
            suffix = "_entropy_p_fd.png"

        cv2.imwrite(os.path.join(visual_dir, "{:05}".format(idx) + suffix), np.abs(255 - temp_img))
        _fit_and_save(idx, points, rect_trans_list[idx], crop_img_list[idx],
                      full_img_list[idx], out_dir, left_or_right, save_full_img)
        if video:
            _write_gray_frame(refine_writer, np.abs(255 - temp_img).astype(np.uint8))

# In [None]:
def inference_main(cfile, test_name, data_flag, video):
    """Load a trained model and run ``inference`` on both eyes of one test sequence.

    Args:
        cfile: path to the JSON config. Keys read: ``data``, ``val_data.dir``,
            ``val_data.num_workers``, ``trainer.load_id`` and ``trainer.runs_dir``.
        test_name: sequence name. Crops are loaded from ``<test_name>/crop``
            under ``val_data.dir``; results go to
            ``<val_data.dir>/<test_name>/<load_id>/InferResults``.
        data_flag: forwarded to ``ASGaze_data(flag=...)``; inference only runs
            when it is 0.
        video: when truthy, also record debug .avi videos (XVID, 30 fps, 321x321).

    Side effects:
        Sets the module global ``device`` and, when *video* is truthy, the six
        ``video_*`` VideoWriter globals that ``inference`` writes to.
    """
    global device
    global video_entropy_left, video_entropy_partial_left, video_refine_left
    global video_entropy_right, video_entropy_partial_right, video_refine_right

    # Load config file (closed promptly; JSON text is UTF-8 by spec).
    with open(cfile, encoding="utf-8") as fp:
        config = json.load(fp)
    data_name = config['data']
    load_id = config['trainer']['load_id']
    vd = config['val_data']

    # -------------------------------- Save dir initialization -------------------------------- #
    runs_dir = os.path.join(vd['dir'], test_name, load_id)
    print("runs_dir", os.path.abspath(runs_dir))
    out_dir = os.path.join(runs_dir, "InferResults")
    # makedirs also creates out_dir itself.
    for sub in ("visual_left", "ellipse_params_left", "visual_right", "ellipse_params_right"):
        os.makedirs(os.path.join(out_dir, sub), exist_ok=True)

    # ------------------------------------ Device & model ------------------------------------- #
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = get_model("A")
    # Move to the selected device; an unconditional ``model.cuda()`` crashes on
    # CPU-only machines even though ``device`` falls back to CPU.
    model.to(device)

    checkpoint_dir = os.path.join(config['trainer']['runs_dir'], data_name + "-" + load_id)
    print("load checkpoingts from {}".format(os.path.abspath(checkpoint_dir)))
    checkpoint = load_checkpoint(checkpoint_dir, False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # ------------------------------------- Test dataset -------------------------------------- #
    # inference() only looks at item 0 of each batch and tracks frames in arrival
    # order, so always use single-frame, unshuffled batches; larger batches or
    # shuffling would silently drop or reorder frames.
    test_patch = test_name + "/crop"
    loaders = {}
    for side in ("left", "right"):
        dataset = ASGaze_data(datapath=vd['dir'], name=test_patch, split=side, flag=data_flag)
        loaders[side] = torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=vd['num_workers'])

    # ---------------------------------- Optional videos -------------------------------------- #
    if video:
        fourcc = cv2.VideoWriter_fourcc(*"XVID")

        def _writer(filename):
            return cv2.VideoWriter(os.path.join(out_dir, filename), fourcc, 30, (321, 321))

        video_entropy_left = _writer("entropy_left.avi")
        video_entropy_partial_left = _writer("entropy_partial_left.avi")
        video_refine_left = _writer("refine_left.avi")

        video_entropy_right = _writer("entropy_right.avi")
        video_entropy_partial_right = _writer("entropy_partial_right.avi")
        video_refine_right = _writer("refine_right.avi")

    try:
        if data_flag == 0:
            with torch.no_grad():
                inference(loaders["left"], model, "left", out_dir, video)
                print("Left eye inference finish")
                inference(loaders["right"], model, "right", out_dir, video)
                print("Right eye inference finish")
    finally:
        # Always finalise the video files, even if inference fails part-way.
        # No HighGUI windows are opened here, so cv2.destroyAllWindows() is not
        # called (it raises on headless OpenCV builds).
        if video:
            for writer in (video_entropy_left, video_entropy_partial_left, video_refine_left,
                           video_entropy_right, video_entropy_partial_right, video_refine_right):
                writer.release()