In [None]:
from Supporting_functions import *
from WBNS import WBNS_image
import RedLionfishDeconv as rl

import scipy.ndimage as ndi
from aicsimageio.readers import CziReader
from aicsimageio.writers import OmeTiffWriter
import re

from skimage import measure  

import os
import numpy as np
import tifffile
from tqdm import tqdm

In [None]:
# Input

## Generate initial tiffs from czi

# Initial data extraction: if True, export tiffs from the .czi files; if False, use pre-cropped images
extractImagesFromCzi = False
czi_folder_path = '/media/hmorales/Elements/Cornelius/20230929'
ChannelId2Export = 0          # index of the channel exported from each .czi file
blurWnd = 2                   # blur window passed to get_image_cropping_box (masking for the crop)
thres_crop = 1.5              # threshold passed to get_image_cropping_box
dataexperiment = '20230929'   # experiment tag used in the exported file names


## Tiff source folders

# Source paths for the tiff images; if extractImagesFromCzi is True, the folders will be created
srcpath = r'/media/hmorales/Skynet/IsoNet/SourceData/'
psf_path = r'/home/hmorales/WorkSpace/DataIsoReconstructions/Averaged_transformed_PSF_488.tif'

dirOut    = r'/media/hmorales/Skynet/IsoNet/Models/Self_Nuclei_Membrane/'

dirSource = srcpath + 'Nuclei_Membrane_20230929_cleaned/'
dirTarget = srcpath + 'NucleiPlusMembrane_deconvolved/'

# Factor by which each xy plane is downsampled (and then upsampled back) to build the
# degraded "intermediate" images; presumably the xy / z pixel-size ratio — TODO confirm
tempScale = 0.23161465

## Generate image planes for training data

## Process Target image
processTargetImages = False   # if False, the target tiffs are used as they are
# BG subtraction (WBNS); resolution_px = 0 disables the step
resolution_px = 0 # FWHM of the PSF, presumably in pixels (lateral)
resolution_pz = 0 # FWHM along the second pass (only used when resolution_px > 0)
noise_lvl = 2
# deconvolution (Richardson-Lucy); Niter = 0 disables it
padding = 32      # reflective padding added on every side before deconvolution
Niter = 2
# post processing
sigmaLoG = 2.0            # sigma of the Laplacian-of-Gaussian edge enhancement (0 disables it)
sigmaLoGAddScale = 20.0   # weight of the negated LoG response added back to the image
sigma = 0.8               # Gaussian smoothing sigma; also passed to get_image_simple_mask
# image normalization
thres_scale_target = 2.0  # threshold for the simple mask used for normalization
percentiles_target = (20, 99.995)
min_v_target = 0
max_v_target = 65535
# save processed images
dirProcTarget = srcpath + 'Nuclei_Membrane_20230929_proc/'

## Process Source image
# image normalization
thres_scale_source = 1.5  # threshold for the simple mask used for normalization
percentiles_source = (20, 99.999)
min_v_source = 0
max_v_source = 65535



## Parameters for the patch selection
patch_size = 128
stride = 128
signal_intensity_threshold = 2000  # a patch is kept only if its maximum reaches this value

xy_interval=2   # use every 2nd exported xy plane
xz_interval=8   # use every 8th exported xz plane

# Generate initial tiffs from czi 

