Set your variables

In [None]:
### python script applied and purpose of the run
# script_run = general_pipline_4D.py
# purpose = preprocessing session 211011 with pipeline from 220220s
# NOTE(review): the session id in `purpose` (211011) does not match the 220209
# folders in the paths below -- confirm which session this run covers.

### input data folder, output folder, and folder holding the trained n2v models
path_to_data = '/home/mpg08/aicha.hajiali/TLI_project/TLI_data/preprocessed/2022/T4/220209/SyNRA_old2/'
save_path = '/home/mpg08/aicha.hajiali/TLI_project/TLI_data/preprocessed/2022/T4/220209/SyNRA_old2/compile/'
model_path = '/home/mpg08/aicha.hajiali/TLI_project/preprocessing/n2v/models/'

### basename of the output tiff files
output_name = 'SyNRA_GFP_220209_brain2.2_LP40_P36.tif'

### data filtering parameters: file-name group, channel names, subset
group = 'SyNRA_GFP_220209_brain2.2_LP40_P36'
ch_names = ['GFP', 'red']  # channel order in the interleaved stacks; for single-channel images still give a list, e.g. ['GFP']
reg_subset = (0,60,325,600,300,650)  # (z_start, z_stop, y_start, y_stop, x_start, x_stop); use (0,0) for no subset
# reg_subset is not used in the current 4D pipeline

### preprocessing steps to apply
steps = ['compile','preshift', 'postshift', 'ants']
# 'all' expands to ['compile','preshift', 'trim','postshift', 'ants', 'n2v', 'clahe', 'mask', 'segment']

### rotation and image flip (not implemented yet)
# rotat_O = 6
# Flip = false

### voxel size metadata written to the output tiffs (unit: um)
xy_pixel = 0.0764616
z_pixel = 0.4

### CLAHE parameters (kernel_size is passed to OpenCV as tileGridSize)
clipLimit = 1
kernel_size = (45, 45)

### n2v parameters: model name, looked up under model_path (N2V basedir)
model_name = 'n2v_3D_v6'

### registration sequence
# NOTE(review): how these three values are consumed is not visible in this
# section (presumably they select the reference timepoint for ANTs) -- confirm.
reference_last = False
ants_ref_no = 35
ref_reset = 50

### registration parameters
# sigma: presumably the Gaussian smoothing applied before phase correlation.
# drift_corr lists ANTs transform types and metric lists similarity metrics;
# presumably they are paired in order, one registration pass each -- TODO confirm.
# The *_iterations / *_factors / *_sigmas tuples are multi-resolution schedules
# passed on to ants.registration.
save_pre_shift = True
sigma = 5
drift_corr = ['Rigid', 'Affine', 'SyNRA']
metric = ['meansquares', 'mattes', 'CC']
grad_step = 0.2
flow_sigma = 3
total_sigma = 0
aff_sampling = 32
syn_sampling = 32
reg_iterations = (80,40,10)
aff_iterations = (2100,1200,1200,10)
aff_shrink_factors = (6,4,2,1)
aff_smoothing_sigmas = (3,2,1,0)

Import packages and define functions

In [None]:
from timeit import default_timer as timer
# We import all our dependencies.
import argparse
import os
import cv2 as cv
import numpy as np 
import tifffile as tif
from detect_delimiter import detect
from n2v.models import N2V
import ants
from skimage.registration import phase_cross_correlation as corr
import csv
import psutil
import gc
from scipy import ndimage, spatial, stats
from sklearn import metrics
from skimage.filters import gaussian, threshold_otsu, median
from tqdm import tqdm
import operator
import Neurosetta as neu

These three helper functions report memory usage (`mem_use`) and read a txt info file (`str2bool`, `txt2dict`); reading the info file is not used in this notebook.

In [None]:
def mem_use():
    """Print a snapshot of CPU load and RAM statistics as reported by psutil."""
    vmem = psutil.virtual_memory
    print('memory usage')
    print('cpu_percent', psutil.cpu_percent())
    print(dict(vmem()._asdict()))
    print('percentage of used RAM', vmem().percent)
    available_pct = vmem().available * 100 / vmem().total
    print('percentage of available memory', available_pct)

def str2bool(v):
    """Convert a yes/no style string to the matching boolean.

    Recognised spellings (case-insensitive):
      truthy: ``yes``, ``true``, ``t``, ``y``
      falsy:  ``no``, ``false``, ``f``, ``n``

    Any other value, including non-string values, is returned unchanged. This
    makes it safe to apply to arbitrary parsed config values.
    """
    if not isinstance(v, str):
        return v
    lowered = v.lower()
    # 'y' was previously accepted but mapped to False; it is truthy.
    if lowered in ("yes", "true", "t", "y"):
        return True
    if lowered in ("no", "false", "f", "n"):
        return False
    return v

