In [None]:
# Step 0: Imports for the full reconstruction pipeline 


# From SIMScope3D Reconstruction
from pathlib import Path
import numpy as np
import tifffile
from skimage.exposure import match_histograms, rescale_intensity
from skimage.restoration import calibrate_denoiser, denoise_tv_chambolle
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


# From VolPy/Mesmerize 
import cv2
import glob
import h5py
import logging
import inspect

#os.environ["MESMERIZE_N_PROCESSES"] = '40'

try:
    # Per OpenCV docs, 0 disables OpenCV's internal threading; presumably done
    # so OpenCV does not oversubscribe cores alongside CaImAn's worker pool.
    cv2.setNumThreads(0)
except Exception:
    # Best effort only: keep going if this OpenCV build lacks the call,
    # but do not swallow KeyboardInterrupt/SystemExit like a bare except.
    pass

try:
    if __IPYTHON__:
        # Debugging aid: reload edited modules/classes automatically.
        # run_line_magic replaces the deprecated InteractiveShell.magic().
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    # Not running under IPython (__IPYTHON__ / get_ipython undefined).
    pass

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.paths import caiman_datadir
from caiman.source_extraction.volpy import utils
from caiman.source_extraction.volpy.volparams import volparams
from caiman.source_extraction.volpy.volpy import VOLPY
from caiman.summary_images import local_correlations_movie_offline
from caiman.summary_images import mean_image
from caiman.utils.utils import download_demo, download_model

# Root logger at INFO; each line carries ms since logging start, the source
# location (file, function, line) and the process id.
logging.basicConfig(
    level=logging.INFO,
    format=(
        "%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s]"
        "[%(process)d] %(message)s"
    ),
)



# From DeepCAD-RT (Python 3.9.18)
from deepcad.train_collection import training_class, testing_class
from deepcad.movie_display import display
from deepcad.utils import get_first_filename,download_demo

In [None]:
# Step 1: Separate the raw stack into individual stacks, one per phase

def seperate_phase(imgs, n_phases=3):
    """Split a phase-interleaved stack into one stack per phase.

    Frame ``i`` is assigned to phase ``i % n_phases``, i.e. phase ``k`` is
    ``imgs[k::n_phases]`` (a view, not a copy).

    Parameters
    ----------
    imgs : ndarray
        Stack indexed by frame along axis 0 (3D ``(frames, y, x)`` in this
        notebook, but any rank >= 1 works).
    n_phases : int, default 3
        Number of interleaved phases.

    Returns
    -------
    tuple of ndarray
        ``n_phases`` stacks. If the frame count is not a multiple of
        ``n_phases``, the earlier phases get one extra frame.

    Raises
    ------
    ValueError
        If ``n_phases`` is smaller than 1.
    """
    if n_phases < 1:
        raise ValueError(f'n_phases must be >= 1, got {n_phases!r}')
    return tuple(imgs[k::n_phases] for k in range(n_phases))

def save_phases(filename):
    """Split the interleaved TIFF stack at ``filename`` and save each phase.

    Writes ``filename + '_p1.tif'``, ``'_p2.tif'`` and ``'_p3.tif'``. The full
    input name, extension included, is used as the prefix, e.g.
    ``x.tif_p1.tif``.

    Parameters
    ----------
    filename : str
        Path of the raw phase-interleaved stack.
    """
    img = tifffile.imread(filename)
    for k, phase_stack in enumerate(seperate_phase(img), start=1):
        # imwrite is the current tifffile API; imsave is a deprecated alias.
        tifffile.imwrite(filename + f'_p{k}.tif', phase_stack)

if False:
    # Raw string: every backslash is a literal Windows path separator
    # (no escape processing, so no invalid-escape warning).
    filename = r'O:\workingdirectory\072623_PVG8\best_sofar\SIM900_b4_GOOD.tif'
    save_phases(filename)


In [None]:
# Step 2: Motion Correct Each Phase