In [None]:
if extractImagesFromCzi:

    # Create output folders
    createFolder(dirSource)
    createFolder(dirTarget)

    # Get all .czi files in the folder, in a reproducible (sorted) order
    image_names = sorted(f for f in os.listdir(czi_folder_path) if f.endswith('.czi'))

    for image_name in image_names:

        start_time = time.time()  # timing covers every view exported from this file

        # get image path
        czi_file_path = os.path.join(czi_folder_path, image_name)
        reader = CziReader(czi_file_path)

        # Time-point id = the number in parentheses in the file name, e.g. "name(7).czi" -> "007"
        match = re.search(r'\((\d+)\)', image_name)
        if match is None:
            print(f"Skipping {image_name}: no time-point id '(<number>)' in the file name")
            continue
        timeId = match.group(1).zfill(3)
        print('spim_TL'+str(timeId))

        # Only the channel of interest is exported
        color = ChannelId2Export
        if not 0 <= color < reader.dims.C:
            print(f"Skipping {image_name}: channel {color} not present ({reader.dims.C} channels)")
            continue

        px_x = reader.physical_pixel_sizes.X
        px_y = reader.physical_pixel_sizes.Y
        px_z = reader.physical_pixel_sizes.Z
        scale = px_x / px_z

        for view in range(reader.dims.V):

            # Open image: lazy dask array -> in-memory 3D (ZYX) numpy array
            lazy_t0 = reader.get_image_dask_data("ZYX", V=view, C=color)
            img = lazy_t0.compute().astype(np.uint16)
            print(img.shape)

            # Crop to the box returned by get_image_cropping_box (presumably the sample's bounding box)
            bounds_min, bounds_max = get_image_cropping_box(img, blurWnd, scale, thres_crop)
            img = img[bounds_min[0]:bounds_max[0], bounds_min[1]:bounds_max[1], bounds_min[2]:bounds_max[2]]
            print(img.shape)

            # Make image isotropic
            img_shape = img.shape
            img = reslice(img, 'xy', px_x, px_z)
            new_img_shape = img.shape
            new_physical_pixel_sizeZ = img_shape[0] * px_z / new_img_shape[0]
            print(f"image dimension from : {img_shape} to {new_img_shape}")
            print(f"z-space from : {px_z} to {new_physical_pixel_sizeZ}")

            # Save image
            img = img.astype(np.uint16)
            outName = 'spim'+dataexperiment+'_TL'+str(timeId)+'_Channel'+str(color)+'_Angle'+str(view)+'.tif'
            img_out = os.path.join(dirSource, outName)
            tifffile.imwrite(
                img_out,
                img,
                imagej=True,
                bigtiff=True,
                resolution=(1.0/px_x, 1.0/px_y),
                metadata={'spacing': new_physical_pixel_sizeZ, 'unit': 'um', 'axes': 'ZYX'})

        Elapsed_time = time.time() - start_time
        print(f"Elapsed Time: {Elapsed_time:.4f} seconds, image {image_name}")

# Generate image planes for training data 

In [None]:
## Create the folders that receive the exported image planes
outDir = dirOut + "raw_data/"
xy_data, xy_lr_data, xz_data = (outDir + sub for sub in ("xy/", "xy_lr/", "xz/"))

# Parents first, in case createFolder does not create intermediate directories
for folder in (dirOut, outDir, xy_data, xy_lr_data, xz_data):
    createFolder(folder)

## Generate image planes for target and intermediate images

In [None]:
# Get all tif images in the target folder, optionally process them, and export their xy planes
# (target planes go to xy_data, degraded "intermediate" planes to xy_lr_data)
image_names = sorted([f for f in os.listdir(dirTarget) if f.endswith('.tif')])

if processTargetImages:
    # Create output folder
    createFolder(dirProcTarget)
    # Open the PSF and normalise it to unit sum for the deconvolution
    psf = tifffile.imread(psf_path)
    psf_norm = psf.astype(np.float32)
    psf_norm /= psf_norm.sum()

