In [24]:
import torch
import numpy as np
import cv2
import dask.array as da
import glob,os
from tqdm import tqdm
import matplotlib.pylab as plt
def read_im(fl,return_pos=False,ncols=4):
    """Lazily load a zarr z-scan as a dask array split into interleaved channels.

    Parameters
    ----------
    fl : str
        Path to a ``*.zarr`` acquisition, e.g. ``...\\Conv_zscan1__393.zarr``.
        The dataset component is built from the directory of ``fl``, the text
        after the last ``_`` of the file name (without extension) and ``\\data``.
    return_pos : bool
        When truthy, also parse the stage position from the sibling ``.xml``
        file (``fl`` with ``.zarr`` replaced by ``.xml``).
    ncols : int
        Number of interleaved channels in the frame stack.

    Returns
    -------
    im : dask.array.Array
        float32 array of shape ``(ncols, n_frames_per_channel, H, W)`` holding
        the squared stored values; the first stored frame is dropped.
    x, y :
        Only when ``return_pos`` is truthy: the two values in the first line of
        the xml file that contains ``stage_position``.
    """
    import ast
    # NOTE(review): the component is built with a literal backslash (Windows-style path)
    data = os.path.dirname(fl)+os.sep+os.path.basename(fl).split('_')[-1].split('.')[0]+r'\data'
    im = da.from_zarr(fl,component=data)
    # drop the first frame, then de-interleave: stored frame k goes to channel k % ncols
    im = im[1:]
    im = im.reshape([-1,ncols,im.shape[-2],im.shape[-1]])
    im = im.swapaxes(0,1)

    im = im.astype(np.float32)
    # NOTE(review): values are squared on load — presumably the data is stored
    # sqrt-compressed by the acquisition software; confirm.
    im = im*im
    if not return_pos:
        return im
    fl_xml = fl.replace('.zarr','.xml')
    # context manager so the xml handle is always closed
    with open(fl_xml) as fid:
        pos_lines = [ln for ln in fid if 'stage_position' in ln]
    # the tag content is a literal such as "123.4,567.8"; literal_eval parses it
    # without executing arbitrary code from the file (unlike eval)
    x,y = ast.literal_eval(pos_lines[0].split('>')[1].split('<')[0])
    return im,x,y
def norm_slice(im,s=50):
    """Remove slowly varying background from every 2-D slice of a stack.

    Each slice is cast to float32 and its ``s x s`` box-filtered version
    (``cv2.blur``) is subtracted from it.

    Returns a float32 numpy array built from the corrected slices.
    """
    stack = im.astype(np.float32)
    flattened = []
    for plane in stack:
        background = cv2.blur(plane,(s,s))
        flattened.append(plane-background)
    return np.array(flattened,dtype=np.float32)
