In [50]:
import os
import random
import sys
import time

import cv2
import numpy as np
from lxml import etree

def get_img_xml_info(xml_path):
    '''
    Read the image metadata xml file and extract the dimension information.

    xml_path - path to the metadata .xml file
    return - (1)dim_elem_num - array(uint) of shape (3,), the 'NumberOfElements'
                 attribute (quantity of voxels) of the first three
                 //Dimensions/DimensionDescription elements,
             (2)dim_len - array(float) of shape (3,), the 'Length' attribute of
                 the same elements (length of one 3D image along each dimension),
             (3)voxel_len - array(float) of shape (3,), dim_len/dim_elem_num,
                 the length of one voxel along each dimension
    raise - ValueError if fewer than three DimensionDescription elements exist
    '''
    parser=etree.XMLParser()
    my_et=etree.parse(xml_path,parser=parser)
    dim_attrib=my_et.xpath('//Dimensions/DimensionDescription')
    if len(dim_attrib)<3:
        raise ValueError('%s: expected 3 DimensionDescription elements, found %d'%(xml_path,len(dim_attrib)))
    # Attributes are strings; convert explicitly instead of relying on numpy's
    # implicit string-to-number coercion on element assignment.
    dim_elem_num=np.array([int(d.attrib['NumberOfElements']) for d in dim_attrib[:3]],dtype='uint')
    dim_len=np.array([float(d.attrib['Length']) for d in dim_attrib[:3]])
    voxel_len=dim_len/dim_elem_num
    return dim_elem_num,dim_len,voxel_len