def txt2dict(path):
    """Read a ``key <delimiter> value`` info file into a parameter dictionary.

    Parsing rules:
      - Lines starting with ``#`` are ignored.
      - Inline ``#`` comments are stripped.
      - Blank (or whitespace-only) lines are skipped.
      - The delimiter is auto-detected from the first remaining line with
        ``detect``.

    Values are converted, in order of preference, to:
      1. a tuple of ints (comma-separated integers);
      2. a list of stripped strings (other comma-separated values);
      3. a float;
      4. otherwise, the result of ``str2bool`` (a bool, or the string unchanged).

    Missing optional parameters are then filled with defaults, and
    ``ch_names`` / ``drift_corr`` / ``steps`` are normalised to lists.

    Parameters
    ----------
    path : str
        Path of the text info file.

    Returns
    -------
    dict
        The parsed parameters.

    Raises
    ------
    ValueError
        If no parameter lines are found, or a line cannot be split into a key
        and a value.
    KeyError
        If the required ``ch_names`` or ``drift_corr`` entries are missing.
    """
    start_time = timer()
    print('getting info from', path)
    with open(path) as f:
        raw_lines = f.readlines()
    # Build a new list instead of calling list.remove() while iterating over
    # the same list, which silently skipped consecutive comment/blank lines.
    lines = []
    for line in raw_lines:
        if line.startswith('#'):
            continue
        if '#' in line:
            line = line[:line.index('#')]
        line = line.replace('\n', '')
        if line.strip() == '':
            continue
        lines.append(line)
    if not lines:
        print('failed to read txt_file')
        raise ValueError('no parameter lines found in %s' % path)
    delimiter = detect(lines[0])
    print(len(lines),'lines found in txt_file with', delimiter, 'as the delimiter')
    try:
        pairs = [item.strip().rsplit(delimiter, 2) for item in lines]
        input_txt = {item[0].strip(): item[1].strip() for item in pairs}
    except IndexError as err:
        # Previously the error was printed and the function then crashed with
        # an unbound `input_txt`; raise a clear error instead.
        print('failed to read txt_file')
        raise ValueError('could not split every line of %s into key and value' % path) from err
    for key, val in input_txt.items():
        if ',' in val:
            parts = val.split(',')
            try:
                input_txt[key] = tuple(map(int, parts))
            except ValueError:
                input_txt[key] = [item.strip() for item in parts]
        else:
            try:
                input_txt[key] = float(val)
            except ValueError:
                input_txt[key] = str2bool(val)
    ### adding some default parameters if missing in info_txt
    input_txt.setdefault('sigma', 0)
    input_txt.setdefault('steps', ['all'])
    input_txt.setdefault('reg_subset', [0, 0])
    input_txt.setdefault('metric', 'mattes')
    #### normalise ch_names to a list BEFORE deriving check_ch from it;
    # otherwise a single channel such as 'GFP' gave check_ch == 'G'.
    ch_names = input_txt['ch_names']
    if isinstance(ch_names, tuple):
        ch_names = list(ch_names)
    elif not isinstance(ch_names, list):
        ch_names = [ch_names]
    input_txt['ch_names'] = ch_names
    input_txt.setdefault('check_ch', ch_names[0])
    if not isinstance(input_txt['drift_corr'], list):
        input_txt['drift_corr'] = [input_txt['drift_corr']]
    # Steps may arrive as a single string or as a list/tuple of strings;
    # lowercase every entry (lists were previously left untouched).
    steps = input_txt['steps']
    if not isinstance(steps, (list, tuple)):
        steps = [steps]
    input_txt['steps'] = [str(s).lower() for s in steps]
    if 'all' in input_txt['steps']:
        input_txt['steps'] = ['compile','preshift', 'trim','postshift', 'ants', 'n2v', 'clahe', 'mask', 'segment']
    # Compare the configured channel's VALUE (the old code tested the
    # literal string 'check_ch', so check_ch was always overwritten).
    if input_txt['check_ch'] not in input_txt['ch_names']:
        print('channel defined for similarity_check not recognized, so %s used' %input_txt['ch_names'][0])
        input_txt['check_ch'] = input_txt['ch_names'][0]
    parameters = {'grad_step':0.2, 'flow_sigma':3, 'total_sigma':0,
                'aff_sampling':32, 'aff_random_sampling_rate':0.2,
                'syn_sampling':32, 'reg_iterations':(40,20,0),
                'aff_iterations':(2100,1200,1200,10),
                'aff_shrink_factors':(6,4,2,1),
                'aff_smoothing_sigmas':(3,2,1,0)}
    for para, default in parameters.items():
        input_txt.setdefault(para, default)
    input_txt.setdefault('ants_ref_st', 0)
    print(input_txt)
    print('reading_text runtime', timer()-start_time)
    return input_txt

In [None]:
def get_file_names(path, group_by='', order=True, nested_files=False, criteria='tif'):
    """Collect file paths under ``path``, filtered by two substrings.

    If ``path`` is itself a file, it is returned alone in a list. Otherwise
    the directory is listed (recursively via ``os.walk`` unless
    ``nested_files == False``). Only paths that contain both ``group_by`` and
    ``criteria`` as substrings are kept. The list is sorted in descending
    order when ``order`` is True, ascending otherwise; the first element is
    used later as the reference.
    """
    t0 = timer()
    if os.path.isfile(path):
        found = [path]
    else:
        if nested_files != False:
            found = [os.path.join(root, fname)
                     for root, _subdirs, fnames in os.walk(path)
                     for fname in fnames]
        else:
            candidates = (os.path.join(path, entry) for entry in os.listdir(path))
            found = [cand for cand in candidates if os.path.isfile(cand)]
        found = [cand for cand in found if group_by in cand and criteria in cand]
        found.sort(reverse=order)
    print('files_list runtime', timer()-t0)
    return found

In [None]:
def img_limits(img, limit=0, ddtype='uint16'):
    """Shift ``img`` to start at zero and rescale it to ``[0, limit]`` as ``ddtype``.

    Parameters
    ----------
    img : array-like
        Input image of any numeric dtype.
    limit : number, optional
        Target maximum value. ``0`` (default) keeps the image's own range after
        the minimum has been subtracted. Values above the dtype cap (255 for
        ``uint8``, 65530 for ``uint16``) are clipped to that cap.
    ddtype : {'uint8', 'uint16'}
        Output dtype.

    Returns
    -------
    numpy.ndarray
        Rescaled image of dtype ``ddtype``. A constant image (no dynamic
        range) is returned as all zeros instead of dividing 0 by 0.
    """
    max_limits = {'uint8': 255, 'uint16': 65530}
    img = np.asarray(img)
    img = img - img.min()
    img_max = img.max()
    if img_max == 0:
        # Constant image: 0/0 would produce NaN, which casts to garbage.
        return np.zeros(img.shape, dtype=ddtype)
    if limit == 0:
        limit = img_max
    if limit > max_limits[ddtype]:
        limit = max_limits[ddtype]
    img = img / img_max
    img = img * limit
    return img.astype(ddtype)

In [None]:
def split_convert(image, ch_names):
    """Deinterleave a stack into one sub-stack per channel and rescale each.

    Plane ``i`` of ``image`` goes to channel ``ch_names[i % len(ch_names)]``.
    When there is more than one channel, the last channel is first filtered
    with ``median``. Every channel is then rescaled with
    ``img_limits(..., limit=0)``.

    Returns
    -------
    dict
        ``{channel_name: rescaled sub-stack}``.
    """
    t0 = timer()
    n_ch = len(ch_names)
    channels = {name: image[offset::n_ch] for offset, name in enumerate(ch_names)}
    if n_ch > 1:
        last = ch_names[-1]
        channels[last] = median(channels[last])
    rescaled = {name: img_limits(stack, limit=0) for name, stack in channels.items()}
    print('split_convert runtime', timer()-t0)
    return rescaled

In [None]:
def files_to_4D(files_list, ch_names=[''],
                save=True, save_path='', save_file='',
                xy_pixel=1, z_pixel=1, ddtype='uint16'):
    """Load 3D tiff stacks and compile them into one 4D array per channel.

    Each file is read with ``tif.imread``, deinterleaved with
    ``split_convert``, and appended as one timepoint. All stacks are cropped
    to the smallest z-depth found (measured on the first channel) so they can
    be stacked into a single array per channel.

    Parameters
    ----------
    files_list : list of str
        Paths of the 3D tiff files. Sorted IN PLACE (ascending); that order
        defines the timepoint order.
    ch_names : list of str
        Channel names, in the interleaving order expected by ``split_convert``.
    save : bool
        If True, write one ``4D_<ch>_<save_file>`` tiff per channel.
    save_path : str
        Output directory; a trailing '/' is added when missing.
    save_file : str
        Output basename. When empty, the common filename prefix of the first
        two files (or the single file's name) is used.
    xy_pixel, z_pixel : float
        Pixel sizes forwarded to ``save_image`` as metadata.
    ddtype : str
        dtype of the compiled arrays.

    Returns
    -------
    dict
        ``{channel: numpy.ndarray}`` with one 4D array per channel.

    Raises
    ------
    ValueError
        If ``files_list`` is empty.
    """
    start_time = timer()
    if not files_list:
        raise ValueError('files_list is empty: no tiff files to compile')
    image_4D = {ch: [] for ch in ch_names}
    files_list.sort()
    for file in tqdm(files_list, desc = 'compiling_files'):
        image = tif.imread(file)
        image = split_convert(image, ch_names=ch_names)
        for ch in ch_names:
            image_4D[ch].append(image[ch])
    z_dim = min(len(img) for img in image_4D[ch_names[0]])
    print(image_4D.keys(), type(image_4D[ch_names[-1]]), len(image_4D[ch_names[-1]]))
    for ch in ch_names:
        print('compiling the', ch, 'channel')
        stacks = [stack[0:z_dim] for stack in image_4D[ch]]
        for tim in tqdm(range(len(stacks)), desc = 'setting img_limits'):
            if stacks[tim].min() != 0 or stacks[tim].dtype != ddtype:
                stacks[tim] = img_limits(stacks[tim], limit=0, ddtype=ddtype)
        # Build the array with the target dtype. Writing rescaled stacks into
        # an already-built array (as before) silently kept its old dtype.
        image_4D[ch] = np.array(stacks, dtype=ddtype)
    if save:
        if save_path and not save_path.endswith('/'):
            save_path += '/'
        if save_file == '':
            # Positional common prefix. The old per-character membership test
            # kept characters that merely occurred anywhere in the 2nd name,
            # and crashed when only one file was given.
            names = [os.path.basename(f) for f in files_list[:2]]
            save_file = os.path.commonprefix(names)
        for ch, img in image_4D.items():
            save_name = save_path+'4D_'+ch+'_'+save_file
            if os.path.splitext(save_name)[-1] not in ['.tif','.tiff']:
                save_name += '.tif'
            save_image(save_name, img, xy_pixel=xy_pixel, z_pixel=z_pixel)
    print('files_to_4D runtime', timer()-start_time)
    return image_4D

