# In [None]:
from lxml import etree
import numpy as np
import time
import cv2
import multiprocessing
from multiprocessing import Value,Array,Process
import ctypes
import random
from numba import jit,njit
import sys

def get_img_xml_info(xml_path,save_path):
    '''
    Read the image-metadata XML file, extract per-dimension sizes, save and return them.

    The first three ``//Dimensions/DimensionDescription`` elements (in document
    order) are read, using their ``NumberOfElements`` and ``Length`` attributes.
    The order is presumably X, Y, Z — TODO confirm against the XML producer.
    The three resulting arrays are also written to disk as
    ``<save_path>\\dim_elem_num.npy``, ``<save_path>\\dim_len.npy`` and
    ``<save_path>\\voxel_len.npy``. The file names are appended with a literal
    backslash, i.e. a Windows-style path.

    Parameters
    ----------
    xml_path : str
        Path of the XML metadata file.
    save_path : str
        Directory that receives the three ``.npy`` files.

    Returns
    -------
    dim_elem_num : ndarray of uint, shape (3,)
        Number of voxels along each dimension.
    dim_len : ndarray of float, shape (3,)
        Length of the whole image along each dimension.
    voxel_len : ndarray of float, shape (3,)
        Length of one voxel along each dimension (``dim_len / dim_elem_num``).
    '''
    parser=etree.XMLParser()
    my_et=etree.parse(xml_path,parser=parser)
    dim_attrib=my_et.xpath('//Dimensions/DimensionDescription')
    dim_elem_num=np.zeros(3,dtype='uint')
    dim_len=np.zeros(3)
    for i in range(3):
        attrib=dim_attrib[i].attrib
        # XML attribute values are strings: convert explicitly instead of
        # relying on numpy's implicit string parsing on assignment.
        dim_elem_num[i]=int(attrib['NumberOfElements'])
        dim_len[i]=float(attrib['Length'])
    voxel_len=dim_len/dim_elem_num
    np.save(save_path+r'\dim_elem_num.npy',dim_elem_num)
    np.save(save_path+r'\dim_len.npy',dim_len)
    np.save(save_path+r'\voxel_len.npy',voxel_len)
    return dim_elem_num,dim_len,voxel_len

def import_2D_img(img_path,img_name,z_th):
    '''
    Load one z-slice of the raw ch00 TIFF series as an 8-bit grayscale image.

    The file read is ``<img_path>\\<img_name>_z<NN>_RAW_ch00.tif``, where ``NN``
    is ``z_th`` zero-padded to at least two digits. The return value is exactly
    what ``cv2.imread`` gives back, so it is ``None`` when the file cannot be read.
    '''
    slice_file=r'%s\%s_z%.2d_RAW_ch00.tif'%(img_path,img_name,z_th)
    gray_slice=cv2.imread(slice_file,cv2.IMREAD_GRAYSCALE)
    return gray_slice

def pyr_down_img(img,times=1):
    '''
    Downsample an image ``times`` levels with ``cv2.pyrDown`` (Gaussian pyramid).

    Each level blurs the image and halves each side (rounded up).

    Parameters
    ----------
    img : ndarray
        Image accepted by ``cv2.pyrDown``.
    times : int, optional
        Number of pyramid levels to go down. ``0`` returns ``img`` itself.
        Previously any value below 1 still downsampled once.

    Raises
    ------
    ValueError
        If ``times`` is negative.
    '''
    if times<0:
        raise ValueError('times must be non-negative, got %r'%(times,))
    img_down=img
    for _ in range(times):
        img_down=cv2.pyrDown(img_down)
    return img_down

