In [None]:
%matplotlib widget

import numpy as np
import cupy as cp
from cupyx.scipy.ndimage import maximum_filter, median_filter
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os, glob, h5py
import importlib, pathlib
import warpfield

from daio.h5 import lazyh5
from video import create_projection_image, recording_to_overlay_preview, showvid, get_lenses, array3d_to_video, AVWriter2
from widgets import get_mask_widget, ortho_views_widget, play_video_widget
from register import average_volumes, mini_registration, save_register_recipe, register_recording
from reconstruction import reconstruct_vols_from_imgs, reconstruct_vols_from_imgs_parallel
from i_o import Paths, VolumeReader, RegisteredVolumeReader, get_stimulus, parse_combinations
from signal_extraction import extract_traces

cmap = "gray"

## setup paths

In [None]:
# pn_psf = r"Y:/hpc/r/lfm2025/psfs"
pn_psfs = r"~/hpc-rw/lfm/psfs" 
includes = [''] 
excludes = []

recs = [os.path.basename(rec) for rec in sorted(glob.glob(os.path.expanduser(os.path.join(pn_psfs, '*'))), reverse=True) if os.path.exists(os.path.join(pn_psfs, rec, 'psf.h5'))] 
recs = [rec for rec in recs if all([s in rec for s in includes]) and not any([s in rec for s in excludes])] 
print('\n'.join(recs))

In [None]:
pn_bg = r"~/hpc-rw/lfm/bg"
for file in sorted(pathlib.Path(os.path.expanduser(pn_bg)).iterdir()):
    if file.is_file() and file.suffix == '.npy':
        print(file.name)

In [None]:
pn_rec = r"~/hpc-r/lfm2025/recordings"
# pn_rec = r"~/hpc-r/lfm2025/recordings/2025_0818_zf_audiostim"

includes = ['']
excludes = ["bg"]

recs = [os.path.basename(rec) for rec in sorted(glob.glob(os.path.expanduser(os.path.join(pn_rec, '*'))), reverse=True) if os.path.exists(os.path.join(pn_rec, rec, 'data.h5'))]
recs = [rec for rec in recs if all([s in rec for s in includes]) and not any([s in rec for s in excludes])]
print('\n'.join(recs))

In [None]:
pn_psfs = r"~/hpc-rw/lfm/psfs"
pn_rec = r"~/hpc-r/lfm2025/recordings"
pn_bg = r"~/hpc-rw/lfm/bg"

url_home = r"/home/lubo12/"
pn_out = r"~/hpc-rw/lfm/results"

dataset_name = "20250902_1940_LB_ZF552_f2_stim2_3"
psf_name = "20250731_1539_PSF_LB_redFB_25x_2"
bg_name = "20250819_1730_LB_25x_75fps_bg.npy"

paths = Paths(dataset_name=dataset_name,
              psf_name = psf_name,
              bg_name=bg_name,
              pn_bg= pn_bg,
              pn_rec = pn_rec,
              pn_psfs=pn_psfs,
              pn_out=pn_out,
              url_home=url_home,
              )

## load and inspect data

In [None]:
psf_f = lazyh5(paths.psf)
crop = np.array(psf_f["crop"])
mask = np.array(psf_f["circle_mask"][crop[0]:crop[1],crop[2]:crop[3]])
psf = np.array(psf_f["psf"])
zpos=np.array(psf_f["z_positions"])
print(psf.shape)

In [None]:
fig, ax = plt.subplots(1,1,figsize = (7,7))
mip = create_projection_image(psf, scalebar=200, vmax = 80, zpos=zpos, text= "PSF", text_size=4, gpu = False)
ax.imshow(mip, cmap='gray')
ax.set_xticks([])
ax.set_yticks([])

In [None]:
vmax=80

fig, ax = plt.subplots(2,1,figsize = (7,7))
bg = lazyh5(paths.bg)["data"][crop[0]:crop[1],crop[2]:crop[3]]