In [None]:
def img_subset(img, subset):
    """Crop a 3D image to ``subset = (z0, z1, y0, y1, x0, x1)``.

    A subset with fewer than six values, or with all values zero (e.g. the
    ``(0, 0)`` "no subset" setting), returns ``img`` unchanged. If the slicing
    itself fails (for example, ``img`` has fewer than three dimensions), a
    message is printed and ``img`` is returned unchanged. Previously this
    case crashed on an unbound variable.
    """
    print('subsetting the image')
    if subset is None or len(subset) < 6 or not any(subset):
        print('no subset requested, image kept whole')
        return img
    try:
        return img[subset[0]:subset[1], subset[2]:subset[3], subset[4]:subset[5]]
    except (IndexError, TypeError):
        print('failed to subset image')
        return img

def rot_flip(img, flip, angle=0):
    """Optionally flip an image, then rotate it in-plane.

    Parameters
    ----------
    img : numpy.ndarray
        A 2D image, or a 3D stack that is rotated slice by slice.
    flip : str or int
        - ``'vertical'``, ``'vertically'``, ``2`` or ``-1``: reverse the last axis.
        - ``'horizontal'``, ``'horizontally'`` or ``1``: reverse the second-to-last axis.
        - Anything else: no flip.
    angle : float
        Angle in degrees passed to ``scipy.ndimage.rotate`` with
        ``reshape=False``. It is applied to the 2D image, or to each slice of
        a stack.

    Returns
    -------
    numpy.ndarray
        A new array. The caller's ``img`` is never modified.
    """
    vert = ['vertical', 'vertically', 2, -1]
    hort = ['horizontally', 'horizontal', 1]
    if flip in vert:
        # .copy(): a reversed slice is a VIEW, and the per-slice rotation below
        # writes into `flipped`, which used to modify the caller's array.
        # `...` keeps this valid for 2D input (old `[:,:,::-1]` raised there).
        flipped = img[..., ::-1].copy()
        print('flipped image vertically')
    elif flip in hort:
        flipped = img[..., ::-1, :].copy()
        print('flipped image horizontally')
    else:
        flipped = img.copy()

    if flipped.ndim == 2:
        flipped = ndimage.rotate(flipped, angle, reshape=False)
    else:
        for ind, sli in enumerate(flipped):
            flipped[ind] = ndimage.rotate(sli, angle, reshape=False)
    print('rotated image by', angle)
    return flipped

In [None]:
def save_image(name, image, xy_pixel=0.0764616, z_pixel=0.4):
    """Write ``image`` as an ImageJ-compatible uint16 tiff with voxel-size metadata.

    Parameters
    ----------
    name : str
        Output file path.
    image : numpy.ndarray
        A 2D (``YX``), 3D (``ZYX``) or 4D (``TZYX``) array. Any dtype other
        than uint16 is cast to uint16 before writing.
    xy_pixel, z_pixel : float
        Pixel size and z spacing, stored as resolution / ``spacing`` metadata
        with ``'unit': 'um'``.

    Raises
    ------
    ValueError
        If ``image`` is not 2D, 3D or 4D. Previously this surfaced as an
        UnboundLocalError on the undefined axes string.
    """
    axes_by_ndim = {2: 'YX', 3: 'ZYX', 4: 'TZYX'}
    ndim = len(image.shape)
    if ndim not in axes_by_ndim:
        raise ValueError('save_image supports 2D, 3D or 4D images, got shape %s' % (image.shape,))
    dim = axes_by_ndim[ndim]
    if image.dtype != 'uint16': ###this part to be omitted later
        print('image type is not uint16')
        image = image.astype('uint16')
    tif.imwrite(name, image, imagej=True, dtype=image.dtype, resolution=(1./xy_pixel, 1./xy_pixel),
                metadata={'spacing': z_pixel, 'unit': 'um', 'finterval': 1/10,'axes': dim})

These two functions perform phase correlation; they are no longer used because Neurosetta now handles this step.

In [None]:
def phase_corr(fixed, moving, sigma):
    """Estimate the translation that registers ``moving`` onto ``fixed``.

    The two images are processed as follows:
      1. Both are cropped to their common (element-wise minimum) shape.
      2. Both are smoothed with a Gaussian of width ``sigma``.
      3. They are passed to scikit-image's ``phase_cross_correlation``
         (imported here as ``corr``).

    Returns
    -------
    numpy.ndarray
        Shift vector with one entry per image axis. If the phase correlation
        fails, a zero vector of that length is returned and a message printed.
    """
    if fixed.shape != moving.shape:
        # Compare shapes element-wise: the old tuple `>`/`<` comparison is
        # lexicographic and could leave the two images still mismatched.
        common = tuple(min(f, m) for f, m in zip(fixed.shape, moving.shape))
        print('fixed and moving images differ in shape', fixed.shape, moving.shape)
        fixed = fixed[tuple(map(slice, common))]
        moving = moving[tuple(map(slice, common))]
        print('both images cropped to', common)
    fixed = gaussian(fixed, sigma=sigma)
    moving = gaussian(moving, sigma=sigma)
    print('applying phase correlation')
    try:
        shift, error, diffphase = corr(fixed, moving)
    except Exception as err:  # deliberate best-effort fallback, now reported
        # One zero per axis (previously len(moving), i.e. the size of axis 0).
        shift = np.zeros(moving.ndim)
        print("couldn't perform PhaseCorr, so shift was casted as zeros", err)
    return shift