def get_local_maxfast_tensor(im_dif_npy,th_fit=500,im_raw=None,dic_psf=None,delta=1,delta_fit=3,sigmaZ=1,sigmaXY=1.5,gpu=False):
    """Find local maxima of a 3-D image and optionally refine and score them.

    Parameters
    ----------
    im_dif_npy : np.ndarray
        3-D image (indexed as z, x, y) searched for maxima.
    th_fit : float
        Only voxels with ``im_dif_npy > th_fit`` are candidates.
    im_raw : np.ndarray or None
        Optional image of the same shape. When given, its values feed the
        ``a`` and ``habs`` columns; when None, ``a`` is computed from
        ``im_dif_npy`` and ``habs`` is 0.
    dic_psf : None
        Reserved; only None is supported. Any other value raises
        ``NotImplementedError`` when ``delta_fit > 0``.
    delta : int
        Radius of the spherical neighbourhood a candidate must be >= to be kept.
    delta_fit : int
        Radius of the spherical window used for centroiding and Gaussian
        scoring; ``<= 0`` disables refinement.
    sigmaZ, sigmaXY : float
        Widths of the reference Gaussian the windows are correlated with.
    gpu : bool
        Use ``cuda:0`` when available.

    Returns
    -------
    list or np.ndarray
        ``[]`` when no maximum is found. Otherwise an ``(N, 4)`` array
        ``[z, x, y, h]`` when ``delta_fit <= 0``, or an ``(N, 8)`` array
        ``[zc, xc, yc, bk, a, habs, hn, h]``: ``zc, xc, yc`` is the centroid
        weighted by (window - window minimum), ``bk`` the window minimum,
        ``hn``/``a`` the mean product of the standardised window
        (background-subtracted im_dif / im_raw) with a standardised Gaussian,
        ``habs`` the im_raw value and ``h`` the im_dif value at the maximum.

    Raises
    ------
    NotImplementedError
        If ``dic_psf`` is not None and refinement is requested.
    """
    import torch
    dev = "cuda:0" if (torch.cuda.is_available() and gpu) else "cpu"
    im_dif = torch.from_numpy(im_dif_npy).to(dev)
    z,x,y = torch.where(im_dif>th_fit)
    zmax,xmax,ymax = im_dif.shape

    def get_ind(x,xmax):
        """Mirror indices outside [0, xmax) back into range without repeating the edge."""
        x_ = torch.clone(x)
        bad = x_>=xmax
        # reflect about the last valid index xmax-1: xmax -> xmax-2, xmax+1 -> xmax-3
        x_[bad] = 2*xmax-x_[bad]-2
        bad = x_<0
        # reflect about 0: -1 -> 1, -2 -> 2
        x_[bad] = -x_[bad]
        # offsets larger than the image itself (tiny dimensions) are clamped
        return torch.clamp(x_,0,xmax-1)

    def stack_columns(cols):
        """Stack 1-D tensors as columns of a numpy array after explicit dtype promotion."""
        dtype = cols[0].dtype
        for col in cols[1:]:
            dtype = torch.promote_types(dtype,col.dtype)
        return torch.stack([col.to(dtype) for col in cols]).T.cpu().detach().numpy()

    # keep a candidate only if it is >= every neighbour within radius delta
    for d1 in range(-delta,delta+1):
        for d2 in range(-delta,delta+1):
            for d3 in range(-delta,delta+1):
                if (d1*d1+d2*d2+d3*d3)<=(delta*delta):
                    z_ = get_ind(z+d1,zmax)
                    x_ = get_ind(x+d2,xmax)
                    y_ = get_ind(y+d3,ymax)
                    keep = im_dif[z,x,y]>=im_dif[z_,x_,y_]
                    z,x,y = z[keep],x[keep],y[keep]
    h = im_dif[z,x,y]

    if len(x)==0:
        return []
    if delta_fit<=0:
        return stack_columns([z,x,y,h])
    if dic_psf is not None:
        raise NotImplementedError("dic_psf-based fitting is not implemented; pass dic_psf=None")

    # integer offsets of a spherical window of radius delta_fit
    d1,d2,d3 = np.indices([2*delta_fit+1]*3).reshape([3,-1])-delta_fit
    kp = (d1*d1+d2*d2+d3*d3)<=(delta_fit*delta_fit)
    d1 = torch.from_numpy(d1[kp]).to(dev)
    d2 = torch.from_numpy(d2[kp]).to(dev)
    d3 = torch.from_numpy(d3[kp]).to(dev)
    # window coordinates, shape (n_window, n_maxima); kept unreflected for the centroid
    im_centers0 = (z.reshape(-1, 1)+d1).T
    im_centers1 = (x.reshape(-1, 1)+d2).T
    im_centers2 = (y.reshape(-1, 1)+d3).T
    z_ = get_ind(im_centers0,zmax)
    x_ = get_ind(im_centers1,xmax)
    y_ = get_ind(im_centers2,ymax)
    im_centers3 = im_dif[z_,x_,y_]
    if im_raw is not None:
        im_raw_ = torch.from_numpy(im_raw).to(dev)
        im_centers4 = im_raw_[z_,x_,y_]
        habs = im_raw_[z,x,y]
    else:
        im_centers4 = im_dif[z_,x_,y_]
        habs = x*0
    Xft = torch.stack([d1,d2,d3]).T

    # background = window minimum; the remaining weights sum to 1 per maximum
    bk = torch.min(im_centers3,0).values
    im_centers3 = im_centers3-bk
    im_centers3 = im_centers3/torch.sum(im_centers3,0)

    # standardised reference Gaussian over the window offsets
    sigma = torch.tensor([sigmaZ,sigmaXY,sigmaXY],dtype=torch.float32,device=dev)
    Xft_ = Xft/sigma
    norm_G = torch.exp(-torch.sum(Xft_*Xft_,-1)/2.)
    norm_G = (norm_G-torch.mean(norm_G))/torch.std(norm_G)

    hn = torch.mean(((im_centers3-im_centers3.mean(0))/im_centers3.std(0))*norm_G.reshape(-1,1),0)
    a = torch.mean(((im_centers4-im_centers4.mean(0))/im_centers4.std(0))*norm_G.reshape(-1,1),0)

    zc = torch.sum(im_centers0*im_centers3,0)
    xc = torch.sum(im_centers1*im_centers3,0)
    yc = torch.sum(im_centers2*im_centers3,0)
    return stack_columns([zc,xc,yc,bk,a,habs,hn,h])