#select which images to look at
with h5py.File(paths.raw, "r") as f:
    ims = np.array(f["data"][225:250,crop[0]:crop[1],crop[2]:crop[3]])
im = ims[0]

ax[0].imshow(im, cmap = "gray",vmax=vmax)
ax[0].set_xticks([])
ax[0].set_yticks([])

ax[1].imshow(((im-bg)*mask), cmap="gray",vmax=vmax)
ax[1].set_xticks([])
ax[1].set_yticks([])
print(dataset_name)

### (optional) 

In [None]:
# inspect raw mean brightness
from i_o import VolumeReader
reader = VolumeReader(paths.raw)
means = []
for frame_n, frame in tqdm(reader, desc="Calculating mean brightness"):
    means.append(cp.asarray(frame).flatten().mean().get())
means = np.array(means)
fig, ax = plt.subplots(1,1,figsize = (16,5))
ax.plot(means)
# ax.set_ylim(4,4.5)
np.save(paths.pn_outrec+"raw_means.npy", means)

In [None]:
# stimulus annotated video of the raw data, takes a while
import video, i_o, importlib
importlib.reload(video)
from video import rawh5_to_video
from i_o import get_stimulus,

stim_path = os.path.expanduser(os.path.join(pn_rec, "stimset_LB_zf2"))

_, stim_names_og, stimulus_ids, final_stimulus, sr = get_stimulus(stim_path, fps = 75,)


fn_vid, fn_df_vid =  rawh5_to_video(paths, 
                                   df=True,
                                   stim_labels=[stim_names_og[i] for i in stimulus_ids],
                                   fps=75,
                                   vmin=0, 
                                   vmax=100,
                                   absolute_limits=False,
                                   df_tau=100, 
                                   df_vmin=-0.5,
                                   df_vmax=0.5,
                                    df_bitrate=10000000,
                                    df_absolute_limits=True,)

In [None]:
play_video_widget(fn_vid)
# play_video_widget(fn_df_vid)

In [None]:
# look at the intensity histograms of the normalized vs non normalized frames

def normalize_mean(arr):
    return arr/arr.mean()
fig, ax = plt.subplots(1,2,figsize = (16,5))
ax[0].set_title("raw")
ax[0].hist(im.flatten(),bins=255)
ax[0].set_ylim(0,100000)
ax[1].set_title("normalized")
ax[1].hist(normalize_mean(((im-bg).clip(0,None)*mask).flatten()),bins=255)
ax[1].set_ylim(0,100000)


## Testing deconvolution


In [None]:
import reconstruction
importlib.reload(reconstruction)
from reconstruction import reconstruct_vols_from_imgs

img_idx = [325,326,1] # arguments for the range used to index the raw data

print(f"PSF zspacing: {np.abs(np.diff(zpos[::1])).mean()*1000} um")

objs, mips, losses, kwargs = reconstruct_vols_from_imgs(paths,
                                                        img_idx=img_idx,
                                                        max_iter = 30, #if the deconv is to blurry increase, if there is too much noise decrease
                                                        roi_size=550, # if any part of the fish is cut off at teh edge increase this, otherwise decrease this
                                                        psf_downsample = [11,250,1], #arguments for the range used to index the PSF. should stay wihtin PSF bounds
                                                        img_subtract_bg=False,
                                                        img_mask=True,
                                                        img_normalize=False,
                                                        plot_decon=True, #plot intensity projection at every deconvolution step
                                                        
                                                        # parameters for the video - here for the deconvolution video, when deconvolving full recording for the video of every frame
                                                        projection="max",
                                                        slice_idx=[35,428,609], #the slice to plot if the projection is "slice"
                                                        vmin=0,
                                                        vmax=6,
                                                        absolute_limits=True,

                                                        OTF_normalize=True, #keep true
                                                        OTF_clip=True, #keep true
                                                        reuse_prev_vol = False, #deprecated, keep as false
                                                        crop = crop, # deprecated, in PSFs without aperture, this is just the same dimensions as the whole camera frame
                                                       )