# NoRMCorre implementation (adapted from VolPy)
def bulk_MC(filepath, save_name, n_processes=50):
    """Rigid NoRMCorre motion correction of one movie with CaImAn (adapted from VolPy).

    Parameters
    ----------
    filepath : str, bytes or os.PathLike
        Movie to correct; decoded with ``os.fsdecode``.
    save_name : str
        Output prefix; the corrected movie is written to
        ``save_name + '_mc_.tif'`` (saved with ``to32=False``).
    n_processes : int, default 50
        Number of workers requested from ``cm.cluster.setup_cluster``.

    Side effects
    ------------
    - Starts a local worker pool and always terminates it, even if motion
      correction raises.
    - Writes the memory-mapped movie produced by
      ``motion_correct(save_movie=True)`` and the TIFF above.
    - Draws the rigid template with ``plt.imshow``.
    """
    print("bulk")
    fnames = os.fsdecode(filepath)

    opts_dict = {
        'fnames': fnames,
        'fr': 500,                      # frame rate handed to volparams
        'index': None,
        'ROIs': None,
        'weights': None,
        # motion-correction parameters
        'pw_rigid': False,              # rigid correction only
        'max_shifts': (10, 10),
        'gSig_filt': (6, 6),
        # per CaImAn docs, strides/overlaps/max_deviation_rigid only matter
        # for piecewise-rigid correction (pw_rigid=True)
        'strides': (20, 20),
        'overlaps': (24, 24),
        'max_deviation_rigid': 3,
        'border_nan': 'copy',
    }
    opts = volparams(params_dict=opts_dict)

    _, dview, _ = cm.cluster.setup_cluster(
        backend='local', n_processes=n_processes, single_thread=False)

    print("Motion COrrection")
    try:
        mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))
        mc.motion_correct(save_movie=True)
    finally:
        # Release the worker pool even when motion correction fails.
        dview.terminate()

    m_rig = cm.load(mc.mmap_file)
    m_rig.save(save_name + '_mc_.tif', to32=False)
    plt.imshow(mc.total_template_rig, cmap='gray')


if False:
    # Motion-correct each of the three phase stacks (p1, p2, p3) independently.
    phase_dir = '101123\\CY01\\100N36_16x\\_6\\'
    for phase in ('p1', 'p2', 'p3'):
        filepath = phase_dir + phase + '.tif'
        save_name = phase + '_mc_'
        bulk_MC(filepath, save_name)




In [None]:
# Step 3: Train 3D-UNet on previous WF recording

# 3D U-Net Training, adapted from demo_train_pipeline.ipynb (DeepCAD-RT)
# --- Training data and patch geometry ---------------------------------------
datasets_path = 'datasets/WF_500_hip'   # folder containing the training recordings
train_datasets_size = 12000             # how many 3D patches are sampled for training
patch_xy = 100                          # width and height of each 3D patch
patch_t = 1000                          # length (frames) of each 3D patch
overlap_factor = 0.6                    # overlap between adjacent patches

# --- Run settings ------------------------------------------------------------
n_epochs = 5                            # number of training epochs
GPU = '0'                               # GPU index string, e.g. '0', '0,1', '0,1,2'
pth_dir = './pth'                       # where .pth files and result images are written
num_workers = 0                         # set to 0 on Windows

# Optional per-epoch feedback during training.
visualize_images_per_epoch = True       # show result images after each epoch
save_test_images_per_epoch = True       # save result images after each epoch

train_dict = dict(
    # dataset-dependent parameters
    patch_x=patch_xy,
    patch_y=patch_xy,
    patch_t=patch_t,
    overlap_factor=overlap_factor,
    scale_factor=1,                     # image intensity scaling factor
    select_img_num=1000000,             # frames used for training (all frames by default)
    train_datasets_size=train_datasets_size,
    datasets_path=datasets_path,
    pth_dir=pth_dir,
    # network / optimiser parameters
    n_epochs=n_epochs,
    lr=0.00005,                         # learning rate
    b1=0.5,                             # Adam beta1
    b2=0.999,                           # Adam beta2
    fmap=16,                            # model complexity (feature maps)
    GPU=GPU,
    num_workers=num_workers,
    visualize_images_per_epoch=visualize_images_per_epoch,
    save_test_images_per_epoch=save_test_images_per_epoch,
)

tc = training_class(train_dict)
tc.run()



In [None]:
# Step 4: Apply the pretrained model to each phase

# 3D U-Net Testing, adapted from demo_test_pipeline.ipynb (DeepCAD-RT)
# --- What to denoise and with which model ------------------------------------
datasets_path = 'datasets/MC100'        # folder containing all files to be denoised
denoise_model = 'WF_500_hip'            # model folder under pth_dir (presumably the Step 3 run; confirm)
test_datasize = 6000                    # number of frames to process

# --- Patch geometry and hardware ----------------------------------------------
patch_xy, patch_t, overlap_factor = 100, 300, 0.6
GPU = '0'
num_workers = 0                         # set to 0 on Windows

