In [4]:
from lxml import etree
import numpy as np
import time
import cv2
import multiprocessing
from multiprocessing import Value,Array,Process
import ctypes
import random
from numba import jit,njit
import sys

def get_img_info(save_path):
    '''
    Load the cached image-geometry arrays (as written by get_img_xml_info).

    save_path - str, directory prefix the .npy files live under
    return - (1)dim_elem_num, (2)dim_len, (3)voxel_len, each exactly as loaded by np.load
    '''
    suffixes = (r'\dim_elem_num.npy', r'\dim_len.npy', r'\voxel_len.npy')
    dim_elem_num, dim_len, voxel_len = (np.load(save_path + suffix) for suffix in suffixes)
    return dim_elem_num, dim_len, voxel_len

def get_img_xml_info(xml_path,save_path):
    '''
    Read the metadata xml file, extract the 3D image dimensions and cache them as .npy files.

    xml_path - str, path of the metadata xml file
    save_path - str, directory prefix the caches are written under
                (dim_elem_num.npy, dim_len.npy, voxel_len.npy; read back by get_img_info)
    return - (1)dim_elem_num - array(uint) of shape (3,), the quantity of voxels for each dimension,
    (2)dim_len - array(float) of shape (3,), the length of one 3D image along each dimension,
    (3)voxel_len - array(float) of shape (3,), the length of a voxel (dim_len/dim_elem_num)
    raises - ValueError if the xml has fewer than 3 Dimensions/DimensionDescription entries
    '''
    parser=etree.XMLParser()
    my_et=etree.parse(xml_path,parser=parser)
    dim_attrib=my_et.xpath('//Dimensions/DimensionDescription')
    if len(dim_attrib)<3:
        raise ValueError('expected 3 DimensionDescription entries in %s, found %d'%(xml_path,len(dim_attrib)))
    dim_elem_num=np.zeros(3,dtype='uint')
    dim_len=np.zeros(3)
    for i in range(3):
        attrib=dim_attrib[i].attrib
        # xml attributes are strings; convert explicitly instead of relying on numpy's implicit parse
        dim_elem_num[i]=int(attrib['NumberOfElements'])
        dim_len[i]=float(attrib['Length'])
    voxel_len=dim_len/dim_elem_num
    np.save(save_path+r'\dim_elem_num.npy',dim_elem_num)
    np.save(save_path+r'\dim_len.npy',dim_len)
    np.save(save_path+r'\voxel_len.npy',voxel_len)
    return dim_elem_num,dim_len,voxel_len