### (loading parameters from previously deconvolved dataset)

In [None]:
# for loading paramteters from a previously deconvolved recording
reg = lazyh5(paths.deconvolved[:-3]+".h5")
kwargs = dict(reg["deconvolution_params"])
# kwargs.pop("out_crop")
# kwargs["img_normalize"]=False
kwargs["vmax"]=7
kwargs["transpose"] = False
# kwargs["vmin"]=-0.5
kwargs

In [None]:
import reconstruction
importlib.reload(reconstruction)
from reconstruction import reconstruct_vols_from_imgs

img_idx = [325,326,1]

print(f"PSF zspacing: {np.abs(np.diff(zpos[::1])).mean()*1000} um")

objs, mips, losses, kwargs = reconstruct_vols_from_imgs(paths,
                                                        img_idx=img_idx,
                                                        **kwargs)

### inspect test denconvolution

In [None]:
fig, ax = plt.subplots(1,1,figsize = (7,7))

# objs.shape = [frames (=len(range(*img_idx))) , z (=len(range(*psf_downsample))) , y (= roi_size) , x (= roi_size)]

#adjust the vmax here to find the right one for the full deconvolution

ax.imshow(create_projection_image(objs[-1,:,:,:], vmax = 20, absolute_limits = True,transpose = True,
                                 scalebar = 200, zpos = zpos, pad = 30), cmap = "gray")
# out_crop is defined further down, to reduce the size of the deconvolved data
# ax.imshow(create_projection_image(objs[-1,:,out_crop[0]:out_crop[1],out_crop[2]:out_crop[3]], vmax = 20, absolute_limits = True,transpose = True,
#                                  scalebar = 200, zpos = zpos, pad = 30), cmap = "gray")
ax.set_xticks([])
ax.set_yticks([])

In [None]:
ortho_views_widget(objs[0], vmin=0, vmax=10, gpu=False)

In [None]:
#inspect the deconvolution itself to find optimal number of deconvolution iterations

from video import array3d_to_video

fn_vid = os.path.join(paths.pn_outrec , f"deconv_mips_{kwargs["vmin"]}-{kwargs["vmax"]}{'_al' if kwargs["absolute_limits"]else''}.mp4")

array3d_to_video(mips[0], fn_vid)

play_video_widget(fn_vid)

In [None]:
# define out_crop to rduce data storage size

x1, y1 = 270, 20 #top left
x2, y2 = 830, 950 #bottom right
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(objs[0].max(axis=0), cmap = "gray", vmax=4)
rect = plt.Rectangle(
        (x1, y1), x2 - x1, y2 - y1,
        linewidth=2, edgecolor="cyan", facecolor="cyan", alpha=0.3
    )
ax.add_patch(rect)
out_crop = (y1,y2,x1,x2)
kwargs["out_crop"] = out_crop

In [None]:
#(optional) final adjustment to the parameters for the full deconvolution
kwargs["vmax"] = 5
kwargs["transpose"]=True
kwargs["max_iter"] = 30
kwargs

## Deconvolve the whole recording

### locally

In [None]:
# a lot slower, since it can only use 1 GPU shard on notebook
import reconstruction
importlib.reload(reconstruction)
from reconstruction import reconstruct_vols_from_imgs_parallel

kwargs, save_fn, vid_fn = reconstruct_vols_from_imgs_parallel(paths,img_idx=None,write_mip_video=True, verbose = 1, **kwargs)
print(save_fn,"\n",vid_fn)

### on the cluster

In [None]:
import slurm, reconstruction
importlib.reload(slurm)
from slurm import PythonExecutorSLURM
importlib.reload(reconstruction)
from reconstruction import reconstruct_vols_from_imgs_parallel

