In [8]:
from Supporting_functions import *
from aicsimageio.readers import CziReader
from aicsimageio.writers import OmeTiffWriter
import re
import scipy.ndimage as ndi
from WBNS import WBNS_image
import RedLionfishDeconv as rl

# Predictions for a folder

In [9]:
# Define paths

NucleiMembranesChannelId = 1
modelNameNuclei       = 'NucleiMembrane2Nuclei'
modelNameMembranes    = 'NucleiMembrane2Membrane'
modelNameIsoNuclei    = 'Self_Nuclei'
modelNameIsoMembranes = 'Self_Membranes'
modelNameEnhNucleiMem = 'Enhance_NucleiMembranes'


modelIdNuclei       = 'deblur_net_8_72000.pkl'
modelIdMembranes    = 'deblur_net_8_96000.pkl'
modelIdIsoNuclei    = 'deblur_net_4_8000.pkl'
modelIdIsoMembranes = 'deblur_net_2_8000.pkl'
modelIdEnhNucleiMem = 'deblur_net_16_36000.pkl'


angles = ['050', '140',  '230', '320'] #['010', '145', '280', '325']


# Models
srcpath = r'/media/hmorales/Skynet/IsoNet/Models/'
model_path_Nuclei    = srcpath+modelNameNuclei+'/'+'checkpoint/saved_models/'+modelIdNuclei
model_path_Membranes = srcpath+modelNameMembranes+'/'+'checkpoint/saved_models/'+modelIdMembranes

model_path_IsoNuclei       = srcpath+modelNameIsoNuclei+'/'+'checkpoint/saved_models/'+modelIdIsoNuclei
model_path_IsoMembranes    = srcpath+modelNameIsoMembranes+'/'+'checkpoint/saved_models/'+modelIdIsoMembranes

model_path_EnhNucleiMem    = srcpath+modelNameEnhNucleiMem+'/'+'checkpoint/saved_models/'+modelIdEnhNucleiMem


# Image to test
img_src_path = '/media/hmorales/Skynet/IsoNet/test/'


#output dir
outdir = '/media/hmorales/Skynet/IsoNet/test/Isotropic/' #'/media/hmorales/MyBookDuo/Data/IsoNet/test/Isotropic/' #/run/user/1000/gvfs/smb-share:server=134.34.176.179,share=pmtest_fast/Cornelius/pErk/20230913/Isotropic/'

# CUDA device
device1 = torch.device('cuda:0')
device2 = torch.device('cuda:0')
batch_size = 6

skip_planes = 4   # for Nuclei Prediction
skip_planes2 = 0  # for Nuclei Prediction
        
# Image Normalization
min_v = 0
max_v = 65535
norm_percentiles =  (20, 99.999)  #99.9995 For Nuclei, 99.999 For Membranes
norm_percentiles_out = (50.0, 99.999)  #99.9995 For Nuclei, 99.999 For Membranes
crop_for_calculations = True
thres_crop = 1.5
blur_mask = 2.0

# post processing

# BG subtraction
resolution_px = 0 # FWHM of the PSF
noise_lvl = 2
sigma = 0.0

# deconvolution
padding = 32
Niter = 0
psf_path = r'/home/hmorales/WorkSpace/DataIsoReconstructions/Averaged_transformed_PSF_488.tif'

# Create output folder
if not os.path.exists(outdir):
    os.mkdir(outdir)
      
    

# Open PSF and Prepare PSF
psf = tifffile.imread(psf_path)
psf_f = psf.astype(np.float32)
psf_norm = psf_f/psf_f.sum()

In [10]:
def open_image_from_reader(reader, view, color, order="ZYX", out_type=np.uint16):
    
    lazy_t0 = reader.get_image_dask_data(order, V=view, C=color)  # returns 3D ZYX numpy array
    img = lazy_t0.compute()  # returns in-memory 4D numpy array
    img = img.astype(out_type)
    
    return img, reader.physical_pixel_sizes.X, reader.physical_pixel_sizes.Z

