# G5HT-PIPELINE

## TODO:

1. I wonder whether computing a spline on each z slice separately, then orienting and warping each slice, would solve the problem of weirdly sheared image stacks
2. quick mp4 for all recordings
   1. now working on Engaging; runs as one sbatch job per nd2 file
3. focus check for all recordings
   1. maybe focus check can be used to specify which z slices are good to use and which frames are good to use
4. for recordings starting in december 2025, need to trim first 2 rather than last 2 z slices
5. flip worms so that VNC is always up
6. fixed mask could be automated, but if not, make sure to save which index is fixed
7. extract behavior
8. posture similarity
   1. posture might consist of the spline + thresholded z-stack
      1. I'm thinking that the orientation shouldn't matter, but the z-planes in focus will, as will the curvature/spline of the head
      2. maybe need to actually interpolate to 117 z slices
   2. sub registration problems
   3. label each set of registered frames with one set of ROIs, or auto segment ROIs from each set of registered frames
9. track z over time: which z slices are consistent
   1. focus + correlation
10. beads -> train/test
11. gfp+1 relative to rfp channel (might only apply to pre december 2025 recordings)
12. wholistic 
    1.  parameter sweep, might change
    2.  python version
    3.  actually, wholistic might be tricky to use all the time, because it only works after parameter optimization, which I don't really know how to automate
13. autocorr/scorr
14. automate z slice trimming
    1.  pre december 2025 (trim last 2 z slices)
    2.  post december 2025 (trim first z slice)
15. photobleaching estimation?
    1.  record immo with serotonin
    2.  at least do it for RFP
16. try deltaF/F [ (F(t) - F0) / F0 ]
17. maybe remove the right-most part of the worm, where the spline is usually kinked
18. port everything to Engaging

## CONDA ENVIRONMENTS

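For steps __0. beads__, __1. shear correction__, __2. channel alignment__, __4. mip__, and __5. drift estimation__, `conda activate g5ht-pipeline`

For step __5. segment__, `conda activate segment-torch` (home PC) or `conda activate torchcu129` (lab PC)

For steps __6. spline__, __7. orient__, __8. warp__, and __9. register__, `conda activate g5ht-pipeline`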


## IMPORTS

In [2]:
import sys
import os
import importlib
from tqdm import tqdm

try:
    import utils
    is_torch_env = False
except ImportError:
    is_torch_env = True
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

## SPECIFY DATA TO PROCESS

In [3]:
# DATA_PTH = r'C:\Users\munib\POSTDOC\DATA\fluorescent_beads_ch_align\20251219'
DATA_PTH = r'D:\DATA\g5ht-free\20260123'

INPUT_ND2 = 'date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004.nd2'

INPUT_ND2_PTH = os.path.join(DATA_PTH, INPUT_ND2)

NOISE_PTH = r'C:\Users\munib\POSTDOC\CODE\g5ht-pipeline\noise\noise_042925.tif'

OUT_DIR = os.path.splitext(INPUT_ND2_PTH)[0]

STACK_LENGTH = 41

if not is_torch_env:
    noise_stack = utils.get_noise_stack(NOISE_PTH, STACK_LENGTH)
    num_frames, height, width, num_channels = utils.get_range_from_nd2(INPUT_ND2_PTH, stack_length=STACK_LENGTH) 
    beads_alignment_file = utils.get_beads_alignment_file(INPUT_ND2_PTH)
else:
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

print(INPUT_ND2)
print('Num z-slices: ', STACK_LENGTH)
if not is_torch_env:
    print('Number of frames: ', num_frames)
    print('Height: ', height)
    print('width: ', width)
    print('Number of channels: ', num_channels)
    print('Beads alignment file: ', beads_alignment_file)

date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004.nd2
Num z-slices:  41
Number of frames:  1200
Height:  512
width:  512
Number of channels:  2
Beads alignment file:  D:\DATA\g5ht-free\20260123\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004_chan_alignment.nd2


## 0. PROCESS BEADS ALIGNMENT DATA (OPTIONAL; BEING CHANGED SO BEADS ARE PROCESSED SEAMLESSLY IN THE PIPELINE)

`conda activate g5ht-pipeline`

The green-red channel registration parameters estimated from the beads are applied to the worm recordings.

### SHEAR CORRECT AND CHANNEL REGISTER

In [None]:
# import the module itself so that reload() actually updates the function that gets called
import preprocess_parallel
_ = importlib.reload(preprocess_parallel)

num_frames_beads, _, _, _ = utils.get_range_from_nd2(beads_alignment_file, stack_length=STACK_LENGTH) 

# command-line arguments
sys.argv = ["", beads_alignment_file, "0", str(num_frames_beads-1), NOISE_PTH, STACK_LENGTH, 5, num_frames_beads, height, width, num_channels]

# Call the main function
preprocess_parallel.main()

### MIP

This step currently saves the median channel registration parameters; that should be moved somewhere else.

In [None]:
# import the module itself so that reload() actually updates mip.main
import mip

_ = importlib.reload(mip)
_ = importlib.reload(sys.modules['utils'])

# command-line arguments
sys.argv = ["", beads_alignment_file, STACK_LENGTH, num_frames_beads, 2]

# Call the main function
mip.main()

## 1. SHEAR CORRECTION

`conda activate g5ht-pipeline`

- shear corrects each volume
  - depending on the exposure time, roughly half a second can pass between the first and last frames of a volume, so any movement during that interval needs to be corrected
- creates one `.tif` for each volume and stores it in the `shear_corrected` directory
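As a rough illustration of the idea (not necessarily what `shear_correct.py` does), the sketch below estimates an integer (dy, dx) offset between adjacent z slices by FFT phase correlation and rolls each slice onto its already-corrected neighbor. It assumes a single-channel `(Z, H, W)` volume and purely translational, circular shifts; real adjacent slices image different planes, so the match is only approximate. `phase_correlation_shift` and `shear_correct_volume` are hypothetical names.

```python
import numpy as np


def phase_correlation_shift(ref, mov):
    """Integer (dy, dx) that, passed to np.roll(mov, ...), best aligns mov onto ref."""
    cross_power = np.fft.fft2(ref) * np.conj(np.fft.fft2(mov))
    cross_power /= np.abs(cross_power) + 1e-12  # keep phase only
    corr = np.fft.ifft2(cross_power).real
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # peaks past the midpoint correspond to negative shifts (FFT wrap-around)
    return tuple(int(p if p <= n // 2 else p - n) for p, n in zip(peak, corr.shape))


def shear_correct_volume(volume):
    """Align each z slice of a (Z, H, W) volume to the previous, already-corrected slice."""
    out = np.asarray(volume, dtype=np.float64).copy()
    shifts = [(0, 0)]
    for z in range(1, out.shape[0]):
        dy, dx = phase_correlation_shift(out[z - 1], out[z])
        out[z] = np.roll(out[z], (dy, dx), axis=(0, 1))
        shifts.append((dy, dx))
    return out, shifts
```

Chaining slice-to-slice (rather than aligning every slice to slice 0) tolerates gradual changes in appearance through z, at the cost of accumulating small errors.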

##### TODO: should probably update the stack length after shear correction, since it is reduced by 2, although this may not be explicitly needed

In [None]:
import shear_correct
_ = importlib.reload(sys.modules['shear_correct'])

start_index = "0"
end_index = str(num_frames-1)
# start_index = "800"
# end_index = "803"
# cpu_count = str(int(os.cpu_count() / 2))
cpu_count = str(int(os.cpu_count()))

# sys.argv = ["", nd2 file, start_frame, end_frame, noise_pth, stack_length, n_workers, num_frames, height, width, num_channels]
sys.argv = ["", INPUT_ND2_PTH, start_index, end_index, NOISE_PTH, STACK_LENGTH, cpu_count, num_frames, height, width, num_channels]

# Call the main function
shear_correct.main()

## 2. CHANNEL ALIGNMENT

`conda activate g5ht-pipeline`

### 2a. GET MEDIAN CHANNEL ALIGNMENT PARAMETERS FROM ALL FRAMES

- if a beads channel-alignment file is found, it is used; otherwise the worm recording itself is used
- creates a `.txt` file for each volume containing the elastix channel-registration parameters
- creates `chan_align_params.csv` and `chan_align.txt`
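The median step collapses the per-volume parameters into one transform that is robust to the occasional failed registration. A minimal sketch, assuming `chan_align_params.csv` has one row per volume with `TransformParameter_0` through `TransformParameter_5` columns (the columns plotted in 2c); the helper names are hypothetical and `median_channel_alignment.py` may do more. The output line follows the elastix parameter-file syntax `(TransformParameters ...)`.

```python
import pandas as pd


def median_transform_parameters(df, n_params=6):
    """Median of each TransformParameter_i column across volumes."""
    cols = [f'TransformParameter_{i}' for i in range(n_params)]
    return df[cols].median().to_numpy()


def format_transform_parameters(values):
    """Render parameter values as an elastix parameter-file entry."""
    return '(TransformParameters ' + ' '.join(f'{v:.6f}' for v in values) + ')'
```

Usage: `format_transform_parameters(median_transform_parameters(pd.read_csv('chan_align_params.csv')))`.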

In [None]:
beads_alignment_file

In [None]:
import get_channel_alignment
import median_channel_alignment
_ = importlib.reload(sys.modules['get_channel_alignment'])
_ = importlib.reload(sys.modules['median_channel_alignment'])

## set beads_alignment_file to None to use worm recording for channel alignment, even if beads file exists
# beads_alignment_file = None

start_index = "0"
# cpu_count = str(int(os.cpu_count() / 2))
cpu_count = str(int(os.cpu_count()))

if beads_alignment_file is not None:
    align_with_beads = True
    num_frames_beads, _, _, _ = utils.get_range_from_nd2(beads_alignment_file, stack_length=STACK_LENGTH) 
    sys.argv = ["", beads_alignment_file, start_index, str(num_frames_beads-1), NOISE_PTH, STACK_LENGTH, cpu_count, num_frames_beads, height, width, num_channels, align_with_beads]
else:
    align_with_beads = False
    sys.argv = ["", INPUT_ND2_PTH, start_index, str(num_frames-1), NOISE_PTH, STACK_LENGTH, cpu_count, num_frames, height, width, num_channels, align_with_beads]

# Call the main function
get_channel_alignment.main()
median_channel_alignment.main()


### 2b. APPLY MEDIAN CHANNEL ALIGNMENT PARAMETERS

- outputs aligned volumes in the `channel_aligned` directory

In [None]:
import apply_channel_alignment
_ = importlib.reload(sys.modules['apply_channel_alignment'])

start_index = "0"
# cpu_count = str(int(os.cpu_count() / 2))
cpu_count = str(int(os.cpu_count()))

# 0786 to 0799 are bad frames in worm005.nd2, copied 0785 for each of those frames

if beads_alignment_file is not None:
    align_with_beads = True
    num_frames_beads, _, _, _ = utils.get_range_from_nd2(beads_alignment_file, stack_length=STACK_LENGTH) 
    sys.argv = ["", INPUT_ND2_PTH, start_index, str(num_frames-1), NOISE_PTH, STACK_LENGTH, cpu_count, num_frames, height, width, num_channels, align_with_beads, beads_alignment_file]
else:
    align_with_beads = False
    sys.argv = ["", INPUT_ND2_PTH, start_index, str(num_frames-1), NOISE_PTH, STACK_LENGTH, cpu_count, num_frames, height, width, num_channels, align_with_beads]


# Call the main function
apply_channel_alignment.main()

In [None]:
# # create copies of 0785 and rename it to 0786 to 0799
# import shutil
# for i in range(786, 800):
#     shutil.copyfile(r'C:\Users\munib\POSTDOC\DATA\g5ht-free\20251223\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005\channel_aligned\0785.tif',
#                     r'C:\Users\munib\POSTDOC\DATA\g5ht-free\20251223\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005\channel_aligned\{:04d}.tif'.format(i))

### 2c. PLOT CHANNEL ALIGNMENT PARAMETER DISTRIBUTIONS

In [None]:
OUT_DIR

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# make font sizes larger for visibility
plt.rcParams.update({'font.size': 18})

PARAMS = ['TransformParameter_0', 'TransformParameter_1', 'TransformParameter_2', 'TransformParameter_3', 'TransformParameter_4', 'TransformParameter_5']
LABELS = ['Rx', 'Ry', 'Rz', 'Tx', 'Ty', 'Tz']

def plot_chan_align_params(out_dir, color, xlims=None):
    """Histogram each channel-alignment parameter and mark its median.
    If xlims (6x2 array) is given, apply it so panels are comparable across figures.
    Returns the x-axis limits of each panel."""
    df = pd.read_csv(os.path.join(out_dir, 'chan_align_params.csv'))
    used_xlims = np.zeros((len(PARAMS), 2))
    plt.figure(figsize=(12,8), tight_layout=True)
    for i, param in enumerate(PARAMS):
        plt.subplot(2, 3, i+1)
        plt.hist(df[param], bins=30, color=color, alpha=0.6)
        # plot the median value as a vertical line
        median_value = df[param].median()
        plt.axvline(median_value, color='black', linestyle='dashed', linewidth=2)
        plt.xlabel(LABELS[i])
        plt.ylabel('Frequency')
        if xlims is not None:
            plt.xlim(xlims[i,0], xlims[i,1])
        used_xlims[i,:] = plt.xlim()
        # title is median value
        plt.title(f'Median: {np.round(median_value,3)}', fontsize=14)
    plt.show()
    return used_xlims

# worm recording (red)
try:
    xlims = plot_chan_align_params(os.path.splitext(INPUT_ND2_PTH)[0], color='red')
except FileNotFoundError:
    print("No chan_align_params.csv found for worm recording")

# beads recording (blue); pass xlims=xlims to use the same x-axis limits as the worm figure
try:
    plot_chan_align_params(os.path.splitext(INPUT_ND2_PTH)[0] + '_chan_alignment', color='blue')
except FileNotFoundError:
    print("No chan_align_params.csv found for beads recording")

## 3. BLEACH CORRECTION

TODO:
- per z slice?



In [None]:
import importlib
import os
import sys

import bleach_correct
_ = importlib.reload(sys.modules['bleach_correct'])


PTH = os.path.splitext(INPUT_ND2_PTH)[0]
REG_DIR = 'channel_aligned' # 'channel_aligned' or 'tif' 
channels = 1
method = 'block' # 'block' or 'exponential'
mode = 'total' # 'total' or 'median'

bleach_correct.correct_bleaching(os.path.join(PTH,REG_DIR), channels=channels, method=method, fbc=0.04, intensity_mode=mode)


# # Correct RFP only with block method (default)
# correct_bleaching("path/to/data")

# # Correct both channels with exponential fit
# correct_bleaching("path/to/data", channels=[0, 1], method='exponential')

# # Command line
# python bleach_correct.py path/to/data --channels 0 1 --method exponential
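`method='exponential'` presumably fits a decay curve to intensity over time. The sketch below shows one common way to do that, as an assumption rather than a copy of `bleach_correct.py`: fit a·exp(-k·t) + c to the mean intensity of each volume, then rescale every volume so the fitted curve stays at its initial value. `exponential_bleach_correct` is a hypothetical name and expects time on the first axis.

```python
import numpy as np
from scipy.optimize import curve_fit


def exp_decay(t, a, k, c):
    return a * np.exp(-k * t) + c


def exponential_bleach_correct(stack):
    """stack: (T, ...) array with time on axis 0. Returns (corrected, fitted_curve)."""
    stack = np.asarray(stack, dtype=np.float64)
    t = np.arange(stack.shape[0], dtype=np.float64)
    mean_intensity = stack.reshape(stack.shape[0], -1).mean(axis=1)
    # rough initial guess: amplitude from first/last frames, decay over the recording
    p0 = (mean_intensity[0] - mean_intensity[-1], 1.0 / len(t), mean_intensity[-1])
    params, _ = curve_fit(exp_decay, t, mean_intensity, p0=p0, maxfev=10000)
    fitted = exp_decay(t, *params)
    gain = fitted[0] / fitted  # > 1 for later frames, undoing the decay
    return stack * gain.reshape((-1,) + (1,) * (stack.ndim - 1)), fitted
```

A single global gain per frame corrects only the overall trend; a per-z-slice fit (the TODO above) would apply the same function to each slice's time series.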

## 4. MIP

`conda activate g5ht-pipeline`

- outputs `means.png`, `focus.png`, `mip.tif`, `mip.mp4`, and `focus_check.csv`

##### TODO: 
- legend for focus.png (entries should be frame numbers)
- mip for xy, xz, zy
- mip for several slices
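For reference, a maximum intensity projection is a max over one axis. The sketch below computes the xy MIP plus the xz and zy projections from the TODO, and a simple per-slice focus score (variance of the Laplacian, a common sharpness measure; `mip.py` and `focus_check.csv` may use a different one). It assumes volumes are stored as `(Z, C, H, W)` with RFP in channel 1, as in the registration cells further down; the function names are hypothetical.

```python
import numpy as np
from scipy import ndimage as ndi


def projections(volume):
    """Max projections of a (Z, C, H, W) volume."""
    xy = volume.max(axis=0)  # (C, H, W): standard MIP through z
    xz = volume.max(axis=2)  # (Z, C, W): project out y
    zy = volume.max(axis=3)  # (Z, C, H): project out x
    return xy, xz, zy


def focus_per_slice(volume, channel=1):
    """Variance of the Laplacian of each z slice in one channel; higher means sharper."""
    return np.array([
        ndi.laplace(volume[z, channel].astype(np.float64)).var()
        for z in range(volume.shape[0])
    ])
```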

In [None]:
import mip

_ = importlib.reload(sys.modules['mip'])
_ = importlib.reload(sys.modules['utils'])

# command-line arguments
framerate = 8
tif_dir = 'bleach_corrected_RFP_block' # one of 'shear_corrected' 'channel_aligned' 'bleach_corrected_RFP_block'
# tif_dir = 'channel_aligned_beads'
rmax = 850
gmax = 150
mp4_quality = 10
sys.argv = ["", INPUT_ND2_PTH, tif_dir, STACK_LENGTH, num_frames, framerate, rmax, gmax, mp4_quality]

# Call the main function
mip.main()

## 5. DRIFT ESTIMATION

`conda activate g5ht-pipeline`

- outputs `z_selection.csv`, `z_selection_diagnostics.png`, and `sharpness.csv`

TODO:
- use z selection going forward
- also use sharpness/focus (and other things) to determine good/bad frames
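One way to use per-slice sharpness for z selection, as a hypothetical sketch (`drift_estimation.py` may work differently): treat sharpness as a `(frames, z slices)` array, track the sharpest slice over time with a median filter to suppress single-frame jumps, and keep only slices that stay reasonably sharp in most frames. The thresholds `frac` and `min_frame_fraction` are illustrative, not tuned values.

```python
import numpy as np
from scipy.ndimage import median_filter


def estimate_focal_drift(sharpness, smooth=9):
    """sharpness: (T, Z). Index of the sharpest slice per frame, median-smoothed over time."""
    peak = np.argmax(sharpness, axis=1)
    return median_filter(peak, size=smooth, mode='nearest')


def consistent_slices(sharpness, frac=0.5, min_frame_fraction=0.8):
    """Z slices whose sharpness is >= frac * that frame's max in enough frames."""
    rel = sharpness / sharpness.max(axis=1, keepdims=True)
    good = (rel >= frac).mean(axis=0) >= min_frame_fraction
    return np.flatnonzero(good)
```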

In [None]:
import drift_estimation

_ = importlib.reload(sys.modules['drift_estimation'])
_ = importlib.reload(sys.modules['utils'])

# command-line arguments
tif_dir = 'bleach_corrected_RFP_block' # one of 'shear_corrected' 'channel_aligned' 'bleach_corrected_RFP_block'

sys.argv = ["", INPUT_ND2_PTH, tif_dir, STACK_LENGTH, num_frames]

# Call the main function
drift_estimation.main()

## 5. SEGMENT

- outputs `label.tif`, which contains the segmented MIP for each volume

__on home pc__: 
`conda activate segment-torch`

Uses a separate conda environment from the rest of the pipeline. Create it using:
`conda env create -f segment_torch.yml`

__on lab pc__: 
`conda activate torchcu129`

Uses a separate conda environment from the rest of the pipeline. Create it following the steps in:
`segment_torch_cu129_environment.yml`

### SETUP EACH TIME THE MODEL WEIGHTS CHANGE
Set the path to the model weights as `CHECKPOINT` in `eval_torch.py`.

In [None]:
import segment.segment_torch
_ = importlib.reload(sys.modules['segment.segment_torch'])

MIP_PTH = os.path.join(os.path.splitext(INPUT_ND2_PTH)[0], 'mip_bleach_corrected_RFP_block.tif')

# command-line arguments
sys.argv = ["", MIP_PTH]

segment.segment_torch.main()

## 6. SPLINE

`conda activate g5ht-pipeline`

- outputs `spline.json`, `spline.tif`, and `dilated.tif`
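To illustrate the general technique (not necessarily what `spline.py` implements), the sketch below fits a parametric cubic spline through ordered `(y, x)` centerline points with SciPy's `splprep`/`splev` and resamples it at evenly spaced arc-length positions, also returning unit tangents that a warping step could use. It assumes the points are already ordered from one end of the worm to the other; `fit_centerline_spline` is a hypothetical name.

```python
import numpy as np
from scipy.interpolate import splprep, splev


def fit_centerline_spline(points, n_samples=100, smoothing=0.0):
    """Fit a parametric cubic spline to ordered (N, 2) (y, x) points.

    Returns (n_samples, 2) points spaced evenly along the curve and
    (n_samples, 2) unit tangents (dy, dx).
    """
    pts = np.asarray(points, dtype=float)
    tck, _ = splprep([pts[:, 0], pts[:, 1]], s=smoothing)
    # sample densely, then invert cumulative arc length to get evenly spaced parameters
    u_dense = np.linspace(0.0, 1.0, 10 * n_samples)
    y, x = splev(u_dense, tck)
    arclen = np.concatenate([[0.0], np.cumsum(np.hypot(np.diff(y), np.diff(x)))])
    u_even = np.interp(np.linspace(0.0, arclen[-1], n_samples), arclen, u_dense)
    y, x = splev(u_even, tck)
    dy, dx = splev(u_even, tck, der=1)
    norm = np.hypot(dy, dx)
    return np.column_stack([y, x]), np.column_stack([dy / norm, dx / norm])
```

With `smoothing=0.0` the spline interpolates the points exactly; noisy skeleton points usually need a positive value.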

In [None]:
import spline
_ = importlib.reload(sys.modules['spline'])

LABEL_PTH = os.path.join(os.path.splitext(INPUT_ND2_PTH)[0], 'label.tif')

# command-line arguments
sys.argv = ["", LABEL_PTH]

spline.main()

## 7. ORIENT

`conda activate g5ht-pipeline`

- outputs `oriented.json`, `oriented.png`, `oriented_stack.tif`

NOTE: `orient_v2.py` attempts to find the orientation fully automatically (not yet working reliably), whereas `orient.py` requires you to input the nose location (`nose_y`, `nose_x`) on the first frame
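A minimal sketch of how a first-frame nose seed can propagate, as an assumption about the approach rather than the actual `orient.py` code: on each frame, whichever centerline end lies closer to the previous frame's nose is taken as the nose, and a constraint frame resets the reference position. `orient_centerlines` is hypothetical. This simple rule can fail on large jumps or bad spline fits, which is the situation the constraint frame addresses.

```python
import numpy as np


def orient_centerlines(centerlines, nose_yx, constraints=None):
    """Order each frame's (N, 2) (y, x) centerline nose-first.

    nose_yx: (y, x) nose position on frame 0.
    constraints: optional {frame_index: (y, x)} nose positions that override tracking.
    """
    constraints = constraints or {}
    prev_nose = np.asarray(nose_yx, dtype=float)
    oriented = []
    for i, pts in enumerate(centerlines):
        pts = np.asarray(pts, dtype=float)
        if i in constraints:
            prev_nose = np.asarray(constraints[i], dtype=float)
        # the end closer to the last known nose is taken to be the nose
        if np.linalg.norm(pts[-1] - prev_nose) < np.linalg.norm(pts[0] - prev_nose):
            pts = pts[::-1]
        oriented.append(pts)
        prev_nose = pts[0]
    return oriented
```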

In [None]:
import orient
_ = importlib.reload(sys.modules['orient'])

SPLINE_PTH = os.path.join(os.path.splitext(INPUT_ND2_PTH)[0], 'spline.json')
nose_y = 250
nose_x = 45

# apply constraints
# might need this when there are frames where the spline fitting fails and orientation is lost intermittently
constrain_frame = 515
constrain_frame_nose_y = 288
constrain_frame_nose_x = 180

# command-line arguments
# sys.argv = ["", SPLINE_PTH, str(nose_y), str(nose_x)]
sys.argv = ["", SPLINE_PTH, str(nose_y), str(nose_x), str(constrain_frame), str(constrain_frame_nose_y), str(constrain_frame_nose_x)]

orient.main()

In [None]:
import orient_v2 # tried to automate finding nose point, not working well at the moment
_ = importlib.reload(sys.modules['orient_v2'])

SPLINE_PTH = os.path.join(os.path.splitext(INPUT_ND2_PTH)[0], 'spline.json')

# command-line arguments
sys.argv = ["", SPLINE_PTH]

orient_v2.main()

## 8. WARP

`conda activate g5ht-pipeline`

- outputs: `warped/*.tif` and `masks/*.tif`

TODO: parallelize
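Judging by the `(Z, 2, 200, 500)` output shape in the registration code below, warping appears to straighten the worm along its spline. A minimal sketch of that idea for one 2D slice, assuming evenly spaced centerline points and unit tangents (for example from a spline fit): sample the image along the normal at each centerline point with `scipy.ndimage.map_coordinates`. Output columns follow the centerline and rows span `2 * half_width` pixels across the body. `straighten` is hypothetical; the real `warp.py` may differ.

```python
import numpy as np
from scipy import ndimage as ndi


def straighten(image, centerline, tangents, half_width=100, order=1):
    """image: (H, W). centerline: (L, 2) (y, x) points; tangents: (L, 2) unit (dy, dx).

    Returns a (2 * half_width, L) image sampled along normals to the centerline.
    """
    # rotate each tangent (dy, dx) by 90 degrees to get the normal (-dx, dy)
    normals = np.column_stack([-tangents[:, 1], tangents[:, 0]])
    offsets = np.arange(-half_width, half_width, dtype=float)
    ys = centerline[None, :, 0] + offsets[:, None] * normals[None, :, 0]
    xs = centerline[None, :, 1] + offsets[:, None] * normals[None, :, 1]
    # order=1 is bilinear interpolation; points outside the image become 0
    return ndi.map_coordinates(image, [ys, xs], order=order, mode='constant', cval=0.0)
```

Because each frame is independent, this per-frame step is a natural target for the parallelization TODO.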

In [3]:
import warp
_ = importlib.reload(sys.modules['warp'])

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

start_index = 516
end_index = num_frames

for i in tqdm(range(start_index, end_index)):
    # command-line arguments
    sys.argv = ["", PTH, i]

    warp.main()

100%|██████████| 684/684 [1:08:33<00:00,  6.01s/it]


## 9. REGISTER

`conda activate g5ht-pipeline`

__ALTERNATIVELY__: register using the wholistic registration algorithm, currently in MATLAB

TODO: parallelize / make faster

- pick a good representative fixed frame that you want to register everything to
  - copy it to the main output folder and name it `fixed_xxxx.tif`, where `xxxx` is the zero-padded frame index
  - copy the corresponding mask and name it `fixed_mask_xxxx.tif`
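A small helper can stage the fixed frame with the index in the filenames, which also records which index was used as fixed. This is a hypothetical sketch that assumes the source files live at `warped/XXXX.tif` and `masks/XXXX.tif` inside the main output folder, as in the registration cells below.

```python
import os
import shutil


def stage_fixed_frame(out_dir, index):
    """Copy warped/XXXX.tif and masks/XXXX.tif to out_dir as fixed_XXXX.tif and fixed_mask_XXXX.tif."""
    src_img = os.path.join(out_dir, 'warped', f'{index:04d}.tif')
    src_mask = os.path.join(out_dir, 'masks', f'{index:04d}.tif')
    dst_img = os.path.join(out_dir, f'fixed_{index:04d}.tif')
    dst_mask = os.path.join(out_dir, f'fixed_mask_{index:04d}.tif')
    shutil.copyfile(src_img, dst_img)
    shutil.copyfile(src_mask, dst_mask)
    return dst_img, dst_mask
```

Usage: `stage_fixed_frame(os.path.splitext(INPUT_ND2_PTH)[0], 100)`.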

In [4]:
import reg
_ = importlib.reload(sys.modules['reg'])

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

start_index = 222
end_index = num_frames
zoom = 1 # albert was using 3
# zoom = 3

for i in tqdm(range(start_index, end_index)):
    # command-line arguments
    try:
        sys.argv = ["", PTH, i, str(zoom)]
        reg.main()
    except Exception as e:
        print(f"Error processing index {i}: {e}")   

 25%|██▌       | 248/978 [1:30:28<3:43:36, 18.38s/it]

Error processing index 469: D:\a\im\build\cp311-abi3-win_amd64\_deps\elx-src\Core\Main\itkElastixRegistrationMethod.hxx:389:
ITK ERROR: ElastixRegistrationMethod(0000020C65C2F9F0): Internal elastix error: See elastix log (use LogToConsoleOn() or LogToFileOn()).


 27%|██▋       | 261/978 [1:35:38<4:47:10, 24.03s/it]

Error processing index 483: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004\\warped\\0483.tif'
Error processing index 484: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004\\warped\\0484.tif'
Error processing index 485: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004\\warped\\0485.tif'
Error processing index 486: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004\\warped\\0486.tif'
Error processing index 487: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm004\\warped\\0487.tif'
Error processing index 488: [Errno 2] No such file or directory: 'D:\\DATA\\g5ht-free\\20260123\\date-202

100%|██████████| 978/978 [6:22:36<00:00, 23.47s/it]   
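Most of the errors above are missing `warped/XXXX.tif` files. A quick check before launching the multi-hour loop can list those gaps so they can be warped first or skipped. This is a hypothetical helper that assumes the `warped/XXXX.tif` naming shown in the error messages.

```python
import os


def missing_frames(out_dir, start, end, subdir='warped'):
    """Indices in [start, end) that have no {subdir}/XXXX.tif file."""
    return [i for i in range(start, end)
            if not os.path.exists(os.path.join(out_dir, subdir, f'{i:04d}.tif'))]
```

Usage: `missing_frames(PTH, start_index, end_index)`.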


### REGISTER WITH GFP+1 TO RFP

TRIM THE LAST RFP Z SLICE AND THE FIRST GFP Z SLICE

It seems that, as of 20251204, all recordings were acquired such that z slice i in the red channel corresponds to z slice i+1 in the green channel
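In array terms, with volumes stored as `(Z, C, H, W)` (GFP in channel 0, RFP in channel 1, as in `register_one` below), the trimming pairs GFP slice `i + 1` with RFP slice `i`. A minimal sketch; `pair_gfp_plus_one` is a hypothetical name and `reg_gfp_indexing.py` may implement it differently.

```python
import numpy as np


def pair_gfp_plus_one(volume):
    """volume: (Z, C, H, W) with GFP in channel 0 and RFP in channel 1.

    Returns a (Z - 1, 2, H, W) volume pairing GFP slice i + 1 with RFP slice i.
    """
    gfp = volume[1:, 0]   # drop the first GFP slice
    rfp = volume[:-1, 1]  # drop the last RFP slice
    return np.stack([gfp, rfp], axis=1)
```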

In [None]:
import sys
import os
from tqdm import tqdm
import importlib

from reg_gfp_indexing import main as reg_worm

PTH = r'C:\Users\munib\POSTDOC\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001_aligned'

for i in tqdm(range(1200)):
    # command-line arguments
    sys.argv = ["", PTH, i, "1"]
    reg_worm()

### MAKE MOVIES OF REGISTERED DATA (see `reg_microfilm.ipynb`)

### REGISTER SINGLE FRAMES WITH ERROR LOGGING

In [None]:
import tifffile
import numpy as np
import scipy.ndimage as ndi
import itk
import sys
import os
import glob

from itk import image_view_from_array

#get channels out of stacks
def register_one(fixed_stack, fixed_mask_stack, moving_stack, moving_mask_stack):
    fixed_rfp = fixed_stack[:, 1].copy()
    moving_gfp, moving_rfp = moving_stack[:, 0].copy(), moving_stack[:, 1].copy()

    #initialize registration parameters
    parameter_object = itk.ParameterObject.New()
    default_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid', 4)
    parameter_object.AddParameterMap(default_rigid_parameter_map)
    default_affine_parameter_map = parameter_object.GetDefaultParameterMap('affine', 4)
    parameter_object.AddParameterMap(default_affine_parameter_map)
    default_bspline_128_parameter_map = parameter_object.GetDefaultParameterMap('bspline', 4, 128)
    parameter_object.AddParameterMap(default_bspline_128_parameter_map)
    default_bspline_64_parameter_map = parameter_object.GetDefaultParameterMap('bspline', 4, 64)
    parameter_object.AddParameterMap(default_bspline_64_parameter_map)
    default_bspline_32_parameter_map = parameter_object.GetDefaultParameterMap('bspline', 4, 32)
    parameter_object.AddParameterMap(default_bspline_32_parameter_map)

    #convert to itk images
    fixed_rfp = itk.image_view_from_array(fixed_rfp.astype(np.float32))
    moving_rfp = itk.image_view_from_array(moving_rfp.astype(np.float32))

    fixed_mask_stack = itk.image_view_from_array(fixed_mask_stack.astype(np.ubyte))
    moving_mask_stack = itk.image_view_from_array(moving_mask_stack.astype(np.ubyte))

    #register rfp first and then apply transform to gfp
    registered_rfp, transform_parameters = itk.elastix_registration_method(fixed_rfp, moving_rfp, parameter_object,
                                                                           fixed_mask=fixed_mask_stack, moving_mask=moving_mask_stack,
                                                                           log_to_console=True)
    registered_gfp = itk.transformix_filter(moving_gfp, transform_parameters)

    #initialize and fill output
    output_stack = np.zeros((fixed_stack.shape[0], 2, 200, 500), np.uint16)
    output_stack[:, 0] = np.clip(registered_gfp, 0, 4095)
    output_stack[:, 1] = np.clip(registered_rfp, 0, 4095)

    return output_stack

    # # enable elastix error logging
    # elastix_filter = itk.ElastixRegistrationMethod.New(fixed_rfp, moving_rfp, parameter_object,
    #                                                    fixed_mask=fixed_mask_stack, moving_mask=moving_mask_stack,
    #                                                    log_to_console=True)

    # elastix_filter.SetParameterObject(parameter_object)
    # elastix_filter.SetNumberOfThreads(8)
    # elastix_filter.LogToConsoleOn()  # Enable console logging
    # elastix_filter.LogToFileOn()
    # elastix_filter.SetOutputDirectory(r"C:\Users\munib\POSTDOC\CODE\g5ht-pipeline\logs")
    # elastix_filter.Update()

    # return elastix_filter.GetOutput(), elastix_filter.GetTransformParameterObject()

input_dir = os.path.splitext(INPUT_ND2_PTH)[0]
warped_path = os.path.join(input_dir, 'warped')
output_path = os.path.join(input_dir, 'registered_fixed_sweep')
os.makedirs(output_path, exist_ok=True)

fixed_list = np.arange(0, 400, 50)
mov_list = np.arange(0, 400, 30)

for fixed in fixed_list:
    for mov in mov_list:
        print(f'Processing fixed: {fixed}, moving: {mov}')
        
        # load stacks
        moving_path = os.path.join(warped_path,f'{mov:04d}.tif')
        moving_stack = tifffile.imread(moving_path).astype(np.float32)
        fixed_path = os.path.join(warped_path,f'{fixed:04d}.tif')
        fixed_stack = tifffile.imread(fixed_path).astype(np.float32)
        fixed_mask_path = os.path.join(input_dir, 'masks', f'{fixed:04d}.tif')
        fixed_mask = tifffile.imread(fixed_mask_path)
        fixed_mask_stack = np.stack([fixed_mask] * fixed_stack.shape[0])

        moving_mask_path = os.path.join(input_dir, 'masks', f'{mov:04d}.tif')
        moving_mask = tifffile.imread(moving_mask_path)
        moving_mask_stack = np.stack([moving_mask] * fixed_stack.shape[0])


        output_stack = register_one(fixed_stack, fixed_mask_stack, moving_stack, moving_mask_stack)
        # save output stack with fixed and moving indices in filename
        output_file_path = os.path.join(output_path, f'fixed_{fixed:04d}_mov_{mov:04d}.tif')
        tifffile.imwrite(output_file_path, output_stack, imagej=True)

# fixed = 100
# mov = 200

# # load stacks
# moving_path = os.path.join(warped_path,f'{mov:04d}.tif')
# moving_stack = tifffile.imread(moving_path).astype(np.float32)
# fixed_path = os.path.join(warped_path,f'{fixed:04d}.tif')
# fixed_stack = tifffile.imread(fixed_path).astype(np.float32)
# fixed_mask_path = os.path.join(input_dir, 'masks', f'{fixed:04d}.tif')
# fixed_mask = tifffile.imread(fixed_mask_path)
# fixed_mask_stack = np.stack([fixed_mask] * fixed_stack.shape[0])

# moving_mask_path = os.path.join(input_dir, 'masks', f'{mov:04d}.tif')
# moving_mask = tifffile.imread(moving_mask_path)
# moving_mask_stack = np.stack([moving_mask] * fixed_stack.shape[0])


# output_stack = register_one(fixed_stack, fixed_mask_stack, moving_stack, moving_mask_stack)

In [None]:
output_stack = register_one(fixed_stack, fixed_mask_stack, moving_stack, moving_mask_stack)

In [None]:
output_stack.shape

In [None]:

import matplotlib.pyplot as plt

plt.figure()
plt.pcolormesh(fixed_stack[10,1,:,:])
plt.show()

plt.figure()
plt.pcolormesh(output_stack[10,1,:,:])
plt.show()