def get_psf(im_,th=1000,th_cor = 0.75,delta=3,delta_fit = 7,sxyzP = [15,30,30],gpu=True):
    r"""Estimate an empirical PSF by sub-pixel aligning and averaging bright spots.

    Steps:
      1. flatten the background with ``norm_slice(im_, s=30)``;
      2. detect spots with ``get_local_maxfast_tensor`` (``th`` as ``th_fit``,
         ``delta``/``delta_fit``/``gpu`` forwarded, ``im_`` as ``im_raw``);
      3. keep spots whose Gaussian-correlation score ``hn`` (second-to-last
         column) exceeds ``th_cor`` and whose refined z satisfies ``int(z) > 0``;
      4. cut a ``(2*sz+1, 2*sx+1, 2*sy+1)`` crop (``sxyzP = [sz, sx, sy]``)
         around each rounded centre; coordinates wrap periodically at the border;
      5. shift each crop by minus its sub-pixel residual with a Fourier phase
         ramp so the spot lands on the crop centre, then average the crops.

    Requires ``delta_fit > 0`` so that column -2 of the detections is ``hn``.

    Returns the averaged PSF as a numpy array, or None when no spot is detected
    or none passes the filters.

    Use as :

    psfs = []
    for ifov in tqdm(range(55,80)):
        im = read_im(r'X:\CGBB_embryo_4_28_2023\P1_Sox11_Sox2_Dcx_D16\Conv_zscan__'+str(ifov).zfill(3)+'.zarr')
        im_ = np.array(im[0][1:,500:2500,500:2500],dtype=np.float32)
        psf = get_psf(im_,th=1000,th_cor = 0.75,delta=3,delta_fit = 7,sxyzP = [15,60,60])
        psfs.append(psf)
    #napari.view_image(im)
    psff = np.mean([psf for psf in psfs if psf is not None],axis=0)
    psff_ = np.array([p-np.median(p) for p in psff])
    from scipy.ndimage import median_filter
    psff_med = median_filter(psff_, size=15)
    psfff = (psff_-psff_med)[5:-5,5:-5,5:-5][:-1,:-1,:-1]
    psfff[psfff<0]=0
    psfff = psfff/np.max(psfff)
    np.save('psf_750_Scope1_embryo_big_final.npy',psfff)
    """
    im_n = norm_slice(im_,s=30)

    # forward the caller's gpu choice (it used to be hard-coded to True here)
    Xh = get_local_maxfast_tensor(im_n,im_raw=im_,th_fit=th,gpu=gpu,delta=delta,delta_fit=delta_fit)
    if Xh is None or len(Xh)==0:
        return None
    # column -2 is the Gaussian-correlation score hn; also drop spots with int(zc) <= 0
    Xh_ = Xh[(Xh[:,-2]>th_cor)&(Xh[:,0].astype(int)>0)]
    if len(Xh_)==0:
        # no spot to average: avoid returning the all-NaN mean of zero crops
        return None

    X = Xh_[:,:3]
    XT = np.round(X).astype(np.int16)
    szP,sxP,syP = sxyzP
    Xi = np.indices([2*szP+1,2*sxP+1,2*syP+1],dtype=np.int16).reshape([3,-1]).T-np.array([szP,sxP,syP])
    dev = "cuda:0" if (torch.cuda.is_available() and gpu) else "cpu"
    XT = torch.from_numpy(XT).to(dev)
    Xi = torch.from_numpy(Xi[:,np.newaxis]).to(dev)
    shape_ = torch.from_numpy(np.array(im_.shape,dtype=np.int16)).to(dev)
    # crop coordinates around every rounded centre, wrapped periodically into the image
    Xf = ((XT+Xi)%(shape_)).type(torch.int)
    imdev = torch.from_numpy(im_).to(dev)
    ims = imdev[Xf[...,0],Xf[...,1],Xf[...,2]].reshape([2*szP+1,2*sxP+1,2*syP+1,-1])

    # DFT sample frequencies (cycles/voxel) in numpy/torch fftn order; the crop
    # sizes are odd, where fftshift(k - n/2)/n would be off by half a bin
    height,width,depth = [2*szP+1,2*sxP+1,2*syP+1]
    kx,ky,kz = np.meshgrid(np.fft.fftfreq(height),np.fft.fftfreq(width),np.fft.fftfreq(depth),indexing='ij')

    dx = torch.from_numpy((X-np.round(X)).astype(np.float32)).to(dev)
    k = torch.from_numpy(np.array([kx,ky,kz]).astype(np.float32)).to(dev)
    # each spot sits dx voxels from its crop centre; multiplying the spectrum by
    # exp(+2*pi*i*f*dx) moves it by -dx, i.e. back onto the centre
    expK = torch.exp(2j*np.pi*torch.tensordot(dx,k,dims=1)).moveaxis(0,-1)
    F = torch.fft.fftn(ims,dim=[0,1,2])
    psf = torch.mean(torch.fft.ifftn(F*expK,dim=[0,1,2]).real,-1)
    return psf.cpu().detach().numpy()