def image_preprocessing(img, mask, percentiles, min_v, max_v):
    
    # Image Normalization
    if percentiles[0] > 0 or percentiles[1] < 100:
        low_thres, high_thres0 = getNormalizationThresholds(img, percentiles) # low thres in whole image
        low_thres0, high_thres = getNormalizationThresholds(img * mask, percentiles) # high thres in FG
        img = remove_outliers_image(img, low_thres, high_thres)
  
    img, scaleI = image_get_scaling(img, min_v, max_v)
    #img = img.astype(np.uint16)
    
    return img, scaleI


In [11]:
# Prepare networks
net_Nuclei    = Self_net_architecture.define_G(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,use_dropout=False,norm='instance')
net_Membranes = Self_net_architecture.define_G(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,use_dropout=False,norm='instance')

net_IsoNuclei    = Self_net_architecture.define_G(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,use_dropout=False,norm='instance')
net_IsoMembranes = Self_net_architecture.define_G(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,use_dropout=False,norm='instance')

net_EhnNucleiMem = Self_net_architecture.define_G(input_nc=1, output_nc=1, ngf=64, netG='deblur_net', device=device1,use_dropout=False,norm='instance')


# Load Model
net_Nuclei.load_state_dict(torch.load(model_path_Nuclei))
net_Membranes.load_state_dict(torch.load(model_path_Membranes))

net_IsoNuclei.load_state_dict(torch.load(model_path_IsoNuclei))
net_IsoMembranes.load_state_dict(torch.load(model_path_IsoMembranes))

net_EhnNucleiMem.load_state_dict(torch.load(model_path_EnhNucleiMem))


initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal


<All keys matched successfully>

In [12]:
# Get all tif images in the folder
image_names = sorted([f for f in os.listdir(img_src_path) if f.endswith('.czi')])

