Coding directions: (pre-encounter - baseline) and (post-encounter - baseline)
- then show the voxel weights
- only use z slices where ROIs have been labeled
- maybe implement temporal smoothing

# LOAD DATA

In [126]:
import sys
import os
import quantify_voxels
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from skimage.morphology import erosion, disk
import importlib
import glob

import matplotlib
%matplotlib widget
font = {'family' : 'Arial',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)

_ = importlib.reload(sys.modules['quantify_voxels'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'
bin_factor = 2  # spatial binning factor between the fixed mask and the g5 voxel grid
fps = 1/0.533  # frames (volumes) per second: one frame every 0.533 s
# epoch windows are half-open [start, stop) frame indices into g5's time axis
baseline_window = (0,60) # specify baseline frame range
pre_window = (82,123) # specify pre-encounter frame range
post_window = (141,170) # specify post-encounter frame range
# NOTE(review): `frames` is not used below -- all frames of normalized_voxels.npy are kept
frames = (baseline_window[0], post_window[1])
rfp_thresh = 50 # voxels whose rfp_mean is below this threshold are removed from the mask
keep_width = (25,228) # x (column) range [start, stop) kept in g5 and the mask after masking

# load roi.tif
roi = tifffile.imread(os.path.join(PTH, 'roi.tif'))

# load g5 data (R/R0) as a read-only memory map: nothing is copied into RAM here, and the
# z-subsetting below reads only the slices we keep (the old .copy() duplicated the whole array)
data = np.load(os.path.join(PTH, 'normalized_voxels.npy'), mmap_mode='r')
g5 = data

# load mask, bin it, then repeat it so it matches g5 dimensions
fixed_mask_fns = sorted(glob.glob(os.path.join(PTH, 'fixed_mask_*.tif')))
if not fixed_mask_fns:
    raise FileNotFoundError(f'No fixed_mask_*.tif found in {PTH}')
fixed_mask_fn = fixed_mask_fns[0]  # sorted so the choice is deterministic if several exist
mask = tifffile.imread(fixed_mask_fn)
h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
# crop remainder rows/columns; the block reshape below requires exact multiples of bin_factor
mask = mask[:h_binned * bin_factor, :w_binned * bin_factor]
# max-pool bin_factor x bin_factor blocks (a block is kept if any pixel in it is set)
mask = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))
mask = np.repeat(mask[np.newaxis, :, :], g5.shape[1], axis=0)

# load rfp_mean.npy
rfp_mean = np.load(os.path.join(PTH, 'rfp_mean.npy'))

print('ROI: ', roi.shape)
print('g5: ', g5.shape)
print('mask: ', mask.shape)
print('rfp_mean: ', rfp_mean.shape)

z_with_roi = np.where(np.sum(roi, axis=(1,2)) > 0)[0]
print('Z slices with ROI:', z_with_roi)

# subsample z slices to only those with ROI (this materializes g5 in memory)
g5 = g5[:, z_with_roi, :, :]
mask = mask[z_with_roi, :, :]
rfp_mean = rfp_mean[z_with_roi, :, :]

print('g5: ', g5.shape)
print('mask: ', mask.shape)
print('rfp_mean: ', rfp_mean.shape)

# update mask based on rfp_thresh
mask_updated = mask.copy()
mask_updated[rfp_mean < rfp_thresh] = 0

# g5 is R/R0 for all frames, restricted to z slices with ROIs; zero out voxels outside
# mask_updated. Multiply by a boolean mask so a label-valued or 0/255 mask cannot rescale R/R0.
g5 = g5 * (mask_updated > 0)[np.newaxis, :, :]

# trim width in g5 and mask_updated
g5 = g5[:, :, :, keep_width[0]:keep_width[1]]
mask_updated = mask_updated[:, :, keep_width[0]:keep_width[1]]

print('Number of voxels in updated mask:', np.sum(mask_updated > 0))
print('g5: ', g5.shape)