# Inference feedback: no on-screen display, but keep saved outputs.
visualize_images_per_epoch, save_test_images_per_epoch = False, True

test_dict = dict(
    # dataset-dependent parameters
    patch_x=patch_xy,
    patch_y=patch_xy,
    patch_t=patch_t,
    overlap_factor=overlap_factor,
    scale_factor=1,                     # image intensity scaling factor
    test_datasize=test_datasize,
    datasets_path=datasets_path,
    pth_dir='./pth',                    # root folder of the trained models
    denoise_model=denoise_model,
    output_dir='./results',             # root folder for denoised outputs
    # network-related parameters
    fmap=16,                            # number of feature maps
    GPU=GPU,
    num_workers=num_workers,
    visualize_images_per_epoch=visualize_images_per_epoch,
    save_test_images_per_epoch=save_test_images_per_epoch,
)

tc = testing_class(test_dict)
tc.run()



In [None]:
# Step 5: Re-interleave phase stacks 

def interleave_stacks(filename_p1, filename_p2, filename_p3, dtype=np.uint16):
    """Re-interleave three per-phase TIFF stacks into one stack.

    This is the inverse of ``seperate_phase``: frame ``i`` of phase ``k``
    (k = 0, 1, 2 for p1, p2, p3) becomes output frame ``3*i + k``. The output
    has ``len(p1) + len(p2) + len(p3)`` frames. Stacks split from a recording
    whose frame count was not a multiple of 3 (p1, and possibly p2, one frame
    longer) are therefore also handled.

    The result is written to ``filename_p1[:-6] + '_i.tif'``.

    Parameters
    ----------
    filename_p1, filename_p2, filename_p3 : str
        Per-phase stacks, in phase order.
    dtype : numpy dtype, default np.uint16
        Output dtype. Input values are cast by numpy on assignment.

    Raises
    ------
    ValueError
        If the stack lengths cannot come from one phase-interleaved recording.
    """
    stacks = [tifffile.imread(name) for name in (filename_p1, filename_p2, filename_p3)]
    n_phases = len(stacks)
    n_frames = sum(stack.shape[0] for stack in stacks)
    img = np.zeros((n_frames,) + stacks[0].shape[1:], dtype=dtype)
    for k, stack in enumerate(stacks):
        slot = img[k::n_phases]  # view into img: the frames owned by phase k
        if slot.shape[0] != stack.shape[0]:
            raise ValueError(
                f'Phase {k + 1} has {stack.shape[0]} frames but its interleaved '
                f'slots hold {slot.shape[0]}; stack lengths are inconsistent')
        slot[...] = stack
    # imwrite is the current tifffile API; imsave is a deprecated alias.
    tifffile.imwrite(filename_p1[:-6] + '_i.tif', img)

if False:
    # Per-phase outputs of one recording (presumably DeepCAD-RT results);
    # the interleaved stack is written next to the p1 file.
    filename_p1, filename_p2, filename_p3 = (
        '122123\\CY01_16x\\CY01_16x\\n120_100_3\\DeepCAD\\FB\\p1_best_model_output.tif',
        '122123\\CY01_16x\\CY01_16x\\n120_100_3\\DeepCAD\\FB\\p2_best_model_output.tif',
        '122123\\CY01_16x\\CY01_16x\\n120_100_3\\DeepCAD\\FB\\p3_best_model_output.tif',
    )
    interleave_stacks(filename_p1, filename_p2, filename_p3)

In [None]:
# Step 6: Perform interleaved OS-SIM and pWF Reconstructions, adapted from SIMScope3D Reconstruction

