In [None]:
import matplotlib.pyplot as plt
from models import DFCAN
from loss_functions import mse_ssim
import tensorflow as tf
from csbdeep.io import load_training_data
from csbdeep.utils import plot_some
from pathlib import Path
import numpy as np
import cv2 
import os
from tensorflow.keras.models import load_model
import tifffile
from sim_fitting import cal_modamp, create_psf
from csbdeep.utils import normalize
import numpy.fft as F
import mrcfile
from pathlib import Path
from read_otf import read_otf

# --- Paths ------------------------------------------------------------------
# Trained models live under the dataset root; plots and results are written to
# a folder next to the notebook (created if missing).
root_dir = '../Microtubules'
dn_model_dir, sr_model_dir = (Path(root_dir) / sub for sub in ('DNModel', 'SRModel_1400_ready'))
output_dir = Path.cwd() / 'DN_Model_plots_and_results'
output_dir.mkdir(exist_ok=True)
otf_path = 'TIRF488_cam1_0_z30_OTF2d.mrc'

In [None]:
# NOTE(review): absolute, machine-specific path; consider deriving it from
# `root_dir` once the directory layout is confirmed.
mrc_file = '/share/klab/argha/Microtubules/Test/Cell_019/RawSIMData_level_05.mrc'
# First, read the raw SIM stack from the MRC file.
with mrcfile.open(mrc_file, mode='r') as mrc:
    full_image = mrc.data

# Move axis 0 to the end so the stack is channels-last
# (presumably (frames, H, W) -> (H, W, frames) — confirm against the data).
full_image = np.transpose(full_image, (1, 2, 0))
print(f'full_image _ DN: {full_image.shape} , min: {np.min(full_image)} , max: {np.max(full_image)} , dtype : {full_image.dtype}' )

def prctile_norm(x, min_prc=1, max_prc=99.9):
    """Linearly rescale `x` so its `min_prc` percentile maps to 0 and `max_prc` percentile to ~1.

    Values outside the percentile range are not clipped. A small epsilon (1e-7)
    in the denominator avoids division by zero for constant input.

    Args:
        x: array-like of numbers; percentiles are taken over all elements.
        min_prc: lower percentile, in [0, 100].
        max_prc: upper percentile, in [0, 100].

    Returns:
        Array with the same shape as `x`.
    """
    # A single percentile call yields both bounds; the original computed
    # three separate percentiles (the lower one twice).
    lo, hi = np.percentile(x, [min_prc, max_prc])
    return (x - lo) / (hi - lo + 1e-7)

# Percentile-normalise the stack, then preview its first plane.
full_image = prctile_norm(full_image)
print(f'full_image _ DN _ percentile normalized : {full_image.shape} , min: {np.min(full_image)} , max: {np.max(full_image)} , dtype : {full_image.dtype}' )

fig, ax = plt.subplots()
ax.imshow(full_image[..., 0])
ax.set_title(f'full_image_ DN : shape: {full_image.shape}, \n min: {np.min(full_image)},\n max: {np.max(full_image)},\n dtype: {full_image.dtype}')
ax.axis('off')
fig.savefig(f'{output_dir}/DN_TEST_full_image_input.png', bbox_inches='tight')
plt.show()

# Chunk geometry used by the models: tile height/width in pixels and the number
# of raw frames per tile (9 — presumably ndirs x nphases; confirm).
height = width = 128
channels = 9
print(f'height: {height} , width: {width} , channels: {channels}')


##### Parameter initialization

In [None]:

# --- Optimiser / training settings ------------------------------------------
init_lr = 1e-4
lr_decay_factor = 0.5  # Learning rate decay factor
batch_size, epochs = 6, 1
beta_1, beta_2 = 0.9, 0.999

# --- Optics / acquisition ---------------------------------------------------
# NOTE(review): lengths look like micrometres (0.488 ~ 488 nm, 62.6e-3 ~ 62.6 nm) — confirm.
wavelength = 0.488
excNA = 1.35
dx = 62.6e-3
dy = dxy = dx
scale = 2.0
setupNUM = 1
# `space` is the spacing whose reciprocal is the nominal pattern frequency k0.
space = wavelength / excNA / 2
k0mod = 1 / space

