In [1]:
import ctypes
import multiprocessing
import os
import random
import time
from multiprocessing import Value,Array,Process

import cv2
import numpy as np
from lxml import etree
from numba import jit,njit

#functions
def get_img_xml_info(xml_path):
    '''
    Read the tile-scan metadata XML file and extract the image dimensions and the tile layout.
    xml_path - str, path of the metadata XML file.
    return - (1)dim_elem_num - array(uint) of length 3, the number of voxels along each dimension,
    (2)dim_len - array(float) of length 3, the physical length of one 3D tile along each dimension,
    (3)voxel_len - array(float) of length 3, the physical length of one voxel (dim_len/dim_elem_num),
    (4)tile_num - int, the number of tiles,
    (5)tile_field - array(uint) of shape (tile_num,2), the FieldX/FieldY identifier of each tile,
    (6)tile_pos - array(float) of shape (tile_num,3), the PosX/PosY/PosZ position of each tile.
    raises - ValueError if the file describes fewer than 3 dimensions.
    '''
    parser=etree.XMLParser()
    my_et=etree.parse(xml_path,parser=parser)
    dim_attrib=my_et.xpath('//Dimensions/DimensionDescription')
    if len(dim_attrib)<3:
        raise ValueError('%s describes %d dimensions, expected at least 3'%(xml_path,len(dim_attrib)))
    # only the first three DimensionDescription entries are used; callers treat index 0/1/2 as X/Y/Z
    dim_elem_num=np.array([int(d.attrib['NumberOfElements']) for d in dim_attrib[:3]],dtype='uint')
    dim_len=np.array([float(d.attrib['Length']) for d in dim_attrib[:3]])
    voxel_len=dim_len/dim_elem_num
    tile_attrib=my_et.xpath('//Attachment/Tile')
    tile_num=len(tile_attrib)
    # reshape keeps the documented 2D shapes even when there are no tiles
    tile_field=np.array([[int(t.attrib['FieldX']),int(t.attrib['FieldY'])] for t in tile_attrib],
                        dtype='uint').reshape(tile_num,2)
    tile_pos=np.array([[float(t.attrib[key]) for key in ('PosX','PosY','PosZ')] for t in tile_attrib],
                      dtype=float).reshape(tile_num,3)
    return dim_elem_num,dim_len,voxel_len,tile_num,tile_field,tile_pos

def judge_tile_contact(dim_len,tile_pos,overlap_tol=0.3):
    '''
    Judge which pairs of tiles touch each other along X and along Y.
    dim_len - array(float) of length 3, the physical size of one 3D tile along X, Y, Z,
    tile_pos - array(float) of shape (n,3), the XYZ position of each tile,
    overlap_tol - float, fraction of the tile size by which the two other axes may differ (default 0.3).
    return - array(bool) of shape (2,n,n): [0,i,j] is True when |dx| < size_x, |dy| < tol*size_y and
    |dz| < tol*size_z (X-neighbours); [1,i,j] is the same test with X and Y swapped (Y-neighbours).
    A tile is never in contact with itself (the diagonal is False).
    '''
    tile_pos=np.asarray(tile_pos,dtype=float)
    dim_len=np.asarray(dim_len,dtype=float)
    tile_num=tile_pos.shape[0]
    # pairwise absolute offsets, shape (n,n,3); replaces an O(n^2) Python loop that is far too slow for many tiles
    offset=np.abs(tile_pos[:,None,:]-tile_pos[None,:,:])
    limit_x=dim_len*np.array([1,overlap_tol,overlap_tol])
    limit_y=dim_len*np.array([overlap_tol,1,overlap_tol])
    tile_contact=np.stack((np.all(offset<limit_x,axis=2),np.all(offset<limit_y,axis=2)))
    diag=np.arange(tile_num)
    tile_contact[:,diag,diag]=False
    return tile_contact