def optical_sectioning_sim_I(imgs, method, flag):
    """Compute an optical-sectioning image from three phase-shifted frames.

    Parameters
    ----------
    imgs : ndarray
        Phase images stacked along axis 0; ``imgs[0]``, ``imgs[1]`` and
        ``imgs[2]`` are used.
    method : {'DD', 'NEIL'}
        - 'DD':   ``0.5 * sqrt((2*I2 - I1 - I3)**2 + (I3 - I1)**2)``
          (labelled here for 90 degree phase shifts).
        - 'NEIL': ``sqrt((I1-I2)**2 + (I1-I3)**2 + (I2-I3)**2)``
          (labelled here for 120 degree phase shifts).
    flag : {0, 1, 2}
        Which stored frame is treated as I1. ``(I1, I2, I3)`` is
        ``imgs[0, 1, 2]`` for flag 0, ``imgs[2, 0, 1]`` for flag 1 and
        ``imgs[1, 2, 0]`` for flag 2. The interleaved reconstruction uses
        this for triplet streams that start at different frame offsets.

    Returns
    -------
    ndarray
        Same shape as ``imgs[0]``.

    Raises
    ------
    ValueError
        If ``method`` or ``flag`` is not one of the supported values.
    """
    phase_order = {0: (0, 1, 2), 1: (2, 0, 1), 2: (1, 2, 0)}.get(flag)
    if phase_order is None:
        raise ValueError(f'Invalid flag {flag!r}. Please choose 0, 1 or 2')
    I1, I2, I3 = (imgs[idx, :] for idx in phase_order)

    if method == 'DD':  # for 90 degree phase shifts
        return 0.5 * np.sqrt((2*I2 - I1 - I3)**2 + (I3 - I1)**2)
    if method == 'NEIL':  # for 120 degree phase shifts
        return np.sqrt(((I1-I2)**2)+((I1-I3)**2)+((I2-I3)**2))
    raise ValueError('Invalid method. Please choose either DD or NEIL')

def match_histogram_z_I(imgs, nangles, nphases, flag):
    """Histogram-match the phase images of one z-plane to a reference phase.

    For every angle index, each phase image ``imgs[angle, phase]`` with
    ``phase != flag`` is replaced by ``match_histograms(image,
    imgs[angle, flag])``. The reference phase itself is left untouched.

    ``imgs`` is modified in place and the same object is returned.
    """
    other_phases = [phase for phase in range(nphases) if phase != flag]
    for angle in range(nangles):
        for phase in other_phases:
            imgs[angle, phase, :] = match_histograms(imgs[angle, phase, :], imgs[angle, flag, :])
    return imgs

def run_SIM_Interleaved(file_Name, method, match_hist):
    """Interleaved OS-SIM and widefield (pWF) reconstruction of a 3-phase stack.

    The stack is read as a repeating sequence of ``nphases = 3`` frames
    (``nangles = 1``, single time point). With ``nz = n_frames // 3``, three
    streams of consecutive triplets are formed, offset by one frame each:

    - stream 1: frames ``[0, 3*nz)``     -> ``nz`` triplets
    - stream 2: frames ``[1, 3*nz - 2)`` -> ``nz - 1`` triplets
    - stream 3: frames ``[2, 3*nz - 1)`` -> ``nz - 1`` triplets

    For the first ``nz - 1`` triplets of every stream, two images are computed:

    - the OS image, via ``optical_sectioning_sim_I``, averaged over angles;
    - the widefield image, as the nanmean over angles and phases.

    Stream ``k`` fills output frames ``k, k+3, ...``. Output positions not
    computed (the last plane of each stream) stay zero. Both volumes are then
    rescaled to the full uint16 range.

    Parameters
    ----------
    file_Name : str or os.PathLike
        TIFF stack to reconstruct.
    method : {'DD', 'NEIL'}
        OS-SIM formula, see ``optical_sectioning_sim_I``.
    match_hist : bool
        If ``True``, histogram-match the phases of each triplet before
        reconstruction (reference phase depends on the stream).

    Returns
    -------
    os_img, wf_img : ndarray of uint16, shape ``(3*nz, ny, nx)``

    Raises
    ------
    ValueError
        If ``method`` is unknown or the stack has fewer than 6 frames.
    """
    nangles = 1
    nphases = 3

    # Fail before reading/processing the stack rather than deep in the loop.
    if method not in ('DD', 'NEIL'):
        raise ValueError('Invalid method. Please choose either DD or NEIL')

    img = tifffile.imread(Path(file_Name))
    print(img.shape)

    # TO DO: parse metadata / handle multiple time points; one is assumed here.
    frames_per_plane = nangles * nphases
    nz = img.shape[0] // frames_per_plane
    if nz < 2:
        raise ValueError(
            f'Need at least {2 * frames_per_plane} frames for an interleaved '
            f'reconstruction, got {img.shape[0]}')
    ny, nx = img.shape[1], img.shape[2]

    # Each spec: (frame offset, planes in stream, histogram reference phase,
    # phase-order flag for optical_sectioning_sim_I).
    stream_specs = ((0, nz, 0, 0), (1, nz - 1, 2, 1), (2, nz - 1, 1, 2))
    streams = []
    for offset, n_planes, hist_ref, phase_flag in stream_specs:
        frames = img[offset:offset + frames_per_plane * n_planes]
        # astype makes an independent float32 copy, so in-place histogram
        # matching never touches `img` or the other streams.
        stack = np.reshape(frames, [n_planes, nangles, nphases, ny, nx]).astype(np.float32)
        streams.append((stack, hist_ref, phase_flag))

    widefields = [np.zeros((nz, ny, nx), dtype=np.float32) for _ in streams]
    os_sims = [np.zeros((nz, ny, nx), dtype=np.float32) for _ in streams]
    os_sim_per_angle = np.zeros((nangles, ny, nx), dtype=np.float32)

    # Only nz - 1 planes exist in streams 2 and 3, so every stream is
    # reconstructed over that common range.
    for z_idx in tqdm(range(0, nz - 1), desc='SIM OS per z plane', leave=True):
        for k, (stack, hist_ref, phase_flag) in enumerate(streams):
            plane = stack[z_idx, :]
            if match_hist == True:  # same test as used when naming the output files
                plane = match_histogram_z_I(plane, nangles, nphases, hist_ref)

            widefields[k][z_idx, :] = np.nanmean(plane, axis=(0, 1))

            # Every angle is overwritten before averaging, so reusing the
            # buffer across planes/streams is safe.
            for angle_idx in range(nangles):
                os_sim_per_angle[angle_idx, :] = optical_sectioning_sim_I(
                    plane[angle_idx, :], method, phase_flag)
            os_sims[k][z_idx, :] = np.nanmean(os_sim_per_angle, axis=0)

    n_streams = len(streams)
    os_img = np.zeros((n_streams * nz, ny, nx), dtype=np.float32)
    wf_img = np.zeros((n_streams * nz, ny, nx), dtype=np.float32)
    for k in range(n_streams):
        os_img[k::n_streams] = os_sims[k]
        wf_img[k::n_streams] = widefields[k]

    # NOTE(review): the zero-valued, uncomputed trailing planes take part in
    # rescale_intensity's min/max; presumably acceptable because the saved
    # output is trimmed, but confirm this scaling is intended.
    os_img = rescale_intensity(os_img, out_range=(0, 65535)).astype(np.uint16)
    wf_img = rescale_intensity(wf_img, out_range=(0, 65535)).astype(np.uint16)

    return os_img, wf_img