for i, image_name in enumerate(image_names):
       
    # get image path 
    czi_file_path = os.path.join(img_src_path, image_name)   
    reader = CziReader(czi_file_path)
    
    # Get time Id 
    match = re.search(r'\((\d+)\)', image_name)
    timeId = match.group(1)
    timeId= timeId.zfill(3)
    
    print(f"Processing image : {image_name}")
    for view in range(0,reader.dims.V ):      
        print(f"** Processing view : {view}")
        
        # Get general mask form NucleiMembranesChannel
        print(f"Getting tissue mask ...")
        start_time = time.time()
        # Open image
        img, pixel_size_X, pixel_size_Z = open_image_from_reader(reader, view, NucleiMembranesChannelId, "ZYX", np.uint16)
       # img, pixel_size_X, pixel_size_Z = open_image_from_reader(reader, view, 0, "ZYX", np.uint16)
        pixel_size_X = 2.0* pixel_size_X #!!!!!!!!!!!!!!!!!!!!! delete Error in objective
        scale = pixel_size_X / pixel_size_Z
        
        # Get mask
        if crop_for_calculations == True:
            bounds_min, bounds_max, mask = get_image_cropping_box(img, blur_mask, scale, thres_crop)                     
        else:
            mask = get_image_simple_mask(img, blur_mask, scale, thres_crop)
            
        # create empty output image as template
        z,y,x = img.shape
        new_z = round(z / scale) 
        base_img  = np.zeros((new_z,y,x))
        Elapsed_time = time.time() - start_time
        print(f"Elapsed Time masking: {Elapsed_time:.4f} seconds")      
    
        # Process color by color
        
        for color in range(reader.dims.C): 
            print("Processing color : ", str(color))
            
            start_time = time.time()  # Record the start time 
            # Open image
            raw_img, pixel_size_X, pixel_size_Z = open_image_from_reader(reader, view, color, "ZYX", np.uint16)
            pixel_size_X = 2.0* pixel_size_X #!!!!!!!!!!!!!!!!!!!!! delete Error in objective
            scale = pixel_size_X / pixel_size_Z

            #  Predict nuclei, membranes and make isotropic image
            if color == NucleiMembranesChannelId:
                
                # prepare image for calculations
                if crop_for_calculations == True:
                    img = raw_img[bounds_min[0]:bounds_max[0], bounds_min[1]:bounds_max[1], bounds_min[2]:bounds_max[2]]
                    # get mask for beads           
                    imgBeads =  img * (~mask)
                else:
                    img = raw_img
                    imgBeads =  raw_img * (~mask)

                # Normalize 
                #low_thres, high_thres = getNormalizationThresholds(img, norm_percentiles)
                #img = remove_outliers_image(img, low_thres, high_thres)
                #img, scaleI = image_get_scaling(img, min_v, max_v)

                img, scaleI = image_preprocessing(img, mask, norm_percentiles, min_v, max_v)
                print("scaleI : ", scaleI)
        
                # Enhance Image
                print("Enhance Image : ")
                #img = predict_stack(img,net_EhnNucleiMem, min_v, max_v, device1, batch_size, np.uint16)       
                #img, scaleI1 = image_preprocessing(img, mask, norm_percentiles, min_v, max_v)
               
                #scaleI = scaleI * scaleI1
                #print("scaleI : ", scaleI)
                
                # Predict image
                print("Predict imgNuclei : ")
                imgNuclei = predict_stack(img,net_Nuclei,    min_v, max_v, device1, batch_size, np.uint16)
                print("Predict imgCells : ")
                imgCells  = predict_stack(img,net_Membranes, min_v, max_v, device1, batch_size, np.uint16)
                
                imgNuclei = imgNuclei * mask
                imgCells  = imgCells  * mask
                
                # Isotropic prediction 
                print("Iso imgNuclei : ")
                imgNuclei = upsample_block(imgNuclei,pixel_size_X,pixel_size_Z,net_IsoNuclei,   net_IsoNuclei,   min_v,max_v, device1, device2, batch_size, skip_planes)              
                print("Iso imgCells : ")
                imgCells  = upsample_block(imgCells, pixel_size_X,pixel_size_Z,net_IsoMembranes,net_IsoMembranes,min_v,max_v, device1, device2, batch_size, skip_planes2)              
            
                # Re-scale intensities
                print("Scaling intensities ...")
                imgNuclei  = imgNuclei.astype(np.float32).astype(np.float32) / scaleI
                imgCells   = imgCells.astype(np.float32) / scaleI
                imgNuclei  = imgNuclei.astype(np.uint16)
                imgCells   = imgCells.astype(np.uint16)

               
                # put image back 
                if crop_for_calculations == True:   
                    ''' 
                    # if not isotropic prediction is done
                    base_img = np.zeros_like(raw_img)
                    imgNuclei = insert_predicted_image(base_img,imgNuclei,bounds_min,bounds_max,1.0)
                    imgCells = insert_predicted_image(base_img,imgCells, bounds_min,bounds_max,1.0)
                    imgBeads = insert_predicted_image(raw_img,imgBeads,bounds_min,bounds_max,1.0)
                    '''
                    imgNuclei = insert_predicted_image(base_img,imgNuclei,bounds_min,bounds_max,scale)
                    imgCells  = insert_predicted_image(base_img,imgCells, bounds_min,bounds_max,scale)
                    imgBeads  = insert_predicted_image(raw_img,imgBeads,bounds_min,bounds_max,1.0)              

                    
                # Make images isotropic 
                '''
                # if not isotropic prediction is done
                imgNuclei = reslice(imgNuclei,'xy',pixel_size_X,pixel_size_Z)
                imgCells  = reslice(imgCells,'xy',pixel_size_X,pixel_size_Z)
                '''
                imgBeads  = reslice_bysize(imgBeads,'xy',new_z)    
                    

                if Niter > 10: 
                    # Padding image
                    imgBeads = np.pad(imgBeads, padding, mode='reflect')
                    imgSizeGB = imgBeads.nbytes / (1024 ** 3)
                    print('     -size(GB) : ', imgSizeGB)
                    # GPU deconvolution
                    res_gpu = rl.doRLDeconvolutionFromNpArrays(imgBeads, psf, niter=Niter,resAsUint8=False)
                    # Removing padding
                    imgBeads = res_gpu[padding:-padding, padding:-padding, padding:-padding]
            
                # Normalize output
                #imgNuclei     = image_normalizing(imgNuclei,norm_percentiles_out, min_v, max_v)
                #imgCells      = image_normalizing(imgCells,norm_percentiles_out, min_v, max_v)
                #imgNucleiTemp = image_normalizing(imgNucleiTemp,norm_percentiles_out, min_v, max_v)

                # Remove noise and BG
                if resolution_px > 0:
                    imgNuclei = WBNS_image(imgNuclei, resolution_px, noise_lvl)
                    imgCells = WBNS_image(imgCells, resolution_px, noise_lvl)

                # Smooth
                if sigma > 0:
                    imgNuclei = ndi.gaussian_filter(imgNuclei, sigma)
                    imgCells = ndi.gaussian_filter(imgCells, sigma)
                
                # Save image
                print("Saving ...")
                outName = 'spim_TL'+str(timeId)+'_Channel'+str(color+2)+'_Angle'+angles[view]+'.tif'        
                custom_save_img(imgNuclei, outdir, outName, pixel_size_X,pixel_size_X, pixel_size_X)
                outName = 'spim_TL'+str(timeId)+'_Channel'+str(color+3)+'_Angle'+angles[view]+'.tif'        
                custom_save_img(imgCells, outdir, outName, pixel_size_X,pixel_size_X, pixel_size_X)
                outName = 'spim_TL'+str(timeId)+'_Channel'+str(color+4)+'_Angle'+angles[view]+'.tif'        
                custom_save_img(imgBeads, outdir, outName, pixel_size_X,pixel_size_X, pixel_size_X)
                 

            else:
                
                if crop_for_calculations == True:
                    img = raw_img[bounds_min[0]:bounds_max[0], bounds_min[1]:bounds_max[1], bounds_min[2]:bounds_max[2]]
                else:
                    img = raw_img
                
                img = img * mask
                fusion_stack=reslice(img,'xy',pixel_size_X,pixel_size_Z)

                if Niter > 0:       
                    # Padding image
                    fusion_stack = np.pad(fusion_stack, padding, mode='reflect')
                    imgSizeGB = fusion_stack.nbytes / (1024 ** 3)
                    print('     -size(GB) : ', imgSizeGB)
                    # GPU deconvolution
                    res_gpu = rl.doRLDeconvolutionFromNpArrays(fusion_stack, psf, niter=Niter,resAsUint8=False)
                    # Removing padding
                    fusion_stack = res_gpu[padding:-padding, padding:-padding, padding:-padding]
                                 
                if crop_for_calculations == True:
                    fusion_stack = insert_predicted_image(base_img,fusion_stack,bounds_min,bounds_max,scale)
                    
                outName = 'spim_TL'+str(timeId)+'_Channel'+str(color)+'_Angle'+angles[view]+'.tif'        
                custom_save_img(fusion_stack, outdir, outName, pixel_size_X,pixel_size_X, pixel_size_X)
  
  
            Elapsed_time = time.time() - start_time
            print(f"Elapsed Time: {Elapsed_time:.4f} seconds, image {outName}")                  
 