def import_img(img_path,ordinal,dim_elem_num):
    '''
    Read all z-slices of one tile and stack them into a 3D uint8 array.
    img_path - str, the directory that contains the slice files,
    ordinal - int, the tile number (the s#### part of the file name),
    dim_elem_num - sequence of 3 ints, the voxel counts along X, Y, Z.
    return - array(uint8) of shape (Y, X, Z): axis 0 = image rows, axis 1 = image columns, axis 2 = slice index.
    raises - FileNotFoundError if a slice cannot be read.
    '''
    n_x,n_y,n_z=(int(n) for n in dim_elem_num[:3])
    # cv2.imread returns (rows, cols) = (Y, X); allocate that way so non-square tiles fit as well.
    # dtype must match the image bit depth.
    voxel_array=np.zeros((n_y,n_x,n_z),dtype='uint8')
    for i in range(n_z):
        # file naming scheme of this acquisition; adjust for other naming methods
        img_name=os.path.join(img_path,'Region 1_s%.4d_z%.3d_RAW_ch00.tif'%(ordinal,i))
        slice_img=cv2.imread(img_name,cv2.IMREAD_GRAYSCALE)
        if slice_img is None:
            # imread signals a missing/unreadable file by returning None instead of raising
            raise FileNotFoundError('cannot read image: %s'%img_name)
        voxel_array[:,:,i]=slice_img
    return voxel_array

def get_2img_border(dim_elem_num,dim_len,voxel_len,tile_pos):
    '''
    Locate, inside each of two overlapping tiles, the voxel index range of their shared region.
    dim_elem_num - not used by this computation (kept so all border helpers share a signature),
    dim_len - sequence of 3 floats, the physical size of one tile along X, Y, Z,
    voxel_len - sequence of 3 floats, the physical size of one voxel along X, Y, Z,
    tile_pos - array of shape (>=2,3), XYZ tile positions; the common region is taken over all rows,
    while the borders are returned for rows 0 and 1.
    return - array(uint) of shape (2,6), per tile [x_min,x_max,y_min,y_max,z_min,z_max] (rounded voxel indices).
    '''
    # physical bounds of the common region: latest start and earliest end over the given tiles
    ovl_start=tile_pos.max(axis=0)
    ovl_end=tile_pos.min(axis=0)+np.asarray(dim_len)
    rows=[]
    for pos in (tile_pos[0],tile_pos[1]):
        lo=np.round((ovl_start-pos)/voxel_len)
        hi=np.round((ovl_end-pos)/voxel_len)
        # interleave to [x_min,x_max,y_min,y_max,z_min,z_max]
        rows.append(np.column_stack((lo,hi)).ravel())
    return np.array(rows,dtype='uint')

def get_2img_border_after_shift(dim_elem_num,voxel_border,xyz_shift):
    '''
    Recompute the overlap voxel ranges of two tiles after translating the second tile by xyz_shift.
    dim_elem_num - sequence of 3 ints, the voxel counts of a tile along X, Y, Z,
    voxel_border - array of shape (2,6), per tile [x_min,x_max,y_min,y_max,z_min,z_max] before translation,
    xyz_shift - sequence of 3 ints, the translation of tile 1 relative to tile 0 in voxels (may be negative).
    return - array(uint) of shape (2,6) in the same layout, clipped to [0, dim] on each axis;
    a range whose min is >= its max is empty (the tiles no longer overlap along that axis).
    '''
    # work in signed integers: uint arithmetic with a negative shift wraps around (or raises under NumPy 2)
    border=np.asarray(voxel_border).astype(np.int64)
    shift=np.asarray(xyz_shift).astype(np.int64)
    dims=np.asarray(dim_elem_num).astype(np.int64)[:3]
    out=np.zeros((2,6),dtype=np.int64)
    for i in range(3):
        lo,hi=2*i,2*i+1
        s,d=shift[i],dims[i]
        if border[0,lo]<border[1,lo]:
            # tile 0's overlap starts at its low edge
            out[0,lo],out[0,hi]=0,border[0,hi]+s
            out[1,lo],out[1,hi]=border[1,lo]-s,d
        elif border[0,lo]>border[1,lo]:
            # tile 0's overlap reaches its high edge
            out[0,lo],out[0,hi]=border[0,lo]+s,d
            out[1,lo],out[1,hi]=0,border[1,hi]-s
        elif s<0:
            # aligned along this axis: the shift trims one side of each range
            out[0,lo],out[0,hi]=0,border[0,hi]+s
            out[1,lo],out[1,hi]=-s,border[1,hi]
        else:
            out[0,lo],out[0,hi]=s,border[0,hi]
            out[1,lo],out[1,hi]=0,border[1,hi]-s
    # keep every index inside the tile: [0, dim_x, 0, dim_x, ...] -> per-column upper bounds
    out=np.clip(out,0,np.repeat(dims,2))
    return out.astype('uint')