def save_Reconstructions_i(wf_img, os_img, input_file, match_hist):
    """Write the OS-SIM and pWF reconstructions as TIFFs derived from ``input_file``.

    The last 6 frames of each volume are dropped first. Output names are
    ``input_file`` without its final extension plus a suffix recording
    whether histogram matching was used (``match`` vs ``NOmatch``).
    """
    wf_trimmed = wf_img[:-6, :, :]
    os_trimmed = os_img[:-6, :, :]

    if match_hist == True:
        sim_suffix, pwf_suffix = ('_interleaved_matchSIM_Reconstruction.tif',
                                  '_interleaved_matchpWF_Reconstruction.tif')
    else:
        sim_suffix, pwf_suffix = ('_interleaved_NOmatchSIM_Reconstruction.tif',
                                  '_interleaved_NOmatchpWF_Reconstruction.tif')

    stem = input_file.rsplit('.', 1)[0]
    tifffile.imwrite(Path(stem + sim_suffix), os_trimmed)
    tifffile.imwrite(Path(stem + pwf_suffix), wf_trimmed)

def make_numFrames_divisible_by_3(fileName):
    """Save a copy of a TIFF stack truncated to a multiple of 3 frames.

    Drops the trailing ``n_frames % 3`` frames and writes the result to
    ``fileName[:-4] + '_divisible.tif'``. The last 4 characters of
    ``fileName`` are assumed to be the ``.tif`` extension.
    """
    img = tifffile.imread(fileName)
    n_keep = img.shape[0] - img.shape[0] % 3
    # imwrite is the current tifffile API; imsave is a deprecated alias.
    tifffile.imwrite(fileName[:-4] + '_divisible.tif', img[:n_keep])