def phase_corr_4D(image, sigma, xy_pixel=1,
                  z_pixel=1, ch_names=[1],
                  ref_ch=-1,
                  save=True, save_path='',
                  save_file='', save_shifts=True):
    """Pre-align every timepoint to the first one with cumulative phase correlation.

    For each timepoint ``t >= 1`` of the reference channel:
      1. ``phase_corr`` measures the shift between the RAW frames ``t-1`` and
         ``t``.
      2. These pairwise shifts are summed into a cumulative shift.
      3. The cumulative shift is applied (``ndimage.shift``) to frame ``t`` of
         every channel.
    Arrays in ``image`` are modified in place.

    Parameters
    ----------
    image : dict or array
        ``{channel: 4D array}``; a bare array is wrapped as ``{ch_names[0]: image}``.
    sigma : float
        Gaussian smoothing passed to ``phase_corr``.
    ch_names : list
        Channel names. ``ref_ch`` indexes into this list; it falls back to the
        last channel if the index is invalid.
    save, save_path, save_file : see ``save_image``
        Each channel is written as ``PhaseCorr_<ch>_<save_file>``.
    save_shifts : bool
        Write the pairwise shifts to ``<save_path>PhaseCorr_shifts.csv``
        (timepoints 1-based).

    Returns
    -------
    (dict, dict)
        The shifted image dict and ``{frame_index: pairwise_shift}``.
    """
    if not isinstance(image, dict):
        image = {ch_names[0]: image}
    pre_shifts = {}
    if len(ch_names) == 1:
        ref_ch = ch_names[0]
    else:
        try:
            ref_ch = ch_names[ref_ch]
        except (IndexError, TypeError):
            ref_ch = ch_names[-1]
    ref_im = image[ref_ch]
    current_shift = [0 for _ in ref_im[0].shape]
    print('initial shift of 0', current_shift)
    # ref_im is shifted in place below, so keep a raw copy of the previous
    # frame: pairwise shifts must be measured between UNshifted frames,
    # otherwise accumulating them double-counts the correction.
    prev_raw = np.array(ref_im[0], copy=True)
    for ind in tqdm(range(1, len(ref_im))):
        curr_raw = np.array(ref_im[ind], copy=True)
        pre_shifts[ind] = phase_corr(prev_raw, curr_raw, sigma)
        current_shift = [sum(x) for x in zip(current_shift, pre_shifts[ind])]
        print(pre_shifts[ind], current_shift)
        print('applying preshift on timepoint', ind, 'with current pre_shift', current_shift)
        # Shift frame `ind` itself (the old code shifted frame ind-1 instead).
        for ch, img in image.items():
            image[ch][ind] = ndimage.shift(img[ind], current_shift)
        prev_raw = curr_raw
    if save == True:
        for ch, img in image.items():
            save_name = str(save_path+'PhaseCorr_'+str(ch)+'_'+save_file)
            if '.tif' not in save_name:
                save_name += '.tif'
            save_image(save_name, img, xy_pixel=xy_pixel, z_pixel=z_pixel)
    if save_shifts == True:
        shift_file = save_path+"PhaseCorr_shifts.csv"
        with open(shift_file, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['timepoint', 'phase_shift'])
            writer.writeheader()
            for timepoint, shift in pre_shifts.items():
                writer.writerow({'timepoint' : timepoint+1, 'phase_shift' : shift})
    return image, pre_shifts

In [None]:
def check_similarity(ref, image):
    """Return the cosine similarity between two images flattened to vectors.

    Either input may be a numpy array or an object exposing ``.numpy()``
    (e.g. an ANTs image). Each input is converted independently; previously
    both conversions shared one try block, so a numpy ``image`` left an ANTs
    ``ref`` unconverted.

    Returns
    -------
    float
        ``sklearn`` cosine similarity of ``image.ravel()`` and ``ref.ravel()``.
    """
    if hasattr(image, 'numpy'):
        image = image.numpy()
    if hasattr(ref, 'numpy'):
        ref = ref.numpy()
    check = metrics.pairwise.cosine_similarity(image.ravel().reshape(1, -1),
                                               ref.ravel().reshape(1, -1))[0, 0]
    return check

def similarity_4D(image_4D, save=True, save_path='', save_file=''):
    """Cosine similarity of every timepoint to the one before it.

    The returned dict is keyed by 1-based timepoint. Timepoint 1 is set to 1,
    and timepoint ``k`` (k >= 2) holds
    ``check_similarity(image_4D[k-2], image_4D[k-1])``.

    If ``save`` is True, the values are also written to
    ``save_path + save_file`` and the runtime is printed. The default file
    name is ``phase_similarity_check.csv``, and ``.csv`` is appended when
    missing.
    """
    t0 = timer()
    sims = {1: 1}
    for prev in tqdm(np.arange(len(image_4D[1:])), desc='cosine_sim for timepoint'):
        sims[prev + 2] = check_similarity(image_4D[prev], image_4D[prev + 1])
    if save != True:
        return sims
    csv_name = save_file if save_file != '' else "phase_similarity_check.csv"
    checks_file = save_path + csv_name
    if '.csv' not in checks_file:
        checks_file += '.csv'
    with open(checks_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=['timepoint', 'cosine_similarity'])
        writer.writeheader()
        writer.writerows({'timepoint': tp, 'cosine_similarity': val}
                         for tp, val in sims.items())
    print('finished measuring similarity check', timer()-t0)
    return sims

In [None]:
def N2V_predict(model_name, model_path, xy_pixel=1, z_pixel=1, image=0, file='', save=True, save_path='', save_file=''):
    """Denoise one 3D (ZYX) stack with a trained Noise2Void model.

    When ``file`` is given, the stack is read from it; otherwise ``image`` is
    used. The model ``model_name`` is loaded from ``model_path``. A prediction
    whose minimum is not 0 is rescaled with ``img_limits``.

    If ``save`` is True, the result is written as ``N2V_<save_file>``, or as
    ``N2V_<basename of file>`` when ``save_file`` is empty.
    """
    if file != '':
        image = tif.imread(file)
    model = N2V(config=None, name=model_name, basedir=model_path)
    denoised = model.predict(image, axes='ZYX', n_tiles=None)
    if denoised.min() != 0:
        denoised = img_limits(denoised, limit=0)
    if save == True:
        suffix = save_file if save_file != '' else os.path.basename(file)
        out_name = str(save_path + 'N2V_' + suffix)
        if '.tif' not in out_name:
            out_name += '.tif'
        save_image(out_name, denoised, xy_pixel=xy_pixel, z_pixel=z_pixel)
    return denoised