def loss_fun(ovl1,ovl2):
    '''
    Normalized sum of squared differences between two equally-shaped overlaps.

        loss = sum((ovl1 - ovl2)**2) / sqrt(sum(ovl1**2) * sum(ovl2**2))

    Both inputs are cast to float32 first, so uint8 images do not overflow.
    Lower is better, and identical non-zero overlaps give 0.

    Returns
    -------
    float
        The loss. ``np.inf`` is returned when the normalizer is zero (an empty
        overlap or an all-zero overlap), instead of NaN/inf plus a
        RuntimeWarning. Such a candidate therefore never wins a ``<``
        comparison against the current best.
    '''
    ovl1,ovl2=ovl1.astype('float32'),ovl2.astype('float32')
    norm=np.sqrt(np.sum(ovl1**2)*np.sum(ovl2**2))
    if not norm>0:
        return np.inf
    loss=np.sum((ovl1-ovl2)**2)/norm
    return loss

def calculate_xy_shift_by_RANSAC(img1,img2,pts1,pts2,loss_threshold=0.05,sample_time=1000,max_spread=50):
    '''
    Estimate the integer XY shift between two images from matched points, RANSAC-style.

    Each iteration:
    1. Draw a small random subset of the matches.
    2. If the subset's per-match shifts (``pts2 - pts1``) agree to within
       ``max_spread`` pixels on both axes, round their mean to an integer
       candidate shift.
    3. Score every distinct candidate once with ``loss_fun`` on the overlap of
       the two images; the lowest loss wins.

    Sampling stops when the best loss is <= ``loss_threshold``, or after
    ``sample_time`` draws.

    Parameters
    ----------
    img1, img2 : 2-D ndarray
        Images to align.
    pts1, pts2 : ndarray, shape (N, 2)
        Matched (x, y) coordinates; row i of pts1 matches row i of pts2.
    loss_threshold : float
        Early-stop threshold. The previous implementation ignored this
        argument and always used 0.05; the default now matches that value.
    sample_time : int
        Maximum number of random draws.
    max_spread : number
        Maximum allowed per-axis spread (max - min) of a subset's shifts.

    Returns
    -------
    xy_shift_min : ndarray
        Best shift as int32 ``[dx, dy]`` (img2 pixel ``x + dx`` lies over img1
        pixel ``x``). If no candidate was scored, this is ``np.zeros(2)`` (float).
    loss_min : float
        Loss of the best shift, or ``np.inf`` if none was scored.
    '''
    matches_num=pts1.shape[0]
    xy_shift_min=np.zeros(2)
    loss_min=np.inf
    if matches_num==0:
        return xy_shift_min,loss_min
    # 10% of the matches, capped at 2, but never 0. The old formula gave 0 for
    # fewer than 10 matches, and np.max on the empty sample then raised.
    RANSAC_num=max(1,min(2,int(matches_num*0.1)))
    # Shifts already scored. Re-scoring gives the same loss, and the old
    # "skip if equal to current best" test (initialised to zeros) meant a
    # genuine (0, 0) shift was never evaluated.
    evaluated=set()
    count=0
    while loss_min>loss_threshold and count<sample_time:
        count+=1
        index_list=random.sample(range(matches_num),RANSAC_num)
        xy_shift_all=pts2[index_list,:]-pts1[index_list,:]
        spread=np.max(xy_shift_all,axis=0)-np.min(xy_shift_all,axis=0)
        if np.any(spread>max_spread):
            continue
        xy_shift=np.int32(np.round(np.mean(xy_shift_all,axis=0)))#XY
        dx,dy=int(xy_shift[0]),int(xy_shift[1])
        if (dx,dy) in evaluated:
            continue
        evaluated.add((dx,dy))
        # Crop both images to the region where they overlap under this shift.
        ovl1=img1[max(0,-dy):,max(0,-dx):]
        ovl2=img2[max(0,dy):,max(0,dx):]
        y_range=min(ovl1.shape[0],ovl2.shape[0])
        x_range=min(ovl1.shape[1],ovl2.shape[1])
        this_loss=loss_fun(ovl1[:y_range,:x_range],ovl2[:y_range,:x_range])
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min=xy_shift
    return xy_shift_min,loss_min