email = 'lukas.born@charite.de' 
#cmd_prefix = ['ssh', 'lubo12@s-sc-frontend1'] # if outside HPC (add your user name), otherwise: 
cmd_prefix = [] #if inside HPC
from slurm import PythonExecutorSLURM
slex = PythonExecutorSLURM(job_path=paths.pn_outrec, job_name='deconv'+paths.dataset_name, conda_env='lfm1',
                                     time="1-00:00", partition="pgpu", gres="gpu:4", cpus_per_task=16,mem='128G',ntasks=1,nodes=1,
                                     cmd_prefix=cmd_prefix, mail_user=email, mail_type='ALL,TIME_LIMIT_90')

In [None]:
job_id = slex.submit(reconstruct_vols_from_imgs_parallel, paths, **kwargs, verbose=2)
vid_fn = paths.deconvolved[:-3] + f"_f{"_all" if kwargs["img_idx"] is None else kwargs["img_idx"]}_mip_vmin{kwargs["vmin"]}_vmax{kwargs["vmax"]}{"_al" if kwargs["absolute_limits"] else ""}.mp4"
print(paths.dataset_name,"\n", vid_fn)

### Inspect deconvolved recording

In [None]:
lazyh5(paths.deconvolved)

In [None]:
play_video_widget(vid_fn)

In [None]:
# (optional) inspect deconvolved mean brightness
from i_o import VolumeReader
reader = VolumeReader(paths.deconvolved)
means = []
for frame_n, frame in tqdm(reader, desc="Calculating mean brightness"):
    means.append(cp.asarray(frame).flatten().mean().get())
means = np.array(means)
fig, ax = plt.subplots(1,1,figsize = (16,5))
ax.plot(means)
# ax.set_ylim(4,4.5)
np.save(paths.deconvolved[:-3]+"_means.npy", means)

## Registration

In [None]:
if os.path.exists(paths.reg_recipe):
    with h5py.File(paths.reg_recipe, "r") as f:
        recipe = warpfield.Recipe.from_yaml(f["recipe_path"])
        crop = np.array(f["crop"])
        print(crop[1]-crop[0],crop[3]-crop[2],crop)
        (y1, y2, x1, x2) = crop
    with h5py.File(paths.deconvolved, "r") as f:
        im = np.array(f["data"][0])
        print(im.shape)
else:
    recipe = warpfield.Recipe.from_yaml('default.yml')
    with h5py.File(paths.deconvolved, "r") as f:
        im = np.array(f["data"][0])
        print(im.shape)
    x1, y1 = 0, 0
    x2, y2 = im.shape[2], im.shape[1]
    #define a new crop if the one you saved isnt good enough
    # x1, y1 = 350, 40
    # x2, y2 = 610, 830
    crop = (y1, y2, x1, x2)
    
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(im[:,crop[0]:crop[1],crop[2]:crop[3]].max(axis=0), cmap = "gray", vmax=5)
rect = plt.Rectangle(
        (x1, y1), x2 - x1, y2 - y1,
        linewidth=2, edgecolor="cyan", facecolor="cyan", alpha=0.3
    )
ax.add_patch(rect)


### (refine the register recipe)

In [None]:
recipe = warpfield.Recipe.from_yaml('default.yml')
print(recipe.pre_filter)
# print("\n")
for i in range(4):
    print(recipe.levels[i])
    # print("\n")

In [None]:
#generate default registration recipe
# create a basic recipe:
recipe = warpfield.Recipe()
recipe.pre_filter.clip_thresh = 0  # clip DC background, if present

# translation level properties
recipe.levels[0].project.max = True
recipe.levels[0].repeats = 1

# affine level properties
recipe.levels[-1].block_stride = 0.5
recipe.levels[-1].project.max = True
recipe.levels[-1].repeats = 10

# add non-rigid registration levels:
recipe.add_level(block_size=[64, 128, 128])
recipe.levels[-1].smooth.sigmas = [1.0, 1.0, 1.0]
recipe.levels[-1].repeats = 10

