In [5]:
from lxml import etree
import numpy as np
import time
import cv2
import random
import sys
import os

def get_img_xml_info(xml_path):
    '''
    Read xml file and extract the information of dimensions and each tile
    return - (1)dim_elem_num - linspace(uint), the quantity of voxels for each dimension,
    (2)dim_len - linspace(float), the length for one 3D image,
    (3)voxel_len - lingspace(float), the length for a voxel,
    (4)tile_num - int, the quantity of tiles,
    (5)tile_field - array(uint), the identifier of each file
    (6)tile_pos - array(float), the XYZ position information of each tile
    '''
    parser=etree.XMLParser()
    my_et=etree.parse(xml_path,parser=parser)
    dim_attrib=my_et.xpath('//Dimensions/DimensionDescription')
    dim_elem_num=np.zeros(3,dtype='uint')
    dim_len=np.zeros(3)
    for i in range(3):
        dim_elem_num[i],dim_len[i]=dim_attrib[i].attrib['NumberOfElements'],dim_attrib[i].attrib['Length']
    voxel_len=dim_len/dim_elem_num
    return dim_elem_num,dim_len,voxel_len

def import_2D_img(img_path,img_name,z_th):
    one_img_name=r'%s\%sz%.3d.tif'%(img_path,img_name,z_th)
    return cv2.imread(one_img_name,cv2.IMREAD_GRAYSCALE)

def pyr_down_img(img,times=1):
    img_down=cv2.pyrDown(img)
    for i in range(times-1):
        img_down=cv2.pyrDown(img_down)
    return img_down

def loss_fun(ovl1,ovl2):
    ovl1,ovl2=ovl1.astype('float32'),ovl2.astype('float32')
    loss=np.sum((ovl1-ovl2)**2)/np.sqrt(np.sum(ovl1**2)*np.sum(ovl2**2))
    return loss

def calculate_xy_shift_by_RANSAC(img1,img2,pts1,pts2,loss_threshold=1,sample_time=1000):
    count=0
    matches_num=pts1.shape[0]
    RANSAC_num=np.int32(np.min((2,np.max((1,matches_num*0.1)))))
    loss_threshold=0.05
    loss_min=np.inf
    xy_shift_min=np.int32(np.zeros(2))
    if pts1.shape[0]==0 or pts1.shape[0]==0:
        return xy_shift_min,loss_min
    while(loss_min>loss_threshold and count<sample_time):
        count+=1
        index_list=random.sample(range(matches_num),RANSAC_num)
        xy_shift_all=pts2[index_list,:]-pts1[index_list,:]
        max_shift,min_shift=np.max(xy_shift_all,axis=0),np.min(xy_shift_all,axis=0)
        if any((max_shift-min_shift)>50):
            continue
        xy_shift=np.int32(np.round(np.mean(xy_shift_all,axis=0)))#XY
        if all(xy_shift==xy_shift_min):
            continue
        ovl1,ovl2=img1[np.max((0,-xy_shift[1])):,np.max((0,-xy_shift[0])):],img2[np.max((0,xy_shift[1])):,np.max((0,xy_shift[0])):]
        x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1])),np.min((ovl1.shape[0],ovl2.shape[0]))
        ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
        this_loss=loss_fun(ovl1,ovl2)
        #print(xy_shift,this_loss)
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min=xy_shift
    #print(xy_shift_min,loss_min)
    return xy_shift_min,loss_min

def calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift):
    ################################################################
    xy_shift_min=np.zeros(0)
    loss_min=np.inf
    for x in range(-5,6):
        for y in range(-5,6):
            this_xy_shift=xy_shift+np.array([x,y],dtype='int32')
            #print(this_xy_shift)
            ovl1=img1_down[np.max((0,-this_xy_shift[1])):,np.max((0,-this_xy_shift[0])):]
            ovl2=img2_down[np.max((0,this_xy_shift[1])):,np.max((0,this_xy_shift[0])):]
            x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1])),np.min((ovl1.shape[0],ovl2.shape[0]))
            ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<loss_min:
                loss_min=this_loss
                xy_shift_min=this_xy_shift
            #print(this_xy_shift,this_loss)
    print('first ',xy_shift_min,loss_min)
    ################################################################
    img1_down=pyr_down_img(img1,2)
    img2_down=pyr_down_img(img2,2)
    xy_shift_whole=xy_shift_min*2
    xy_shift_min=np.zeros(0)
    loss_min=np.inf
    for x in range(-18,19,3):
        for y in range(-18,19,3):
            this_xy_shift=xy_shift_whole+np.array([x,y],dtype='int32')
            ovl1=img1_down[np.max((0,-this_xy_shift[1])):,np.max((0,-this_xy_shift[0])):]
            ovl2=img2_down[np.max((0,this_xy_shift[1])):,np.max((0,this_xy_shift[0])):]
            x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1],2000)),np.min((ovl1.shape[0],ovl2.shape[0],2000))
            ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
#             if x==0 and y==0:
#                 cv2.imshow('1',ovl1)
#                 cv2.imshow('2',ovl2)
#                 cv2.waitKey()
#                 cv2.destroyAllWindows()
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<loss_min:
                loss_min=this_loss
                xy_shift_min=this_xy_shift
    #print(xy_shift_min)
    xy_shift_whole=xy_shift_min
    for x in range(-2,3):
        for y in range(-2,3):
            this_xy_shift=xy_shift_whole+np.array([x,y],dtype='int32')
            ovl1=img1_down[np.max((0,-this_xy_shift[1])):,np.max((0,-this_xy_shift[0])):]
            ovl2=img2_down[np.max((0,this_xy_shift[1])):,np.max((0,this_xy_shift[0])):]
            x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1],2000)),np.min((ovl1.shape[0],ovl2.shape[0],2000))
            ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<loss_min:
                loss_min=this_loss
                xy_shift_min=this_xy_shift
    print('second ',xy_shift_min,loss_min)
    ################################################################
    xy_shift_whole=xy_shift_min*4
    xy_shift_min=np.zeros(0)
    loss_min=np.inf
    for x in range(-18,19,3):
        for y in range(-18,19,3):
            this_xy_shift=xy_shift_whole+np.array([x,y],dtype='int32')
            ovl1=img1[np.max((0,-this_xy_shift[1])):,np.max((0,-this_xy_shift[0])):]
            ovl2=img2[np.max((0,this_xy_shift[1])):,np.max((0,this_xy_shift[0])):]
            x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1],2000)),np.min((ovl1.shape[0],ovl2.shape[0],2000))
            ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
#             if x==0 and y==0:
#                 cv2.imshow('1',ovl1)
#                 cv2.imshow('2',ovl2)
#                 cv2.waitKey()
#                 cv2.destroyAllWindows()
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<loss_min:
                loss_min=this_loss
                xy_shift_min=this_xy_shift
    #print(xy_shift_min)
    xy_shift_whole=xy_shift_min
    for x in range(-2,3):
        for y in range(-2,3):
            this_xy_shift=xy_shift_whole+np.array([x,y],dtype='int32')
            ovl1=img1[np.max((0,-this_xy_shift[1])):,np.max((0,-this_xy_shift[0])):]
            ovl2=img2[np.max((0,this_xy_shift[1])):,np.max((0,this_xy_shift[0])):]
            x_range,y_range=np.min((ovl1.shape[1],ovl2.shape[1],2000)),np.min((ovl1.shape[0],ovl2.shape[0],2000))
            ovl1,ovl2=ovl1[0:y_range,0:x_range],ovl2[0:y_range,0:x_range]
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<loss_min:
                loss_min=this_loss
                xy_shift_min=this_xy_shift
    print('third ',xy_shift_min,loss_min)
    return xy_shift_min,loss_min