def import_2D_img(img_path,img_name,z_th):
    '''
    Read one z-slice '<img_name>_z<NN>_RAW_ch00.tif' from img_path as a grayscale image.

    img_path - folder containing the slices
    img_name - slice file-name prefix
    z_th - z index, formatted with at least two digits (z00, z01, ...)
    return - the image as returned by cv2.imread(..., cv2.IMREAD_GRAYSCALE)
    raise - FileNotFoundError if OpenCV cannot read the file (cv2.imread
            returns None instead of raising)
    '''
    one_img_name=os.path.join(img_path,'%s_z%.2d_RAW_ch00.tif'%(img_name,z_th))
    img=cv2.imread(one_img_name,cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError('cannot read image %s'%(one_img_name))
    return img

def pyr_down_img(img,times=1):
    '''
    Downsample img by applying cv2.pyrDown `times` times in a row.

    Each cv2.pyrDown step blurs and roughly halves width and height, so times=2
    gives about a 4x smaller image per side.
    times <= 0 returns img unchanged (previously it still downsampled once).
    '''
    img_down=img
    for _ in range(times):
        img_down=cv2.pyrDown(img_down)
    return img_down

def loss_fun(ovl1,ovl2):
    '''
    Normalized sum of squared differences between two equally shaped overlaps:

        sum((ovl1-ovl2)**2) / sqrt(sum(ovl1**2) * sum(ovl2**2))

    0 means identical; lower is better.
    return - the loss as a float, or np.inf when either overlap is empty or all
             zero (zero denominator) instead of nan/inf with a RuntimeWarning
    '''
    # float64: sums over up to millions of pixels keep full precision when
    # comparing the losses of neighbouring candidate shifts.
    ovl1,ovl2=np.asarray(ovl1,dtype='float64'),np.asarray(ovl2,dtype='float64')
    denom=np.sqrt(np.sum(ovl1**2)*np.sum(ovl2**2))
    if denom==0:
        return np.inf
    return np.sum((ovl1-ovl2)**2)/denom

def calculate_xy_shift_by_RANSAC(img1,img2,pts1,pts2,loss_threshold=0.05,sample_time=1000,max_spread=50):
    '''
    Estimate the integer XY shift between img1 and img2 from matched keypoints
    with a RANSAC-like random search.

    Each trial draws a small random subset of matches (2 matches, or 1 when
    fewer than 20 matches exist). The subset is rejected if its per-match shifts
    disagree by more than `max_spread` pixels on either axis. Otherwise its
    rounded mean shift is scored with loss_fun on the overlapping region. Each
    distinct shift is scored only once.

    img1, img2 - 2D images to align
    pts1, pts2 - matched keypoint coordinates as (N, 2) arrays of (x, y);
                 pts2[k] in img2 matches pts1[k] in img1
    loss_threshold - stop early once a shift with loss <= this value is found
                     (the old default 1 was always overridden with 0.05, so
                     0.05 is now the real default and explicit values are honored)
    sample_time - maximum number of random trials
    max_spread - maximum allowed spread (pixels) of the sampled per-match shifts
    return - (xy_shift, loss): int32 array [x, y] such that img1 pixel (x, y)
             lines up with img2 pixel (x+xy_shift[0], y+xy_shift[1]), and its
             loss; ([0, 0], np.inf) if no candidate could be scored
    Note: sampling uses Python's global `random` state; seed it for reproducibility.
    '''
    pts1,pts2=np.asarray(pts1).reshape(-1,2),np.asarray(pts2).reshape(-1,2)
    matches_num=pts1.shape[0]
    # int32 so the result can be used directly for slicing by the BF refinement
    xy_shift_min=np.zeros(2,dtype='int32')
    loss_min=np.inf
    if matches_num==0:
        return xy_shift_min,loss_min
    sample_num=max(1,int(min(2,matches_num*0.1)))
    tried=set()
    count=0
    while loss_min>loss_threshold and count<sample_time:
        count+=1
        index_list=random.sample(range(matches_num),sample_num)
        xy_shift_all=pts2[index_list,:]-pts1[index_list,:]
        spread=np.max(xy_shift_all,axis=0)-np.min(xy_shift_all,axis=0)
        if np.any(spread>max_spread):
            continue
        xy_shift=np.int32(np.round(np.mean(xy_shift_all,axis=0)))#XY
        key=(int(xy_shift[0]),int(xy_shift[1]))
        if key in tried:
            continue
        tried.add(key)
        sx,sy=key
        ovl1=img1[max(0,-sy):,max(0,-sx):]
        ovl2=img2[max(0,sy):,max(0,sx):]
        h,w=min(ovl1.shape[0],ovl2.shape[0]),min(ovl1.shape[1],ovl2.shape[1])
        if h==0 or w==0:
            continue  # shift moves the images completely apart
        this_loss=loss_fun(ovl1[:h,:w],ovl2[:h,:w])
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min=xy_shift
    return xy_shift_min,loss_min

def calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift,scale=4,max_overlap=2000):
    '''
    Refine a coarse XY shift by coarse-to-fine brute-force search.

    Stage 1: try every shift within +-5 px (step 1) of `xy_shift` on the
             downsampled images.
    Stage 2: multiply the best stage-1 shift by `scale`, then try +-18 px
             (step 3) on the full-resolution images.
    Stage 3: try +-2 px (step 1) around the stage-2 best to fill the step-3 gaps,
             keeping the stage-2 loss as the bar to beat.
    Full-resolution overlaps are cropped to at most `max_overlap` px per side to
    bound the cost of each loss evaluation.

    img1_down, img2_down - downsampled images that `xy_shift` refers to
    img1, img2 - full-resolution images
    xy_shift - initial [x, y] shift in downsampled pixels (e.g. from RANSAC);
               rounded to int
    scale - full-resolution / downsampled size ratio (4 for two pyrDown steps)
    max_overlap - crop limit (pixels) for full-resolution overlaps
    return - (xy_shift, loss): int32 [x, y] shift in full-resolution pixels such
             that img1 pixel (x, y) lines up with img2 pixel (x+shift[0], y+shift[1]),
             and its loss; when no candidate overlaps, the scaled initial shift
             with loss np.inf
    '''
    def overlap_loss(a,b,shift,limit):
        # Crop the region where a and b overlap under `shift` and score it.
        sx,sy=int(shift[0]),int(shift[1])
        ovl1=a[max(0,-sy):,max(0,-sx):]
        ovl2=b[max(0,sy):,max(0,sx):]
        h,w=min(ovl1.shape[0],ovl2.shape[0]),min(ovl1.shape[1],ovl2.shape[1])
        if limit is not None:
            h,w=min(h,limit),min(w,limit)
        if h<=0 or w<=0:
            return np.inf
        return loss_fun(ovl1[:h,:w],ovl2[:h,:w])

    def grid_search(a,b,center,offsets,limit,best_shift,best_loss):
        # Try center+[dx,dy] for all dx, dy in offsets (dx outer, dy inner);
        # only a strictly lower loss replaces the current best.
        for dx in offsets:
            for dy in offsets:
                cand=center+np.array([dx,dy],dtype='int32')
                this_loss=overlap_loss(a,b,cand,limit)
                if this_loss<best_loss:
                    best_shift,best_loss=cand,this_loss
        return best_shift,best_loss

    start=np.int32(np.round(np.asarray(xy_shift,dtype='float64')))
    down_shift,_=grid_search(img1_down,img2_down,start,range(-5,6),None,start,np.inf)
    whole=down_shift*scale
    whole,loss_min=grid_search(img1,img2,whole,range(-18,19,3),max_overlap,whole,np.inf)
    return grid_search(img1,img2,whole,range(-2,3),max_overlap,whole,loss_min)