recipe.add_level(block_size=[64, 64, 64])
recipe.levels[-1].block_stride = 0.5
recipe.levels[-1].smooth.sigmas = [1.0, 1.0, 1.0]
recipe.levels[-1].repeats = 15

recipe.add_level(block_size=[32, 32, 32])
recipe.levels[-1].block_stride = 0.5
recipe.levels[-1].smooth.sigmas = [2.0, 2.0, 2.0]
recipe.levels[-1].repeats = 15


print("my recipe:")
print(recipe.pre_filter)

for i in range(5):
    
    print(recipe.levels[i])

### define and refine reference volume

In [None]:
#inspect the indexes to use for generating the refenrence - should be very little movement
import register
importlib.reload(register)




ref_idx = [4000,6000,100]
ref_vol_unreg, video_fn = average_volumes(paths,
                                           ref_idx,
                                           preprocess = lambda x: x[:,crop[0]:crop[1],crop[2]:crop[3]],
                                           vmax=3,
                                           vmin=0,
                                           fps=1,
                                           absolute_limits=True, transpose=True)

play_video_widget(video_fn)

In [None]:
ortho_views_widget(ref_vol_unreg, vmin = 0, vmax=3, absolute_limits=True, transpose = True)

In [None]:
importlib.reload(warpfield)

#refine reference volume by registering every volume to the average and then averaging again
def register_reference(vol):
    return warpfield.register_volumes(ref_vol_unreg, vol[:,crop[0]:crop[1],crop[2]:crop[3]], recipe)[0]
    


ref_vol, video_fn = average_volumes(paths,
                                    ref_idx,
                                    preprocess = register_reference,
                                    vmax=100,
                                    vmin=0,
                                    absolute_limits=False,
                                    transpose = True,
                                    fps=1
                                      )
play_video_widget(video_fn)

In [None]:
ortho_views_widget(ref_vol, vmin = 0, vmax=3, absolute_limits=True, transpose = True)

### (optional) define mask

In [None]:
# Controls:
# double left click: define new point
# single left click and drag: move nearest point. 
#     Sometimes you have to click near the point you want to move for it to register that you want to move, and then click and drag
# right click: delete closest point

get_mask_widget(ref_vol, paths.reg_mask, vmin=0, vmax=1, figsize=(12,5), mask_every=10, sigma=5.0, transpose=True)

In [None]:
with h5py.File(paths.reg_mask, "r") as f:
    mask_3d = np.array(f["mask_3d"])
    
ortho_views_widget(mask_3d*ref_vol, vmax = 1, absolute_limits=True, transpose=True)

### Test registration

In [None]:
# ideally choose frames with movement or very obvious activity that could cause registration failure, also to judge what the r_threshold should be
test_reg_frames = [1350, 1450, 1]
fn_addendum = ""
video_fn, video_reg_fn, warpfields, metrics = mini_registration(paths, 
                                                                test_reg_frames,
                                                                ref_vol.transpose((0,2,1)),
                                                                recipe,
                                                                fn_addendum = fn_addendum,
                                                                preprocess = lambda x: x[:,crop[0]:crop[1],crop[2]:crop[3]].transpose((0,2,1)),
                                                                vmax=3,
                                                                vmin=0,
                                                                absolute_limits=True,
                                                                fps=10,)


In [None]:
# unregistered
# play_video_widget(video_fn)

# registered
play_video_widget(video_reg_fn)

In [None]:
fig, ax = plt.subplots(1,4,figsize=(15,5))
ax[0].plot(metrics["r"]) # this is the 
ax[0].set_title("r")
ax[1].plot(metrics["mse"])
ax[1].set_title("mse")
ax[2].plot(metrics["ssim"])
ax[2].set_title("ssim")
ax[3].plot(metrics["dmf"])
ax[3].set_title("dmf")
plt.savefig(os.path.join(paths.pn_outrec,f"mini_registration_registered_{fn_addendum}_f{test_reg_frames}_metrics.png"))