def import_2D_img(img_path,img_name,z_th):
    '''
    Read one z-slice of the raw ch00 stack as a grayscale image.

    img_path - str, directory containing the tif slices
    img_name - str, file-name prefix of the slices
    z_th - int, z index of the slice (formatted with at least 2 digits)
    return - the image array from cv2.imread(..., cv2.IMREAD_GRAYSCALE)
    raises - FileNotFoundError if the file is missing or unreadable
    '''
    one_img_name=r'%s\%s_z%.2d_RAW_ch00.tif'%(img_path,img_name,z_th)
    img=cv2.imread(one_img_name,cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None instead of raising,
        # which otherwise surfaces later as an obscure error inside cv2.pyrDown
        raise FileNotFoundError('could not read image file %s'%one_img_name)
    return img

def pyr_down_img(img,times=1):
    '''
    Downsample an image by applying cv2.pyrDown repeatedly.

    img - image array accepted by cv2.pyrDown
    times - int, number of cv2.pyrDown passes; 0 or less returns img unchanged
            (with default size each pass roughly halves width and height)
    return - the downsampled image
    '''
    img_down=img
    for _ in range(times):
        img_down=cv2.pyrDown(img_down)
    return img_down

def loss_fun(ovl1,ovl2):
    '''
    Normalized sum of squared differences between two equally shaped overlap regions.

    loss = sum((ovl1-ovl2)^2) / sqrt(sum(ovl1^2) * sum(ovl2^2)), computed in float32.
    Lower is better; identical inputs give 0.
    return - float; np.inf when either region is empty or all-zero, where the normalization
             is undefined (inf also guarantees such a candidate never wins a min-search)
    '''
    ovl1,ovl2=ovl1.astype('float32'),ovl2.astype('float32')
    denom=np.sqrt(np.sum(ovl1**2)*np.sum(ovl2**2))
    if denom==0:
        return np.inf
    return np.sum((ovl1-ovl2)**2)/denom

def calculate_xy_shift_by_RANSAC(img1,img2,pts1,pts2,loss_threshold=0.05,sample_time=1000,max_spread=50):
    '''
    Estimate the integer XY shift between img1 and img2 from matched keypoints.

    Repeatedly draws a small random subset of the matches and averages its displacement
    (pts2-pts1) into a candidate shift [dx, dy]. Each new candidate is scored with loss_fun
    on the overlap where row max(0,dy)/col max(0,dx) of img2 aligns with row max(0,-dy)/
    col max(0,-dx) of img1. Stops after sample_time draws or once the best loss <= loss_threshold.

    img1, img2 - 2D images the keypoints were detected on (rows=Y, cols=X)
    pts1, pts2 - array(float) of shape (n,2), matched XY coordinates in img1 / img2
    loss_threshold - float, early-stop level (the old body hard-coded 0.05 and ignored this
                     argument; the default is now 0.05 so default callers behave the same)
    sample_time - int, maximum number of random draws
    max_spread - float, a draw is rejected when its displacements differ by more than this
                 many pixels along X or Y (inconsistent matches)
    return - (xy_shift_min, loss_min): array(int32) [dx, dy] and its loss;
             ([0, 0], inf) when no candidate could be scored (e.g. no matches)
    '''
    xy_shift_min=np.zeros(2,dtype='int32')  # int so it stays usable as a slice offset downstream
    loss_min=np.inf
    pts1,pts2=np.asarray(pts1).reshape(-1,2),np.asarray(pts2).reshape(-1,2)
    matches_num=pts1.shape[0]
    if matches_num==0:
        return xy_shift_min,loss_min
    # 2 matches per draw once there are >= 20 of them, otherwise 1
    # (min(2, 0.1*n) alone truncates to 0 for n < 10)
    RANSAC_num=max(1,int(min(2,matches_num*0.1)))
    tried={}  # candidate (dx, dy) -> loss; a repeated candidate is not re-scored
    count=0
    while loss_min>loss_threshold and count<sample_time:
        count+=1
        index_list=random.sample(range(matches_num),RANSAC_num)
        xy_shift_all=pts2[index_list,:]-pts1[index_list,:]
        spread=np.max(xy_shift_all,axis=0)-np.min(xy_shift_all,axis=0)
        if np.any(spread>max_spread):
            continue
        xy_shift=np.int32(np.round(np.mean(xy_shift_all,axis=0)))#XY
        key=(int(xy_shift[0]),int(xy_shift[1]))
        if key in tried:
            continue
        dx,dy=key
        ovl1=img1[max(0,-dy):,max(0,-dx):]
        ovl2=img2[max(0,dy):,max(0,dx):]
        y_range=min(ovl1.shape[0],ovl2.shape[0])
        x_range=min(ovl1.shape[1],ovl2.shape[1])
        this_loss=loss_fun(ovl1[:y_range,:x_range],ovl2[:y_range,:x_range])
        tried[key]=this_loss
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min=xy_shift
    return xy_shift_min,loss_min

def _overlap_pair(img1,img2,xy_shift,max_size=None):
    '''
    Crop the overlapping regions of img1 and img2 for an XY shift [dx, dy].

    Row max(0,dy)/col max(0,dx) of img2 is aligned with row max(0,-dy)/col max(0,-dx) of img1.
    max_size - optional int cap on both the overlap height and width
    return - (ovl1, ovl2), equally shaped views into img1 and img2
    '''
    dx,dy=int(xy_shift[0]),int(xy_shift[1])
    ovl1=img1[max(0,-dy):,max(0,-dx):]
    ovl2=img2[max(0,dy):,max(0,dx):]
    y_range=min(ovl1.shape[0],ovl2.shape[0])
    x_range=min(ovl1.shape[1],ovl2.shape[1])
    if max_size is not None:
        y_range,x_range=min(y_range,max_size),min(x_range,max_size)
    return ovl1[:y_range,:x_range],ovl2[:y_range,:x_range]


def _grid_search_shift(img1,img2,center,offsets,best_shift,best_loss,max_size=None):
    '''
    Score center+[x, y] for every x in offsets (outer loop) and y in offsets (inner loop).

    Starts from (best_shift, best_loss); a candidate replaces it only with a strictly lower
    loss_fun value, so ties keep the earlier winner.
    return - (best_shift, best_loss)
    '''
    for x in offsets:
        for y in offsets:
            this_xy_shift=center+np.array([x,y],dtype='int32')
            ovl1,ovl2=_overlap_pair(img1,img2,this_xy_shift,max_size)
            this_loss=loss_fun(ovl1,ovl2)
            if this_loss<best_loss:
                best_loss=this_loss
                best_shift=this_xy_shift
    return best_shift,best_loss


def calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift,scale=4,max_ovl=2000):
    '''
    Refine an XY shift by coarse-to-fine brute-force search.

    Stage 1, downsampled images: every offset in [-5,5]^2 around xy_shift.
    Stage 2, full images: the stage-1 best times scale, then offsets in [-18,18]^2 with step 3.
    Stage 3, full images: offsets in [-2,2]^2 around the stage-2 best (fills the step-3 gaps).
    Full-image overlaps are cropped to at most max_ovl x max_ovl pixels to bound the cost.

    img1_down, img2_down - downsampled versions of img1 and img2
    xy_shift - [dx, dy] initial shift on the downsampled images
               (e.g. the output of calculate_xy_shift_by_RANSAC)
    scale - int, full-resolution pixels per downsampled pixel (presumably 2**levels for
            cv2.pyrDown pyramids; the default 4 matches two pyrDown passes)
    max_ovl - int, cap on overlap height/width in stages 2 and 3
    return - (xy_shift_min, loss_min): array(int32) [dx, dy] on the full images and its loss
    '''
    center=np.int32(np.round(np.asarray(xy_shift,dtype='float64')))
    # stage 1: +-5 px on the downsampled images (no size cap). Starting from the input shift
    # instead of an empty array keeps a valid result even if every candidate scores inf.
    xy_shift_min,loss_min=_grid_search_shift(img1_down,img2_down,center,range(-5,6),center,np.inf)
    # stage 2: scale up to full resolution and search +-18 px in steps of 3
    xy_shift_whole=xy_shift_min*scale
    xy_shift_min,loss_min=_grid_search_shift(img1,img2,xy_shift_whole,range(-18,19,3),
                                             xy_shift_whole,np.inf,max_ovl)
    # stage 3: +-2 px around the stage-2 best, keeping its loss as the bar to beat
    xy_shift_min,loss_min=_grid_search_shift(img1,img2,xy_shift_min,range(-2,3),
                                             xy_shift_min,loss_min,max_ovl)
    return xy_shift_min,loss_min