Processing image : embryo__2023_12_07__14_03_40_043(0).czi
** Processing view : 0
Getting tissue mask ...
     -threshold_value: 310.5
Elapsed Time masking: 100.4250 seconds
Processing color :  0




Elapsed Time: 106.7644 seconds, image spim_TL000_Channel0_Angle050.tif
Processing color :  1
scaleI :  36.367924528301884
Enhance Image : 
Predict imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 28/28 [00:25<00:00,  1.08 planes/s]


Predict imgCells : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 28/28 [00:25<00:00,  1.08 planes/s]


Iso imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 43/43 [00:28<00:00,  1.51 planes/s]
Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 56/56 [00:28<00:00,  1.97 planes/s]


Iso imgCells : 


Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 171/171 [01:53<00:00,  1.50 planes/s]
Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 226/226 [01:54<00:00,  1.97 planes/s]


Scaling intensities ...
Saving ...




Elapsed Time: 716.6096 seconds, image spim_TL000_Channel5_Angle050.tif
** Processing view : 1
Getting tissue mask ...
     -threshold_value: 310.5
Elapsed Time masking: 89.0073 seconds
Processing color :  0




Elapsed Time: 91.5640 seconds, image spim_TL000_Channel0_Angle140.tif
Processing color :  1
scaleI :  72.65521064301552
Enhance Image : 
Predict imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 25/25 [00:17<00:00,  1.41 planes/s]