def choose_reference_tile(tile_contact_array,if_tile_stitched):
    '''
    Choose the reference tiles for one tile: the first already-stitched tile it contacts in row 0
    (X-neighbours) and the first one it contacts in row 1 (Y-neighbours).
    tile_contact_array - array, bool (2,n), one tile's slice of the contact array.
    if_tile_stitched - sequence of bool (n): list, numpy array or multiprocessing.Array.
    return - tuple (j,k) of int, the indices of the two reference tiles, or -1 where none exists.
    '''
    stitched=np.array(if_tile_stitched[:],dtype=bool)
    contact=np.asarray(tile_contact_array,dtype=bool)
    # elementwise AND; Python's `and` on an array is ambiguous and raises ValueError
    cand_x=np.flatnonzero(contact[0]&stitched)
    cand_y=np.flatnonzero(contact[1]&stitched)
    j=int(cand_x[0]) if cand_x.size else -1
    k=int(cand_y[0]) if cand_y.size else -1
    return j,k

def _contrast_gains(m1,m2):
    '''Return the multiplicative gains (g1,g2) for two slices whose mean intensities are m1 and m2.'''
    # (threshold, target mean): if either slice is darker than the threshold, both are rescaled to the target
    for threshold,target in ((5,10),(10,20),(20,30),(30,40)):
        if m1<threshold or m2<threshold:
            break
    else:
        if np.abs(m1-m2)<=5:
            return 1.0,1.0
        # bright enough but unbalanced: bring both to the brighter mean
        target=max(m1,m2)
    # an all-black slice (mean 0) cannot be rescaled; leave it unchanged instead of dividing by zero
    return (target/m1 if m1>0 else 1.0),(target/m2 if m2>0 else 1.0)

def adjust_contrast(ovl1,ovl2):
    '''
    Bring two overlap slices to comparable, not-too-dark intensities and smooth them before SIFT.
    ovl1, ovl2 - 2D arrays, the overlap planes of the two tiles.
    return - (ovl1, ovl2) as uint8 arrays after gain, clipping to [0,255] and a 3x3 Gaussian blur.
    Empty inputs are returned unchanged (converted to uint8).
    '''
    ovl1,ovl2=ovl1.astype('float32'),ovl2.astype('float32')
    if ovl1.size==0 or ovl2.size==0:
        # the mean of an empty slice is NaN and there is nothing to blur
        return ovl1.astype('uint8'),ovl2.astype('uint8')
    g1,g2=_contrast_gains(np.mean(ovl1),np.mean(ovl2))
    ovl1,ovl2=np.clip(g1*ovl1,0,255),np.clip(g2*ovl2,0,255)
    ovl1,ovl2=cv2.GaussianBlur(ovl1,(3,3),0),cv2.GaussianBlur(ovl2,(3,3),0)
    return ovl1.astype('uint8'),ovl2.astype('uint8')