count = 1  # running plane index; planes are written as 1.tif, 2.tif, ...
for image_name in tqdm(image_names):

    start_time = time.time()  # Record the start time
    print(f"** Processing image : {image_name}")

    # Open image and get metadata
    img_path = os.path.join(dirTarget, image_name)
    img = tifffile.imread(img_path)
    img_shape = img.shape
    [physical_pixel_sizeX, physical_pixel_sizeY, physical_pixel_sizeZ] = read_tiff_voxel_size(img_path)
    scale = physical_pixel_sizeX / physical_pixel_sizeZ

    if processTargetImages:
        # Make image isotropic
        if abs(1.0 - scale) > 1e-4:
            img = reslice(img, 'xy', physical_pixel_sizeX, physical_pixel_sizeZ)
        new_img_shape = img.shape
        new_physical_pixel_sizeZ = img_shape[0] * physical_pixel_sizeZ / new_img_shape[0]
        # Reslicing changes the number of z planes: update the aspect ratio so that the
        # mask below is computed on (and has the shape of) the resliced volume
        scale = physical_pixel_sizeX / new_physical_pixel_sizeZ
        print(f"     -image dimension from : {img_shape} to {new_img_shape}")
        print(f"     -z-space from : {physical_pixel_sizeZ} to {new_physical_pixel_sizeZ}")

    # Get mask; it must have the same shape as img for the `img * mask` products below
    mask = get_image_simple_mask(img, sigma, scale, thres_scale_target)
    mask = mask.astype(np.int16)

    if processTargetImages:
        img = img.astype(np.float32)

        # Deconvolution
        if Niter > 0:
            # Padding image (reflective)
            img = np.pad(img, padding, mode='reflect')
            imgSizeGB = img.nbytes / (1024 ** 3)
            print('     -size(GB) : ', imgSizeGB)
            # GPU deconvolution with the unit-sum PSF
            res_gpu = rl.doRLDeconvolutionFromNpArrays(img, psf_norm, niter=Niter, resAsUint8=False)
            # Remove the padding; explicit end indices so that padding == 0 keeps the whole volume
            img = res_gpu[tuple(slice(padding, n - padding) for n in res_gpu.shape)]

        # Remove noise and BG
        if resolution_px > 0:
            img = WBNS_image(img, resolution_px, noise_lvl)
            if resolution_pz > 0:
                img_xz = np.transpose(img, [1, 0, 2])
                img_xz = WBNS_image(img_xz, resolution_pz, 0)
                img = np.transpose(img_xz, [1, 0, 2])

        # LoG filter: subtract the scaled Laplacian-of-Gaussian (i.e. add its negation)
        if sigmaLoG > 0:
            img -= sigmaLoGAddScale * ndi.gaussian_laplace(img, sigmaLoG)

        # Smooth
        if sigma > 0:
            img = ndi.gaussian_filter(img, sigma)

        # Image Normalization: low threshold from the whole image, high threshold from the foreground
        if percentiles_target[0] > 0 or percentiles_target[1] < 100:
            low_thres, _ = getNormalizationThresholds(img, percentiles_target)
            _, high_thres = getNormalizationThresholds(img * mask, percentiles_target)
            img = remove_outliers_image(img, low_thres, high_thres)

        img = image_scaling(img, min_v_target, max_v_target, True)
        img = img.astype(np.uint16)

        # Save processed image
        img_out_name = os.path.join(dirProcTarget, image_name)
        tifffile.imwrite(
            img_out_name,
            img,
            imagej=True,
            resolution=(1.0/physical_pixel_sizeX, 1.0/physical_pixel_sizeY),
            metadata={'spacing': new_physical_pixel_sizeZ, 'unit': 'um', 'axes': 'ZYX'})

    # Generate intermediate images: downsample each xy plane by tempScale, then upsample back
    img_lr = np.zeros_like(img)
    z, y, x = img.shape
    new_y = round(y * tempScale)
    new_x = round(x * tempScale)

    for iz in range(z):
        temp_img = cv2.resize(img[iz, :, :], (new_x, new_y), interpolation=cv2.INTER_CUBIC)
        img_lr[iz, :, :] = cv2.resize(temp_img, (x, y), interpolation=cv2.INTER_LINEAR)

    # Export planes for target and intermediate images, each plane as a TIFF image.
    # A plane is kept when it holds more than 10% of the average number of foreground voxels
    # per plane AND its mean foreground intensity exceeds the mean over the whole volume.
    fg = img * mask
    fg_pos = fg > 0
    plane_counts = np.count_nonzero(fg_pos, axis=(1, 2))
    plane_sums = np.where(fg_pos, fg, 0).sum(axis=(1, 2), dtype=np.float64)
    total_count = plane_counts.sum()

    if total_count > 0:  # an empty mask has nothing to export
        meanI = plane_sums.sum() / total_count
        for iz in range(z):
            if plane_counts[iz] > 0.1 * total_count / z and plane_sums[iz] / plane_counts[iz] > meanI:
                outName_target = f"{xy_data}{count}.tif"
                outName_interm = f"{xy_lr_data}{count}.tif"
                tifffile.imwrite(outName_target, img[iz, :, :])
                tifffile.imwrite(outName_interm, img_lr[iz, :, :])
                count += 1

    Elapsed_time = time.time() - start_time
    print(f"Elapsed Time: {Elapsed_time:.4f} seconds, image {image_name}, {count-1} images exported ")

## Generate image planes for source images

In [None]:
# Get all tif images in the dirSource folder, normalise them and export their xz planes
image_names = sorted([f for f in os.listdir(dirSource) if f.endswith('.tif')])