def N2V_4D(image_4D, model_name, model_path, xy_pixel=1, z_pixel=1, save=True, save_path='', save_file=''):
    """Denoise every timepoint of a 4D image with one Noise2Void model.

    Each 3D stack is predicted with ``axes='ZYX'``. A prediction whose
    minimum is not 0 is rescaled with ``img_limits``, which is the same
    post-processing as ``N2V_predict``. Results are written back into
    ``image_4D`` in place.

    If ``save`` is True, the whole 4D result is written as
    ``N2V_<save_file>``, or ``N2V_4D.tif`` when ``save_file`` is empty.
    """
    # Load the model once: calling N2V_predict per stack reloaded the network
    # from disk for every timepoint.
    model = N2V(config=None, name=model_name, basedir=model_path)
    for ind, stack in enumerate(image_4D):
        predict = model.predict(stack, axes='ZYX', n_tiles=None)
        if predict.min() != 0:
            predict = img_limits(predict, limit=0)
        image_4D[ind] = predict
    if save == True:
        if save_file == '':
            save_name = str(save_path+'N2V_4D.tif')
        else:
            save_name = str(save_path+'N2V_'+save_file)
        if '.tif' not in save_name:
            save_name +='.tif'
        save_image(save_name, image_4D, xy_pixel=xy_pixel, z_pixel=z_pixel)
    return image_4D

In [None]:
def apply_clahe(kernel_size, xy_pixel=1, z_pixel=1, image=0, file='', clipLimit=1, save=True, save_path='', save_file=''):
    """Apply CLAHE to each z-slice, then keep only the brightest pixels.

    Processing steps:
      1. The stack is read from ``file`` when given; otherwise ``image`` is used.
      2. It is shifted to be non-negative and cast to uint16.
      3. Each slice is equalised with OpenCV CLAHE
         (``tileGridSize=kernel_size``, ``clipLimit``).
      4. ``THRESH_TOZERO`` sets every value at or below the slice's 95th
         percentile to 0.
      5. A result whose minimum is not 0 is rescaled with ``img_limits``.

    If ``save`` is True, the result is written as ``clahe_<save_file>``, or
    ``clahe_<basename of file>``.
    """
    if file != '':
        image = tif.imread(file)
    if image.min() < 0:
        image = image - image.min()
    image = image.astype('uint16')
    print(image.dtype)
    clahe_mask = cv.createCLAHE(clipLimit=clipLimit, tileGridSize=kernel_size)
    image_clahe = np.empty(image.shape)
    for z, plane in enumerate(image):
        image_clahe[z] = clahe_mask.apply(plane)
        level = image_clahe[z]
        image_clahe[z] = cv.threshold(level,
                                      thresh=np.percentile(level, 95),
                                      maxval=level.max(),
                                      type=cv.THRESH_TOZERO)[1]
    if image_clahe.min() != 0:
        image_clahe = img_limits(image_clahe, limit=0)
    if save == True:
        suffix = save_file if save_file != '' else os.path.basename(file)
        out_name = save_path + 'clahe_' + suffix
        if '.tif' not in out_name:
            out_name += '.tif'
        save_image(out_name, image_clahe, xy_pixel=xy_pixel, z_pixel=z_pixel)
    return image_clahe

def clahe_4D(image_4D, kernel_size, clipLimit=1, xy_pixel=1, z_pixel=1, save=True, save_path='', save_file=''):
    """Run ``apply_clahe`` on every timepoint of a 4D image, in place.

    If ``save`` is True, the full 4D result is written as
    ``clahe_<save_file>``, or ``clahe_4D.tif`` when ``save_file`` is empty.
    """
    for t in range(len(image_4D)):
        image_4D[t] = apply_clahe(image=image_4D[t],
                                  kernel_size=kernel_size,
                                  clipLimit=clipLimit,
                                  save=False)
    if save != True:
        return image_4D
    out_name = str(save_path + ('clahe_' + save_file if save_file != '' else 'clahe_4D.tif'))
    if '.tif' not in out_name:
        out_name += '.tif'
    save_image(out_name, image_4D, xy_pixel=xy_pixel, z_pixel=z_pixel)
    return image_4D

These two functions create masked images; they are not used because Neurosetta now handles masking.

In [None]:
def mask_image(volume, return_mask = False ,sig = 2):
    """
    Threshold a 2D or 3D image into a foreground mask.

    Processing steps:
      1. The image is converted to float32 and normalised to [0, 1].
      2. It is blurred with a Gaussian of width ``sig``.
      3. It is thresholded with Otsu's method (``threshold_otsu``) on the
         blurred values.

    Parameters
    ----------
    volume          np.array
                    2D or 3D image. Not modified.
    return_mask     bool
                    If False (default), return the *normalised* image with
                    every pixel outside the mask set to 0. If True, return the
                    boolean mask itself.
    sig             float
                    Standard deviation of the Gaussian blur (2 by default).
    Returns
    -------
    np.array
                    Either the masked, normalised float32 image or the boolean
                    mask, with the same shape as ``volume``. A constant image
                    yields an all-False mask (and an all-zero masked image)
                    instead of dividing by zero.
    """
    start_time = timer()
    image = volume.astype('float32')  # astype copies, so volume is untouched
    image -= image.min()
    peak = image.max()
    if peak == 0:
        # Constant image: nothing to separate, and 0/0 would produce NaNs.
        mask = np.zeros(image.shape, dtype=bool)
    else:
        image /= peak
        blur = gaussian(image, sigma=sig)
        mask = blur > threshold_otsu(blur.ravel())
    print('mask_image runtime', timer()-start_time)
    if return_mask:
        return mask
    image[~mask] = 0
    return image

def mask_4D(image, xy_pixel=1, z_pixel=1, sig=2, save=True, save_path='', save_file=''):
    """Compute a foreground mask and a masked image for every timepoint.

    Each stack is passed to ``mask_image`` twice: once for the boolean mask
    and once for the masked, normalised image. The masked image is then
    rescaled to 0..255 with ``img_limits``. A stack that fails to mask is
    kept unmasked (all-ones mask, original stack), and a message is printed.

    If ``save`` is True, the following files are written:
      - masks: ``mask_<save_file>``, or ``mask.tif`` when ``save_file`` is empty;
      - masked images: ``image_mask_<save_file>``, or ``masked_image.tif``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(mask, masked_image)``. Both are uint16 arrays shaped like ``image``;
        the mask holds 0/1 and the masked image holds 0..255.
    """
    start_time = timer()
    shape = np.shape(image)
    # NOTE: the result array used to be called `mask_image`, which shadowed
    # the mask_image() function, so every call raised and the bare `except`
    # silently left the input unmasked.
    masks = np.zeros(shape, dtype='uint16')
    masked = np.zeros(shape, dtype='uint16')
    for i, img in enumerate(image):
        print('calculating mask for stack#', i)
        try:
            stack_mask = mask_image(img, return_mask=True, sig=sig)
            stack_masked = mask_image(img, return_mask=False, sig=sig)
        except Exception as err:  # best-effort per stack, now reported
            print('masking failed for stack#', i, '- kept unmasked:', err)
            stack_mask = np.ones(np.shape(img), dtype=bool)
            stack_masked = img
        masks[i] = stack_mask
        # Rescale BEFORE storing: writing the [0, 1] float image into a uint16
        # array first would truncate it to 0/1.
        masked[i] = img_limits(stack_masked, limit=255, ddtype='uint16')
    if save == True:
        if save_file == '':
            save_name = save_path+'masked_image.tif'
            mask_name = save_path+'mask.tif'
        else:
            mask_name = save_path+'mask_'+save_file
            save_name = save_path+'image_mask_'+save_file
        if '.tif' not in save_name:
            save_name += '.tif'
        if '.tif' not in mask_name:
            mask_name += '.tif'
        save_image(mask_name, masks, xy_pixel=xy_pixel, z_pixel=z_pixel)
        save_image(save_name, masked, xy_pixel=xy_pixel, z_pixel=z_pixel)
    print('mask_4D runtime', timer()-start_time)
    return masks, masked