def compute_sift_and_loss(ovl1,ovl2,sift,ratio=0.75,min_matches=5):
    '''
    Estimate the 2D translation between two overlap slices from matched SIFT keypoints (provisional design).
    ovl1, ovl2 - 2D uint8 arrays, the overlap planes of the two tiles,
    sift - an object with detectAndCompute (e.g. cv2 SIFT),
    ratio - float, Lowe's ratio-test threshold for keeping a match,
    min_matches - int, the minimum number of good matches required.
    return - (shift, loss): shift is a float32 array [d_col, d_row] (cv2 keypoint order x=column, y=row),
    the median displacement of matched keypoints from ovl1 to ovl2; loss is the mean distance of the
    individual displacements from that median (lower = more consistent). NOTE(review): for the slices built
    in calculate_xyz_shift (rows = Y, columns = Z) this is (dz, dy); mapping to xyz is left to the caller.
    Returns (None, np.inf) when there are no descriptors or fewer than min_matches good matches.
    '''
    kp1,des1=sift.detectAndCompute(ovl1,None)
    kp2,des2=sift.detectAndCompute(ovl2,None)
    if des1 is None or des2 is None:
        # SIFT found no keypoints in at least one slice
        return None,np.inf
    matches=cv2.BFMatcher().knnMatch(des1,des2,k=2)
    good=[m[0] for m in matches if len(m)==2 and m[0].distance<ratio*m[1].distance]
    if len(good)<min_matches:
        return None,np.inf
    # queryIdx indexes ovl1's keypoints, trainIdx indexes ovl2's
    pts1=np.float32([kp1[m.queryIdx].pt for m in good])
    pts2=np.float32([kp2[m.trainIdx].pt for m in good])
    disp=pts2-pts1
    # the median is robust against the occasional wrong match
    shift=np.median(disp,axis=0)
    loss=float(np.mean(np.linalg.norm(disp-shift,axis=1)))
    return shift,loss

def loss_fun():
    '''Placeholder: not implemented yet; always returns None.'''
    return None

def affine_trans_array_to_shift():
    '''Placeholder: not implemented yet; always returns None.'''
    return None

def calculate_xyz_shift(sift,ids,voxel_range,step):
    '''
    Work in progress: estimate the translation of tile ids[0] relative to its reference tiles.
    sift - SIFT detector, passed through to compute_sift_and_loss,
    ids - [i,j,k]: i = tile to place, j = reference tile from contact row 0 (X-neighbour, -1 if none),
          k = reference tile from contact row 1 (Y-neighbour, -1 if none),
    voxel_range - int, the X shift is scanned over range(-voxel_range, voxel_range+1, step),
    step - int, the scan step.
    Reads the module-level globals img_path, dim_elem_num, dim_len, voxel_len and tile_pos.
    NOTE(review): unfinished —
      * each scan step's (this_xyz, this_loss) is computed but never compared with loss_min or kept;
      * the function returns None, while run_sift_stitcher unpacks two values from it;
      * run_sift_stitcher calls it with only (sift, ids), omitting voxel_range and step;
      * the ids[2] (Y-neighbour) branch loads the image but does nothing with it.
    '''
    img0=import_img(img_path,ids[0],dim_elem_num)
    if ids[1]!=-1:#compute along the X direction, against reference tile ids[1]
        img1=import_img(img_path,ids[1],dim_elem_num)
        border1=get_2img_border(dim_elem_num,dim_len,voxel_len,tile_pos[[ids[0],ids[1]],:])
        loss_min=np.inf
        xyz_shift=None
        for x in range(-voxel_range,voxel_range+1,step):
            border1_s=get_2img_border_after_shift(dim_elem_num,border1,[x,0,0])
            # one plane per tile: rows = Y range, a single X index (the overlap start), columns = Z range
            this_xyz,this_loss=compute_sift_and_loss(img0[border1_s[0,2]:border1_s[0,3],border1_s[0,0],border1_s[0,4]:border1_s[0,5]],
                                                     img1[border1_s[1,2]:border1_s[1,3],border1_s[1,0],border1_s[1,4]:border1_s[1,5]],
                                                     sift)
            # TODO: keep the best (lowest-loss) shift here
        
    if ids[2]!=-1:
        img2=import_img(img_path,ids[2],dim_elem_num)
        # TODO: Y-direction matching against reference tile ids[2]
    
    return