def linear_flat_correction(ims,fl=None,reshape=True,resample=4,vec=[0.1,0.15,0.25,0.5,0.75,0.9]):
    """Fit a per-pixel linear response relative to the image-centre pixel.

    Intensity percentiles (fractions ``vec``) are taken across frames for every
    pixel. For each pixel, ``perc_pixel ~= bM[0]*perc_centre + bM[1]`` is fitted
    by least squares against the percentiles of the centre pixel, so an image
    can later be corrected as ``(im - bM[1]) / bM[0]``.

    Parameters
    ----------
    ims : array
        With ``reshape=True``: 4-D ``(a, b, H, W)``; the first two axes are
        merged into one frame axis. Otherwise: frames along axis 0 (``(N, H, W)``).
    fl : str or None
        If given, ``bM`` is also written there with ``np.save``.
    reshape : bool
        See ``ims``.
    resample : int
        Use every ``resample``-th frame (applied exactly once in both modes).
    vec : sequence of float
        Percentile fractions in [0, 1]; 1.0 selects the maximum frame value.

    Returns
    -------
    np.ndarray
        ``bM`` of shape ``(2, H, W)``: slope (``bM[0]``) and intercept (``bM[1]``).
    """
    if reshape:
        ims_pix = np.reshape(ims,[ims.shape[0]*ims.shape[1],ims.shape[2],ims.shape[3]])[::resample]
    else:
        # subsample only once (previously the result was subsampled again before sorting)
        ims_pix = np.array(ims[::resample])
    ims_pix_sort = np.sort(ims_pix,axis=0)
    n_frames = len(ims_pix_sort)
    # clamp so that frac == 1.0 selects the last (largest) value instead of overflowing
    ims_perc = np.array([ims_pix_sort[min(int(frac*n_frames),n_frames-1)] for frac in vec])
    # reference pixel: the image centre
    i1,i2 = np.array(np.array(ims_perc.shape)[1:]/2,dtype=int)
    x = ims_perc[:,i1,i2]
    X = np.array([x,np.ones(len(x))]).T
    # normal equations bM = (X^T X)^-1 X^T y, evaluated for all pixels at once
    a = np.linalg.inv(np.dot(X.T,X))
    bM = np.tensordot(np.dot(a,X.T),ims_perc,axes=(1,0))
    if fl is not None:
        np.save(fl,bM)
    return bM

In [30]:
# Median of frame 20 of channels 1-3 across every FOV in H1, saved as med{icol}.npy.
# NOTE(review): presumably used as a flat-field/background reference — confirm.
fls = glob.glob(rf'S:\12_04_2025_JenieSample\H1\*.zarr')
from tqdm import tqdm
for icol in [1,2,3]:
    # one float32 plane (channel icol, frame 20) per file, stacked along axis 0
    ims = np.array([np.array(read_im(fl)[icol][20],dtype=np.float32)for fl in tqdm(fls)])
    immed = np.median(ims,axis=0)
    np.save(f"med{icol}.npy",immed)