In [None]:
save_register_recipe(paths, 
                     recipe=recipe,
                     ref_vol=ref_vol,
                     crop = crop,
                     r_threshold=0.93, # choose based on R from the cell above
                     cov_tau=100,
                     vid_params = {"write_video": True,
                                   "write_dff_video": True,
                                   "fps": 75,
                                   "vid": {"vmax": 8,
                                           "vmin": 0,
                                           "absolute_limits": True},
                                   "dff": {"vmax": 0.5,
                                           "vmin": -0.5,
                                           "absolute_limits": True,
                                           "tau": 100,},
                                   "zpos": list(lazyh5(paths.deconvolved)["zpos"]),
                                   "scalebar": 200,
                                   "transpose": True},

                    )

## Register the whole recording

### locally

In [None]:
register_recording(paths)

### on the cluster

In [None]:
email = 'lukas.born@charite.de' 
#cmd_prefix = ['ssh', 'lubo12@s-sc-frontend1'] # if outside HPC (add your user name), otherwise: 
cmd_prefix = [] #if inside HPC
slex = PythonExecutorSLURM(job_path=paths.pn_outrec, job_name='reg'+paths.dataset_name, conda_env='lfm1',
                                     time="2-00:00:00", partition="gpu", gres='gpu:nvidia_a100_80gb_pcie', cpus_per_task=16, mem='256G',ntasks=1,nodes=1,exclude="s-sc-pgpu03",
                                     cmd_prefix=cmd_prefix, mail_user=email, mail_type='ALL,TIME_LIMIT_90')
job_id = slex.submit(register_recording, paths)
print(paths.dataset_name)

### (automatically segment and extract traces)

In [None]:
#locally
import register
importlib.reload(register)
from register import register_segment_extract
register_segment_extract(paths, 
                         z_crop = (10,-10),
                         n_traces = 200000,
                         fn_add = "",
                         step=0.000001, 
                         voxel_size=[3,2,2],)

In [None]:
import register
importlib.reload(register)
from register import register_segment_extract

job_id = slex.submit(register_segment_extract, paths)

### Inpect registration

In [None]:
reg = lazyh5(paths.registered)
reg_recipe = lazyh5(paths.reg_recipe)
reg

In [None]:
vid_params = reg_recipe["vid_params"]
fn_vid = os.path.expanduser(paths.pn_outrec + f'/registered{"_T" if vid_params["transpose"] else ""}_vmin{vid_params["vid"]["vmin"]}-vmax{vid_params["vid"]["vmax"]}{"_al" if vid_params["vid"]["absolute_limits"] else ""}.mp4')
play_video_widget(fn_vid)

In [6]:
dff_vid_params = reg_recipe["vid_params"]
fn_dff_vid = os.path.expanduser(paths.pn_outrec + f'/registered_dff{"_T" if vid_params["transpose"] else ""}_vmin{vid_params["dff"]["vmin"]}-vmax{vid_params["dff"]["vmax"]}{"_al" if vid_params["dff"]["absolute_limits"] else ""}.mp4')
play_video_widget(fn_dff_vid)

NameError: name 'reg_recipe' is not defined

In [None]:
fig, axs = plt.subplots(4,1,figsize=(12,8))
ax = axs.flatten()
metrics = reg["metrics"].T
ax[0].plot(metrics[0])
ax[0].set_ylabel("Pearsons R")
ax[1].plot((maximum_filter(median_filter(cp.asarray(metrics[0]), 3), size=21) - cp.asarray(metrics[0])).get())
ax[1].set_ylabel("DMF")
ax[2].plot(metrics[1])
ax[2].set_ylabel("mean squared error")
ax[3].plot(metrics[2])
ax[3].set_ylabel("Structural similarity index")