def run_sift_stitcher(lock):
    '''
    Worker loop: repeatedly claim an unstitched tile, choose its reference tiles and compute its shift.
    lock - multiprocessing lock guarding the shared arrays.
    Reads and writes module-level state (not passed as arguments): if_tile_stitched, if_tile_shelved,
    tile_num, tile_contact, tile_pos_index, tile_pos_stitch.
    NOTE(review): these must exist as globals in the worker process. Under the "spawn" start method
    (Windows default) they are not inherited automatically. TODO: confirm or pass them explicitly.
    Returns when every tile has been claimed as stitched.
    '''
    stitch_num=0
    sift=cv2.xfeatures2d.SIFT_create()
    while True:
        # `with` releases the lock on every path, including `continue` and `break`
        with lock:
            if False not in if_tile_stitched[:]:
                break
            # elementwise OR: a tile is usable if it is neither stitched nor shelved
            usable_tile_index=[t for t in range(tile_num) if not (if_tile_stitched[t] or if_tile_shelved[t])]
            if len(usable_tile_index)==0:
                # only shelved tiles remain; retry them, since new references may have been stitched meanwhile
                for t in range(tile_num):
                    if_tile_shelved[t]=False
                print('All shelved tile has been released')
                continue
            i=usable_tile_index[0]
            j,k=choose_reference_tile(tile_contact[:,i,:],if_tile_stitched)
            if j==-1 and k==-1:
                if_tile_shelved[i]=True
                print('%d.th tile has no appropriate contact tile'%(i))
                continue
            # claim the tile before releasing the lock so no other worker picks it
            if_tile_stitched[i]=True
            if k==-1:
                tile_pos_index[3*i:3*i+3]=[j,j,j]
            elif j==-1:
                tile_pos_index[3*i:3*i+3]=[k,k,k]
            else:
                tile_pos_index[3*i:3*i+2]=[j,k]
        # the expensive matching runs without holding the lock
        # NOTE(review): calculate_xyz_shift currently also requires voxel_range and step and returns None — TODO
        xyz_shift,z_index=calculate_xyz_shift(sift,[i,j,k])
        with lock:
            tile_pos_stitch[3*i:3*i+3]=xyz_shift
            tile_pos_index[3*i+2]=z_index
        stitch_num+=1
        print('%d.th tile has been stitched'%(i))
    print('%s stops and has stitch %d tiles.'%(multiprocessing.current_process().name,stitch_num))

def update_pos():
    '''Placeholder: not implemented yet; always returns None.'''
    return None

def start_multi_stitchers():
    '''
    Load the tile metadata, set up the shared stitching state and run run_sift_stitcher in worker processes.
    Reads the module-level xml_path. It publishes tile_num, tile_contact, if_tile_stitched,
    if_tile_shelved, tile_pos_index and tile_pos_stitch as module-level globals, because
    run_sift_stitcher reads them as globals.
    NOTE(review): workers inherit these globals only under the "fork" start method. With "spawn"
    (Windows default) they re-import the module and will not see them, so the shared objects
    should be passed to the workers explicitly. TODO.
    Blocks until all workers have finished.
    '''
    global tile_num,tile_contact,if_tile_stitched,if_tile_shelved,tile_pos_index,tile_pos_stitch
    dim_elem_num,dim_len,voxel_len,tile_num,tile_field,tile_pos=get_img_xml_info(xml_path)
    tile_contact=judge_tile_contact(dim_len,tile_pos)
    lock=multiprocessing.RLock()
    if_tile_stitched=Array(ctypes.c_bool,[False]*tile_num)
    if tile_num>0:
        # tile 0 starts as stitched so it can serve as the first reference
        if_tile_stitched[0]=True
    if_tile_shelved=Array(ctypes.c_bool,[False]*tile_num)
    # 3 slots per tile, filled by run_sift_stitcher; -1 = not assigned yet
    tile_pos_index=Array('i',[-1]*(tile_num*3))
    tile_pos_stitch=Array('d',[0.0]*(tile_num*3))
    # about 40% of the CPUs, but always at least one worker
    process_num=max(1,round(0.4*multiprocessing.cpu_count()))
    print('Current processing quantities: %d'%(process_num))
    process_list=[]
    for _ in range(process_num):
        # run_sift_stitcher requires the lock argument
        one_pro=multiprocessing.Process(target=run_sift_stitcher,args=(lock,))
        one_pro.start()
        process_list.append(one_pro)
    for one_pro in process_list:
        one_pro.join()
    return