def update_lower_layer_info(xy_shift,axis_range1,axis_range2,dim_elem_num,voxel_len):
    '''
    Translate the XY axis range of the lower stack so it lines up with the upper stack.

    The XY start of axis_range2 is moved to axis_range1's start minus xy_shift*voxel_len,
    and its end is moved by the same amount, so the XY extent is preserved. The Z row is
    left untouched. axis_range2 is modified in place and also returned.

    xy_shift - [dx, dy] in voxels (full-resolution pixels)
    axis_range1, axis_range2 - array(float) of shape (3,2), [start, end] per axis
    dim_elem_num - unused; kept for interface compatibility
    voxel_len - array(float) of shape (3,), voxel length per axis
    return - (axis_range2, voxel_num): the updated range and array(uint64) voxel count per axis
    '''
    xy_offset=np.asarray(xy_shift[:2],dtype='float64')*np.asarray(voxel_len[:2],dtype='float64')
    # One translation applied to both ends. The old code wrote the start first and then derived
    # the end from that overwritten start, which collapsed the end onto the old start value.
    delta=axis_range1[:2,0]-xy_offset-axis_range2[:2,0]
    axis_range2[:2,:]+=delta[:,np.newaxis]
    voxel_num=np.uint64(np.round((axis_range2[:,1]-axis_range2[:,0])/voxel_len))
    return axis_range2,voxel_num