count = 1  # running plane index; planes are written as 1.tif, 2.tif, ...
for image_name in tqdm(image_names):

    start_time = time.time()  # Record the start time
    print(f"** Processing image : {image_name}")

    # Open image and get metadata
    img_path = os.path.join(dirSource, image_name)
    img = tifffile.imread(img_path)
    [physical_pixel_sizeX, physical_pixel_sizeY, physical_pixel_sizeZ] = read_tiff_voxel_size(img_path)
    scale = physical_pixel_sizeX / physical_pixel_sizeZ

    # Get mask
    mask = get_image_simple_mask(img, sigma, scale, thres_scale_source)
    mask = mask.astype(np.int16)

    # Image Normalization: low threshold from the whole image, high threshold from the foreground
    if percentiles_source[0] > 0 or percentiles_source[1] < 100:
        low_thres, _ = getNormalizationThresholds(img, percentiles_source)
        _, high_thres = getNormalizationThresholds(img * mask, percentiles_source)
        img = remove_outliers_image(img, low_thres, high_thres)

    img = image_scaling(img, min_v_source, max_v_source, True)
    img = img.astype(np.uint16)

    # Reslice the volume; presumably its first axis then indexes the xz planes
    img_xz = reslice(img, 'xz', physical_pixel_sizeX, physical_pixel_sizeZ)

    # Export every plane along the first axis as its own TIFF image
    for plane in img_xz:
        tifffile.imwrite(f"{xz_data}{count}.tif", plane)
        count += 1

    Elapsed_time = time.time() - start_time
    print(f"Elapsed Time: {Elapsed_time:.4f} seconds, image {image_name}, {count-1} images exported ")
    

# Generate training data from image planes

In [None]:
# Folder that receives train_data.npz and the tiff stacks saved for checking
train_data_path = os.path.join(dirOut, 'train_data/')
createFolder(train_data_path)

In [None]:
# Cut the exported planes into patch_size x patch_size tiles (step `stride`) and keep only
# tiles whose maximum intensity reaches signal_intensity_threshold.
# NOTE: planes are read by their running index 1..N, so each folder should contain only the
# planes exported by the current run (stale higher-numbered files from older runs are read too).


def _patch_origins(height, width):
    """Yield the (row, col) top-left corners of every full patch inside a height x width plane."""
    for m in range(0, height - patch_size + 1, stride):
        for n in range(0, width - patch_size + 1, stride):
            yield m, n


def _count_tif(folder):
    """Return the number of .tif files in `folder` (the planes are named 1.tif ... N.tif)."""
    return sum(1 for f in os.listdir(folder) if f.endswith('.tif'))


# Initialize arrays
xy = []
xy_lr = []
xz = []

# Loop over lateral (target) planes and their degraded counterparts
for i in tqdm(range(0, _count_tif(xy_data), xy_interval)):
    xy_img = tifffile.imread(f"{xy_data}{i + 1}.tif")
    xy_lr_img = tifffile.imread(f"{xy_lr_data}{i + 1}.tif")
    # Tile only the extent common to both planes
    height = min(xy_img.shape[0], xy_lr_img.shape[0])
    width = min(xy_img.shape[1], xy_lr_img.shape[1])
    for m, n in _patch_origins(height, width):
        crop_xy = xy_img[m:m + patch_size, n:n + patch_size]
        if np.max(crop_xy) >= signal_intensity_threshold:
            xy.append(crop_xy)
            xy_lr.append(xy_lr_img[m:m + patch_size, n:n + patch_size])

# Loop over axial planes
for i in tqdm(range(0, _count_tif(xz_data), xz_interval)):
    xz_img = tifffile.imread(f"{xz_data}{i + 1}.tif")
    for m, n in _patch_origins(xz_img.shape[0], xz_img.shape[1]):
        crop_xz = xz_img[m:m + patch_size, n:n + patch_size]
        if np.max(crop_xz) >= signal_intensity_threshold:
            xz.append(crop_xz)


In [None]:
# Number of collected patches per category (xy and xy_lr are appended in pairs)
for patches in (xy, xy_lr, xz):
    print(len(patches))


In [None]:
# Convert the patch lists to float32 arrays and store them together in one .npz archive
xy, xy_lr, xz = (np.array(patches, dtype=np.float32) for patches in (xy, xy_lr, xz))
print(xy.shape, xy_lr.shape, xz.shape)

npz_path = os.path.join(train_data_path, 'train_data.npz')
np.savez(npz_path, xy=xy, xy_lr=xy_lr, xz=xz)

In [None]:
# Save the three patch stacks as multi-page tiffs so they can be double checked by eye
import tifffile as tiff

for stack_name, stack in (('xy', xy), ('xy_lr', xy_lr), ('xz', xz)):
    tiff.imwrite(os.path.join(train_data_path, stack_name + '.tif'), stack)