def update_xml(xml_path,tile_pos_final):
    '''
    Placeholder: not implemented yet; the arguments are ignored and None is returned.
    xml_path - str, the metadata XML file to update,
    tile_pos_final - the final tile positions.
    '''
    return None

In [2]:
if __name__=='__main__':
    # raw slice images live in img_path; the metadata XML sits in its MetaData subfolder
    img_path=r'D:\Albert\Data\ZhaoLab\Imaging\20220219_Thy1_EGFP_M_high_resolution_40X_10overlap_50G'
    xml_path=img_path+r'\MetaData\Region 1.xml'
    dim_elem_num,dim_len,voxel_len,tile_num,tile_field,tile_pos=get_img_xml_info(xml_path)
    tile_contact=judge_tile_contact(dim_len,tile_pos)

KeyboardInterrupt: 

In [4]:
# Exploration: inspect the overlap of tiles 650 and 651 after a trial shift of [-5,6,0] voxels.
i,j=650,651
img1=import_img(img_path,i,dim_elem_num)
img2=import_img(img_path,j,dim_elem_num)
border=get_2img_border(dim_elem_num,dim_len,voxel_len,tile_pos[[i,j],:])
border=get_2img_border_after_shift(dim_elem_num,border,np.array([-5,6,0]))
print(border)
# One plane of each overlap: rows = Y range, a single X index (the overlap start), columns = Z range.
ovl1=img1[border[0,2]:border[0,3],border[0,0],border[0,4]:border[0,5]]
ovl2=img2[border[1,2]:border[1,3],border[1,0],border[1,4]:border[1,5]]
ovl1,ovl2=adjust_contrast(ovl1,ovl2)
# Native OpenCV preview windows; waitKey blocks until a key is pressed.
cv2.imshow('1',ovl1)
cv2.imshow('2',ovl2)
cv2.waitKey()
cv2.destroyAllWindows()
# `sift` is reused by the matching cells further down.
sift=cv2.xfeatures2d.SIFT_create()
# Disabled preview of knn matches (needs kp1, kp2 and good from a matching cell):
#img3=cv2.drawMatchesKnn(ovl1,kp1,ovl2,kp2,good,None,flags=2)
#cv2.imshow('3',img3)
#cv2.waitKey()
#cv2.destroyAllWindows()

[[456 512   6 512   0 126]
 [  0  56   0 506   0 126]]


In [44]:
# NOTE(review): `matches` is created by a later cell, so this only works after out-of-order execution.
matches[0][0].queryIdx
#query is img1 (ovl1 / des1), train is img2 (ovl2 / des2)

0

In [52]:
# Detect SIFT keypoints in both overlap planes and preview the matches that pass Lowe's ratio test.
kp1,des1=sift.detectAndCompute(ovl1,None)
kp2,des2=sift.detectAndCompute(ovl2,None)
# define the matcher here so this cell does not depend on a later cell having run
bf=cv2.BFMatcher()
matches=bf.knnMatch(des1,des2,k=2)
good=[]
for m in matches:
    # keep a match only when it is clearly better than the second-best candidate
    if len(m)==2 and m[0].distance<0.75*m[1].distance:
        good.append(m[0])
img3=cv2.drawMatches(ovl1,kp1,ovl2,kp2,good,None,flags=2)
cv2.imshow('3',img3)
cv2.waitKey()
cv2.destroyAllWindows()

In [64]:
# Display the matched point sets computed by the homography cell below (that cell must run first).
pts1,pts2