def update_lower_layer_info(xy_shift,axis_range1,axis_range2,voxel_len):
    '''
    Move the XY range of the lower layer (axis_range2) into the frame of the upper layer.

    axis_range1, axis_range2 - 2x2 float arrays [[x_min, x_max], [y_min, y_max]]
    xy_shift - [x, y] pixel shift such that upper-layer pixel p lines up with
               lower-layer pixel p+xy_shift
    voxel_len - per-axis physical length of one voxel (index 0 = x, 1 = y)
    return - axis_range2, modified in place. Both bounds of each axis are moved by
             the same offset, so its minimum becomes
             axis_range1[k,0]-xy_shift[k]*voxel_len[k] and its extent
             (max - min) is preserved. Previously the max bound was computed
             from the already-overwritten min and collapsed to the old min.
    '''
    for k in range(2):
        offset=axis_range1[k,0]-xy_shift[k]*voxel_len[k]-axis_range2[k,0]
        axis_range2[k,0]+=offset
        axis_range2[k,1]+=offset
    return axis_range2

def adjust_contrast(img1,img2):
    '''
    Scale the image with the lower mean intensity so both images share the higher mean.

    Each image is multiplied by max(mean1, mean2)/own_mean, clipped to [0, 255]
    and truncated to uint8. An image whose mean is 0 (all black) is left
    unscaled instead of being divided by zero.
    return - (img1, img2) as new uint8 arrays
    '''
    img1,img2=img1.astype('float32'),img2.astype('float32')
    m1,m2=np.mean(img1),np.mean(img2)
    m=max(m1,m2)

    def rescale(img,mean):
        gain=m/mean if mean>0 else 1.0
        return np.uint8(np.clip(gain*img,0,255))

    return rescale(img1,m1),rescale(img2,m2)

def find_roi_region(img,thredshold=5):
    '''
    Crop img to the bounding box of the rows/columns whose mean intensity exceeds thredshold.

    img - 2D image
    thredshold - a row/column counts as content when its mean is > this value
    return - (cropped image, roi_range). roi_range is a 2x2 int64 array
             [[x_first, x_last], [y_first, y_last]] of inclusive bounds. The crop
             now includes the last row/column (it used to be dropped), and the
             range is signed so offset arithmetic in callers cannot wrap around.
    raise - ValueError if no row or no column exceeds the threshold
    '''
    cols=np.flatnonzero(np.mean(img,axis=0)>thredshold)
    rows=np.flatnonzero(np.mean(img,axis=1)>thredshold)
    if cols.size==0 or rows.size==0:
        raise ValueError('find_roi_region: no row/column has mean intensity above %s'%(thredshold,))
    roi_range=np.array([[cols[0],cols[-1]],[rows[0],rows[-1]]],dtype='int64')
    # +1: roi_range holds inclusive bounds, slice stops are exclusive
    return img[rows[0]:rows[-1]+1,cols[0]:cols[-1]+1],roi_range

def _create_sift():
    '''Create a SIFT detector: cv2.SIFT_create when available (OpenCV >= 4.4), else the contrib one.'''
    if hasattr(cv2,'SIFT_create'):
        return cv2.SIFT_create()
    return cv2.xfeatures2d.SIFT_create()

def _match_sift_points(sift,bf,img1,img2,show_matches=False):
    '''
    Detect SIFT keypoints in both images and keep knn matches passing Lowe's ratio test (0.75).

    show_matches - draw the kept matches in a window for 3 s (debugging aid)
    return - (pts1, pts2) float32 arrays of shape (N, 2) holding the matched
             keypoint coordinates in img1 and img2; N is 0 when nothing matched
    '''
    kpts1,des1=sift.detectAndCompute(img1,None)
    kpts2,des2=sift.detectAndCompute(img2,None)
    if des1 is None or des2 is None:
        # no keypoints at all in one of the images
        return np.zeros((0,2),dtype='float32'),np.zeros((0,2),dtype='float32')
    good=[m[0] for m in bf.knnMatch(des1,des2,k=2) if len(m)==2 and m[0].distance<0.75*m[1].distance]
    if show_matches:
        img3=cv2.drawMatches(img1,kpts1,img2,kpts2,good,None,flags=2)
        cv2.imshow('3',img3)
        cv2.waitKey(3000)
        cv2.destroyAllWindows()
    pts1=np.float32([kpts1[m.queryIdx].pt for m in good]).reshape(-1,2)
    pts2=np.float32([kpts2[m.trainIdx].pt for m in good]).reshape(-1,2)
    return pts1,pts2