ROI:  (39, 200, 500)
g5:  (1200, 39, 100, 250)
mask:  (39, 100, 250)
rfp_mean:  (39, 100, 250)
Z slices with ROI: [ 6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24]
g5:  (1200, 19, 100, 250)
mask:  (19, 100, 250)
rfp_mean:  (19, 100, 250)
Number of voxels in updated mask: 154568
g5:  (1200, 19, 100, 203)


In [129]:
# plot a few representative slices
%matplotlib qt
frame = 150 # choose a frame to visualize
slice_indices = [0, len(z_with_roi)//2, len(z_with_roi)-1]

fig, axs = plt.subplots(1, len(slice_indices), figsize=(15,3.5), constrained_layout=True)
for i, slice_idx in enumerate(slice_indices):
    ax = axs[i]
    im = ax.imshow(g5[frame, slice_idx], cmap='viridis')
    ax.set_title(f'Z slice {z_with_roi[slice_idx]}')
    # make colorbar small and horizontal on bottom
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=10)
    # turn off axis elements
    ax.axis('off')
plt.show()

In [None]:
# plot heatmap of g5 voxels over time
%matplotlib qt

plt.close('all')

# reshape g5 to (time, voxels)
g5_reshaped = g5.reshape(g5.shape[0], -1)
# only keep voxels that are in the updated mask
mask_flat = mask_updated.reshape(-1)
g5_reshaped = g5_reshaped[:, mask_flat > 0]
# sort voxels by their variance over time
voxel_vars = np.var(g5_reshaped, axis=0)
sort_indices = np.argsort(voxel_vars)
g5_reshaped = g5_reshaped[:, sort_indices][:, ::-1] # descending order
# keep the top 50000 most variable voxels
g5_reshaped = g5_reshaped[:, :50000]
voxel_indices = np.arange(g5_reshaped.shape[1])
# keep a specified time window
time_window = (50, post_window[1])
g5_reshaped = g5_reshaped[time_window[0]:time_window[1], :]

time = np.arange(g5_reshaped.shape[0]) / fps
fig, ax = plt.subplots(figsize=(4,6), constrained_layout=True)
im = ax.pcolormesh(time, voxel_indices, g5_reshaped.T, shading='auto', cmap='plasma', vmin=0.8, vmax=7.5)
# flip the y axis
ax.invert_yaxis()
ax.set_xlabel('Time (s)')
ax.set_ylabel('Voxels')
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('R/R0', rotation=270, labelpad=15)
plt.show()

# not plot the mean over the chosen voxels
mean_g5 = np.mean(g5_reshaped, axis=1)
fig, ax = plt.subplots(figsize=(6,4), constrained_layout=True)
ax.plot(time, mean_g5, lw=2)
ax.axhline(1, color='k', ls='--')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Mean R/R0')
ax.set_title('Mean for top 50000 most variable voxels')
plt.show()



In [None]:
# calculate coding directions (one weight per voxel, flattened over Z*H*W)
# - pre:       mean(pre-encounter window)  - mean(baseline window), scaled to unit L2 norm
# - post:      mean(post-encounter window) - mean(baseline window), scaled to unit L2 norm
# - encounter: (unit post) - (unit pre), re-scaled to unit L2 norm
# Voxels outside mask_updated are 0 in g5, so they get 0 weight.
# Alternatives tried earlier: L1 normalization (divide by sum |w|), or dividing by
# sqrt(std) of the baseline window plus sqrt(std) of the epoch window.

# True -> Gram-Schmidt: post made orthogonal to pre, encounter orthogonal to pre and post
orthogonalize = False


def _window_mean(window):
    """Mean of g5 over frames [window[0], window[1]), giving a (Z, H, W) array."""
    frames_in_window = g5[window[0]:window[1]]
    if frames_in_window.shape[0] == 0:
        raise ValueError(f'window {window} selects no frames of g5 ({g5.shape[0]} frames)')
    return np.mean(frames_in_window, axis=0)