def update_lower_layer_info(xy_shift,axis_range1,axis_range2):
    axis_range=np.zeros((2,2))
    axis_range[0,0]=axis_range1[0,0]-xy_shift[0]
    axis_range[0,1]=axis_range2[0,1]-axis_range2[0,0]+axis_range1[0,0]-xy_shift[0]
    axis_range[1,0]=axis_range1[1,0]-xy_shift[1]
    axis_range[1,1]=axis_range2[1,1]-axis_range2[1,0]+axis_range1[1,0]-xy_shift[1]
    return axis_range

def adjust_contrast(img1,img2):
    img1,img2=img1.astype('float32'),img2.astype('float32')
    m1,m2=np.mean(img1),np.mean(img2)
    m=np.max((m1,m2,40))
    #print(m,m1,m2)
    img1=np.uint8(np.clip(m/m1*img1,0,255))
    img2=np.uint8(np.clip(m/m2*img2,0,255))
    return img1,img2

def get_axis_range_from_img(img_path,img_name):
    img=import_2D_img(img_path,img_name,0)
    axis_range=np.zeros((2,2))
    axis_range[1,1],axis_range[0,1]=img.shape[0],img.shape[1]
    return axis_range

def start_vertical_stitch(i,img_path,save_path,file_name,img_name):
    sift=cv2.xfeatures2d.SIFT_create()
    bf=cv2.BFMatcher()
    img_path1,img_path2=img_path+'\\'+file_name%(i),img_path+'\\'+file_name%(i+1)
    dim_elem_num1,dim_elem_num2=len(os.listdir(img_path1)),len(os.listdir(img_path1))
    axis_range1,axis_range2=get_axis_range_from_img(img_path1,img_name),get_axis_range_from_img(img_path1,img_name)
    
    img1=cv2.medianBlur(import_2D_img(img_path1,img_name,dim_elem_num1-1),3)
    
    step=1
    loss_min=np.inf
    index_min=-1
    xy_shift_min=np.zeros(2)
    for j in range(10,np.uint32(np.round(dim_elem_num2*0.2)),step):
        img2=cv2.medianBlur(import_2D_img(img_path2,img_name,j),3)
        img1_down=pyr_down_img(img1,3)
        img2_down=pyr_down_img(img2,3)
        img1_down,img2_down=adjust_contrast(img1_down,img2_down)
        kpts1,des1=sift.detectAndCompute(img1_down,None)
        kpts2,des2=sift.detectAndCompute(img2_down,None)
        kp1,kp2=np.float32([kp.pt for kp in kpts1]),np.float32([kp.pt for kp in kpts2])
        matches=bf.knnMatch(des1,des2,k=2)
        good_matches=[]
        good=[]
        for m in matches:
            if len(m)==2 and m[0].distance<0.75*m[1].distance:
                good_matches.append((m[0].queryIdx,m[0].trainIdx))
                good.append(m[0])
        img3=cv2.drawMatches(img1_down,kpts1,img2_down,kpts2,good,None,flags=2)
        cv2.imshow('3',img3)
        cv2.waitKey(3000)
        cv2.destroyAllWindows()
        pts1,pts2=np.float32([kp1[i,:] for (i,_) in good_matches]),np.float32([kp2[j,:] for (_,j) in good_matches])
        xy_shift,this_loss=calculate_xy_shift_by_RANSAC(img1_down,img2_down,pts1,pts2)
        print('%d.th layer for pyrDown image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        xy_shift,this_loss=calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift)
        print('%d.th layer for whole image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min[0],xy_shift_min[1]=xy_shift[0],xy_shift[1]
            index_min=j
    print('################################################################################')
    print('Finally the matched one is %d.th layer, xy_shift is %s, loss is %.8f'%(index_min,str(xy_shift_min),loss_min))
    print('################################################################################')
    
    print('start saving data')
    axis_range2=update_lower_layer_info(xy_shift_min,axis_range1,axis_range2)
    np.save(save_path+r'\first_last_index_%d.npy'%(i+1),np.array([index_min,dim_elem_num],dtype='uint32'))
    np.save(save_path+r'\axis_range_%d.npy'%(i+1),axis_range2)
    print('end saving data')

In [6]:
if __name__=='__main__':
    img_path=r'F:\20220811_VIP_Cre_DIOmS_RetromS_in_SCN'
    file_name=r'S%d'
    img_name=r''
    save_path=r'C:\Users\dingj\ZhaoLab\20220814_SiftVertStitch_WYL'
    layer_num=4
    for i in range(1,layer_num+1):
        dim_elem_num=len(os.listdir(img_path+'\\'+file_name%(i)))
        np.save('first_last_index_%d.npy'%(i),np.array([0,dim_elem_num-1]))
        print(dim_elem_num)
        if i==1:
            img=import_2D_img(img_path+'\\'+file_name%(i),img_name,0)
            axis_range=np.zeros((2,2))
            axis_range[1,1],axis_range[0,1]=img.shape[0],img.shape[1]
            print(axis_range)
            np.save(save_path+r'\axis_range_%d.npy'%(i),axis_range)
    for i in range(1,layer_num):
        start_vertical_stitch(i,img_path,save_path,file_name,img_name)

242
[[   0. 4254.]
 [   0. 4682.]]
242
242
241
10.th layer for pyrDown image, xy_shift is [-3  1], loss is 0.02286655
first  [-4 -2] 0.021653203
second  [-10   2] 0.12790072
third  [-20   5] 0.13964298
10.th layer for whole image, xy_shift is [-20   5], loss is 0.13964298
11.th layer for pyrDown image, xy_shift is [-2 -1], loss is 0.02215685
first  [-4 -2] 0.021483641
second  [-10   2] 0.1275784
third  [-20   4] 0.13849227
11.th layer for whole image, xy_shift is [-20   4], loss is 0.13849227
12.th layer for pyrDown image, xy_shift is [-4  1], loss is 0.02219105
first  [-4 -2] 0.021309536
second  [-9  2] 0.1269234
third  [-16   0] 0.1374781
12.th layer for whole image, xy_shift is [-16   0], loss is 0.13747810
13.th layer for pyrDown image, xy_shift is [-4  1], loss is 0.02210139
first  [-4 -2] 0.02126807
second  [-9  2] 0.1259591
third  [-16   2] 0.13645096
13.th layer for whole image, xy_shift is [-16   2], loss is 0.13645096
14.th layer for pyrDown image, xy_shift is [-2 -1], loss i

third  [-20 -44] 0.09630771
46.th layer for whole image, xy_shift is [-20 -44], loss is 0.09630771
47.th layer for pyrDown image, xy_shift is [-3 -4], loss is 0.01790887
first  [-4 -4] 0.017471464
second  [-8 -4] 0.09097625
third  [-12 -17] 0.09162921
47.th layer for whole image, xy_shift is [-12 -17], loss is 0.09162921
################################################################################
Finally the matched one is 47.th layer, xy_shift is [-12. -17.], loss is 0.09162921
################################################################################
start saving data
end saving data
10.th layer for pyrDown image, xy_shift is [-2  3], loss is 0.04519908
first  [-2  1] 0.04421405
second  [-3  3] 0.10909582
third  [-3 18] 0.12602542
10.th layer for whole image, xy_shift is [-3 18], loss is 0.12602542
11.th layer for pyrDown image, xy_shift is [-8  4], loss is 0.04885099
first  [-3  0] 0.04317061
second  [-3  3] 0.11050539
third  [-2 19] 0.12628277
11.th layer for whole image,

45.th layer for pyrDown image, xy_shift is [-2 -6], loss is 0.03162164
first  [-1 -2] 0.030312143
second  [-3 -2] 0.072264366
third  [2 1] 0.070448704
45.th layer for whole image, xy_shift is [2 1], loss is 0.07044870
46.th layer for pyrDown image, xy_shift is [-3 -2], loss is 0.03207040
first  [-1 -3] 0.030535484
second  [-3 -2] 0.07235057
third  [2 0] 0.069478005
46.th layer for whole image, xy_shift is [2 0], loss is 0.06947801
47.th layer for pyrDown image, xy_shift is [-2 -6], loss is 0.03202441
first  [-1 -2] 0.030514827
second  [ 0 -4] 0.07220561
third  [2 0] 0.06888654
47.th layer for whole image, xy_shift is [2 0], loss is 0.06888654
################################################################################
Finally the matched one is 47.th layer, xy_shift is [2. 0.], loss is 0.06888654
################################################################################
start saving data
end saving data
10.th layer for pyrDown image, xy_shift is [ 0 -4], loss is 0.03762732
fi

third  [-2 15] 0.06453622
43.th layer for whole image, xy_shift is [-2 15], loss is 0.06453622
44.th layer for pyrDown image, xy_shift is [ 4 -2], loss is 0.03288465
first  [0 1] 0.016205942
second  [0 2] 0.056846518
third  [-2 15] 0.06343351
44.th layer for whole image, xy_shift is [-2 15], loss is 0.06343351
45.th layer for pyrDown image, xy_shift is [0 2], loss is 0.01704556
first  [0 1] 0.014066865
second  [0 2] 0.051635925
third  [-2 14] 0.062462196
45.th layer for whole image, xy_shift is [-2 14], loss is 0.06246220
46.th layer for pyrDown image, xy_shift is [0 2], loss is 0.01595158
first  [0 1] 0.012539756
second  [0 2] 0.04743348
third  [ 5 13] 0.060842287
46.th layer for whole image, xy_shift is [ 5 13], loss is 0.06084229
47.th layer for pyrDown image, xy_shift is [0 1], loss is 0.01158753
first  [0 1] 0.011587532
second  [0 2] 0.04581683
third  [-2 13] 0.060289733
47.th layer for whole image, xy_shift is [-2 13], loss is 0.06028973
##########################################

In [7]:
img_path=r'F:\20220811_VIP_Cre_DIOmS_RetromS_in_SCN'
file_name=r'S%d'
img_name=r''
img_save_path=r'F:\20220811_VIP_Cre_DIOmS_RetromS_in_SCN\zstitched'
save_img_name=r'merged_stitched%.4d'
save_path=r'C:\Users\dingj\ZhaoLab\20220814_SiftVertStitch_WYL'
stitch_layer_num=4

axis_range_array=np.zeros((stitch_layer_num,4))
for i in range(1,stitch_layer_num+1):
    axis_range_array[i-1,:]=np.load(save_path+r'\axis_range_%d.npy'%(i)).reshape((1,-1))
xy_axis_range=np.zeros((2,2))
xy_axis_range[0,0],xy_axis_range[0,1]=np.min(axis_range_array[:,0])-2,np.max(axis_range_array[:,1])+100
xy_axis_range[1,0],xy_axis_range[1,1]=np.min(axis_range_array[:,2])-2,np.max(axis_range_array[:,3])+100
xy_voxel_num=np.int64(np.round((xy_axis_range[:,1]-xy_axis_range[:,0])+1))
print(xy_voxel_num)

img_num=0
for i in range(stitch_layer_num):
    first_last_index=np.load(save_path+r'\first_last_index_%d.npy'%(i+1))
    axis_range=np.load(save_path+r'\axis_range_%d.npy'%(i+1))
    this_img_path=img_path+'\\'+file_name%(i+1)
    x_th=np.int64(np.round((axis_range[0,0]-xy_axis_range[0,0])))
    y_th=np.int64(np.round((axis_range[1,0]-xy_axis_range[1,0])))
    for j in range(first_last_index[0],first_last_index[1]):
        this_img=np.zeros((xy_voxel_num[1::-1]),dtype='uint8')
        img_2D=import_2D_img(this_img_path,img_name,j)
        this_img[y_th:y_th+img_2D.shape[0],x_th:x_th+img_2D.shape[1]]=img_2D
        cv2.imwrite(r'%s\z%.4d.tif'%(img_save_path,img_num),this_img)
        img_num+=1

[4371 4815]