Predict imgCells : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 25/25 [00:17<00:00,  1.40 planes/s]


Iso imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 33/33 [00:18<00:00,  1.76 planes/s]
Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 55/55 [00:19<00:00,  2.83 planes/s]


Iso imgCells : 


Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 135/135 [01:17<00:00,  1.75 planes/s]
Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 220/220 [01:17<00:00,  2.83 planes/s]


Scaling intensities ...
Saving ...




Elapsed Time: 661.2334 seconds, image spim_TL000_Channel5_Angle140.tif
** Processing view : 2
Getting tissue mask ...
     -threshold_value: 310.5
Elapsed Time masking: 98.7541 seconds
Processing color :  0




Elapsed Time: 92.6896 seconds, image spim_TL000_Channel0_Angle230.tif
Processing color :  1
scaleI :  99.59726443768997
Enhance Image : 
Predict imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 33/33 [00:28<00:00,  1.14 planes/s]


Predict imgCells : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 33/33 [00:28<00:00,  1.17 planes/s]


Iso imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 42/42 [00:30<00:00,  1.37 planes/s]
Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 53/53 [00:30<00:00,  1.73 planes/s]


Iso imgCells : 


Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 168/168 [02:03<00:00,  1.36 planes/s]
Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 213/213 [02:03<00:00,  1.72 planes/s]


Scaling intensities ...
Saving ...




Elapsed Time: 880.8501 seconds, image spim_TL000_Channel5_Angle230.tif
** Processing view : 3
Getting tissue mask ...
     -threshold_value: 310.5
Elapsed Time masking: 83.1766 seconds
Processing color :  0




Elapsed Time: 63.9215 seconds, image spim_TL000_Channel0_Angle320.tif
Processing color :  1
scaleI :  43.458222811671085
Enhance Image : 
Predict imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 27/27 [00:27<00:00,  1.03s/ planes]


Predict imgCells : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 27/27 [00:27<00:00,  1.03s/ planes]


Iso imgNuclei : 


Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 43/43 [00:29<00:00,  1.44 planes/s]
Predicting : 100%|█████████████████████████████████████████████████████████████████████████████| 59/59 [00:29<00:00,  1.99 planes/s]


Iso imgCells : 


Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 172/172 [01:56<00:00,  1.47 planes/s]
Predicting : 100%|███████████████████████████████████████████████████████████████████████████| 237/237 [01:58<00:00,  1.99 planes/s]


Scaling intensities ...
Saving ...




OSError: 2993356800 requested and 221880128 written

In [None]:
imgNuclei.shape

In [None]:
'''
scale_img = reslice(raw_img,'xy',reader.physical_pixel_sizes.X,reader.physical_pixel_sizes.Z)            
beads_image =  img * (~mask)
beads_image = insert_predicted_image(scale_img,beads_image,bounds_min,bounds_max,1.0)


outName = 'spim_TL'+str(timeId)+'_Channel'+str(color+2)+'_Angle'+angles[view]+'.tif'        
custom_save_img(beads_image, outdir, outName, reader.physical_pixel_sizes.X, reader.physical_pixel_sizes.Y, reader.physical_pixel_sizes.Z)
'''