# --- SIM reconstruction settings --------------------------------------------
napodize = 10
nphases = ndirs = 3
sigma_x = sigma_y = 0.5
recalcarrays = 2
ifshowmodamp = 0
norders = (nphases + 1) // 2
phase_space = 2 * np.pi / nphases

# Low-resolution (chunk) grid and the high-resolution (x scale) grid used for
# pattern synthesis.
Nx, Ny = height, width
Nx_hr, Ny_hr = Nx * scale, Ny * scale
dx_hr = dy_hr = dxy / scale

# Real-space coordinates of the high-resolution grid, centred on 0.
xx = dx_hr * np.arange(-Nx_hr / 2, Nx_hr / 2, 1)
yy = dy_hr * np.arange(-Ny_hr / 2, Ny_hr / 2, 1)
X, Y = np.meshgrid(xx, yy)

# Frequency-space sampling of the low-resolution grid.
dkx = 1.0 / (Nx * dxy)
dky = 1.0 / (Ny * dxy)
dkr = np.min([dkx, dky])



# Preset illumination angles (rad) for each supported optical setup, as
# (k0angle_c, k0angle_g). k0angle_g is the fallback used when an estimated
# pattern angle deviates too far from the expected one.
_K0_ANGLES_BY_SETUP = {
    0: ([1.48, 2.5272, 3.5744], [0.0908, -0.9564, -2.0036]),
    1: ([-1.66, -0.6128, 0.4344], [3.2269, 2.1797, 1.1325]),
    2: ([1.5708, 2.618, 3.6652], [0, -1.0472, -2.0944]),
}

# Fail loudly on an unknown setup instead of leaving the angle lists undefined
# (which previously surfaced later as a NameError).
try:
    k0angle_c, k0angle_g = (list(a) for a in _K0_ANGLES_BY_SETUP[setupNUM])
except KeyError:
    raise ValueError(
        f'Unsupported setupNUM={setupNUM!r}; expected one of {sorted(_K0_ANGLES_BY_SETUP)}'
    ) from None



# All settings gathered in one dict; it is passed to `cal_modamp`, and the
# illumination-estimation helper reads 'ndirs', 'k0angle_g' and 'space' from it.
# NOTE(review): 'Ny' comes from `height` and 'Nx' from `width`, the reverse of
# the module-level `Nx = height`, `Ny = width`. This is harmless while height == width — confirm.
parameters = dict(
    Ny=height,
    Nx=width,
    wavelength=wavelength,
    excNA=excNA,
    ndirs=ndirs,
    nphases=nphases,
    init_lr=init_lr,
    ifshowmodamp=ifshowmodamp,
    batch_size=batch_size,
    epochs=epochs,
    beta_1=beta_1,
    beta_2=beta_2,
    scale_gt=scale,
    setupNUM=setupNUM,
    k0angle_c=k0angle_c,
    k0angle_g=k0angle_g,
    recalcarrays=recalcarrays,
    dxy=dxy,
    space=space,
    k0mod=k0mod,
    norders=norders,
    napodize=napodize,
    scale=scale,
    sigma_x=sigma_x,
    sigma_y=sigma_y,
    den_model_dir=dn_model_dir,
    sr_model_dir=sr_model_dir,
)


#### Load the models


In [None]:

def _load_trained_model(model_dir):
    """Load a saved Keras model from `model_dir` with the custom `mse_ssim` object registered.

    Raises:
        FileNotFoundError: if `model_dir` is empty. Later cells use both models
            unconditionally, so silently skipping the load would only defer the failure.
    """
    if not os.listdir(model_dir):
        raise FileNotFoundError(f"No saved model found in {model_dir}")
    print(f"Loading model from {model_dir}")
    with tf.keras.utils.custom_object_scope({'mse_ssim': mse_ssim}):
        return load_model(model_dir)