(array([[ 56.896244, 221.67865 ],
        [ 64.54202 , 202.09543 ],
        [ 68.21898 , 273.5331  ],
        [ 89.62303 , 114.346466],
        [ 89.62303 , 114.346466],
        [ 95.87566 , 466.323   ],
        [ 96.63001 , 498.4076  ],
        [ 97.57501 , 447.06204 ],
        [104.400444, 493.83417 ],
        [104.911064, 153.31549 ],
        [106.307556, 463.19763 ],
        [109.632   , 153.45854 ],
        [110.15449 , 453.5782  ],
        [121.51503 , 387.59186 ],
        [121.67128 , 191.4723  ],
        [122.97225 ,  41.816624]], dtype=float32),
 array([[ 56.609116, 221.51585 ],
        [ 64.817   , 200.40761 ],
        [ 75.11885 , 456.66498 ],
        [ 88.56638 , 114.71671 ],
        [ 88.56638 , 114.71671 ],
        [ 96.09102 , 465.00543 ],
        [ 96.700905, 497.91245 ],
        [ 97.15104 , 444.8361  ],
        [104.35014 , 494.40057 ],
        [ 54.30769 , 248.50444 ],
        [105.82382 , 463.7349  ],
        [104.098045, 151.01514 ],
        [109.301834, 455.18204 

In [None]:
# Estimate a homography from ratio-test SIFT matches between the two overlap planes and preview warped ovl1.
kp1,des1=sift.detectAndCompute(ovl1,None)
kp2,des2=sift.detectAndCompute(ovl2,None)
# keypoint coordinates as (x=column, y=row) float arrays; kp1/kp2 keep the KeyPoint lists
xy1=np.float32([kp.pt for kp in kp1]).reshape(-1,2)
xy2=np.float32([kp.pt for kp in kp2]).reshape(-1,2)
bf=cv2.BFMatcher()
# SIFT returns None descriptors when it finds no keypoints
matches=bf.knnMatch(des1,des2,k=2) if des1 is not None and des2 is not None else []
good=[]
for m in matches:
    # Lowe's ratio test
    if len(m)==2 and m[0].distance<0.75*m[1].distance:
        good.append((m[0].trainIdx,m[0].queryIdx))
if len(good)>4:
    # queryIdx indexes ovl1's keypoints, trainIdx indexes ovl2's
    pts1=np.float32([xy1[q,:] for (_,q) in good])
    pts2=np.float32([xy2[t,:] for (t,_) in good])
    print(pts1)
    print(pts2)
    H,status=cv2.findHomography(pts1,pts2,cv2.RANSAC,5)
    if H is None:
        print('findHomography could not estimate a homography from %d matches'%len(good))
    else:
        result=cv2.warpPerspective(ovl1,H,(ovl1.shape[1]*2,ovl1.shape[0]))
        #result[:,ovl1.shape[1]:2*ovl1.shape[1]]=ovl2
        cv2.imshow('3',result)
        cv2.waitKey()
        cv2.destroyAllWindows()
else:
    # do not fall through with undefined (or stale) pts1/pts2
    print('only %d good matches; more than 4 are needed to estimate a homography'%len(good))

[[ 56.896244 221.67865 ]
 [ 64.54202  202.09543 ]
 [ 68.21898  273.5331  ]
 [ 89.62303  114.346466]
 [ 89.62303  114.346466]
 [ 95.87566  466.323   ]
 [ 96.63001  498.4076  ]
 [ 97.57501  447.06204 ]
 [104.400444 493.83417 ]
 [104.911064 153.31549 ]
 [106.307556 463.19763 ]
 [109.632    153.45854 ]
 [110.15449  453.5782  ]
 [121.51503  387.59186 ]
 [121.67128  191.4723  ]
 [122.97225   41.816624]]
[[ 56.609116 221.51585 ]
 [ 64.817    200.40761 ]
 [ 75.11885  456.66498 ]
 [ 88.56638  114.71671 ]
 [ 88.56638  114.71671 ]
 [ 96.09102  465.00543 ]
 [ 96.700905 497.91245 ]
 [ 97.15104  444.8361  ]
 [104.35014  494.40057 ]
 [ 54.30769  248.50444 ]
 [105.82382  463.7349  ]
 [104.098045 151.01514 ]
 [109.301834 455.18204 ]
 [119.81332  392.13077 ]
 [120.46411  193.23158 ]
 [122.93256  132.4474  ]]


In [193]:
# 3x3 homography mapping pts1 (ovl1) onto pts2 (ovl2), from cv2.findHomography in an earlier cell.
H

array([[ 8.80501224e-01,  6.14535965e-03,  4.55339680e+00],
       [-7.10588476e-02,  9.59973950e-01,  5.99304799e+00],
       [-5.59793175e-04,  3.06752876e-05,  1.00000000e+00]])