## Segment and extract traces

### Voxel based segmentation

In [None]:
reg = lazyh5(paths.registered)
cov_map = reg["cov_map"]
reg

In [None]:
fig, ax = plt.subplots(1,1,figsize = (8,4))
arr = cov_map.flatten()
ax.hist(arr[np.isfinite(arr)], bins=100)
ax.set_ylim(0,1000)
print(cov_map.shape)

In [None]:
z_crop = [50,200] #because there is a lot of registration noise at the very top and very bottom
ortho_views_widget(cov_map[z_crop[0]:z_crop[1]],vmin = 0.0001, vmax=0.001, absolute_limits = True, transpose = True)

In [None]:
voxel_size = np.array([3,2,2]) #z,y,x
brightness_threshold = 0.0002
cov_map_dwn = cov_map[z_crop[0]:z_crop[1]:voxel_size[0],::voxel_size[1],::voxel_size[2]]

n_voxels = cov_map_dwn[cov_map_dwn >= brightness_threshold].flatten().shape[0]
print(f"{n_voxels} voxels")
voxel_coords=(np.argwhere(cov_map_dwn>= brightness_threshold)*voxel_size + [z_crop[0],0,0])
labels = np.zeros_like(cov_map, dtype = np.uint32)
for i in tqdm(range(n_voxels)):
    coords = voxel_coords[i]
    labels[max(coords[0]-voxel_size[0]//2, 0):min(cov_map.shape[0],coords[0]+voxel_size[0]//2+voxel_size[0]%2),
           max(coords[1]-voxel_size[1]//2, 0):min(cov_map.shape[1],coords[1]+voxel_size[1]//2+voxel_size[1]%2),
           max(coords[2]-voxel_size[2]//2, 0):min(cov_map.shape[2],coords[2]+voxel_size[2]//2+voxel_size[2]%2)] = i+1
ortho_views_widget((labels!=0).astype(int), transpose=True)

In [None]:
ortho_views_widget(reg["cov_map"]*(labels!=0),vmin = 0.0001, vmax=0.005, absolute_limits = True, transpose = True)

In [None]:
fn = paths.segmentation[:-3]+".h5"
with h5py.File(fn, "w") as f:
    f.create_dataset("segmentation", data=labels)
lazyh5(fn)

### Trace extraction

### locally

In [None]:
extract_traces(paths, voxel_size)

### on the cluster

In [None]:
email = 'lukas.born@charite.de' 
#cmd_prefix = ['ssh', 'lubo12@s-sc-frontend1'] # if outside HPC (add your user name), otherwise: 
cmd_prefix = [] #if inside HPC

slex = PythonExecutorSLURM(job_path=paths.pn_outrec, job_name='trex'+paths.dataset_name, conda_env='lfm1',
                           time="1-00:00:00", partition="gpu", gres='gpu:nvidia_a100_80gb_pcie', cpus_per_task=16, mem='256G',ntasks=1,nodes=1,exclude="s-sc-pgpu03",
                           cmd_prefix=cmd_prefix, mail_user=email, mail_type='ALL,TIME_LIMIT_90')

In [None]:
job_id = slex.submit(extrace_traces, paths, voxel_size)

### Analysis: fully in the ./thesis_figures notebooks and analysis.ipynb

In [None]:
from cupyx.scipy.ndimage import maximum_filter, median_filter, convolve1d

r_thresh = 0.95
dmf_thresh = 0.0013

fig, axs = plt.subplots(2,1,figsize=(12,8))
ax = axs.flatten()

reg = lazyh5(paths.registered)
metrics = reg["metrics"].T
r = metrics[0]
ax[0].plot(r)
ax[0].set_ylabel("Pearsons R")
dmf =(maximum_filter(median_filter(cp.asarray(r), 3), size=21) - cp.asarray(r)).get()
ax[1].plot(dmf)
ax[1].set_ylabel("DMF")
trace_mask = np.ones_like(metrics[0], dtype = bool)
trace_mask[(dmf > dmf_thresh) | (r < r_thresh) | np.isnan(r)] = 0

neighbor_sum = convolve1d(cp.asarray(trace_mask).astype(cp.int8), weights=cp.array([1,1, 1]), mode='constant', cval=0).get()
trace_mask = trace_mask ^ ((trace_mask == True) & (neighbor_sum == 1))


if dataset_name == "20250818_2111_LB_zf552_f25x_f2_audiostim_40fps_1":
    trace_mask = np.ones(shape=(7219,))
    trace_mask[1764:1796]=0
    trace_mask[5322:5342]=0
    trace_mask[7198:]=0
    trace_mask = trace_mask.astype(bool)

diff = np.diff(np.concatenate(([False], trace_mask, [False])).astype(int))

# Find the start and end frame indices of each shaded region
starts_idx = np.where(diff == 1)[0]
ends_idx = np.where(diff == -1)[0]



for start, end in zip(starts_idx, ends_idx):
    # axvspan draws a vertical rectangle from start_time to end_time
    ax[0].axvspan(start,end, color='gray', alpha=0.4,zorder=0,linewidth=0  )
    ax[1].axvspan(start,end, color='gray', alpha=0.4,zorder=0,linewidth=0  )
# Redraw the figure to show the new spans
fig.canvas.draw()

In [None]:
import i_o, importlib
importlib.reload(i_o)
from i_o import get_stimulus, parse_combinations

stim_path = os.path.expanduser(os.path.join(pn_rec, "stimset_LB_zf2"))
timestamps = np.array(lazyh5(paths.raw)["tstmp"])
timestamps = timestamps[~np.isnan(timestamps)]
_, stim_names_og, stimulus_id_fps, final_stimulus, sr= get_stimulus(stim_path,timestamps,lag_frames =5, fps = 75,)

fig, ax = plt.subplots(2,1,figsize=(12,8))
ax[0].plot(final_stimulus)
ax[1].plot(stimulus_id_fps)
# combs = {
#     "R": "angle90.0",
#     "L": "angle270.0",
#     "400Hz": ("400Hz", "4e+02Hz"),  
#     "800Hz": ("800Hz", "8+02Hz"),  
#     "1200Hz": ("1200Hz", "1.2e+03Hz"),  
#     "R_gammatone": ["angle90.0", "gammatone"],  
#     "L_gammatone": ["angle270.0", "gammatone"],  
#     "R_sine": ["angle90.0", "gammatone"],  
#     "L_sine": ["angle270.0", "gammatone"],  
#     "R_400Hz": ["angle90.0", ("400Hz", "4e+02Hz")], 
#     "R_800Hz": ["angle90.0", ("800Hz", "8+02Hz")], 
#     "R_1200Hz": ["angle90.0", ("1200Hz", "1.2e+03Hz")], 
#     "L_400Hz": ["angle270.0", ("400Hz", "4e+02Hz")], 
#     "L_800Hz": ["angle270.0", ("800Hz", "8+02Hz")], 
#     "L_1200Hz": ["angle270.0", ("1200Hz", "1.2e+03Hz")], 
# }
combs = {"pulse": "pulse"}

combinations = parse_combinations(stim_names_og, combs)
bool_stim, stim_names, _, _, _ = get_stimulus(stim_path, timestamps = timestamps,lag_frames =6, fps = 75, combinations=combinations)
print(bool_stim.shape)
fig, ax = plt.subplots(1,1,figsize=(12,4))
ax.imshow(bool_stim[:,:].T, aspect="auto")
stim_names

In [None]:
traces_data = lazyh5(paths.traces)
traces = traces_data["traces"]
# segmentation =  lazyh5(paths.segmentation)["segmentation"]
segmentation = traces_data["segmentation"]
print(np.unique(segmentation).shape)
traces_data