# Load the SR model
sr_model_path = sr_model_dir
sr_trained_model = _load_trained_model(sr_model_path)

# Load the DN model
dn_model_path = dn_model_dir
dn_trained_model = _load_trained_model(dn_model_path)

: 

#### Step 1: Chunk the image into 128 x 128 tiles

In [None]:

def chunk_image(image, chunk_size):
    """Split an image into square tiles of `chunk_size` x `chunk_size` pixels.

    Tiles start on a regular grid with stride `chunk_size`. If the last tile in a
    row/column would run past the image edge, it is shifted back to end exactly
    at the edge, so it overlaps its neighbour instead of being cropped. Every
    tile therefore has the full size whenever the image is at least
    `chunk_size` along that axis. For a 502-px image this gives starts
    0, 128, 256, 374, which the previous hard-coded logic produced only for
    that one size.

    Args:
        image: array indexed as image[y, x, ...]; trailing axes (e.g. channels) are kept.
        chunk_size: tile edge length in pixels (must be > 0).

    Returns:
        (chunks, chunk_coords): list of tile arrays (views into `image`) and
        list of their (x, y) top-left corners, in row-major tile order.
    """
    image_height, image_width = image.shape[:2]

    def _tile_starts(length):
        # Regular grid of starts; only the last one can exceed `last` and is clamped.
        last = max(length - chunk_size, 0)
        return [min(start, last) for start in range(0, length, chunk_size)]

    chunks = []
    chunk_coords = []
    for y in _tile_starts(image_height):
        for x in _tile_starts(image_width):
            chunks.append(image[y:y + chunk_size, x:x + chunk_size])
            chunk_coords.append((x, y))

    return chunks, chunk_coords

# No resizing is applied here; the normalised stack is chunked as-is.
resized_image = full_image

chunk_size = 128
chunks, chunk_coords = chunk_image(resized_image, chunk_size)
print(f'after chunkinh: {len(chunks)} :: {len(chunk_coords)} :: {type(chunks)} {type(chunk_coords)}')
chunks = np.array(chunks).astype(np.float32)

print(f'chunks: {chunks.shape} :: {chunks.dtype} :: {type(chunks)}')

# Visualize the chunks in a near-square grid layout (channel 1 of each chunk).
num_chunks = chunks.shape[0]
ncols = int(np.ceil(np.sqrt(num_chunks)))
nrows = int(np.ceil(num_chunks / ncols))
plt.figure(figsize=(10, 10))
plt.suptitle('Chunks')
print(f'chunks: {chunks.shape} {type(chunks)} : chunk_coords: {len(chunk_coords)}{type(chunk_coords)} ')
for i, (chunk, (x_start, y_start)) in enumerate(zip(chunks, chunk_coords)):
    print(f'chunk.shape: {chunk.shape} :: {type(chunk)} , {x_start}, {y_start}')
    print(f'min value: {np.min(chunk)} :: max value: {np.max(chunk)} dtype: {chunk.dtype}')
    plt.subplot(nrows, ncols, i + 1)
    plt.imshow(chunk[..., 1])
    plt.title(f"({x_start}, {y_start}):S {chunk.shape}", fontsize=8)
    plt.axis('off')
# Save once, after all panels are drawn, instead of re-saving the whole
# figure on every loop iteration.
plt.savefig(f'{output_dir}/DN_TEST_chunk_.png')
plt.show()

#### Step 2: Run the SR and DN predictions

In [None]:

def upscale_chunks(chunks, chunk_coords, upscale_factor=2):
    """Run the SR -> pattern-synthesis -> DN pipeline on a batch of raw SIM chunks.

    Steps:
      1. The SR model predicts one super-resolved image per raw chunk.
      2. For each chunk, `cal_modamp` estimates the illumination wave-vector and
         modulation from the raw chunk. The SR image is then re-modulated with
         synthetic two-beam patterns and filtered by the OTF (MPE input).
      3. The raw chunks form the PFE input. The DN model runs on [MPE, PFE], and
         its output is regrouped to ndirs * nphases channels.

    Args:
        chunks: array-like of raw chunks, presumably (N, Nx, Ny, ndirs * nphases) — TODO confirm.
        chunk_coords: (x, y) top-left corner of each chunk in the full image.
        upscale_factor: currently unused. NOTE(review): per the shape comments
            below, DN outputs keep the chunk resolution, so the returned
            coordinates are not scaled.

    Returns:
        (upscaled_chunks, upscaled_chunk_coords): list with one DN prediction per
        chunk (each (H, W, ndirs * nphases)) and the matching (x, y) coordinates.
    """
    chunks = np.array(chunks)
    print(chunks.shape)

    OTF, prol_OTF, PSF = read_otf(otf_path, Nx_hr, Ny_hr, dkx, dky, dkr)
    print(f'PSF: {PSF.shape} {PSF.dtype}:: OTF: {OTF.shape} {OTF.dtype}')

    def psf_otf():
        """Plot the PSF and |OTF| side by side and save the figure."""
        fig, axes = plt.subplots(1, 2, figsize=(15, 15))
        axes[0].imshow(PSF)
        axes[0].set_title('PSF')
        axes[1].imshow(abs(OTF))
        axes[1].set_title('OTF')
        plt.tight_layout()
        plt.savefig(f'{output_dir}/PSF_OTF_Prediction.png', bbox_inches='tight')
        plt.show()

    psf_otf()

    def _get_cur_k(image_gt):
        """Estimate illumination wave-vector magnitude, angle and modulation for one raw chunk.

        Estimated angles more than 0.05 rad from ``parameters['k0angle_g']`` and
        magnitudes more than 0.1 from ``1 / parameters['space']`` are replaced by
        those configured values.

        Returns:
            (cur_k0, cur_k0_angle, modamp); cur_k0 and cur_k0_angle have one entry per direction.
        """
        print(f'inside _get_cur_k: {image_gt.shape}')
        cur_k0, modamp = cal_modamp(np.array(image_gt).astype(np.float32), prol_OTF, parameters)
        cur_k0_angle = np.array(np.arctan2(cur_k0[:, 1], cur_k0[:, 0]))
        # Directions 1..ndirs-1 are rotated by pi, then the angle convention is flipped.
        cur_k0_angle[1:parameters['ndirs']] = cur_k0_angle[1:parameters['ndirs']] + np.pi
        cur_k0_angle = -(cur_k0_angle - np.pi / 2)
        for nd in range(parameters['ndirs']):
            if np.abs(cur_k0_angle[nd] - parameters['k0angle_g'][nd]) > 0.05:
                cur_k0_angle[nd] = parameters['k0angle_g'][nd]
        cur_k0 = np.sqrt(np.sum(np.square(cur_k0), 1))
        given_k0 = 1 / parameters['space']
        cur_k0[np.abs(cur_k0 - given_k0) > 0.1] = given_k0
        return cur_k0, cur_k0_angle, modamp

    def reshape_to_3_channels(batch):
        """Reshape a (B, H, W, C) batch to (B * C // ndirs, H, W, nphases).

        NOTE(review): this is a flat C-order ``reshape``, not a channel split.
        Each output image is built from H * nphases / C consecutive input rows
        (with all C channels interleaved), so the spatial layout is NOT
        preserved. A per-direction split would be
        ``batch.reshape(B, H, W, ndirs, nphases).transpose(0, 3, 1, 2, 4).reshape(-1, H, W, nphases)``.
        It is kept as-is because the DN model may have been trained on data
        prepared the same way; confirm against the training pipeline before changing.
        """
        B, H, W, C = batch.shape
        assert C % ndirs == 0, "The last dimension must be divisible by 3"
        new_batch_size = B * (C // ndirs)
        return batch.reshape(new_batch_size, H, W, nphases)

    def _phase_computation(img_SR, modamp, cur_k0_angle, cur_k0):
        """Synthesise ndirs * nphases patterned, OTF-filtered images from one SR image.

        For direction d and phase step i, a two-beam interference pattern
        (magnitude cur_k0[d], angle cur_k0_angle[d], phase offset
        -angle(modamp)[d] + i * phase_space) multiplies `img_SR`. The product is
        filtered by `OTF` in Fourier space, and its magnitude is resized to
        (Nx, Ny) and normalised.

        Returns:
            array of shape (ndirs * nphases, Nx, Ny), ordered direction-major.
        """
        phase_list = -np.angle(modamp)
        print(f'phase_list: {phase_list},  {len(phase_list)}')
        img_gen = []
        for d in range(ndirs):
            alpha = cur_k0_angle[d]
            # Beam wave-vector components depend only on the direction, not the phase step.
            kxL = cur_k0[d] * np.pi * np.cos(alpha)
            kyL = cur_k0[d] * np.pi * np.sin(alpha)
            kxR = -cur_k0[d] * np.pi * np.cos(alpha)
            kyR = -cur_k0[d] * np.pi * np.sin(alpha)
            for i in range(nphases):
                phOffset = phase_list[d] + i * phase_space
                interBeam = np.exp(1j * (kxL * X + kyL * Y + phOffset)) + np.exp(1j * (kxR * X + kyR * Y))
                pattern = normalize(np.square(np.abs(interBeam)))
                patterned_img_fft = F.fftshift(F.fft2(pattern * img_SR)) * OTF
                modulated_img = np.abs(F.ifft2(F.ifftshift(patterned_img_fft)))
                # cv2 dsize is (width, height), so the result has shape (Nx, Ny).
                modulated_img = normalize(cv2.resize(modulated_img, (Ny, Nx)))
                img_gen.append(modulated_img)
        return np.asarray(img_gen)

    def _plot_batch(batch, label, filename):
        """Show every image of a 3-channel batch in one row and save the figure."""
        num_images = batch.shape[0]
        plt.figure(figsize=(15, 5))
        plt.suptitle(f'{label} batch : shape : {batch.shape}  ')
        for k in range(num_images):
            plt.subplot(1, num_images, k + 1)
            plt.imshow(batch[k])
            plt.axis('off')
        plt.savefig(f'{output_dir}/{filename}')
        plt.show()

    def reshape_to_9_channels(batch):
        """Flat-reshape (B', H, W, C) back to (B' * C // (ndirs * nphases), H, W, ndirs * nphases).

        This is the inverse of the flat reshape done by `reshape_to_3_channels`.
        """
        B, H, W, C = batch.shape
        print(f'B: {B} , H: {H}, W: {W} , C: {C} ')
        # Integer arithmetic avoids float rounding in the batch-size computation.
        new_batch_size = B * C // (ndirs * nphases)
        print(f' new_batch_size : { new_batch_size} B: {B} , H: {H}, W: {W} , C: {C} ')
        return batch.reshape(new_batch_size, H, W, ndirs * nphases)

    list_image_in = []
    list_image_gen = []
    # Run the SR model once for the whole batch.
    print(f'this is goign for prediction: {chunks.shape}')
    sr_predictions = sr_trained_model.predict(chunks)
    print(f'sr prediction: {sr_predictions.shape} : min_value : {np.min(sr_predictions)} :: max_value : {np.max(sr_predictions)} : dtype : {sr_predictions.dtype}')
    sr_predictions = tf.squeeze(sr_predictions, axis=-1)
    print(f'sr prediction: {sr_predictions.shape}')

    for i in range(sr_predictions.shape[0]):
        image_in = chunks[i]
        print(f' from for loop  image_in: {image_in.shape}')
        list_image_in.append(image_in)

        img_SR = sr_predictions[i]
        print(f'from for loop image_SR: {img_SR.shape}')
        # NOTE(review): illumination parameters are estimated from the raw input
        # chunk; the original author marked this choice as uncertain.
        cur_k0, cur_k0_angle, modamp = _get_cur_k(image_gt=image_in)

        image_gen = _phase_computation(img_SR, modamp, cur_k0_angle, cur_k0)
        print(f'image_gen: {image_gen.shape}  from phase compution')
        # (ndirs * nphases, H, W) -> (H, W, ndirs * nphases), channels-last like the raw chunks.
        list_image_gen.append(np.transpose(image_gen, (1, 2, 0)))

    input_PFE_batch = reshape_to_3_channels(np.asarray(list_image_in))
    input_MPE_batch = reshape_to_3_channels(np.asarray(list_image_gen))

    # MPE branch (moire pattern extraction, in DN): synthesised patterned images.
    _plot_batch(input_MPE_batch, 'MPE', 'DN_Test_input_MPE_batch.png')
    # PFE branch (pattern formation): the raw chunks.
    _plot_batch(input_PFE_batch, 'PFE', 'DN_Test_input_PFE_batch.png')

    print(f'inptu_MPE_batch: {input_MPE_batch.shape} :: input_PFE_batch: {input_PFE_batch.shape}')

    predictions = dn_trained_model.predict([input_MPE_batch, input_PFE_batch])
    print(f'prediction : {predictions.shape}')
    predictions = reshape_to_9_channels(predictions)
    print(f'prediction reshape to channel 9 : {predictions.shape}\n : min_value : {np.min(predictions)}\n :: max_value : {np.max(predictions)}\n : dtype : {predictions.dtype}')

    assert len(predictions) == len(chunk_coords), "one DN prediction per input chunk expected"
    upscaled_chunks = list(predictions)
    upscaled_chunk_coords = list(chunk_coords)
    return upscaled_chunks, upscaled_chunk_coords

#### Step 3: Reassemble the full image

In [None]:
# Reassemble the per-chunk predictions into one full-size image.
# NOTE(review): `upscaled_chunks` and `upscaled_chunk_coords` must exist at
# notebook level before this cell runs. No earlier cell assigns them (they only
# appear as locals inside `upscale_chunks`), so on a fresh kernel this raises NameError.
# NOTE(review): 1004 is presumably 2 x the 502-px raw image. If the chunks are DN
# outputs at input resolution, the canvas should probably be 502 — confirm.
target_size = 1004
channels = 9  # Number of channels in the image

# Start from zeros and merge tiles with a per-pixel maximum, so overlapping
# tiles keep the larger value. This also clips negative predictions to 0.
final_image = np.zeros((target_size, target_size, channels))

for upscaled_chunk, (x_start, y_start) in zip(upscaled_chunks, upscaled_chunk_coords):
    print('inside the  reassemble loop')
    print(f'upscaled_chunk: {upscaled_chunk.shape}  :: {type(upscaled_chunk)} dtype: {upscaled_chunk.dtype} :: min_value: {np.min(upscaled_chunk)} :: max_value: {np.max(upscaled_chunk)}')

    # Clip the tile to the canvas.
    x_end = min(x_start + upscaled_chunk.shape[1], target_size)
    y_end = min(y_start + upscaled_chunk.shape[0], target_size)

    if x_end > x_start and y_end > y_start:
        final_image[y_start:y_end, x_start:x_end, :] = np.maximum(
            final_image[y_start:y_end, x_start:x_end, :],
            upscaled_chunk[:y_end - y_start, :x_end - x_start, :],
        )

# Write a (channels, H, W) stack, the inverse of the (1, 2, 0) transpose used when
# loading the MRC stack. The previous transpose(1, 2, 0) produced (W, C, H).
tifffile.imwrite(f'{output_dir}/prediction _ fullimage.tif', final_image.transpose(2, 0, 1))

# Visualize the final image (channel 1)
plt.figure(figsize=(10, 10))
plt.title(f'Reassembled Image _DN  :: shape :: {final_image.shape} \n final_imahge:dtype :: {final_image.dtype} :: \n min_value: {np.min(final_image)} ::\n max_value: {np.max(final_image)}')
plt.imshow(final_image[..., 1])
plt.axis('off')
plt.savefig(f'{output_dir}/DN_Test_final_image.png')
plt.show()