if False:
    folder_Name = 'GEVI_SIM'
    filename = 'N36_400_X.tif'
    method = 'NEIL'
    match_hist = False
    input_path = folder_Name + '\\' + filename
    try:
        os_img, wf_img = run_SIM_Interleaved(input_path, method, match_hist)
        save_Reconstructions_i(wf_img, os_img, folder_Name + '\\' + method + '_1_' + filename, match_hist)
    except Exception as err:
        # Retry on a copy trimmed to a multiple of 3 frames. The trimmed copy is
        # written next to the input, so it is read back from the same folder.
        print(f'Reconstruction failed ({err!r}); retrying on a copy trimmed to a multiple of 3 frames')
        make_numFrames_divisible_by_3(input_path)
        divisible_name = filename[:-4] + '_divisible.tif'
        os_img, wf_img = run_SIM_Interleaved(folder_Name + '\\' + divisible_name, method, match_hist)
        save_Reconstructions_i(wf_img, os_img, folder_Name + '\\' + method + '_1_' + divisible_name, match_hist)



In [None]:
# Step 7: Seeded motion correction (applied to reconstructions that were not preprocessed)

# Function for seeded motion correction, adapted from VolPy
def SEED_MC(filepath, filepath2, save_name, n_processes=50):
    """Motion-correct ``filepath`` and apply the same shifts to ``filepath2``.

    Shifts are estimated on ``filepath`` with CaImAn's ``MotionCorrect``
    (rigid, parameters adapted from VolPy). ``mc.apply_shifts_movie`` then
    applies them to ``filepath2``, so both movies are registered identically.
    This presumably assumes both movies have the same frame count and
    geometry; TODO confirm.

    Parameters
    ----------
    filepath : str, bytes or os.PathLike
        Movie used to estimate the shifts.
    filepath2 : str, bytes or os.PathLike
        Movie the estimated shifts are applied to.
    save_name : str
        Output prefix. Writes ``save_name + '_mc_.tif'`` (corrected
        ``filepath``) and ``save_name + '_seed_mc_.tif'`` (shifted
        ``filepath2``), both with ``to32=False``.
    n_processes : int, default 50
        Number of workers requested from ``cm.cluster.setup_cluster``.

    Side effects
    ------------
    - Starts a local worker pool and always terminates it, even if motion
      correction raises.
    - Writes memory-mapped movies and the TIFFs above.
    - Draws the rigid template with ``plt.imshow``.
    """
    print("bulk")
    fnames = os.fsdecode(filepath)
    fnames2 = os.fsdecode(filepath2)

    opts_dict = {
        'fnames': fnames,
        'fr': 500,                      # frame rate handed to volparams
        'index': None,
        'ROIs': None,
        'weights': None,
        # motion-correction parameters
        'pw_rigid': False,              # rigid correction only
        'max_shifts': (10, 10),
        'gSig_filt': (6, 6),
        # per CaImAn docs, strides/overlaps/max_deviation_rigid only matter
        # for piecewise-rigid correction (pw_rigid=True)
        'strides': (20, 20),
        'overlaps': (24, 24),
        'max_deviation_rigid': 3,
        'border_nan': 'copy',
    }
    opts = volparams(params_dict=opts_dict)

    _, dview, _ = cm.cluster.setup_cluster(
        backend='local', n_processes=n_processes, single_thread=False)

    print("Motion COrrection")
    try:
        mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))
        mc.motion_correct(save_movie=True)
    finally:
        # Release the worker pool even when motion correction fails.
        dview.terminate()

    m_rig = cm.load(mc.mmap_file)
    m_rig.save(save_name + '_mc_.tif', to32=False)

    # Re-use the shifts estimated on `filepath` for the second movie.
    seeded_mmap = mc.apply_shifts_movie(fnames2, save_memmap=True, order='C')
    m_seeded = cm.load(seeded_mmap)
    m_seeded.save(save_name + '_seed_mc_.tif', to32=False)

    plt.imshow(mc.total_template_rig, cmap='gray')
    #plt.imsave(mc.total_template_rig)   

if False:
    # Shifts are estimated on the first (non-denoised) pWF reconstruction and
    # applied to the second one.
    filepath, filepath2, save_name = (
        '122123\\CY01_16x\\CY01_16x\\n90_500_b4_1\\NEIL_1_n90_500_b4_MMStack.ome_interleaved_NOmatchpWF_Reconstruction.tif',
        '122123\\CY01_16x\\CY01_16x\\n90_500_b4_1\\DeepCAD\\N2N_WF500\\NEIL_1_p1_E_05_Iter_6254_outp_i_interleaved_NOmatchpWF_Reconstruction.tif',
        'Seeded_WF_n2n',
    )
    SEED_MC(filepath, filepath2, save_name)