### CONDA ENVIRONMENTS

For steps __1. preprocess__ and __2. mip__, `conda activate g5ht-pipeline`

For step __3. segment__, `conda activate segment-torch` or `conda activate torchcu129`

For step __4. spline, 5. orient, 6. warp, 7. reg__

## TODO:

1. I wonder if I computed a spline on each and every z slice and warped each, oriented each of them, and warped each of them, if the problem of weirdly sheared image stacks would be solved
2. quick mp4 for all recordings
   1. now working in engaging, works per one nd2 sbatch
3. focus check for all recordings
   1. maybe focus check can be used to specify which z slices are good to use and which frames are good to use
4. for recordings starting in december 2025, need to trim first 2 rather than last 2 z slices
5. flip worms so that VNC is always up
6. fixed mask could be automated, but if not, make sure to save which index is fixed
7. extract behavior
8. posture similarity
   1. posture might consist of the spline + thresholded z-stack
      1. I'm thinking that the orientation shouldn't matter, but the z-planes in focus will, and curvature/spline of the head will
      2. maybe need to actually interpolate to 117 z slices
   2. sub registration problems
   3. label each set of registered frames with one set of ROIs, or auto segment ROIs from each set of registered frames
9.  track z over time, which zslices are consistent
   1. focus + correlation
10. beads -> train/test
11. gfp+1 relative to rfp channel (might only apply to pre december 2025 recordings)
12. wholistic 
    1.  parameter sweep, might change
    2.  python version
13. autocorr/scorr
14. automate z slice trimming
    1.  pre december 2025 (trim last 2 z slices)
    2.  post december 2025 (trim first z slice)
15. photobleaching estimation?
    1.  record immo with serotonin
    2.  at least do it for RFP
16. try deltaF/F [ (F(t) - F0) / F0 ]
17. coding directions (preencounter-baseline) (postencounter-baseline)
    1.  then show voxel weights
18. port everything to engaging

In [1]:
import sys
import os
import importlib
from tqdm import tqdm

try:
    import utils
    is_torch_env = False
except ImportError:
    is_torch_env = True
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

## SPECIFY DATA TO PROCESS

In [2]:
# DATA_PTH = r'C:\Users\munib\POSTDOC\DATA\fluorescent_beads_ch_align\20251219'
DATA_PTH = r'D:\DATA\g5ht-free\20260123'

INPUT_ND2 = 'date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm005.nd2'

INPUT_ND2_PTH = os.path.join(DATA_PTH, INPUT_ND2)

NOISE_PTH = r'C:\Users\munib\POSTDOC\CODE\g5ht-pipeline\noise\noise_042925.tif'

OUT_DIR = os.path.splitext(INPUT_ND2_PTH)[0]

STACK_LENGTH = 41

if not is_torch_env:
    noise_stack = utils.get_noise_stack(NOISE_PTH, STACK_LENGTH)
    num_frames, height, width, num_channels = utils.get_range_from_nd2(INPUT_ND2_PTH, stack_length=STACK_LENGTH) 
    beads_alignment_file = utils.get_beads_alignment_file(INPUT_ND2_PTH)
else:
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

print(INPUT_ND2)
print('Num z-slices: ', STACK_LENGTH)
if not is_torch_env:
    print('Number of frames: ', num_frames)
    print('Height: ', height)
    print('width: ', width)
    print('Number of channels: ', num_channels)
    print('Beads alignment file: ', beads_alignment_file)

date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm005.nd2
Num z-slices:  41
Number of frames:  715
Height:  512
width:  512
Number of channels:  2
Beads alignment file:  D:\DATA\g5ht-free\20260123\date-20260123_strain-ISg5HT-nsIS180_condition-fedpatch_worm005_chan_alignment.nd2


## 10. LABEL ROIs

`conda activate g5ht-pipeline`

- after this step, use `lbl` conda env to label ROI of fixed frame
  - run `labelme` in terminal