def _bf_overlap(img1,img2,xy_shift,max_size=None):
    '''
    Crop the parts of img1 and img2 that overlap when img2 is offset by xy_shift.

    A positive dx/dy drops that many leading columns/rows from img2; a negative
    one drops them from img1. Both crops are then trimmed to their common size,
    optionally capped at ``max_size`` pixels per axis.
    '''
    dx,dy=int(xy_shift[0]),int(xy_shift[1])
    ovl1=img1[max(0,-dy):,max(0,-dx):]
    ovl2=img2[max(0,dy):,max(0,dx):]
    y_range=min(ovl1.shape[0],ovl2.shape[0])
    x_range=min(ovl1.shape[1],ovl2.shape[1])
    if max_size is not None:
        y_range,x_range=min(y_range,max_size),min(x_range,max_size)
    return ovl1[:y_range,:x_range],ovl2[:y_range,:x_range]

def _bf_grid_search(img1,img2,center,offsets,max_size,best_shift,best_loss):
    '''
    Score ``center + [dx, dy]`` for all dx, dy in ``offsets`` and keep the lowest loss.

    dx is the outer loop and dy the inner one. A candidate replaces the current
    best only if its loss is strictly lower, so ties keep the earlier candidate.
    '''
    for dx in offsets:
        for dy in offsets:
            this_xy_shift=center+np.array([dx,dy],dtype='int32')
            this_loss=loss_fun(*_bf_overlap(img1,img2,this_xy_shift,max_size))
            if this_loss<best_loss:
                best_loss=this_loss
                best_shift=this_xy_shift
    return best_shift,best_loss

def calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift,scale=4,max_size=2000):
    '''
    Refine an XY shift estimate by coarse-to-fine brute-force search.

    1. Search +/-5 px (step 1) around ``xy_shift`` on the downsampled pair.
    2. Multiply the result by ``scale`` and search +/-18 px (step 3) on the
       full-resolution pair.
    3. Search +/-2 px (step 1) around the stage-2 result. The stage-2 loss is
       kept as the bar to beat.

    At full resolution the overlaps are capped at ``max_size`` px per axis to
    bound the cost.

    Parameters
    ----------
    img1_down, img2_down : 2-D ndarray
        Downsampled images. ``scale=4`` presumably matches
        ``pyr_down_img(img, 2)`` — TODO confirm with callers.
    img1, img2 : 2-D ndarray
        Full-resolution images.
    xy_shift : array-like of 2 numbers
        Initial ``[dx, dy]`` in downsampled pixels. It is rounded to int32, so
        a float fallback such as ``np.zeros(2)`` no longer breaks slicing.
    scale : int
        Full-resolution / downsampled size ratio.
    max_size : int
        Per-axis cap on the overlap size used at full resolution.

    Returns
    -------
    xy_shift_min : ndarray of int32, shape (2,)
        Best full-resolution ``[dx, dy]``. If no candidate has a loss below
        inf, it falls back to the search center instead of the previous empty
        array, which raised an IndexError.
    loss_min : float
        Loss of ``xy_shift_min`` (``np.inf`` if no candidate scored finitely).
    '''
    center=np.int32(np.round(np.asarray(xy_shift)))
    best_shift,loss_min=_bf_grid_search(img1_down,img2_down,center,range(-5,6),None,center,np.inf)
    center=best_shift*scale
    best_shift,loss_min=_bf_grid_search(img1,img2,center,range(-18,19,3),max_size,center,np.inf)
    best_shift,loss_min=_bf_grid_search(img1,img2,best_shift,range(-2,3),max_size,best_shift,loss_min)
    return best_shift,loss_min