This part covers image registration with ANTsPy.

In [None]:
def antspy_regi(ref, img, drift_corr, metric='mattes',
                reg_iterations=(40,20,0),
                aff_iterations=(2100,1200,1200,10),
                aff_shrink_factors=(6,4,2,1),
                aff_smoothing_sigmas=(3,2,1,0),
                grad_step=0.2, flow_sigma=3, total_sigma=0,
                aff_sampling=32, syn_sampling=32):
    """Calculate the transform that registers ``img`` (moving) onto ``ref`` (fixed).

    Parameters
    ----------
    ref, img : numpy.ndarray or ANTs image
        Fixed and moving images. numpy inputs are converted to float32 ANTs
        images. An input that cannot be converted that way is used as a copy
        as-is (e.g. an ANTs image). Each input is converted independently.
    drift_corr : str
        ANTs ``type_of_transform`` (e.g. 'Rigid', 'Affine', 'SyNRA').
    metric : str
        Used as both ``aff_metric`` and ``syn_metric``.
    reg_iterations, aff_iterations, aff_shrink_factors, aff_smoothing_sigmas : sequence of int
        Multi-resolution schedules, passed to ``ants.registration`` as tuples
        of any length. Previously they were truncated to 3 or 4 levels.
    grad_step, flow_sigma, total_sigma, aff_sampling, syn_sampling
        Passed through unchanged to ``ants.registration``.

    Returns
    -------
    dict
        Result of ``ants.registration``; callers use its ``'fwdtransforms'``.
    """
    def _to_ants(arr):
        # Deliberate fallback: anything np.float32 cannot convert (e.g. an
        # ANTs image) is registered as-is.
        try:
            return ants.from_numpy(np.float32(arr.copy()))
        except Exception:
            return arr.copy()

    fixed = _to_ants(ref)
    moving = _to_ants(img)
    # (The old post-registration conversions of ref/img only rebound locals
    # and were dead code; they have been removed.)
    return ants.registration(fixed, moving, type_of_transform=drift_corr,
                             aff_metric=metric, syn_metric=metric,
                             reg_iterations=tuple(reg_iterations),
                             aff_iterations=tuple(aff_iterations),
                             aff_shrink_factors=tuple(aff_shrink_factors),
                             aff_smoothing_sigmas=tuple(aff_smoothing_sigmas),
                             grad_step=grad_step, flow_sigma=flow_sigma,
                             total_sigma=total_sigma,
                             aff_sampling=aff_sampling, syn_sampling=syn_sampling)

In [None]:
def antspy_drift(ref, img, shift, check=False):
    """Resample ``img`` onto ``ref`` using a precomputed ANTs transform list.

    Both inputs are converted to float32 ANTs images, and
    ``ants.apply_transforms`` is applied with ``transformlist=shift``. The
    result is returned as a uint16 numpy array.

    With ``check=True``, the cosine similarity to ``ref`` is compared before
    and after. If it drops by more than 0.1, the unshifted ``img`` is
    returned instead.
    """
    def _plain(arr):
        # ANTs images expose .numpy(); anything else is used unchanged.
        try:
            return arr.numpy()
        except Exception:
            return arr

    fixed_ants = ants.from_numpy(np.float32(_plain(ref)))
    moving_ants = ants.from_numpy(np.float32(_plain(img)))
    warped = ants.apply_transforms(fixed_ants, moving_ants, transformlist=shift)
    ref = fixed_ants.numpy().astype('uint16')
    img = moving_ants.numpy().astype('uint16')
    vol_shifted = warped.numpy().astype('uint16')
    if check == True:
        pre_check = check_similarity(ref, img)
        post_check = check_similarity(ref, vol_shifted)
        print('similarity_check', pre_check, 'improved to', post_check)
        if pre_check - post_check > 0.1:
            vol_shifted = img.copy()
            print('similarity_check was smaller after shift, so shift was ignored:', pre_check, '>>', post_check)
    print(shift, (vol_shifted == img).all())
    return vol_shifted

In [None]:
def apply_ants_channels(ref, image, drift_corr='Rigid',  xy_pixel=1,
                        z_pixel=1, ch_names=[''], ref_ch=-1,
                        metric='mattes',
                        reg_iterations=(40,20,0),
                        aff_iterations=(2100,1200,1200,10),
                        aff_shrink_factors=(6,4,2,1),
                        aff_smoothing_sigmas=(3,2,1,0),
                        grad_step=0.2, flow_sigma=3, total_sigma=0,
                        aff_sampling=32, syn_sampling=3,
                        check_ch='',
                        save=True, save_path='',save_file=''):
    """Register one timepoint to ``ref`` and apply the transform to all channels.

    Steps:
      1. The transform is computed with ``antspy_regi`` on channel
         ``ch_names[ref_ch]``.
      2. It is applied to every channel of ``image`` with ``antspy_drift``.
      3. If ``check_ch`` is a channel of ``image``, the shift is discarded
         (the unshifted channels are kept) unless the cosine similarity to
         ``ref`` drops by at most 0.1.
      4. Channels whose minimum is not 0 are rescaled with ``img_limits``.
      5. When ``save`` is True, each channel is saved as
         ``<drift_corr>_<ch>_<save_file>``.

    NOTE(review): ``syn_sampling`` defaults to 3 here but to 32 in
    ``antspy_regi`` and in the config cell -- confirm which is intended.

    Returns
    -------
    (dict, dict)
        ``{channel: shifted array}`` and the ``ants.registration`` result.
    """
    key = ch_names[ref_ch]
    regi_kwargs = dict(reg_iterations=reg_iterations,
                       aff_iterations=aff_iterations,
                       aff_shrink_factors=aff_shrink_factors,
                       aff_smoothing_sigmas=aff_smoothing_sigmas,
                       grad_step=grad_step, flow_sigma=flow_sigma,
                       total_sigma=total_sigma,
                       aff_sampling=aff_sampling,
                       syn_sampling=syn_sampling)
    shift = antspy_regi(ref[key], image[key], drift_corr, metric, **regi_kwargs)
    shifted = {ch: antspy_drift(ref[ch], stack, shift=shift['fwdtransforms'], check=False)
               for ch, stack in image.items()}
    if check_ch not in image.keys():
        print(check_ch, 'not a recognized ch in image')
    else:
        pre_check = check_similarity(ref[check_ch], image[check_ch])
        post_check = check_similarity(ref[check_ch], shifted[check_ch])
        if (pre_check - post_check) <= 0.1:
            print('similarity_check', pre_check, 'improved to', post_check)
        else:
            print('similarity_check was smaller after shift, so it was ignored:', pre_check, '>>', post_check)
            shifted = image.copy()
    for ch, stack in shifted.items():
        if stack.min() != 0:
            stack = img_limits(stack, ddtype='uint16')
            shifted[ch] = stack
        if save == True:
            out_name = str(save_path + drift_corr + '_' + ch + '_' + save_file)
            if '.tif' not in out_name:
                out_name += '.tif'
            save_image(out_name, stack, xy_pixel=xy_pixel, z_pixel=z_pixel)
    return shifted, shift