maybe also see here for video annotation: https://github.com/wkentaro/labelme/tree/main/examples/video_annotation

### EXPORT FIXED VOLUME AS PNGs for labeling with `labelme`

In [3]:
# code that exports each z-slice of fixed.tif as a separate png
import tifffile
import os
import glob
import scipy.ndimage as ndi

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

# in PTH directory, find a fixed_XXXX*.tif file, where XXXX are digits
fixed_fn = glob.glob(os.path.join(PTH, 'fixed_[0-9][0-9][0-9][0-9]*.tif'))[0]
fixed_pth = os.path.join(PTH, fixed_fn)

# fixed_pth = os.path.join(PTH, 'fixed.tif')
# fixed_stack = ndi.zoom(tifffile.imread(fixed_pth), zoom=(3,1,1,1))
fixed_stack = tifffile.imread(fixed_pth)

out_dir = os.path.join(PTH, 'fixed_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[0]):
    slice_pth = os.path.join(out_dir, f'fixed_z{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[i,1,:,:]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    tifffile.imwrite(slice_pth, slice_img) 

out_dir = os.path.join(PTH, 'fixed_xz_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[2]):
    slice_pth = os.path.join(out_dir, f'fixed_xz{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[:,1,i,:]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    tifffile.imwrite(slice_pth, slice_img) 

out_dir = os.path.join(PTH, 'fixed_yz_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[3]):
    slice_pth = os.path.join(out_dir, f'fixed_yz{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[:,1,:,i]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    # save as high quality tiff
    tifffile.imwrite(slice_pth, slice_img) 

### PARSE OUTPUT OF `labelme`

- outputs `roi.tif`

In [19]:
roi_labels

['dnc', 'vnc', 'nerve ring', 'isthmus']

In [20]:
import numpy as np
import json
from skimage.draw import polygon

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

out_dir = os.path.join(PTH, 'fixed_png')
fixed_fn = glob.glob(os.path.join(PTH, 'fixed_[0-9][0-9][0-9][0-9]*.tif'))[0]
fixed_pth = os.path.join(PTH, fixed_fn)
fixed_stack = tifffile.imread(fixed_pth)
Z,C,H,W = fixed_stack.shape

roi = np.zeros((Z, H, W), dtype=fixed_stack.dtype) # ZHW

# get all unique roi_labels from json files
roi_json_files = glob.glob(os.path.join(out_dir, 'fixed_z[0-9][0-9]*.json'))
roi_labels = []
for roi_json_file in roi_json_files:
    with open(roi_json_file, 'r') as f:
        roi_dict = json.load(f)
        roi_labels.append([shape['label'] for shape in roi_dict['shapes']])
# get all unique roi_labels
roi_labels = list(set([item for sublist in roi_labels for item in sublist]))

# # roi_labels = ['PC','MC','IM','TB','NR','VNC','DNC']
# # procorpus, metacorpus, isthmus, terminal bulb, nerve ring, ventral nerve cord, dorsal nerve cord

for i in range(Z):
    slice_roi_json = os.path.join(out_dir, f'fixed_z{i:02d}.json')
    # if slice_roi_json doesn't exist, continue
    if not os.path.exists(slice_roi_json):
        continue
    with open(slice_roi_json, 'r') as f:
        roi_dict = json.load(f)
        # loop through each shape in roi_dict['shapes']
        for shape in roi_dict['shapes']:
            label = shape['label']
            if label in roi_labels:
                points = shape['points']
                # get integer coordinates
                points = [(int(round(p[1])), int(round(p[0]))) for p in points]
                # create a mask for the polygon
                
                rr, cc = polygon([p[0] for p in points], [p[1] for p in points], shape=(H,W))
                # should set to correct z slice
                roi[i, rr, cc] = roi_labels.index(label) + 1 # start from 1

# save roi stack as tif image, imagej=true and save the roi labels as metadata
roi_pth = os.path.join(PTH, 'roi.tif')
tifffile.imwrite(roi_pth, roi.astype(np.uint8), imagej=True, metadata={'Labels': roi_labels})

## 11. QUANTIFY

`conda activate g5ht-pipeline`

Have to first label dorsal and ventral nerve rings and pharynx. See ...

In [None]:
import sys
import os
import quantify
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import importlib

import matplotlib
font = {'family' : 'Arial',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)
matplotlib.rcParams['svg.fonttype'] = 'none'

_ = importlib.reload(sys.modules['quantify'])

PTH = os.path.splitext(INPUT_ND2_PTH)[0]
REG_DIR = r'registered_elastix'
# PLOT_ONLY = True

%matplotlib inline
# %matplotlib qt


sys.argv = ["", PTH, REG_DIR, PLOT_ONLY]
quantify.main()

## 12 QUANTIFY VOXELS

In [None]:
import sys
import os
import quantify_voxels
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from skimage.morphology import erosion, disk
import importlib
import glob

import matplotlib
%matplotlib widget
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)

_ = importlib.reload(sys.modules['quantify_voxels'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'
reg_dir = 'registered'
bin_factor = 2
fps = 1/0.533

sys.argv = ["", PTH, reg_dir, bin_factor]
quantify_voxels.main()

In [None]:
# load normalized_voxels.npy, also been binned
g5 = np.load(os.path.join(PTH, 'normalized_voxels.npy'))
rfp_mean = np.load(os.path.join(PTH, 'rfp_mean.npy'))
gfp_mean = np.load(os.path.join(PTH, 'gfp_mean.npy'))
baseline = np.load(os.path.join(PTH, 'baseline.npy'))

# load mask, bin it
# find the fixed_mask_*.tif file in PTH directory
try:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask_*.tif'))[0]
except:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask*.tif'))[0]

mask = tifffile.imread(fixed_mask_fn)

h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
# binning of mask
mask_binned = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))

In [None]:
g5_masked = g5 * mask_binned[np.newaxis, :, :]

In [None]:
print(g5.shape)
print(rfp_mean.shape)
print(gfp_mean.shape)
print(baseline.shape)
print(mask_binned.shape)

In [None]:
g5_masked.shape

# g5_masked is shape (T, Z, H, W) array with zeros outside the worm region
# find all voxels where g5_masked is zero for each time point
zero_mask = g5_masked == 0
# calculate the probability of a voxel being zero across time, but don't include time points where the voxel is masked out (i.e., outside the worm)
zero_prob = np.mean(zero_mask, axis=0)
# set zero probability to NaN for voxels outside the worm (mask binned needs a z dimension added)
zero_prob[mask_binned[np.newaxis].repeat(zero_prob.shape[0], axis=0) == 0] = np.nan

plt.close('all')
%matplotlib qt
# plot zero_prob as an image, with colorbar, for each z slice in subplots
fig, axes = plt.subplots(4, 10, figsize=(20, 8), constrained_layout=True)
for z in range(zero_prob.shape[0]):
    ax = axes[z // 10, z % 10]
    im = ax.imshow(zero_prob[z, :, :], cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'Z={z}')
    ax.axis('off')
    # replace last subplot with colorbar
    if z == zero_prob.shape[0] - 1:
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
# delete last subplots if z slices are less than 40
for z in range(zero_prob.shape[0], 40):
    ax = axes[z // 10, z % 10]
    ax.axis('off')
plt.show()

plt.figure()
plt.hist(zero_prob.flatten(), bins=50, color='blue', alpha=0.7)
plt.xlabel('Probability of Voxel Being Zero')
plt.ylabel('Frequency')
plt.title('Histogram of Zero Probability Across Voxels')
plt.show()

In [None]:
probability_threshold = 0.3 # if a voxel is zero more than this fraction of the time, consider it outside the worm or a bad voxel
# create a mask of good voxels
good_voxel_mask = zero_prob < probability_threshold
normalized_data = g5_masked * good_voxel_mask[np.newaxis].repeat(g5_masked.shape[0], axis=0)

In [None]:
normalized_data.shape

In [None]:
norm_data_mean = np.mean(normalized_data, axis=0)

# clip norm_data mean between 0th and 99th percentiles
p0 = np.percentile(norm_data_mean, 0)
p99 = np.percentile(norm_data_mean, 99)
norm_data_mean = np.clip(norm_data_mean, p0, p99)

z2plot = [0,5,15,25,35]
# z2plot = [0,5]

%matplotlib inline
plt.close('all')
fig, axes = plt.subplots(len(z2plot), 3, figsize=(15, 9), constrained_layout=True)
for i in z2plot: # for each z slice
    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+1)
    plt.pcolor(gfp_mean[i,:,:])
    plt.title(f'GFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+2)
    plt.pcolor(rfp_mean[i,:,:])
    plt.title(f'RFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+3)
    plt.pcolor(norm_data_mean[i,:,:])
    # add colorbar outside to the right
    plt.colorbar()
    # plt.clim(0, 12)
    # make all axes not be squished
plt.show()


In [None]:
# cluster normalized_data using k means
from sklearn.cluster import KMeans
n_clusters = 10
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
T, Z, H, W = normalized_data.shape
normalized_data_reshaped = normalized_data.reshape(T, Z*H*W).T  # shape (Z*H*W, T)
kmeans.fit(normalized_data_reshaped)


In [None]:
# plot cluster mean activity over time
cluster_means = np.zeros((n_clusters, T))
for c in range(n_clusters):
    cluster_means[c, :] = normalized_data_reshaped[kmeans.labels_ == c, :].mean(axis=0)

In [None]:
plt.figure(figsize=(10,6))
t = np.arange(T) * (1/fps)
ic = 0
for c in range(6,9):#range(n_clusters):
    plt.plot(t, cluster_means[c, :], label=f'Cluster {ic+1}')
    ic += 1
plt.legend(frameon=False)
plt.xlabel('Time (sec)')
plt.ylabel('$\ R / R_{baseline}$')
plt.show()

In [None]:
# plot cluster 1 and 2 spatial maps at each z slice
plt.close('all')
ic = 0
for c in range(6,9):#range(n_clusters):
    fig, axes = plt.subplots(4, 10, figsize=(20, 8), constrained_layout=True)
    for z in range(Z):
        ax = axes[z // 10, z % 10]
        cluster_map = kmeans.labels_.reshape(Z, H, W)[z, :, :] == c
        im = ax.imshow(cluster_map, cmap='gray')
        ax.set_title(f'Cluster {ic+1} Z={z}')
        ax.axis('off')
    plt.suptitle(f'Spatial Map of Cluster {c+1} Across Z Slices')
    ic += 1
    # deleate last subplots if z slices are less than 40
    for z in range(Z, 40):
        ax = axes[z // 10, z % 10]
        ax.axis('off')
    plt.show()

In [None]:
# PCA
from sklearn.decomposition import PCA
# flatten normalized_data to (T, Z*H*W)
T, Z, H, W = normalized_data.shape
data_reshaped = normalized_data.reshape(T, Z*H*W)
pca = PCA(n_components=5)
pca.fit(data_reshaped)

In [None]:

# get scores, plot
scores = pca.transform(data_reshaped)
# plot the first 5 principal component scores
for i in range(5):
    plt.figure()
    plt.plot(scores[:,i])
    plt.show()

# plot the first 5 principal component weights as images across z slices
components = pca.components_



In [None]:

# plot gfp_mean, rfp_mean, dat in a (5,3) grid of subplots (5 z slices, 3 columns)


# divide gfp_mean by rfp_mean, but only if rfp_mean is not zero

dat = gfp_mean / rfp_mean
# dat = np.divide(gfp_mean, rfp_mean, out=np.zeros_like(gfp_mean), where=rfp_mean!=0)
# divide each voxel in dat by its baseline, but only if baseline is not zero
dat = np.divide(dat, baseline, out=np.zeros_like(dat), where=baseline!=0)
dat = dat * mask_binned[np.newaxis, :, :]

# remove outliers by clipping to 1st and 99th percentile
p1 = np.percentile(dat, 0)
p99 = np.percentile(dat, 99.9)
dat = np.clip(dat, p1, p99)


# plt.figure()
# plt.hist(baseline.ravel(), bins=1000, range=(0,100))
# plt.show()

# z2plot = [0,5,15,25,35]
z2plot = [0,5]

%matplotlib inline
plt.close('all')
fig, axes = plt.subplots(len(z2plot), 3, figsize=(15, 5), constrained_layout=True)
for i in z2plot: # for each z slice
    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+1)
    plt.pcolor(gfp_mean[i,:,:])
    plt.title(f'GFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+2)
    plt.pcolor(rfp_mean[i,:,:])
    plt.title(f'RFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+3)
    plt.pcolor(dat[i,:,:])
    # add colorbar outside to the right
    plt.colorbar()
    plt.clim(0, 12)
    # make all axes not be squished
plt.show()


In [None]:
# load mask, bin
# find the fixed_mask_*.tif file in PTH directory
try:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask_*.tif'))[0]
except:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask*.tif'))[0]

mask = tifffile.imread(fixed_mask_fn)

h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
# binning of mask
mask_binned = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))
# shrink the mask slightly using erosion skimage
mask_binned = erosion(mask_binned, disk(5))

# zero out values outside the mask
g5_masked = g5 * mask_binned[np.newaxis,np.newaxis,:,:]



In [None]:
g5_masked.shape

In [None]:
g5_masked_trimmed  = g5_masked[:, 5:27, :,40:230]
g5_masked_trimmed.shape

In [None]:
plt.close('all')

# get the intensity over time of the voxels in the middle of the (x,y) plane, definding a box, and all z slices
middle_x = int(g5_masked_trimmed.shape[3] // 1.5)
middle_y = int(g5_masked_trimmed.shape[2] // 1.95)
box_size = 10  # Define the size of the box around the middle point
middle_voxels = g5_masked_trimmed[:, 5:27, middle_y-box_size:middle_y+box_size, middle_x-box_size:middle_x+box_size].mean(axis=(2,3))
# divide each z slice by the mean activity in the first 60 time points to get F/F0
middle_voxels = middle_voxels / (middle_voxels[:60,:].mean(axis=0) + 1e-6)

t = np.arange(middle_voxels.shape[0]) / fps

plt.close('all')

plt.figure(figsize=(10, 6))
plt.plot(middle_voxels)
plt.xlabel('Time (sec)')
plt.ylabel('$F/F_{\\mathrm{baseline}}$')
# plt.ylim(0, 5)
plt.show()
# plot same thing as heatmap and sort by mean signal across time for each z slice
sorted_indices = np.argsort(middle_voxels.mean(axis=0))
middle_voxels_sorted = middle_voxels[:, sorted_indices]

plt.figure(figsize=(10, 6))
plt.pcolor(t, np.arange(middle_voxels.shape[1]), middle_voxels_sorted.T, cmap='plasma',vmin=0,vmax=5)
plt.colorbar(label='$F/F_{\\mathrm{baseline}}$')
plt.xlabel('Time (sec)')
plt.ylabel('Z slice (sorted by mean intensity)')
plt.show()

# plot g5_masked and the box defined by box_size and middle_x, middle_y to confirm that the box is in the middle of the (x,y) plane and covers the expected area
for i in range(g5_masked_trimmed.shape[1]):
    plt.figure(figsize=(10, 4))
    plt.pcolor(g5_masked_trimmed[158,i,:,:],vmin=0, vmax=10)
    plt.plot([middle_x-box_size, middle_x+box_size, middle_x+box_size, middle_x-box_size, middle_x-box_size],
             [middle_y-box_size, middle_y-box_size, middle_y+box_size, middle_y+box_size, middle_y-box_size], color='red')
    plt.show()

# GRID ANALYSIS

In [None]:
import grid_analysis
import sys
import importlib
_ = importlib.reload(sys.modules['grid_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

xspace = 10
yspace = 10
fps = 1/0.533
bin_factor = 1
z_start = 5
z_end = 25
baseline_start_sec = 5 # start baseline calculation here (relative to the start of the recording, not relative to start_time_sec)
baseline_end_sec = 35 # end baseline calculation here (relative to the start of the recording, not relative to start_time_sec)
start_time_sec = 0
time_food_sec = 435
sys.argv = ["", PTH, xspace, yspace, fps, bin_factor, z_start, z_end, baseline_start_sec, baseline_end_sec, start_time_sec]#time_food_sec] 
grid_analysis.main()

In [None]:
import grid_analysis
_ = importlib.reload(sys.modules['grid_analysis'])
# load the grid_timeseries_flat_smoothed.npz file and plot the heatmap using the plot_heatmap function
data = np.load(os.path.join(PTH, 'grid_timeseries_flat_smoothed.npz'))
flat_timeseries_smoothed = data['timeseries']
z_labels = data['z_labels']
grid_analysis.plot_heatmap(flat_timeseries_smoothed, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap.png'), data_start_time=start_time_sec)
# grid_analysis.plot_heatmap(flat_timeseries_smoothed, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap.png'), data_start_time=start_time_sec)

In [None]:
flat_timeseries_smoothed.shape

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(np.arange(flat_timeseries_smoothed.shape[0])/fps,flat_timeseries_smoothed[:,20:27])
# plt.ylim(0.03,0.3)
plt.show()

In [None]:
np.unique(z_labels)

In [None]:
flat_timeseries_smoothed.shape

In [None]:
_ = importlib.reload(sys.modules['grid_analysis'])
# sort the grid squares by their mean signal across time, and plot the heatmap again with the grid squares in sorted order
mean_signal_grid = flat_timeseries_smoothed.mean(axis=0)
sorted_indices = np.argsort(mean_signal_grid)
flat_timeseries_smoothed_sorted = flat_timeseries_smoothed[:, sorted_indices]

plt.close('all')

# smooth some more
from grid_analysis import causal_smooth
toplot = causal_smooth(flat_timeseries_smoothed_sorted, sigma=1.0)
grid_analysis.plot_heatmap(toplot, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap_sorted.png'), data_start_time=start_time_sec)
# plt.xlim(-50,125)
# plt.ylim(50,770)

# plot the mean signal across all grid squares for each z-slice over time, and mark the time of food addition
mean_signal_grid = flat_timeseries_smoothed.mean(axis=1)
plt.figure(figsize=(10, 6))
plt.plot(np.arange(flat_timeseries_smoothed.shape[0])/fps - start_time_sec, mean_signal_grid)
plt.axvline(0, color='red', linestyle='--', label='Food addition')
plt.xlabel('Time (sec)')
plt.ylabel('Mean $F/F_{\\mathrm{baseline}}$ across grid squares')
plt.legend()
# plt.xlim(-50,125)
# plt.ylim(0.03,0.3)
plt.show()

In [None]:
flat_timeseries_smoothed.shape

In [None]:
# pca on flat_timeseries_smoothed, which of size (time,features)
from sklearn.decomposition import PCA
import utils
pca = PCA(n_components=5)
pca.fit(flat_timeseries_smoothed)
# get scores
pc_scores = pca.transform(flat_timeseries_smoothed)
explained_variance = pca.explained_variance_ratio_
explained_variance


line_colors = ['C0', 'C1', 'C2', 'C3', 'C4']
ylim = (-12,14)
for i in range(5):
    fig, ax = utils.pretty_plot()
    ax.plot(np.arange(pc_scores.shape[0])/fps, pc_scores[:,i], label=f'PC{i+1}', lw=3, color=line_colors[i])
    ax.set_ylim(ylim)
    ax.set_xlabel('Time (sec)')
    plt.ylabel(f'PC{i+1} projection')
    plt.show()

# scree plot
fig, ax = utils.pretty_plot()
ax.bar(np.arange(len(explained_variance))+1, explained_variance*100)
ax.set_xlabel('PC')
ax.set_ylabel('Variance explained (%)')
plt.show()

In [None]:
# cluster flat_timeseries_smoothed, then plot heatmap with clusters indicated
from sklearn.cluster import KMeans
n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(flat_timeseries_smoothed.T)
cluster_labels = kmeans.labels_


In [None]:
plt.figure()
plt.plot(sorted_timeseries)
plt.show()

In [None]:
# sort flat_timeseries_smoothed by cluster labels
sorted_indices = np.argsort(cluster_labels)
sorted_timeseries = flat_timeseries_smoothed[:, sorted_indices]

plt, ax = utils.pretty_plot(figsize=(10,6))
im = ax.pcolormesh(sorted_timeseries.T, cmap='plasma')
plt.show()

# plot mean of each cluster over time on the same axis
i = 0
for c in range(n_clusters-1):
    cluster_mean = flat_timeseries_smoothed[:, cluster_labels == c].mean(axis=1)
    # normalize between 0 and 1
    cluster_mean = (cluster_mean - cluster_mean.min()) / (cluster_mean.max() - cluster_mean.min())
    # smooth cluste means
    # cluster_mean = pca_analysis.causal_smooth(cluster_mean[:,np.newaxis], sigma=3.0)
    if i==0:
        fig, ax = utils.pretty_plot()
    ax.plot(np.arange(cluster_mean.shape[0])/fps, cluster_mean, lw=1)
    ax.set_xlabel('Time (sec)')
    # plt.ylabel(f'Cluster {c+1} mean $F/F_{{\\mathrm{{baseline}}}}$')
    i += 1
plt.show()

# cross correlation between cluster 1 and 2 means signals
from scipy.signal import correlate
cluster1_mean = flat_timeseries_smoothed[:, cluster_labels == 0].mean(axis=1)
cluster2_mean = flat_timeseries_smoothed[:, cluster_labels == 1].mean(axis=1)
# normalize means bwetween 0 and 1
cluster1_mean = (cluster1_mean - cluster1_mean.min()) / (cluster1_mean.max() - cluster1_mean.min())
cluster2_mean = (cluster2_mean - cluster2_mean.min()) / (cluster2_mean.max() - cluster2_mean.min())
# compute cross-correlation
corr = correlate(cluster1_mean - cluster1_mean.mean(), cluster2_mean - cluster2_mean.mean(), mode='full')
lags = np.arange(-len(cluster1_mean)+1, len(cluster1_mean))
# plot xcorr
fig, ax = utils.pretty_plot()
ax.plot(lags/fps, corr)
ax.set_xlabel('Lag (sec)')
ax.set_ylabel('Cross-correlation')
plt.show()

# PCA analysis

In [None]:
import pca_analysis
import sys
_ = importlib.reload(sys.modules['pca_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

# Arguments: input_dir, fps, bin_factor, z_start, z_end, lowpass_freq, n_components,
#            n_pcs_plot, baseline_start, baseline_end, start_time, time_offset
sys.argv = ["", PTH, 1/0.533, 1, 7, 18, 0, 50, 6, 5, 35, 0]
pca_analysis.main()

# DMD analysis

In [None]:
import dmd_analysis
import sys
_ = importlib.reload(sys.modules['dmd_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

# Arguments: input_dir, fps, bin_factor, z_start, z_end, lowpass_freq, dmd_rank, 
#            n_modes_plot, baseline_start, baseline_end, start_time, time_offset
sys.argv = ["", PTH, 1/0.533, 1, 7, 18, 0, 5, 6, 5, 35, 0]

dmd_analysis.main()