def update_lower_layer_info(xy_shift,axis_range1,axis_range2,dim_elem_num,voxel_len):
    '''
    Overwrite the X and Y rows of axis_range2 from axis_range1 and an XY shift.
    Then recompute the per-axis voxel count of axis_range2.

    Parameters (shapes inferred from the indexing below — TODO confirm with callers)
    ----------
    xy_shift : sequence of 2 numbers
        [dx, dy]. Each is multiplied by voxel_len[0] / voxel_len[1], so presumably
        in voxels.
    axis_range1, axis_range2 : 2-D ndarray indexed [axis, k], k in {0, 1}
        Column 0 / 1 presumably hold the lower / upper bound of each axis.
        axis_range2 is modified IN PLACE; only rows 0 and 1 are written.
    dim_elem_num :
        Not used by this function.
    voxel_len : ndarray
        Per-axis voxel length. It divides
        (axis_range2[:,1] - axis_range2[:,0]) elementwise.

    Returns
    -------
    axis_range2 :
        The same array object that was passed in, after modification.
    voxel_num : ndarray of uint64
        np.round((axis_range2[:,1] - axis_range2[:,0]) / voxel_len) cast to uint64.

    NOTE(review): likely bug. The [.,1] assignments repeat the [.,0] right-hand
    side, and they read axis_range2[.,0] after it has already been overwritten.
    The old axis_range2[.,1] is never read. Algebraically:
    - the new upper bound equals the OLD lower bound;
    - the new lower bound is axis_range1[.,0] - old_axis_range2[.,0] - shift*voxel_len.
    The resulting extent can be negative, and casting a negative value to
    uint64 does not give a meaningful count. Confirm the intended formula
    before relying on this.
    '''
    axis_range2[0,0]=-axis_range2[0,0]+axis_range1[0,0]-xy_shift[0]*voxel_len[0]
    # NOTE(review): reads the axis_range2[0,0] just written above, not the old value.
    axis_range2[0,1]=-axis_range2[0,0]+axis_range1[0,0]-xy_shift[0]*voxel_len[0]
    axis_range2[1,0]=-axis_range2[1,0]+axis_range1[1,0]-xy_shift[1]*voxel_len[1]
    # NOTE(review): same pattern as the X row — reads the freshly written axis_range2[1,0].
    axis_range2[1,1]=-axis_range2[1,0]+axis_range1[1,0]-xy_shift[1]*voxel_len[1]
    voxel_num=np.uint64(np.round((axis_range2[:,1]-axis_range2[:,0])/voxel_len))
    return axis_range2,voxel_num

def adjust_contrast(img1,img2):
    '''
    Rescale two images so that both mean intensities match the larger mean.

    Each image is multiplied by (larger mean / its own mean), clipped to
    [0, 255] and converted to uint8; fractional parts are truncated by the cast.
    An image whose mean is 0 (e.g. an all-black tile) is left unscaled.
    Previously it was divided by zero, producing inf/NaN before the uint8 cast.

    Returns
    -------
    (img1, img2) : tuple of uint8 ndarray
    '''
    img1,img2=img1.astype('float32'),img2.astype('float32')
    m1,m2=np.mean(img1),np.mean(img2)
    m=np.max((m1,m2))

    def _rescale(img,mean):
        # Guard the division: a zero mean keeps the image as-is.
        gain=m/mean if mean!=0 else 1.0
        return np.uint8(np.clip(gain*img,0,255))

    return _rescale(img1,m1),_rescale(img2,m2)

def find_ROI(*args,**kwargs):
    '''
    Placeholder for an unfinished function.

    The original line was a bare ``def find_ROI`` with no parameter list or
    body. That is a SyntaxError which prevents the whole file from being
    imported or run. The intended signature and behaviour are not recoverable
    from the surviving source.

    Raises
    ------
    NotImplementedError
        Always, until a real implementation is written.
    '''
    # TODO: implement find_ROI; its intended signature is unknown.
    raise NotImplementedError('find_ROI is not implemented yet')

def start_vertical_stitch(i,img_path,save_path,file_name,img_name):
    