In [None]:
from Supporting_functions import *
from aicsimageio.readers import CziReader
from aicsimageio.writers import OmeTiffWriter
import re
import scipy.ndimage as ndi



# Isotropic predictions for all .czi files in a folder

In [None]:
# Define paths

NucleiChannelId = 0  # channel that goes through the network; all other channels are only resliced
modelName = 'nuclei2nuclei'
modelId = 'deblur_net_22_7200.pkl'
angles = ['023', '068', '203', '338']  # output-name suffix for each view index (one entry per view)


# Models
srcpath = r'/home/hmorales/WorkSpace/DataIsoReconstructions/'+modelName+'/'
model_path = srcpath+'checkpoint/saved_models/deblur/'+modelId

# Folder containing the .czi files to process
img_src_path = '/run/user/1000/gvfs/smb-share:server=134.34.176.179,share=pmtest_fast/Cornelius/pErk/20230913/sample/'


# Output dir
# alternative location: /run/user/1000/gvfs/smb-share:server=134.34.176.179,share=pmtest_fast/Cornelius/pErk/20230913/Isotropic/
outdir = '/media/hmorales/Skynet/Isotropic/'

# CUDA devices (both currently point to the same GPU)
device1 = torch.device('cuda:0')
device2 = torch.device('cuda:0')
batch_size = 4

# Image Normalization
min_v = 0
max_v = 65535  # full uint16 range; stacks are saved as np.uint16
norm_percentiles = (50.0, 99.9995)     # input normalization: 99.9995 for nuclei, 99.999 for membranes
norm_percentiles_out = (50.0, 99.999)  # percentiles used to re-normalize the network output
crop_for_calculations = True  # run the network only inside a bounding box of the sample
thres_crop = 1.2              # threshold passed to get_image_cropping_box (semantics defined there)



# Create output folder, including missing parent folders; no-op if it already exists
os.makedirs(outdir, exist_ok=True)

In [None]:
# Prepare networks: two deblur generators with identical architecture.
# NOTE(review): both are built on device1 although upsample_block later receives
# device1 and device2 separately; this is harmless while both are 'cuda:0'.
net_kwargs = dict(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,
                  use_dropout=False, norm='instance')
deblur_net_A = Self_net_architecture.define_G(**net_kwargs)
deblur_net_B = Self_net_architecture.define_G(**net_kwargs)


# Load Model
# Read (unpickle) the checkpoint once and apply it to both networks. map_location keeps
# the load working even if the checkpoint was saved from a different device.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint_state = torch.load(model_path, map_location=device1)
deblur_net_A.load_state_dict(checkpoint_state)
deblur_net_B.load_state_dict(checkpoint_state)

In [None]:
# Get all .czi images in the folder
image_names = sorted(f for f in os.listdir(img_src_path) if f.endswith('.czi'))

# NOTE(review): the original loop skipped view 0 of the *first* file only, which looks like
# a leftover from resuming an interrupted run. Kept as an explicit setting; set it to 0 to
# process every view of every file.
first_view_of_first_image = 1