In [None]:
def _neighbour_similarity(stack):
    """Return check_similarity of every timepoint with its predecessor.

    The first entry is the placeholder 1 (timepoint 0 has no predecessor), so the
    result has exactly one entry per timepoint.
    """
    checks = [1]
    for i in tqdm(np.arange(len(stack)-1), desc='similarity_check'):
        checks.append(check_similarity(stack[i], stack[i+1]))
    return checks


def apply_ants_4D(image, drift_corr,  xy_pixel=1, 
                  z_pixel=1, ch_names=[1], ref_t=0,
                  ref_ch=-1, metric='mattes',
                  reg_iterations=(40,20,0), 
                  aff_iterations=(2100,1200,1200,10), 
                  aff_shrink_factors=(6,4,2,1), 
                  aff_smoothing_sigmas=(3,2,1,0),
                  grad_step=0.2, flow_sigma=3, total_sigma=0,
                  aff_sampling=32, syn_sampling=3,  
                  check_ch='',                       
                  save=True, save_path='',save_file=''):
    """Register every timepoint of a multi-channel time series to timepoint ``ref_t`` with ANTsPy.

    For each timepoint the transform is computed by ``apply_ants_channels`` on channel
    ``ch_names[ref_ch]`` and applied to all channels. The registered frames are written
    back into ``image`` in place.

    Args:
        image: dict mapping channel name -> stack indexed by timepoint along axis 0
            (presumably (t, z, y, x) -- TODO confirm). A bare stack is wrapped as
            ``{ch_names[0]: image}``.
        drift_corr: ANTs transform type (e.g. 'Rigid', 'Affine', 'SyNRA'). Also used as
            the prefix of the saved file names.
        xy_pixel, z_pixel: pixel sizes forwarded to ``save_image``.
        ch_names: channel names. ``ch_names[ref_ch]`` drives the registration.
        ref_t: index of the reference (fixed) timepoint. Negative values count from the end.
        metric, reg_iterations, aff_iterations, aff_shrink_factors, aff_smoothing_sigmas,
        grad_step, flow_sigma, total_sigma, aff_sampling, syn_sampling:
            forwarded unchanged to ``apply_ants_channels``.
        check_ch: channel used for the neighbour-similarity checks. If it is not a key of
            ``image``, the checks are skipped instead of raising KeyError.
        save: if True, write each registered channel as a tif, plus two csv files:
            the forward transforms per timepoint, and the similarity before/after
            registration.
        save_path, save_file: output path prefix and base file name.

    Returns:
        The (in-place modified) dict of registered channels.
    """
    start_time = timer()
    if not isinstance(image, dict):
        image = {ch_names[0]: image}
    n_timepoints = len(image[ch_names[ref_ch]])
    # Normalize a negative reference index BEFORE building the sequence. Previously
    # ref_t=-1 leaked into np.arange(ref_t, ...), and the last timepoint was registered twice.
    if ref_t < 0:
        ref_t += n_timepoints
    # Every timepoint, including the reference itself, is registered to the reference.
    scope = np.arange(0, n_timepoints)
    print('ants seq for 4D regi', scope)

    has_check = check_ch in image
    if not has_check:
        print(check_ch, 'not a recognized ch in image')
    raw_sim_checks = _neighbour_similarity(image[check_ch]) if has_check else []

    fixed = {ch: img[ref_t].copy() for ch, img in image.items()}
    # One entry per registered timepoint. There is no leading placeholder, so entry k
    # belongs to timepoint index k (the old [0] placeholder shifted all labels by one).
    shifts = []
    desc = 'AntsPy_'+drift_corr
    for i in tqdm(scope, desc=desc):
        moving = {ch: img[i].copy() for ch, img in image.items()}
        shifted, shift = apply_ants_channels(fixed, moving, drift_corr=drift_corr,  
                                            xy_pixel=xy_pixel, 
                                            z_pixel=z_pixel, ch_names=ch_names, 
                                            ref_ch=ref_ch,
                                            metric=metric,
                                            reg_iterations=reg_iterations, 
                                            aff_iterations=aff_iterations, 
                                            aff_shrink_factors=aff_shrink_factors, 
                                            aff_smoothing_sigmas=aff_smoothing_sigmas,
                                            grad_step=grad_step, flow_sigma=flow_sigma, 
                                            total_sigma=total_sigma,
                                            aff_sampling=aff_sampling, 
                                            syn_sampling=syn_sampling,  
                                            check_ch=check_ch,                       
                                            save=False)
        shifts.append(shift['fwdtransforms'])
        for ch in ch_names:
            image[ch][i] = shifted[ch]
            if image[ch][i].min() != 0:
                image[ch][i] = img_limits(image[ch][i], ddtype='uint16')
        del shifted, moving, shift

    sim_checks = _neighbour_similarity(image[check_ch]) if has_check else []

    if save:
        for ch, img in image.items():
            # str(ch) so the default integer channel name does not break concatenation
            save_name = str(save_path+drift_corr+'_'+str(ch)+'_'+save_file)
            if '.tif' not in save_name:
                save_name += '.tif'
            save_image(save_name, img, xy_pixel=xy_pixel, z_pixel=z_pixel)

        shift_file = save_path+drift_corr+'AntsShifts_'+save_file+'.csv'
        shift_file = shift_file.replace('.tif','')
        with open(shift_file, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['timepoint', 'shift_mat'])
            writer.writeheader()
            # 'timepoint' is 1-based: row t holds the transforms of timepoint index t-1
            for timepoint, shift in enumerate(shifts, start=1):
                writer.writerow({'timepoint': timepoint, 'shift_mat': shift})

        check_file = save_path+drift_corr+'ANTsCheck_'+save_file+".csv"
        check_file = check_file.replace('.tif','')
        with open(check_file, 'w', newline='') as csvfile1:
            writer = csv.DictWriter(csvfile1, fieldnames=['timepoint', 'original', 'after_shift'])
            writer.writeheader()
            for timepoint, (raw, check) in enumerate(zip(raw_sim_checks, sim_checks), start=1):
                writer.writerow({'timepoint': timepoint, 'original': raw, 'after_shift': check})
    print('ants_round runtime', timer()-start_time)
    return image

In [None]:
####### compiling the single 3D tif files to 4D_image, or reading 4D_image(s)
start_time = timer()
try:
    file_4D = output_name
except NameError:
    # output_name is optional in the settings cell: derive a name from the group instead
    file_4D = group.split('_')[0]+'.tif'
if os.path.isdir(path_to_data):
    files_list = get_file_names(path_to_data, 
                                group_by=group, 
                                order=reference_last)
    print('the first 5 files (including ref) are', files_list[0:5])
    if 'compile' in steps:
        # loading raw skimage files into 4D array and saving raw 4D
        print('compiling 3D image_files into dict of 4D_images of specified channels')
        image_4D = files_to_4D(files_list, ch_names=ch_names, save=True, 
                            save_path=save_path, 
                            save_file='raw_'+file_4D, 
                            xy_pixel=xy_pixel, 
                            z_pixel=z_pixel, 
                            ddtype='uint16')
    else:
        # Already-compiled 4D files: one file per channel. Assumes files_list is ordered
        # like the alphabetically sorted channel names -- TODO confirm with get_file_names.
        image_4D = {ch: tif.imread(files_list[ind]) for ind, ch in enumerate(sorted(ch_names))}