def start_vertical_stitch(i,img_path,save_path,file_name,img_name,show_matches=False):
    '''
    Find which z-slice of layer i matches the first slice (z=0) of layer i+1,
    and place layer i+1 in the XY frame of layer i.

    Slices j of layer i are tried from Nz-10 down to (excluding) round(0.3*Nz),
    where Nz is layer i's z voxel count. For each slice:
    - Both images are cropped to their ROI and downsampled with two pyrDown steps.
    - The downsampled images are contrast-matched.
    - SIFT matches give a coarse shift (calculate_xy_shift_by_RANSAC).
    - The shift is refined at full resolution (calculate_xy_shift_by_BF).
    The slice with the lowest loss wins.

    Reads  <save_path>/axis_range_stitch_<i>.npy (layer i range, already stitched;
           the unstitched axis_range_<i> was used before, which broke chaining
           beyond two layers) and <save_path>/axis_range_<i+1>.npy
    Writes <save_path>/first_index_<i+1>.npy (Nz minus the matched slice index)
           and <save_path>/axis_range_stitch_<i+1>.npy
    show_matches - display each slice's SIFT matches for 3 s (was always on)
    raise - RuntimeError if no slice of layer i could be matched
    '''
    layer1,layer2=os.path.join(img_path,file_name%(i)),os.path.join(img_path,file_name%(i+1))
    dim_elem_num1,_,voxel_len=get_img_xml_info(os.path.join(layer1,'MetaData',img_name+'.xml'))
    axis_range1=np.load(os.path.join(save_path,'axis_range_stitch_%d.npy'%(i)))
    axis_range2=np.load(os.path.join(save_path,'axis_range_%d.npy'%(i+1)))

    sift=_create_sift()
    bf=cv2.BFMatcher()
    img2,roi_range2=find_roi_region(import_2D_img(layer2,img_name,0))
    # loop-invariant; adjust_contrast returns new arrays, so this stays untouched
    img2_down_raw=pyr_down_img(img2,2)

    z_num=int(dim_elem_num1[2])
    loss_min=np.inf
    index_min=-1
    xy_shift_min=np.zeros(2,dtype='int64')
    for j in range(z_num-10,int(np.round(z_num*0.3)),-1):
        img1,roi_range1=find_roi_region(import_2D_img(layer1,img_name,j))
        img1_down,img2_down=adjust_contrast(pyr_down_img(img1,2),img2_down_raw)
        pts1,pts2=_match_sift_points(sift,bf,img1_down,img2_down,show_matches)
        if pts1.shape[0]==0:
            print('%d.th layer has no good SIFT matches, skipped'%(j))
            continue
        xy_shift,this_loss=calculate_xy_shift_by_RANSAC(img1_down,img2_down,pts1,pts2)
        print('%d.th layer for pyrDown image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        xy_shift,this_loss=calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift)
        print('%d.th layer for whole image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        if this_loss<loss_min:
            loss_min=this_loss
            # shift between ROI crops -> shift between whole images, per axis
            # (the y component previously used xy_shift[0])
            xy_shift_min=np.array([int(xy_shift[k])+int(roi_range2[k,0])-int(roi_range1[k,0]) for k in range(2)],dtype='int64')
            index_min=j
    print('################################################################################')
    print('Finally the matched one is %d.th layer, xy_shift is %s, loss is %.8f'%(index_min,str(xy_shift_min),loss_min))
    print('################################################################################')
    if index_min<0:
        raise RuntimeError('no slice of layer %d could be matched to layer %d'%(i,i+1))

    print('start saving data')
    axis_range2=update_lower_layer_info(xy_shift_min,axis_range1,axis_range2,voxel_len)
    # z count of layer i itself (previously read an unrelated global `dim_elem_num`)
    np.save(os.path.join(save_path,'first_index_%d.npy'%(i+1)),np.array([z_num-index_min],dtype='uint32'))
    np.save(os.path.join(save_path,'axis_range_stitch_%d.npy'%(i+1)),axis_range2)
    print('end saving data')

In [52]:
# Configuration. NOTE(review): machine-specific absolute paths; move to a
# config cell or environment variables before sharing.
img_path=r'C:\Users\dingj\ZhaoLab'
file_name=r'Project%d'
img_name=r'Region 1_Merged'
save_path=r'C:\Users\dingj\ZhaoLab\20220808_SiftVertStitich_MSR'
layer_num=3
# np.save does not create folders
os.makedirs(save_path,exist_ok=True)

# Initial XY range [[0, x_len], [0, y_len]] of every layer in its own frame.
# Layer 1 is the reference, so its stitched range is its own range.
for i in range(1,layer_num+1):
    xml_path=os.path.join(img_path,file_name%(i),'MetaData',img_name+'.xml')
    dim_elem_num,dim_len,voxel_len=get_img_xml_info(xml_path)
    axis_range=np.zeros((2,2))
    axis_range[0,1],axis_range[1,1]=dim_len[0],dim_len[1]
    np.save(os.path.join(save_path,'axis_range_%d.npy'%(i)),axis_range)
    if i==1:
        np.save(os.path.join(save_path,'axis_range_stitch_%d.npy'%(i)),axis_range)

# Stitch layer i+1 under layer i, top to bottom.
for i in range(1,layer_num):
    start_vertical_stitch(i,img_path,save_path,file_name,img_name)


73.th layer for pyrDown image, xy_shift is [ 80 -73], loss is 0.87126195
73.th layer for whole image, xy_shift is [ 298 -257], loss is 0.93835229
72.th layer for pyrDown image, xy_shift is [-256 -147], loss is 0.93997133
72.th layer for whole image, xy_shift is [-984 -614], loss is 0.97304702
71.th layer for pyrDown image, xy_shift is [-264  126], loss is 1.04500329
71.th layer for whole image, xy_shift is [-1016   532], loss is 1.00901592
70.th layer for pyrDown image, xy_shift is [-197 -208], loss is 0.90158623
70.th layer for whole image, xy_shift is [-748 -840], loss is 0.94850528
69.th layer for pyrDown image, xy_shift is [-27 -22], loss is 0.74769270
69.th layer for whole image, xy_shift is [-107  -85], loss is 0.79024857
68.th layer for pyrDown image, xy_shift is [-205 -165], loss is 0.88729662
68.th layer for whole image, xy_shift is [-848 -700], loss is 0.93854380
67.th layer for pyrDown image, xy_shift is [-216 -287], loss is 0.93783927
67.th layer for whole image, xy_shift i

57.th layer for pyrDown image, xy_shift is [-236 -212], loss is 0.91749573
57.th layer for whole image, xy_shift is [-984 -888], loss is 1.89005864
56.th layer for pyrDown image, xy_shift is [-280 -291], loss is 0.79047513
56.th layer for whole image, xy_shift is [-1131 -1156], loss is 1.73104811
55.th layer for pyrDown image, xy_shift is [-294 -176], loss is 0.91408288
55.th layer for whole image, xy_shift is [-1204  -692], loss is 1.90730476
54.th layer for pyrDown image, xy_shift is [-225 -210], loss is 0.90652496
54.th layer for whole image, xy_shift is [-900 -880], loss is 1.86235249
53.th layer for pyrDown image, xy_shift is [-448 -296], loss is 0.78025776
53.th layer for whole image, xy_shift is [-1792 -1184], loss is 1.34500527
52.th layer for pyrDown image, xy_shift is [-205 -122], loss is 0.98433840
52.th layer for whole image, xy_shift is [-820 -528], loss is 2.10103798
51.th layer for pyrDown image, xy_shift is [-450 -272], loss is 0.78959310
51.th layer for whole image, xy

In [4]:
from lxml import etree
import numpy as np
import time
import cv2
import random
import sys

a=np.array([[1,2,3],[2,3,4],[4,5,6],[6,7,8]])
# Scratch check of the index-list idiom `[i for [i] in np.argwhere(mask)]`:
# np.argwhere on a 1-D mask yields one single-element row per True entry.
col_mask=a.mean(axis=0)>1
[idx for (idx,) in np.argwhere(col_mask)]

[0, 1, 2]