def _unit(vec):
    """Return vec scaled to unit L2 norm (NaN in vec propagates to every entry)."""
    norm = np.linalg.norm(vec)
    if norm == 0:
        raise ValueError('coding direction is all zeros; check the window definitions')
    return vec / norm


# calculate mean activity in each epoch window
base_mean = _window_mean(baseline_window)  # ZHW
pre_mean = _window_mean(pre_window)
post_mean = _window_mean(post_window)

pre_coding_direction = _unit((pre_mean - base_mean).ravel())  # Z*H*W
post_coding_direction = _unit((post_mean - base_mean).ravel())  # Z*H*W
# difference of the two unit directions, not of the raw window means
encounter_coding_direction = _unit(post_coding_direction - pre_coding_direction)

if orthogonalize:
    post_coding_direction = _unit(
        post_coding_direction
        - np.dot(post_coding_direction, pre_coding_direction) * pre_coding_direction
    )
    encounter_coding_direction = (
        encounter_coding_direction
        - np.dot(encounter_coding_direction, pre_coding_direction) * pre_coding_direction
    )
    encounter_coding_direction = _unit(
        encounter_coding_direction
        - np.dot(encounter_coding_direction, post_coding_direction) * post_coding_direction
    )

# reshape coding directions back to ZHW
zhw_shape = g5.shape[1:]
pre_coding_direction_reshaped = pre_coding_direction.reshape(zhw_shape)
post_coding_direction_reshaped = post_coding_direction.reshape(zhw_shape)
encounter_coding_direction_reshaped = encounter_coding_direction.reshape(zhw_shape)

In [158]:
# print the total weight of each coding direction, skipping NaN voxels
for direction in (pre_coding_direction_reshaped,
                  post_coding_direction_reshaped,
                  encounter_coding_direction_reshaped):
    print(np.nansum(direction))

107.30636434259637
241.19774349049607
167.3386373984176


In [159]:
# plot representative slices of coding directions

%matplotlib qt
slice_indices = [0, len(z_with_roi)//2, len(z_with_roi)-1]

fig, axs = plt.subplots(1, len(slice_indices), figsize=(15,4.5), constrained_layout=True)
for i, slice_idx in enumerate(slice_indices):
    ax = axs[i]
    im = ax.imshow(pre_coding_direction_reshaped[slice_idx, :, :] * mask_updated[slice_idx, :, :], cmap='viridis')
    # area outside mask should be white
    im.set_clim(vmin=np.nanmin(pre_coding_direction_reshaped * mask_updated), vmax=np.nanmax(pre_coding_direction_reshaped * mask_updated))
    ax.set_title(f'z{z_with_roi[slice_idx]}')
    # make colorbar small and horizontal on bottom
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=12)
    # turn off axis elements
    ax.axis('off')
plt.suptitle('Pre-encounter coding direction', y=0.9)
plt.show()

fig, axs = plt.subplots(1, len(slice_indices), figsize=(15,4.5), constrained_layout=True)
for i, slice_idx in enumerate(slice_indices):
    ax = axs[i]
    im = ax.imshow(post_coding_direction_reshaped[slice_idx, :, :] * mask_updated[slice_idx, :, :], cmap='viridis')
    im.set_clim(vmin=np.nanmin(post_coding_direction_reshaped * mask_updated), vmax=np.nanmax(post_coding_direction_reshaped * mask_updated))
    ax.set_title(f'z{z_with_roi[slice_idx]}')
    # make colorbar small and horizontal on bottom
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=12)
    # turn off axis elements
    ax.axis('off')
plt.suptitle('Post-encounter coding direction', y=0.9)
plt.show()

fig, axs = plt.subplots(1, len(slice_indices), figsize=(15,4.5), constrained_layout=True)
for i, slice_idx in enumerate(slice_indices):
    ax = axs[i]
    im = ax.imshow(encounter_coding_direction_reshaped[slice_idx, :, :] * mask_updated[slice_idx, :, :], cmap='viridis')
    im.set_clim(vmin=np.nanmin(encounter_coding_direction_reshaped * mask_updated), vmax=np.nanmax(encounter_coding_direction_reshaped * mask_updated))
    ax.set_title(f'z{z_with_roi[slice_idx]}')
    # make colorbar small and horizontal on bottom
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=12)
    # turn off axis elements
    ax.axis('off')