100%|████████████████████████████████████████████████████████████████████████████████| 563/563 [01:31<00:00,  6.13it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 563/563 [01:28<00:00,  6.36it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 563/563 [01:16<00:00,  7.32it/s]


In [25]:
# Per-pixel linear flat correction from the frame stack `ims`, saved to medcor{icol}.npy.
# NOTE(review): hidden state — uses `ims`/`icol` left over from the median loop (the last
# channel processed there); its execution count (25) predates that loop's (30).
bM = linear_flat_correction(ims,fl=f"medcor{icol}.npy",reshape=False,resample=1,vec=[0.1,0.15,0.25,0.5,0.75,0.9])

In [31]:
# Inspect the median image of the last channel processed by the median loop.
# napari is otherwise first imported further down, so import it here.
import napari
napari.view_image(immed)

Viewer(camera=Camera(center=(0.0, np.float64(1399.5), np.float64(1399.5)), zoom=np.float64(0.2032321428571428), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(1.0, 1.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=2, ndisplay=2, order=(0, 1), axis_labels=('0', '1'), rollable=(True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(2799.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(2799.0), step=np.float64(1.0))), margin_left=(0.0, 0.0), margin_right=(0.0, 0.0), point=(np.float64(1399.0), np.float64(1399.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'immed' at 0x25a194ea160>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), theme='dark', title='napari', mouse_over_canvas=False, mouse_move_callbacks=[], mouse_drag_callbacks=[], mouse_double_click_callbacks=[<functio

In [7]:
# FOV ids (zero-padded strings) used for PSF estimation, written as 7 rows of 9 ids.
# NOTE(review): the row layout presumably mirrors the FOV tiling on the sample — confirm.
numbers = [
    '393', '392', '391', '390', '154', '153', '152', '065', '066',
    '405', '406', '407', '408', '061', '062', '063', '064', '055',
    '412', '411', '410', '409', '060', '059', '058', '057', '056',
    '420', '442', '443', '444', '446', '447', '448', '449', '450',
    '422', '421', '441', '440', '445', '474', '473', '472', '471',
    '423', '437', '438', '439', '475', '476', '477', '486', '487',
    '427', '436', '434', '433', '480', '479', '478', '485', '489'
]
len(numbers)

63

In [8]:
# Load the first listed FOV and materialise channel index 2 as a float32 numpy stack.
ifov_str = numbers[0]
fl = rf'S:\12_04_2025_JenieSample\H1\Conv_zscan1__{ifov_str}.zarr'
im = read_im(fl)
im_ = np.array(im[2],dtype=np.float32)

In [10]:
# Visual check of the background-flattened stack (norm_slice with a 30 px box filter).
import napari
napari.view_image(norm_slice(im_,s=30))

Viewer(camera=Camera(center=(0.0, np.float64(1399.5), np.float64(1399.5)), zoom=np.float64(0.1957678571428571), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(19.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(39.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(2799.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(2799.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(19.0), np.float64(1399.0), np.float64(1399.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'Image' at 0x1bc9241b670>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), the

Cannot find steve


In [9]:
# Tile origins (z, x, y) on an sxy-pixel grid covering the image plane; one PSF is
# estimated per tile. The y range now uses sy (it used sx, wrong for non-square images).
sx,sy = im_.shape[-2:]
sxy = 400
keys = []
for x in np.arange(0,sx,sxy):
    for y in np.arange(0,sy,sxy):
        keys.append((0,x,y))
len(keys)

49

In [None]:
# One list of per-FOV PSF estimates per tile key; re-running this cell discards accumulated PSFs.
dic_psfs = {key:[]for key in keys}

In [16]:
# Accumulate per-tile PSF estimates over FOVs into dic_psfs.
# NOTE(review): numbers[0] is skipped — presumably processed in an earlier run; confirm.
for ifov_str in numbers[1:]:
    fl = rf'S:\12_04_2025_JenieSample\H1\Conv_zscan1__{ifov_str}.zarr'
    im = read_im(fl)
    im_ = np.array(im[2],dtype=np.float32)

    for z,x,y in tqdm(keys):
        psf = get_psf(im_[z:z+sxy,x:x+sxy,y:y+sxy],th=5000,th_cor = 0.5,delta=3,delta_fit = 3,sxyzP = [15,60,60])
        # get_psf returns None when no spot passes its filters; storing None would make
        # the per-tile np.mean in the following cells fail
        if psf is not None:
            dic_psfs[(z,x,y)].append(psf)

100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [04:48<00:00,  5.89s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [03:25<00:00,  4.19s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [05:48<00:00,  7.12s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [06:30<00:00,  7.98s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [03:03<00:00,  3.75s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [04:13<00:00,  5.16s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [10:57<00:00, 13.41s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [15:40<00:00, 19.19s/it]
100%|███████████████████████████████████

KeyboardInterrupt: 

In [18]:
import pickle
# Persist the raw per-tile PSF lists; the context manager closes (and flushes) the file.
with open("psfs.pkl",'wb') as fid:
    pickle.dump(dic_psfs,fid)

In [23]:
# First-pass mean PSF per tile; None entries (tiles where get_psf found no spot) are ignored.
mean_psf = {key:np.mean([psf for psf in dic_psfs[key] if psf is not None],axis=0) for key in dic_psfs}

In [26]:
# Re-average each tile using only PSFs that correlate (r > 0.75) with the first-pass mean.
mean_psf2 = {}
for key in dic_psfs:
    ref = mean_psf[key].ravel()
    good = [psf for psf in dic_psfs[key] if psf is not None and np.corrcoef(psf.ravel(),ref)[0,1]>0.75]
    # if nothing passes, np.mean([]) would give a NaN scalar that breaks later cells;
    # fall back to the first-pass mean instead
    mean_psf2[key] = np.mean(good,axis=0) if good else mean_psf[key]

In [27]:
# Array of the per-tile consensus PSFs, in dict key order.
impsfs = np.array([mean_psf2[key] for key in mean_psf2])

In [29]:
# Save the consensus PSFs; the context manager closes (and flushes) the file.
# NOTE(review): a later cell overwrites psf_final.pkl with the cleaned PSFs.
with open("psf_final.pkl",'wb') as fid:
    pickle.dump(mean_psf2,fid)

In [31]:
# NOTE(review): redundant — the next cell starts with the same assignment.
keys = list(mean_psf2.keys())


In [35]:
# Clean every tile PSF: subtract each z-plane's median, remove a size-15 median-filtered
# background, crop 5 voxels from the low side and 6 from the high side of every axis,
# zero negative values and scale to a maximum of 1.
from scipy.ndimage import median_filter
keys = list(mean_psf2.keys())
psf_final = {}
for key in tqdm(keys):
    psf = mean_psf2[key]
    psff_ = np.array([p-np.median(p) for p in psf])
    psff_med = median_filter(psff_, size=15)
    # equivalent to [5:-5,5:-5,5:-5][:-1,:-1,:-1]
    psfff = (psff_-psff_med)[5:-6,5:-6,5:-6]
    psfff[psfff<0]=0
    psf_final[key] = psfff = psfff/np.max(psfff)
#np.save(r'T:\20250309_R165_AllenBrain_S4\psf_R165.npy',psfff)

100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [08:38<00:00, 10.58s/it]


In [36]:
# Save the cleaned per-tile PSFs; the context manager closes (and flushes) the file.
with open("psf_final.pkl",'wb') as fid:
    pickle.dump(psf_final,fid)

In [3]:
import pickle,numpy as np
# NOTE: pickle.load can execute code from the file — only load trusted, locally produced pickles.
with open("psf_final.pkl",'rb') as fid:
    psf_final = pickle.load(fid)

In [4]:
# Single field-averaged PSF: mean of all cleaned per-tile PSFs.
psf_ = np.mean(list(psf_final.values()),axis=0)

In [45]:
# View the averaged PSF. NOTE(review): relies on napari having been imported in an earlier cell.
napari.view_image(psf_)

Viewer(camera=Camera(center=(0.0, np.float64(54.5), np.float64(54.5)), zoom=np.float64(4.9831818181818175), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(9.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(19.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(109.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(109.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(9.0), np.float64(54.0), np.float64(54.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'psf_' at 0x1ef1e6a8940>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), theme='dark', ti

Cannot find steve


In [5]:
# Load FOV 200 and materialise its last channel as float32 for a deconvolution check.
fl = rf'S:\12_04_2025_JenieSample\H1\Conv_zscan1__200.zarr'
im = read_im(fl)
im_ = np.array(im[-1],dtype=np.float32)

In [33]:
# Tile-wise Wiener deconvolution (400 px tiles, 100 px extra per tile); each tile uses the
# psf_final entry whose key is nearest (L1) to the tile origin. The trailing comment is an
# alternative single-PSF call.
imd = full_deconv(im_,s_=400,pad=100,psf=psf_final,parameters={'method': 'wiener', 'beta': 0.0001},gpu=True,force=True)#apply_deconv(im_[:,:,:],psf=psf_final[(0,1600,1600)],plt_val=False,parameters = {'method':'wiener','beta':0.01,'niter':50},gpu=True,force=True,pad=None)

In [34]:
# Overlay deconvolved and raw stacks; scale is presumably the (z, x, y) voxel size in µm — confirm.
import napari
V = napari.view_image(imd,scale=[0.4,0.10833,0.10833])
V.add_image(im_,scale=[0.4,0.10833,0.10833])

<Image layer 'im_' at 0x25a17c22670>

In [32]:
# Load FOV 434 and materialise channel index 0 as float32.
fl = rf'S:\12_04_2025_JenieSample\H1\Conv_zscan1__434.zarr'
im = read_im(fl)
im_ = np.array(im[0],dtype=np.float32)

In [9]:
# Same tile-wise Wiener deconvolution with a larger beta (0.005); 'niter' is only read by
# apply_deconv for the 'lucy' method and is ignored here.
imd = full_deconv(im_,s_=400,pad=100,psf=psf_final,parameters={'method': 'wiener', 'beta': 0.005, 'niter': 50},gpu=True,force=True)#apply_deconv(im_[:,:,:],psf=psf_final[(0,1600,1600)],plt_val=False,parameters = {'method':'wiener','beta':0.01,'niter':50},gpu=True,force=True,pad=None)

In [10]:
# Overlay deconvolved and raw stacks; scale is presumably the (z, x, y) voxel size in µm — confirm.
import napari
V = napari.view_image(imd,scale=[0.4,0.10833,0.10833])
V.add_image(im_,scale=[0.4,0.10833,0.10833])

<Image layer 'im_' at 0x25975ba3cd0>

In [2]:
def get_local_max_tile(im_,th=2500,s_ = 300,pad=50,psf=None,plt_val=None,snorm=30,gpu=False,deconv={'method':'wiener','beta':0.001},
                        delta=1,delta_fit=3,sigmaZ=1,sigmaXY=1.5):
    """Detect spots in a large 3-D stack tile by tile.

    The (x, y) plane of ``im_`` (indexed z, x, y) is split into ``s_``-pixel
    tiles, each extended by ``pad`` pixels on its high side. Each tile is
    optionally deconvolved with ``apply_deconv`` (``deconv`` parameters; if
    ``psf`` is a dict, the entry whose key is L1-nearest to ``(0, ix, iy)`` is
    used), flattened with ``norm_slice(s=snorm)`` and searched with
    ``get_local_maxfast_tensor`` (threshold ``th``; ``delta``, ``delta_fit``,
    ``sigmaZ``, ``sigmaXY``, ``gpu`` forwarded). Only spots inside the tile's
    own region (``pad/2`` trimmed on interior edges) are kept, and their x/y
    are shifted to whole-image coordinates.

    plt_val : None, bool or dict
        Any value other than None (including False) opens a napari viewer with
        the image, the normalised image of the LAST tile only, and the spots;
        ``plt_val['size']`` overrides the point size when it is a dict.

    Returns
    -------
    np.ndarray or None
        Concatenated detection rows (8 columns when ``delta_fit > 0``, else 4),
        or None when no tile returned any detection.
    """
    sx,sy = im_.shape[1:]
    ixys = [(ix,iy) for ix in np.arange(0,sx,s_) for iy in np.arange(0,sy,s_)]
    tile_results = []
    out_im2 = None
    for ix,iy in ixys:
        imsm = im_[:,ix:ix+pad+s_,iy:iy+pad+s_]
        out_im = imsm
        if deconv is not None:
            psf_ = psf
            if type(psf) is dict:
                keys = list(psf.keys())
                ikey = np.argmin(np.sum(np.abs(np.array(keys)-[0,ix,iy]),axis=-1))
                psf_ = psf[keys[ikey]]
            out_im = apply_deconv(imsm,psf=psf_,plt_val=False,parameters = deconv,gpu=gpu,force=True,pad=None)
        out_im2 = norm_slice(out_im,s=snorm)
        Xh = get_local_maxfast_tensor(out_im2,th,im_raw=imsm,dic_psf=None,delta=delta,delta_fit=delta_fit,sigmaZ=sigmaZ,sigmaXY=sigmaXY,gpu=gpu)
        if Xh is None or len(Xh)==0:
            continue
        # keep only spots in this tile's own, non-overlapping region
        keep = np.all(Xh[:,1:3]<(s_+pad/2),axis=-1)
        keep &= np.all(Xh[:,1:3]>=(pad/2*np.array([ix>0,iy>0])),axis=-1)
        Xh = Xh[keep]
        Xh[:,1]+=ix
        Xh[:,2]+=iy
        tile_results.append(Xh)
    # concatenate once rather than growing an array inside the loop
    Xhf = np.concatenate(tile_results) if tile_results else None

    if plt_val is not None:
        import napari
        v = napari.Viewer()
        v.add_image(im_,name='Original image')
        if out_im2 is not None:
            # NOTE: this is only the last tile, not the whole field
            v.add_image(out_im2,name='Deconv image')
        # without detections there is nothing to draw (and np.percentile would fail)
        if Xhf is not None and len(Xhf)>0:
            H = Xhf[:,-1]
            size = None
            if type(plt_val) is dict:
                size = plt_val.get('size')
            if size is None:
                size = np.clip(H/np.percentile(H,99.99),0,1)*5
            v.add_points(Xhf[:,:3],face_color=[0,0,0,0],edge_color='y',size=size)
    return Xhf
def apply_deconv(imsm,psf=None,plt_val=False,parameters = {'method':'wiener','beta':0.001,'niter':50},gpu=True,force=True,pad=None):
    r"""Deconvolve the 3-D image <imsm> with sdeconv: https://github.com/sylvainprigent/sdeconv/

    PSF handling:
      * psf is None: a theoretical PSF is generated with sdeconv's
        ``SPSFGibsonLanni(M=60, shape=imsm.shape)``; the other optical parameters
        are left at sdeconv's defaults (an earlier note quoted ~1.4 NA — verify
        against the installed sdeconv version).
      * otherwise: ``psf`` (numpy-indexable) is centre-cropped or zero-padded
        per axis to ``imsm.shape``.

    Parameters
    ----------
    imsm : array-like
        Image to deconvolve; converted to float32.
    psf : array-like or None
        See above.
    plt_val : bool
        If True, show the result and the input in napari.
    parameters : dict
        ``method``: 'wiener' (default), 'lucy' or 'spitfire'; ``beta`` (wiener,
        default 0.001); ``niter`` (lucy, default 50).
    gpu : bool
        Run on ``cuda:0`` when available.
    force : bool
        Delete the ``dic_psf`` attribute of the sdeconv ``SSettings`` instance
        if present (presumably a PSF cache of a patched sdeconv — confirm).
    pad : int or None
        Padding passed to the sdeconv filter; None -> min(min(imsm.shape)-1, 50).

    Returns
    -------
    np.ndarray
        float32 deconvolved image.

    Raises
    ------
    ValueError
        If ``parameters['method']`` is not 'wiener', 'lucy' or 'spitfire'.

    Recommendations: the default wiener method with a low beta is best for very
    fast local fitting, albeit with larger-scale artifacts. For images, the lucy
    method with ~30 iterations is recommended.

    This wraps around pytorch.

    To install:
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
    pip install sdeconv
    Optional: the __init__ of the SSettingsContainer in
    C:\Users\BintuLabUser\anaconda3\envs\cellpose\Lib\site-packages\sdeconv\core\_settings.py
    can be modified to

    import os
    gpu = True
    if os.path.exists("use_gpu.txt"):
        gpu = eval(open("use_gpu.txt").read())
    self.device = torch.device("cuda:0" if (torch.cuda.is_available() and gpu) else "cpu")
    to toggle the GPU on or off. By default it just uses the GPU if one is detected by pytorch.
    """
    method = parameters.get('method','wiener')
    if method not in ('wiener','lucy','spitfire'):
        # fail fast instead of an UnboundLocalError on filter_ further down
        raise ValueError(f"Unknown deconvolution method {method!r}; expected 'wiener', 'lucy' or 'spitfire'")

    import torch
    from sdeconv.core import SSettings
    obj = SSettings.instance()
    obj.device = torch.device("cuda:0" if (torch.cuda.is_available() and gpu) else "cpu")
    if force:
        if hasattr(obj,'dic_psf'): del obj.dic_psf

    imsm_ = torch.from_numpy(np.array(imsm,dtype=np.float32))
    if psf is None:
        from sdeconv.psfs import SPSFGibsonLanni
        psf_generator = SPSFGibsonLanni(M=60,shape=imsm_.shape)
        psf = psf_generator().to(obj.device)
    else:
        # centre the given PSF in an image-sized array: crop axes where the PSF is
        # larger than the image, zero-pad axes where it is smaller
        psff = np.zeros(imsm_.shape,dtype=np.float32)
        slices = [(slice((s_psff-s_psf_full_)//2,(s_psff+s_psf_full_)//2),slice(None)) if s_psff>s_psf_full_ else
                  (slice(None),slice((s_psf_full_-s_psff)//2,(s_psf_full_+s_psff)//2))
                  for s_psff,s_psf_full_ in zip(psff.shape,psf.shape)]
        sl_psff,sl_psf_full_ = list(zip(*slices))
        psff[sl_psff]=psf[sl_psf_full_]
        psf = torch.from_numpy(np.array(psff,dtype=np.float32)).to(obj.device)

    if pad is None:
        pad = int(np.min(list(np.array(imsm.shape)-1)+[50]))
    if method=='wiener':
        from sdeconv.deconv import SWiener
        beta = parameters.get('beta',0.001)
        filter_ = SWiener(psf, beta=beta, pad=pad)
    elif method=='lucy':
        from sdeconv.deconv import SRichardsonLucy
        niter = parameters.get('niter',50)
        filter_ = SRichardsonLucy(psf, niter=niter, pad=pad)
    else:  # 'spitfire'
        from sdeconv.deconv import Spitfire
        filter_ = Spitfire(psf, weight=0.6, reg=0.995, gradient_step=0.01, precision=1e-6, pad=pad)
    out_image = filter_(imsm_)
    out_image = out_image.cpu().detach().numpy().astype(np.float32)
    if plt_val:
        import napari
        viewer = napari.view_image(out_image)
        viewer.add_image(imsm)
    return out_image

def full_deconv(im_,s_=300,pad=100,psf=None,parameters={'method': 'wiener', 'beta': 0.001, 'niter': 50},gpu=True,force=True):
    """Deconvolve a large 3-D stack tile by tile and stitch the tiles together.

    The (x, y) plane of ``im_`` (indexed z, x, y) is split into ``s_``-pixel
    tiles. Each tile is deconvolved with ``pad`` extra pixels on its high side
    via ``apply_deconv``, and only its central part is written back, so every
    written pixel is at least ``pad//2`` pixels from an interior tile edge.

    Parameters
    ----------
    psf : array, dict or None
        A single PSF, or a dict keyed by ``(z, x, y)`` tile origins; for a dict
        the entry whose key is nearest (L1) to ``(0, ix, iy)`` is used per tile,
        and ``force`` is set to True.
    parameters, gpu, force
        Forwarded to ``apply_deconv``.

    Returns
    -------
    np.ndarray
        Deconvolved image with the shape of ``im_``; floating point even when
        ``im_`` has an integer dtype.
    """
    # apply_deconv returns float32 tiles; np.zeros_like on an integer image would truncate them
    im0 = np.zeros(im_.shape,dtype=np.result_type(im_.dtype,np.float32))
    sx,sy = im_.shape[1:]
    tile_origins = [(ix,iy) for ix in np.arange(0,sx,s_) for iy in np.arange(0,sy,s_)]

    for ix,iy in tile_origins:
        imsm = im_[:,ix:ix+pad+s_,iy:iy+pad+s_]
        if type(psf) is dict:
            keys = list(psf.keys())
            ikey = np.argmin(np.sum(np.abs(np.array(keys)-[0,ix,iy]),axis=-1))
            psf_ = psf[keys[ikey]]
            force=True
        else:
            psf_ = psf
        imt = apply_deconv(imsm,psf=psf_,parameters=parameters,gpu=gpu,plt_val=False,force=force)
        # interior tiles skip pad//2 pixels at their low edge; every tile writes up to
        # pad//2 past its nominal end (numpy slicing clips at the image border)
        start_x = ix+pad//2 if ix>0 else 0
        end_x = ix+pad//2+s_
        start_y = iy+pad//2 if iy>0 else 0
        end_y = iy+pad//2+s_
        im0[:,start_x:end_x,start_y:end_y] = imt[:,(start_x-ix):(end_x-ix),(start_y-iy):(end_y-iy)]
    return im0

In [33]:
# NOTE(review): hidden state — psfff is left over from the last iteration of the PSF
# clean-up loop (the cleaned PSF of the last tile processed there).
import napari
napari.view_image(psfff)

Viewer(camera=Camera(center=(0.0, np.float64(54.5), np.float64(54.5)), zoom=np.float64(5.181818181818181), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(9.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(19.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(109.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(109.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(9.0), np.float64(54.0), np.float64(54.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'psfff' at 0x1ef1b14ac70>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), theme='dark', ti

Cannot find steve