def start_vertical_stitch(i,img_path,save_path,file_name,img_name):
    '''
    Find where stack i+1 (lower) continues stack i (upper) and save the stitched position of i+1.

    The reference is layer first_last_index_<i>[1] of stack i. Candidates are the
    round(0.3 * z voxel count of stack i) layers of stack i+1 starting at
    first_last_index_<i+1>[0]. For each candidate, SIFT matches on the downsampled images give
    a coarse shift (calculate_xy_shift_by_RANSAC) that calculate_xy_shift_by_BF refines on the
    full images. The candidate with the lowest loss wins.

    Reads from save_path: the get_img_info caches, axis_range_stitch_<i>.npy,
    axis_range_<i+1>.npy, first_last_index_<i>.npy and first_last_index_<i+1>.npy.
    Writes to save_path: axis_range_stitch_<i+1>.npy and first_last_index_stitch_<i+1>.npy.
    raises - RuntimeError if no candidate layer could be matched
    '''
    # SIFT_create lives in the main cv2 module in newer OpenCV; fall back to the contrib location
    sift_create=getattr(cv2,'SIFT_create',None) or cv2.xfeatures2d.SIFT_create
    sift=sift_create()
    bf=cv2.BFMatcher()
    pyr_times=2  # calculate_xy_shift_by_BF scales the coarse shift by 4 (=2**2) by default
    img_path1,img_path2=img_path+'\\'+file_name%(i),img_path+'\\'+file_name%(i+1)
    dim_elem_num,dim_len,voxel_len=get_img_info(save_path)
    axis_range1=np.load(save_path+r'\axis_range_stitch_%.4d.npy'%(i))
    axis_range2=np.load(save_path+r'\axis_range_%.4d.npy'%(i+1))
    voxel_num1=np.uint64(np.round((axis_range1[:,1]-axis_range1[:,0])/voxel_len))
    index1=np.load(save_path+r'\first_last_index_%.4d.npy'%(i))[1]
    index2_all=np.load(save_path+r'\first_last_index_%.4d.npy'%(i+1))
    index2=int(index2_all[0])

    img1=import_2D_img(img_path1,img_name,index1)
    img1_down=pyr_down_img(img1,pyr_times)
    # the reference layer never changes, so detect its features once
    kpts1,des1=sift.detectAndCompute(img1_down,None)
    kp1=np.float32([kp.pt for kp in kpts1])

    loss_min=np.inf
    index_min=-1
    xy_shift_min=np.zeros(2,dtype='int32')
    search_end=int(np.round(voxel_num1[2]*0.3))+index2
    for j in range(index2,search_end):
        img2=import_2D_img(img_path2,img_name,j)
        img2_down=pyr_down_img(img2,pyr_times)
        kpts2,des2=sift.detectAndCompute(img2_down,None)
        if des1 is None or des2 is None:
            print('%d.th layer: no SIFT descriptors, skipped'%j)
            continue
        kp2=np.float32([kp.pt for kp in kpts2])
        # ratio test: keep a match only if clearly better than the second-best candidate
        good_matches=[(m[0].queryIdx,m[0].trainIdx) for m in bf.knnMatch(des1,des2,k=2)
                      if len(m)==2 and m[0].distance<0.75*m[1].distance]
        if not good_matches:
            print('%d.th layer: no good SIFT matches, skipped'%j)
            continue
        pts1=np.float32([kp1[q,:] for (q,_) in good_matches])
        pts2=np.float32([kp2[t,:] for (_,t) in good_matches])
        xy_shift,this_loss=calculate_xy_shift_by_RANSAC(img1_down,img2_down,pts1,pts2)
        print('%d.th layer for pyrDown image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        xy_shift,this_loss=calculate_xy_shift_by_BF(img1_down,img2_down,img1,img2,xy_shift)
        print('%d.th layer for whole image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
        if this_loss<loss_min:
            loss_min=this_loss
            xy_shift_min=xy_shift
            index_min=j
    print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
    print('Finally the matched one is %d.th layer, xy_shift is %s, loss is %.8f'%(index_min,str(xy_shift_min),loss_min))
    print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
    if index_min<0:
        raise RuntimeError('no layer of stack %d could be matched to stack %d'%(i+1,i))

    index2_all[0]=index_min
    # use the winning layer's shift; xy_shift only holds the last layer tried
    axis_range2,voxel_num2=update_lower_layer_info(xy_shift_min,axis_range1,axis_range2,dim_elem_num,voxel_len)
    print('start saving data')
    np.save(save_path+r'\axis_range_stitch_%.4d.npy'%(i+1),axis_range2)
    np.save(save_path+r'\first_last_index_stitch_%.4d.npy'%(i+1),index2_all)
    print('end saving data')
    
def adjust_contrast(img1,img2):
    '''
    Brighten the darker image so both images share the larger mean intensity.

    Each image is multiplied by max(mean1, mean2)/own_mean, clipped to [0, 255] and cast
    to uint8 (fractional values are truncated by the cast). An image whose mean is 0 cannot
    be rescaled and is returned unchanged (clipped and cast to uint8).
    return - (img1, img2) as uint8 arrays
    '''
    img1,img2=img1.astype('float32'),img2.astype('float32')
    m1,m2=np.mean(img1),np.mean(img2)
    m=np.max((m1,m2))

    def _rescale(img,mean):
        # all-black image: m/mean would be inf and inf*0 = nan, whose uint8 cast is undefined
        if mean==0:
            return np.uint8(np.clip(img,0,255))
        return np.uint8(np.clip(m/mean*img,0,255))

    return _rescale(img1,m1),_rescale(img2,m2)

In [None]:
# Paths and SIFT/brute-force matcher for stitching with stack index i=1
xml_path = r'C:\Users\dingj\ZhaoLab\Project1\MetaData\Region 1_Merged.xml'
img_path = r'C:\Users\dingj\ZhaoLab'
file_name, img_name = r'Project%d', r'Region 1_Merged'
save_path = r'C:\Users\dingj\ZhaoLab\20220806_SIFTvertStitch'
sift, bf = cv2.xfeatures2d.SIFT_create(), cv2.BFMatcher()
i = 1


In [5]:
# Stitch stack i (upper) onto stack i+1 (lower): find the z-layer of stack i+1 that best matches
# the reference layer of stack i, then save the aligned axis range and first index of stack i+1.
xml_path=r'C:\Users\dingj\ZhaoLab\Project1\MetaData\Region 1_Merged.xml'
img_path=r'C:\Users\dingj\ZhaoLab'
file_name=r'Project%d'
img_name=r'Region 1_Merged'
save_path=r'C:\Users\dingj\ZhaoLab\20220806_SIFTvertStitch'
sift=cv2.xfeatures2d.SIFT_create()
bf=cv2.BFMatcher()
i=2
pyr_times=2                  # calculate_xy_shift_by_BF multiplies the coarse shift by 4 (=2**2)
search_layers=range(25,45)   # hand-picked candidate z-layers of stack i+1
SHOW_MATCHES=True            # display each SIFT match image for 3 s; set False for unattended runs

img_path1,img_path2=img_path+'\\'+file_name%(i),img_path+'\\'+file_name%(i+1)
dim_elem_num,dim_len,voxel_len=get_img_info(save_path)
axis_range1,axis_range2=np.load(save_path+r'\axis_range_stitch_%.4d.npy'%(i)),np.load(save_path+r'\axis_range_%.4d.npy'%(i+1))
voxel_num1,voxel_num2=np.uint64(np.round((axis_range1[:,1]-axis_range1[:,0])/voxel_len)),np.uint64(np.round((axis_range2[:,1]-axis_range2[:,0])/voxel_len))
index1,index2=np.load(save_path+r'\first_last_index_%.4d.npy'%(i)),np.load(save_path+r'\first_last_index_%.4d.npy'%(i+1))
index1,index2=index1[1],index2[0]
img1=import_2D_img(img_path1,img_name,index1)
img1_down=pyr_down_img(img1,pyr_times)

loss_min=np.inf
index_min=-1
xy_shift_min=np.zeros(0)
for j in search_layers:
    img2=import_2D_img(img_path2,img_name,j)
    img2_down=pyr_down_img(img2,pyr_times)
    # Equalize contrast against the untouched reference layer each iteration. Reassigning
    # img1/img1_down here compounded the rescaling (and clipping) from layer to layer.
    img1_adj,img2=adjust_contrast(img1,img2)
    img1_down_adj,img2_down=adjust_contrast(img1_down,img2_down)
    kpts1,des1=sift.detectAndCompute(img1_down_adj,None)
    kpts2,des2=sift.detectAndCompute(img2_down,None)
    if des1 is None or des2 is None:
        print('%d.th layer: no SIFT descriptors, skipped'%j)
        continue
    kp1,kp2=np.float32([kp.pt for kp in kpts1]),np.float32([kp.pt for kp in kpts2])
    matches=bf.knnMatch(des1,des2,k=2)
    # ratio test: keep a match only if clearly better than the second-best candidate
    good=[m[0] for m in matches if len(m)==2 and m[0].distance<0.75*m[1].distance]
    good_matches=[(m.queryIdx,m.trainIdx) for m in good]
    if not good_matches:
        print('%d.th layer: no good SIFT matches, skipped'%j)
        continue
    if SHOW_MATCHES:
        img3=cv2.drawMatches(img1_down_adj,kpts1,img2_down,kpts2,good,None,flags=2)
        cv2.imshow('3',img3)
        cv2.waitKey(3000)
        cv2.destroyAllWindows()
    pts1=np.float32([kp1[q,:] for (q,_) in good_matches])
    pts2=np.float32([kp2[t,:] for (_,t) in good_matches])
    xy_shift,this_loss=calculate_xy_shift_by_RANSAC(img1_down_adj,img2_down,pts1,pts2)
    print('%d.th layer for pyrDown image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
    xy_shift,this_loss=calculate_xy_shift_by_BF(img1_down_adj,img2_down,img1_adj,img2,xy_shift)
    print('%d.th layer for whole image, xy_shift is %s, loss is %.8f'%(j,str(xy_shift),this_loss))
    if this_loss<loss_min:
        loss_min=this_loss
        xy_shift_min=xy_shift
        index_min=j
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print('Finally the matched one is %d.th layer, xy_shift is %s, loss is %.8f'%(index_min,str(xy_shift_min),loss_min))
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
if index_min<0:
    raise RuntimeError('no layer in search_layers could be matched')

index2=np.load(save_path+r'\first_last_index_%.4d.npy'%(i+1))
index2[0]=index_min
# use the winning layer's shift; xy_shift only holds the last layer tried
axis_range2,voxel_num2=update_lower_layer_info(xy_shift_min,axis_range1,axis_range2,dim_elem_num,voxel_len)
print('start saving data')
np.save(save_path+r'\axis_range_stitch_%.4d.npy'%(i+1),axis_range2)
np.save(save_path+r'\first_last_index_stitch_%.4d.npy'%(i+1),index2)
print('end saving data')

25.th layer for pyrDown image, xy_shift is [-7 17], loss is 0.56074208
25.th layer for whole image, xy_shift is [-16  81], loss is 0.57778776
26.th layer for pyrDown image, xy_shift is [-8 22], loss is 0.50946945
26.th layer for whole image, xy_shift is [-18  78], loss is 0.48364830
27.th layer for pyrDown image, xy_shift is [-7 17], loss is 0.41899672
27.th layer for whole image, xy_shift is [-19  77], loss is 0.42954984
28.th layer for pyrDown image, xy_shift is [-8 16], loss is 0.38128644
28.th layer for whole image, xy_shift is [-20  75], loss is 0.38492319
29.th layer for pyrDown image, xy_shift is [-8 17], loss is 0.34649470
29.th layer for whole image, xy_shift is [-20  75], loss is 0.34594545
30.th layer for pyrDown image, xy_shift is [-7 17], loss is 0.31658813
30.th layer for whole image, xy_shift is [-21  74], loss is 0.31181121
31.th layer for pyrDown image, xy_shift is [-7 17], loss is 0.28997391
31.th layer for whole image, xy_shift is [-21  74], loss is 0.27972698
32.th 

In [42]:
# Display index1: the [1] entry of first_last_index for stack i, used as the reference z-layer
index1

74

In [4]:
# Preview the most recently processed downsampled lower-stack layer;
# a delay of 0 makes waitKey block until a key is pressed (same as the default)
cv2.imshow('1', img2_down)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [8]:
# Switch to the C:\Users\admin\DJY layout; file_name here is a literal name, not a %d template
img_name = r'Region 1_Merged'
file_name = 'TileScan 1'
img_path = r'C:\Users\admin\DJY\Project1'
xml_path = r'C:\Users\admin\DJY\Project1\MetaData\Region 1_Merged.xml'
save_path = r'C:\Users\admin\DJY\20220806_SIFTvertStitch'

In [47]:
# Re-read the dimensions from the Project3 / Region 2 metadata. Note: get_img_xml_info also
# overwrites the dim_elem_num / dim_len / voxel_len .npy caches under save_path.
xml_path = r'C:\Users\admin\DJY\Project3\MetaData\Region 2_Merged.xml'
(dim_elem_num, dim_len, voxel_len) = get_img_xml_info(xml_path, save_path)

In [65]:
# Display the reference z-index of stack i and the starting z-index of stack i+1
index1,index2

(82, 0)

In [49]:
# Full-extent axis range per dimension: [0, dim_len]. Vectorized so that the cell no longer
# reuses `i` as a loop variable, which silently overwrote the global stack index `i`.
axis_range=np.zeros((3,2))
axis_range[:,1]=dim_len[:3]
axis_range

array([[0.        , 0.00939157],
       [0.        , 0.01071177],
       [0.        , 0.00029499]])

In [49]:
#np.save('axis_range_%.4d.npy'%(3),axis_range)
# NOTE(review): these paths are relative to the kernel's working directory, not save_path
# like the rest of the notebook - confirm this reads the intended file
np.load('first_last_index_stitch_%.4d.npy'%(2))

array([25, 73], dtype=uint32)

In [None]:
 # NOTE(review): 'axis_range_zstitch' differs from the 'axis_range_stitch' name used by the stitching code - confirm which file is intended
 np.save(save_path+r'\axis_range_zstitch_%.4d.npy'%(i+1),axis_range2)

In [18]:
# Show z-layer 60 of the stack under img_path at full resolution until a key is pressed
img = import_2D_img(img_path, img_name, 60)
cv2.imshow('1', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [19]:
# Downsample once into a new name so re-running this cell does not keep shrinking `img`
img_down=pyr_down_img(img)
cv2.imshow('1',img_down)
cv2.waitKey()
cv2.destroyAllWindows()