plt.suptitle('Encounter coding direction', y=0.9)
plt.show()

In [162]:
# project data onto coding directions to get time series of coding direction scores
# reshape g5 to (time, voxels) and keep only voxels in the updated mask
g5_reshaped = g5.reshape(g5.shape[0], -1)
mask_flat = mask_updated.reshape(-1)
in_mask_flat = mask_flat > 0
g5_reshaped = g5_reshaped[:, in_mask_flat]
pre_coding_direction_flat = pre_coding_direction[in_mask_flat]
post_coding_direction_flat = post_coding_direction[in_mask_flat]
encounter_coding_direction_flat = encounter_coding_direction[in_mask_flat]


def _minmax_scale(scores):
    """Linearly rescale scores to [0, 1], ignoring NaN frames.

    A constant series has zero range and is mapped to all zeros instead of dividing by 0.
    """
    lo, hi = np.nanmin(scores), np.nanmax(scores)
    span = hi - lo
    if span == 0:
        return np.zeros_like(scores)
    return (scores - lo) / span


# projections (one score per frame), then min-max normalized to [0, 1]
pre_coding_direction_scores = _minmax_scale(np.dot(g5_reshaped, pre_coding_direction_flat))
post_coding_direction_scores = _minmax_scale(np.dot(g5_reshaped, post_coding_direction_flat))
encounter_coding_direction_scores = _minmax_scale(np.dot(g5_reshaped, encounter_coding_direction_flat))

In [163]:
# sanity check: one score per frame for each projection
for scores in (pre_coding_direction_scores, post_coding_direction_scores):
    print(scores.shape)

(1200,)
(1200,)


In [165]:
# plot projections
%matplotlib qt
# plot pre-encounter coding direction scores

time = np.arange(g5_reshaped.shape[0]) / fps

fig, ax = plt.subplots(figsize=(6,4), constrained_layout=True)
ax.plot(time, pre_coding_direction_scores, lw=2)
ax.plot(time, post_coding_direction_scores, lw=2)
ax.plot(time, encounter_coding_direction_scores, lw=2)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Projection (a.u)')
plt.show()

In [166]:
# fraction (0-1) of the total data variance captured by each coding direction
# denominator: sum over masked voxels of each voxel's variance across frames
# numerator:   variance across frames of the data projected onto the direction
voxel_vars = np.var(g5_reshaped, axis=0)
total_variance = np.sum(voxel_vars)

if total_variance != 0:
    pre_proj = g5_reshaped @ pre_coding_direction_flat
    post_proj = g5_reshaped @ post_coding_direction_flat
    encounter_proj = g5_reshaped @ encounter_coding_direction_flat
    pre_variance_explained = np.var(pre_proj) / total_variance
    post_variance_explained = np.var(post_proj) / total_variance
    encounter_variance_explained = np.var(encounter_proj) / total_variance
else:
    pre_variance_explained = post_variance_explained = encounter_variance_explained = 0.0

# clip into [0, 1] so float round-off cannot push a fraction outside that range
pre_variance_explained, post_variance_explained, encounter_variance_explained = (
    float(np.clip(fraction, 0.0, 1.0))
    for fraction in (pre_variance_explained, post_variance_explained, encounter_variance_explained)
)

In [167]:
# plot variance explained
# use default colors (blue orange tab20)
%matplotlib qt
fig, ax = plt.subplots(figsize=(6,4), constrained_layout=True)
ax.bar(['Pre-encounter', 'Post-encounter', 'Encounter'], [pre_variance_explained*100, post_variance_explained*100, encounter_variance_explained*100], color=['tab:blue', 'tab:orange', 'tab:green'])
ax.set_ylabel('Variance explained (%)')
plt.show()