elif os.path.isfile(path_to_data):
    image_4D = {ch_names[0]: tif.imread(path_to_data)}
    # NOTE(review): len(image_4D) is the number of channels (1 here), not timepoints.
    # The original author did not remember why this is needed; kept unchanged.
    files_list = [i for i in np.arange(len(image_4D))]
    if file_4D == '':
        file_4D = os.path.basename(path_to_data)
        file_4D = file_4D.split('_')[0]+'.tif'
else:
    # Fail here instead of with a NameError on image_4D in a later cell.
    raise FileNotFoundError('path_to_data is neither an existing directory nor a file: '+str(path_to_data))
print('finished reading and compiling images', timer()-start_time)

In [None]:
####### initial registration of images using phase_correlation on red (last) channel using Neurosetta
if 'preshift' in steps:
    start_time = timer()
    print('applying preshift')
    neuron = neu.Neuron(image_4D[ch_names[-1]])
    # shifts between consecutive timepoints, computed on the last channel only
    shifts = neu.shifts(neuron)
    for ch, img in image_4D.items():
        # enumerate: iterating the stack directly yields 3D frames, which cannot be
        # unpacked into (index, frame)
        for ind, sta in enumerate(img):
            # cumulative shift of timepoint ind relative to timepoint 0
            current_shift = sum(shifts[:ind])
            image_4D[ch][ind] = ndimage.shift(sta, current_shift)
        # save inside the channel loop so that every channel is written, not only the last
        sa_name = save_path+'Phase_'+ch+'_'+output_name
        save_image(sa_name, image_4D[ch], xy_pixel=xy_pixel, z_pixel=z_pixel)
    del neuron
    print('finished preshift', timer()-start_time)

In [None]:
###### optional removal of the last quarter of Z slices from every 3D image
#### (axis 1 of each 4D stack); this shortens the ANTs run time a little
if 'trim' in steps:
    start_time = timer()
    print('image size before trimming is', image_4D[ch_names[-1]].shape)
    # keep Z slices [0, trim): three quarters of the Z extent of the last channel
    trim = int(image_4D[ch_names[-1]].shape[1] * 3 / 4)
    print('trimming all images in Z dim to 0:', trim)
    for ch in list(image_4D):
        image_4D[ch] = image_4D[ch][:, :trim]
    print('image size after trimming is', image_4D[ch_names[-1]].shape)
    print('finished trimming images', timer()-start_time)

In [None]:
###### applying Ants registration based on the last (red) channel
if 'ants' in steps:
    start_time = timer()
    ref_t = ants_ref_no
    # ref_t indexes a timepoint, so it must satisfy 0 <= ref_t < number of timepoints
    # (the previous `>` check let ref_t == len through and caused an IndexError)
    if not isinstance(ref_t, int) or ref_t < 0 or ref_t >= len(image_4D[ch_names[-1]]):
        ref_t = 0
    # drift corrections are chained: each step registers the output of the previous one
    for i, drift_t in enumerate(drift_corr):
        ants_step = str(i+1)
        try:
            metric_t = metric[i]
        except (IndexError, TypeError):
            # fewer metrics than drift corrections (or no metric list): fall back to mattes
            print('optimization metric not recognized. mattes used instead')
            metric_t = 'mattes'
        image_4D = apply_ants_4D(image_4D, 
                                drift_corr=drift_t,  
                                xy_pixel=xy_pixel, 
                                z_pixel=z_pixel, 
                                ch_names=ch_names, 
                                ref_t=ref_t,
                                ref_ch=-1, 
                                metric=metric_t,
                                reg_iterations=reg_iterations, 
                                aff_iterations=aff_iterations, 
                                aff_shrink_factors=aff_shrink_factors, 
                                aff_smoothing_sigmas=aff_smoothing_sigmas,
                                grad_step=grad_step, 
                                flow_sigma=flow_sigma, 
                                total_sigma=total_sigma,
                                aff_sampling=aff_sampling, 
                                syn_sampling=syn_sampling, 
                                check_ch=ch_names[0],                       
                                save=True, 
                                save_path=save_path,
                                save_file=ants_step+'_'+file_4D)
        print('finished ants run with', drift_t)
    print('finished antspy registration', timer()-start_time)

In [None]:
####### post-registration with ANTs driven by the first channel (neuron signal)
if 'postshift' in steps:
    start_time = timer()
    # neurons: dict label -> 4D stack that replaces channel ch_names[0] during registration.
    # If no such dict was defined earlier, the first channel itself is used.
    if 'neurons' not in locals():
        neurons = {1: image_4D[ch_names[0]]}
    ref_t = ants_ref_no
    # ref_t indexes a timepoint, so it must satisfy 0 <= ref_t < number of timepoints
    if not isinstance(ref_t, int) or ref_t < 0 or ref_t >= len(image_4D[ch_names[-1]]):
        ref_t = 0
    for l, neuron in neurons.items():
        # NOTE(review): shallow dict copy -- apply_ants_4D writes frames in place, so the
        # arrays of the other channels in image_4D are modified as well.
        image = image_4D.copy()
        image[ch_names[0]] = neuron
        for i, drift_t in enumerate(drift_corr):
            ants_step = str(i+1)
            try:
                metric_t = metric[i]
            except (IndexError, TypeError):
                # fewer metrics than drift corrections (or no metric list): fall back to mattes
                print('optimization metric not recognized. mattes used instead')
                metric_t = 'mattes'
            # Chain the drift corrections on `image`. Previously `del image` ran inside this
            # loop, so the second drift correction failed with a NameError.
            image = apply_ants_4D(image, 
                                    drift_corr=drift_t,  
                                    xy_pixel=xy_pixel, 
                                    z_pixel=z_pixel, 
                                    ch_names=ch_names, 
                                    ref_t=ref_t,
                                    ref_ch=0, ### this is the main defference between ants and postshift
                                    metric=metric_t,
                                    reg_iterations=reg_iterations, 
                                    aff_iterations=aff_iterations, 
                                    aff_shrink_factors=aff_shrink_factors, 
                                    aff_smoothing_sigmas=aff_smoothing_sigmas,
                                    grad_step=grad_step, 
                                    flow_sigma=flow_sigma, 
                                    total_sigma=total_sigma,
                                    aff_sampling=aff_sampling, 
                                    syn_sampling=syn_sampling,
                                    check_ch=ch_names[0],                       
                                    save=True, 
                                    save_path=save_path,
                                    save_file='neuron'+str(l)+'_'+ants_step+'_'+file_4D)
            print('finished postshift on neuron %i run with' %l, drift_t)
        image_4D = image
        del image
    print('total ruuntime of postshift', timer()-start_time)

##The parts for N2V, clahe, mask and segment are to be added