for i, image_name in enumerate(image_names):

    # get image path and open a reader for it
    czi_file_path = os.path.join(img_src_path, image_name)
    reader = CziReader(czi_file_path)

    # Time point: the digits inside "(...)" in the file name, zero-padded to 3 characters
    match = re.search(r'\((\d+)\)', image_name)
    if match is None:
        raise ValueError(f"No '(<number>)' time-point tag found in file name {image_name!r}")
    timeId = match.group(1).zfill(3)
    print('spim_TL'+str(timeId))

    # Voxel sizes are read once per file instead of once per view/channel.
    pixel_sizes = reader.physical_pixel_sizes
    # Lateral/axial voxel-size ratio; z-indices in the resliced stack are original z-indices / scale.
    scale = pixel_sizes.X / pixel_sizes.Z

    # angles[view] is used in the output name: fail before any expensive work if it would be out of range.
    if reader.dims.V > len(angles):
        raise ValueError(f"{image_name} has {reader.dims.V} views but `angles` has only {len(angles)} entries")

    N0 = first_view_of_first_image if i == 0 else 0

    for view in range(N0, reader.dims.V):
        print("Processing view : ", str(view))
        for color in range(reader.dims.C):

            start_time = time.time()  # Record the start time for this stack

            # Open image: lazy dask array for this view/channel, materialized as a 3D ZYX numpy array
            lazy_t0 = reader.get_image_dask_data("ZYX", V=view, C=color)
            img = lazy_t0.compute().astype(np.uint16)
            print(img.shape)

            if color == NucleiChannelId:
                # Make isotropic image and predict nuclei

                # crop for calculations: run the network only inside the bounding box
                if crop_for_calculations:
                    raw_img = img
                    bounds_min, bounds_max = get_image_cropping_box(raw_img, 2.0, scale, thres_crop)
                    img = raw_img[bounds_min[0]:bounds_max[0], bounds_min[1]:bounds_max[1], bounds_min[2]:bounds_max[2]]

                # Normalize
                img = image_preprocessing(img, norm_percentiles, min_v, max_v)

                # Predict image
                fusion_stack = upsample_block(img, scale, 1, deblur_net_A, deblur_net_B, min_v, max_v,
                                              device1, device2, batch_size)

                # put image back into the full resliced stack
                if crop_for_calculations:
                    scale_img = reslice(raw_img, 'xy', pixel_sizes.X, pixel_sizes.Z)
                    # Size the z-slice from the prediction itself: round(max/scale) - round(min/scale)
                    # may differ by one from the predicted depth and break the assignment.
                    z0 = round(bounds_min[0] / scale)
                    scale_img[z0:z0 + fusion_stack.shape[0],
                              bounds_min[1]:bounds_max[1],
                              bounds_min[2]:bounds_max[2]] = fusion_stack
                    # NOTE(review): the pasted region is presumably network output in the normalized
                    # [min_v, max_v] range, while surrounding voxels are raw resliced intensities — confirm.
                    fusion_stack = scale_img

                # Normalize output
                fusion_stack = image_preprocessing(fusion_stack, norm_percentiles_out, min_v, max_v)

            else:
                # Other channels are only resliced, without the network
                fusion_stack = reslice(img, 'xy', pixel_sizes.X, pixel_sizes.Z)

            # Save image: clip first so out-of-range or negative values cannot wrap around in uint16
            fusion_stack = np.clip(fusion_stack, 0, np.iinfo(np.uint16).max).astype(np.uint16)
            print(fusion_stack.shape)
            outName = 'spim_TL'+str(timeId)+'_Channel'+str(color)+'_Angle'+angles[view]+'.tif'
            img_out = os.path.join(outdir, outName)
            # ImageJ 'spacing' (z step) is written as the X pixel size, i.e. the resliced stack is treated as isotropic.
            tifffile.imwrite(
                img_out,
                fusion_stack,
                imagej=True,
                bigtiff=True,
                resolution=(1.0/pixel_sizes.X, 1.0/pixel_sizes.Y),
                metadata={'spacing': pixel_sizes.X, 'unit': 'um', 'axes': 'ZYX'})

            Elapsed_time = time.time() - start_time
            print(f"Elapsed Time: {Elapsed_time:.4f} seconds, image {outName}")
 

In [None]:
# Intensity check on the last stack produced by the processing loop.
# NOTE(review): `fusion_stack` is whatever the final loop iteration left behind (last view,
# last channel), so this cell only works after the processing cell has run.
newimg_min = np.amin(fusion_stack)
newimg_max = np.amax(fusion_stack)
print('Intensity Norm (%d, %d) ' % (newimg_min, newimg_max), '\n')

# Renormalize into a new variable so re-running this cell does not repeatedly renormalize
# (and overwrite) `fusion_stack`. NOTE(review): assumes image_preprocessing returns a new array
# rather than modifying its input in place — confirm.
fusion_stack_renorm = image_preprocessing(fusion_stack, norm_percentiles, min_v, max_v)

newimg_min = np.amin(fusion_stack_renorm)
newimg_max = np.amax(fusion_stack_renorm)
print('Intensity Norm (%d, %d) ' % (newimg_min, newimg_max), '\n')