In [6]:
import os
# change to the upper level folder to detect dj_local_conf.json
if os.path.basename(os.getcwd())=='notebooks': os.chdir('..')
assert os.path.basename(os.getcwd())=='adamacs', ("Please move to the main directory")
from adamacs.pipeline import subject, session, equipment, surgery, event, trial, imaging, behavior, scan, model
from adamacs.ingest import session as isess
from adamacs.ingest.harp import CamLoader_sync
from adamacs.helpers import stack_helpers as sh
from adamacs.helpers import trace_helpers as th
from adamacs.helpers import dj_helpers as djh
from adamacs.ingest import behavior as ibe
from adamacs.paths import get_experiment_root_data_dir
import datajoint as dj
from rspace_client.eln import eln
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from element_interface.utils import find_full_path
print(dj.__version__)
print(dj.config['custom']['database.prefix'])

0.14.1
tobiasrext_


### some functions used here (will be hidden later)

In [2]:
import bisect
import numpy as np
from skimage.transform import resize
from tqdm import tqdm
import concurrent.futures
import cv2
import math
import matplotlib as mpl

def get_closest_timestamps(series, target_timestamp):
    # List to store the indices
    indices = []

    # For each timestamp in series1, find the closest timestamp in series2 and get its index
    for t1 in series:
        closest_index = closest_timestamp(target_timestamp, t1)
        indices.append(closest_index)
    return indices

# Function to find closest timestamp
def closest_timestamp(series, target_timestamp):
    index = bisect.bisect_left(series, target_timestamp)
    if index == 0:
        return 0
    if index == len(series):
        return len(series)-1
    before = series[index - 1]
    after = series[index]
    if after - target_timestamp < target_timestamp - before:
       return index
    else:
       return index-1


def resize_movie(movie, new_height, new_width):
    # Get the number of frames and color channels
    num_frames, _, _, num_channels = movie.shape
    
    # Initialize an empty array for the scaled movie
    scaled_movie = np.empty((num_frames, new_height, new_width, num_channels), dtype=np.uint8)
    
    # Iterate through each frame
    for i in tqdm(range(num_frames), desc="Resizing frames"):
        # Resize the frame and store it in the new array
        scaled_movie[i] =  cv2.resize(movie[i], (new_width, new_height), interpolation = cv2.INTER_AREA)
    
    # Return the scaled movie
    return scaled_movie


def resize_frame(frame, new_height, new_width):
    return cv2.resize(frame, (new_width, new_height), interpolation = cv2.INTER_AREA)

def resize_movie_mt(movie, new_height, new_width):
    num_frames, _, _, num_channels = movie.shape
    scaled_movie = np.empty((num_frames, new_height, new_width, num_channels), dtype=np.uint8)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i, resized_frame in tqdm(enumerate(executor.map(resize_frame, movie, [new_height]*num_frames, [new_width]*num_frames)), total=num_frames, desc="Resizing frames"):
            scaled_movie[i] = resized_frame

    return scaled_movie



def create_snippets(videodata, indices, fps):
    snippets = []
    for index in indices:
        start = max(index - starter, 0)
        end = min(index + ender, len(videodata))
        snippet = videodata[start:end]
        snippets.append(snippet)
    return snippets




def equalize_snippets(snippets):
    # Find the maximum number of frames in any snippet
    max_frames = max(snippet.shape[0] for snippet in snippets)

    # Equalize the number of frames in each snippet
    equalized_snippets = []
    for snippet in snippets:
        if snippet.shape[0] < max_frames:
            # Pad the snippet with blank frames or repeat the last frame
            padding = np.zeros((max_frames - snippet.shape[0],) + snippet.shape[1:])
            snippet_padded = np.concatenate([snippet, padding], axis=0)
            equalized_snippets.append(snippet_padded)
        else:
            equalized_snippets.append(snippet)
    
    return equalized_snippets

# Figure Style settings for notebook.

def concatenate_to_grid(snippets):
    # Determine the grid size
    num_snippets = len(snippets)
    grid_size = int(math.ceil(math.sqrt(num_snippets)))

    # Initialize placeholders for rows and the final grid
    rows = []
    final_grid = None

    # Concatenate snippets into rows
    for i in range(0, num_snippets, grid_size):
        row = snippets[i:i + grid_size]
        while len(row) < grid_size:  # Pad the row if necessary
            row.append(np.zeros_like(snippets[0]))
        concatenated_row = np.concatenate(row, axis=2)  # Concatenate along width
        rows.append(concatenated_row)

    # Concatenate rows to form the grid
    final_grid = np.concatenate(rows, axis=1)  # Concatenate along height

    return final_grid


def resize_video_frames(grid_data, target_width=3840, target_height=2160):
    resized_video = []

    for frame in grid_data:
        # Resize frame to 4K resolution
        resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
        resized_video.append(resized_frame)

    return np.array(resized_video)


plot_params = {
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    'font.family': 'sans-serif',
    # 'font.sans-serif': 'Helvetica Neue',
    'font.size': 16,
    'lines.color': 'black',
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'axes.spines.top': True,
    'axes.spines.right': True,
    'axes.edgecolor': 'black',  
    # 'legend.frameon': False,
    'figure.subplot.wspace': .5,
    'figure.subplot.hspace': .5,
    # 'figure.figsize': (18, 13),
    'ytick.major.left': True,
    'xtick.major.bottom': True
}

map_params = {
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    'font.family': 'sans-serif',
    # 'font.sans-serif': 'Helvetica Neue',
    'font.size': 12,
    'lines.color': 'black',
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'axes.spines.left': False,
    'axes.spines.bottom': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.edgecolor': 'black',  
    # 'legend.frameon': False,
    'figure.subplot.wspace': .5,
    'figure.subplot.hspace': .5,
    # 'figure.figsize': (18, 13),
    'ytick.major.left': False,
    'xtick.major.bottom': False
}


img_params = {
    'axes.titlecolor': 'white',
    'axes.facecolor': 'black',
    'figure.facecolor': 'black',
    'axes.spines.left': False,
    'axes.spines.bottom': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
    'figure.subplot.wspace': .01,
    'figure.subplot.hspace': .01,
    'figure.figsize': (18, 13),
    'ytick.major.left': False,
    'xtick.major.bottom': False
}


# Recordings and example analysis for BonnBrain - NATASHA

In [None]:
# Natasha - example scans

# mini2p
# sess9FIG39RU
# sess9FHS7Y22

# bench2p
# sessi = sess9FJ66OT2
# scan9FGLE1FN

# hese are the recordings for the poster. But I am not sure if the recording is good enough. If not, it would be with translucent arena, posting in the next message
# x2 5min + x1 10 mins recordings
# LED flickering
# IMU on
# arena is covered with rubber walls and floor, upper IR light
# no task
# power 70% tLens -110
# threshold for the blob tracking 15
# /datajoint-data/data/nataliak/NK_WEZ-8869_2023-06-07_scan9FIG39RU_sess9FIG39RU
# /datajoint-data/data/nataliak/NK_WEZ-8869_2023-06-07_scan9FIG3MU8_sess9FIG39RU
# /datajoint-data/data/nataliak/NK_WEZ-8869_2023-06-07_scan9FIG3GCJ_sess9FIG39RU
# 9:32
# x2 10min recordings
# LED flickering
# IMU on
# translucent arena
# no task
# power 70% tLens -80
# /datajoint-data/data/nataliak/NK_ROS-1485_2023-04-28_scan9FHS7Y22_sess9FHS7Y22
# /datajoint-data/data/nataliak/NK_ROS-1485_2023-04-28_scan9FHS845A_sess9FHS7Y22

# Bench2p figures and examples

In [None]:
# first define a key to be used across multiple tables

scansi = "scan9FIG39RU"
# scansi = "scan9FKNRW9Y"

scan_key = (scan.Scan & f'scan_id = "{scansi}"').fetch('KEY')[0]
# curation_key = (imaging.Curation & scan_key & 'curation_id=1').fetch1('KEY')
sessi = (scan.Scan & f'scan_id = "{scansi}"').fetch('session_id')[0]
aux_setup_typestr = (scan.ScanInfo() & scan_key).fetch("userfunction_info")[0] # check setup type (not needed)
print(aux_setup_typestr)
print((scan.ScanPath & scan_key).fetch("path")[0])

### Get and show overview images from suite2p registration

In [None]:
ref_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('ref_image')
average_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('average_image')
correlation_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('correlation_image')
max_proj_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('max_proj_image')

In [None]:
# get some scaling values from pixel distribution
scalemin = 2
scalemax = 100

cmin = np.percentile(average_image,scalemin)  
cmax = np.percentile(average_image,scalemax)

# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(img_params)

# Make figure with all templates
plt.subplot(1, 4, 1)
plt.imshow(ref_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.title(f"Reference Image - {scansi}")

plt.subplot(1, 4, 2)
plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.title("Registered Image, Mean Projection")

plt.subplot(1, 4, 3)
plt.imshow(max_proj_image, cmap='gray')
plt.title("Registered Image, Max Projection")

plt.subplot(1, 4, 4)
plt.imshow(correlation_image, cmap='gray')
plt.title("Registered Image, Correlation Map")

plt.show(block=False)

# plt.savefig

In [None]:
# just mean image
scalemin = 2
scalemax = 100

cmin = np.percentile(average_image,scalemin)  
cmax = np.percentile(average_image,scalemax)


# plt.subplot(1, 4, 1)
plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.show()


### Plot treadmill data

In [None]:
# extract scan treadmill data from database using the scan_key from above
treadmill = (behavior.TreadmillRecording.Channel() & scan_key).fetch("data")[0]
auxtime = (behavior.TreadmillRecording.Channel() & scan_key).fetch("time")[0]

In [None]:
# smoothing window size (ms)
window = 5000

# convert voltage to degree
treadmillnorm = (treadmill-np.min(treadmill)) / np.max(treadmill) * 360

# compute running speed (see function above)
angular_velocity_smoothed, unwrapped_angle_smoothed = ibe.compute_angular_velocity(auxtime, treadmillnorm, window)

# get some values to scale running speed plot 
scalemin = 0
scalemax = 100
offset = 10
ymin = np.percentile(angular_velocity_smoothed,scalemin)  - offset
ymax = np.percentile(angular_velocity_smoothed,scalemax)  + offset

# load plot styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(plot_params)


fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(15, 10))
# plt.rcParams['agg.path.chunksize'] = 10000  # Add this line if it does not rende

# Plotting the time series
axes[0].plot(auxtime, treadmillnorm)
axes[0].set_ylim([-10, 370])
axes[0].set_ylabel("Wheel position \n[degree]")
axes[0].set_xlabel("Time [s]")

axes[1].plot(auxtime[:-window+1],unwrapped_angle_smoothed )
# axes[1].set_ylim([-10000, 10000])
axes[1].set_ylabel("Unwrapped wheel position \n[cumulative degree]")
axes[1].set_xlabel("Time [s]")

axes[2].plot(auxtime[:-window],angular_velocity_smoothed)
axes[2].set_ylim([ymin, ymax])
axes[2].set_ylabel("Running speed \n[degree / s]")
axes[2].set_xlabel("Time [s]")

fig.suptitle(scan_key["scan_id"], fontsize=16)

plt.show() 

### Get the fluorescence traces of this recording

In [None]:
# get mask positions of masks that are classified as cells and that are larger than a certain pixel size
mask_xpix, mask_ypix = (
    imaging.Segmentation.Mask * imaging.MaskClassification.MaskType
    & scan_key
    & "mask_npix > 1"
).fetch("mask_xpix", "mask_ypix")

Using this query, we've fetched the coordinates of segmented masks. We can overlay these
masks onto our average image.

In [None]:
mask_image = np.zeros(np.shape(average_image), dtype=bool)
for xpix, ypix in zip(mask_xpix, mask_ypix):
    mask_image[ypix, xpix] = True

In [None]:
# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(img_params)

plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.contour(mask_image, colors="red", linewidths=0.5)
plt.show()

One more example using queries - plot fluorescence and deconvolved activity
traces:

Here we fetch the primary key attributes of the entry with `curation_id=0` for the
current session in the `imaging.Curation` table. 

Then, we fetch all cells that fit the
restriction criteria from `imaging.Segmentation.Mask` and
`imaging.MaskClassification.MaskType` as a `projection`. 

We then use this projection as
a restriction to fetch and plot fluorescence and deconvolved activity traces from the
`imaging.Fluorescence.Trace` and `imaging.Activity.Trace` tables, respectively.

In [None]:
curation_key = (imaging.Curation & scan_key & "curation_id=1").fetch1("KEY")
query_cells = (
    imaging.Segmentation.Mask * imaging.MaskClassification.MaskType
    & curation_key
    & "mask_center_z=0"
    & "mask_npix > 1"
).proj()

# query_cells

In [None]:
neuropilcorr = True

fluorescence_traces = (imaging.Fluorescence.Trace & query_cells).fetch(
    "fluorescence", order_by="mask"
)

neuropil_traces = (imaging.Fluorescence.Trace & query_cells).fetch(
    "neuropil_fluorescence", order_by="mask"
)

if neuropilcorr:
    print("DOING VANILLA NEUROPIL CORRECTION NOW!")
    fluorescence_traces = fluorescence_traces - 0.7 * neuropil_traces

activity_traces = (imaging.Activity.Trace & query_cells).fetch(
    "activity_trace", order_by="mask"
)

sampling_rate = (scan.ScanInfo & curation_key).fetch1("fps")

# timebase_2p = np.r_[: fluorescence_traces[0].size] * 1 / sampling_rate

timebase_2p = np.linspace(0, fluorescence_traces[0].size / sampling_rate, fluorescence_traces[0].shape[0])


In [None]:
from rastermap import Rastermap
from scipy import stats 
from scipy.stats import zscore
from scipy.ndimage import gaussian_filter1d

# stack fluorescence for rastermap
fluos = np.vstack(fluorescence_traces)

nan_mask = np.isnan(fluos).any(axis=1)
S = fluos[~nan_mask]
S = zscore(S, axis=1)

rmmodel = Rastermap(n_clusters=None, # None turns off clustering and sorts single neurons 
                  n_PCs=80, # use fewer PCs than neurons
                  locality=0.15, # some locality in sorting (this is a value from 0-1)
                  time_lag_window=15, # use future timepoints to compute correlation
                  grid_upsample=0, # 0 turns off upsampling since we're using single neurons
                ).fit(S)


y = rmmodel.embedding # neurons x 1
isort = rmmodel.isort

# sort by embedding and smooth over neurons (uncomment)

# Sfilt = gaussian_filter1d(S[isort], np.minimum(1,np.maximum(1,int(S.shape[0]*0.001))),axis=0)
Sfilt = S[isort]

In [None]:
# Plot sorted data
# load plot styles for display

mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(plot_params)

plt.figure(figsize=(5,3))
plt.imshow(Sfilt, vmin = -0.1, vmax=1, extent= [timebase_2p[0], timebase_2p[-1], 0, Sfilt.shape[0]], aspect='auto', cmap='gray_r')
plt.xlabel('time [s]')
plt.ylabel('sorted neurons')
plt.show()

Different visualization


In [None]:
import seaborn as sns
figure = plt.figure(figsize=(15,15))
ax = figure.add_subplot(111)

offset_scaler = 10 # We want to plot every cell with a little offset to the last one
for no,trace in enumerate(Sfilt):
    if no == 80: break # not more than 80

    # get the neuropil corrected values for that trace:
    # trace = Sfilt
    ax.plot(timebase_2p,trace + (no*offset_scaler),lw=1,c='k',alpha=.8)

ax.set_xlim(0,timebase_2p[-1])

ax.get_yaxis().set_ticks([])
ax.set_title('Sorted z-scored traces')    

ax.set_ylabel('Cells')
ax.set_xlabel('Time [s]')
sns.despine(left=True)

plt.show()

#### Synchronization!

In [None]:
# and 2p timestamps (which will always be in the recording gate).
twoptimestamps = (event.Event()  &  "event_type='bench2p_frames'" &  scan_key ).fetch('event_start_time')
aligned_wheel_indices = get_closest_timestamps(twoptimestamps,auxtime[:-window]) #smoothing windwo from above

# use this to index into the wheelspeed
angular_velocity_smoothed_2pref = angular_velocity_smoothed[aligned_wheel_indices]

# both arrays have same shape now - now 2pdata and wheel speed can be plotted together on the 2ptimestamps
print(np.shape(twoptimestamps))
print(np.shape(angular_velocity_smoothed_2pref))


In [None]:
kp_colors = np.array([[0.55,0.55,0.55]])


# timepoints to visualize
tstart = 0
tend =  timebase_2p[-1] - 10

xmin = int(np.floor(tstart * sampling_rate))
xmax = int(np.floor(tend * sampling_rate))

# make figure with grid for easy plotting
fig = plt.figure(figsize=(8,5), dpi=200)
grid = plt.GridSpec(9, 20, figure=fig, wspace = 0.05, hspace = 0.3)

# plot running speed
ax = plt.subplot(grid[:2, :-1])
ax.plot(angular_velocity_smoothed_2pref,  color=kp_colors[0])
ax.set_xlim([0, xmax-xmin])
ax.axis("off")
ax.set_title("running speed", color=kp_colors[0])
# ax.set_xlabel("running speed")

# plot superneuron activity
ax = plt.subplot(grid[2:, :-1])
ax.imshow(Sfilt[:, xmin:xmax], cmap="gray_r", vmin=-0.1, vmax=0.7,  extent= [timebase_2p[xmin], timebase_2p[xmax], 0, Sfilt.shape[0]], aspect="auto")
ax.set_xlabel("time [s]")
ax.set_ylabel("sorted boutons")

plt.show()

# ax = plt.subplot(grid[1:, -1])
# ax.imshow(np.arange(0, len(sn))[:,np.newaxis], cmap="gist_ncar", aspect="auto")
# ax.axis("off")

# Mini2p figures and examples

In [7]:
# first define a key to be used across multiple tables

scansi = "scan9FKNRW9Y"
scan_key = (scan.Scan & f'scan_id = "{scansi}"').fetch('KEY')
# curation_key = (imaging.Curation & scan_key & 'curation_id=10').fetch1('KEY') # SET THE CURATION ID OF MANUAL CURATION HERE!
sessi = (scan.Scan & f'scan_id = "{scansi}"').fetch('session_id')[0]
aux_setup_typestr = (scan.ScanInfo() & scan_key).fetch("userfunction_info")[0] # check setup type (not needed)
print(aux_setup_typestr)
print((scan.ScanPath & scan_key).fetch("path")[0])

mini2p1_openfield
/datajoint-data/data/tobiasr/NK_ROS-1629_2023-10-19_scan9FKNRW9Y_sess9FKNRW9Y


### Get and show overview images from suite2p registration

In [None]:
ref_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('ref_image')
average_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('average_image')
correlation_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('correlation_image')
max_proj_image = (imaging.MotionCorrection.Summary & curation_key & 'field_idx=0').fetch1('max_proj_image')

In [None]:
# get some scaling values from pixel distribution
scalemin = 0
scalemax = 100
offset = 0


cmin = np.percentile(average_image,scalemin)  
cmax = np.percentile(average_image,scalemax) + offset

# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(img_params)

# Make figure with all templates
plt.subplot(1, 4, 1)
plt.imshow(ref_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.title("Reference Image for Registration");

plt.subplot(1, 4, 2)
plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.title("Registered Image, Mean Projection");

plt.subplot(1, 4, 3)
plt.imshow(max_proj_image, cmap='gray')
plt.title("Registered Image, Max Projection")

plt.subplot(1, 4, 4)
plt.imshow(correlation_image, cmap='gray')
plt.title("Registered Image, Correlation Map")
plt.show(block=False)
# plt.savefig

In [None]:
# just mean image
scalemin = 5
scalemax = 100

cmin = np.percentile(average_image,scalemin) 
cmax = np.percentile(average_image,scalemax) #+ offset


# plt.subplot(1, 4, 1)
plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.show()


## make cam-synchronized movie

In [None]:
from pathlib import Path
import skvideo.io

# get the movie file name from the database
topfile = (model.VideoRecording.File & scan_key).fetch('file_path')[0]

#load to array
videodata = skvideo.io.vread(str(topfile))
# videodata = np.asarray([skvideo.io.vshape(frame)[0] for frame in videodata], dtype=np.uint8)

#### get / make DLC overlay video

In [None]:
import glob
from pathlib import Path
import skvideo.io

key =  (model.PoseEstimationTask & f'recording_id="{scansi}"').fetch1('KEY')
destfolder = (model.PoseEstimationTask & key).fetch1('pose_estimation_output_dir')

labeled_videofile = glob.glob(f"{destfolder}/*.mp4")

# DEEPLABCUT OVERLAY - CURRENTLY ONLY WORKING IN MY ENVIRONMENT. NEED TO CHECK VERSIONS

# key = (model.VideoRecording & scan_key).fetch1('KEY')
# key.update({'model_name': 'Head_orientation-NK', 'task_mode': 'trigger'})


# from deeplabcut.utils.make_labeled_video import create_labeled_video
# import yaml
# from element_interface.utils import find_full_path
# from adamacs.paths import get_dlc_root_data_dir


# destfolder = model.PoseEstimationTask.infer_output_dir(scan_key)

# config_paths = sorted( # Of configs in the project path, defer to the datajoint-saved
#     list(
#         find_full_path(
#             get_dlc_root_data_dir(), ((model.Model & key).fetch1("project_path"))
#         ).glob("*.y*ml")
#     )
# )

# create_labeled_video( # Pass strings to label the video
#     config=str(config_paths[-1]),
#     videos=str(topfile),
#     destfolder=str(destfolder),
# )


# labeled_videofile = '/datajoint-data/data/tobiasr/RN_OPI-1681_2023-04-05_scan9FHELAYA_sess9FHELAYA/device_mini2p1_top_recording_scan9FHELAYA_model_Head_orientation-NK/scan9FHELAYA_top_video_2023-04-05T15_19_53DLC_resnet50_Head_orientationJul17shuffle1_90000_labeled.mp4'

print(destfolder)

# labeled_videodata = skvideo.io.vread(str(labeled_videofile[0]))


#### load moving average registered Ca2+ imaging movie

In [None]:
# get the registered moving average (blinking) movie data of the specified scanID

scandir = (scan.ScanPath & scan_key).fetch('path')[0]

directory = Path(scandir + "/suite2p/plane0/reg_tif")
pattern = '*20_frame*.mp4'
files = list(directory.glob(pattern))
blinkvideodata = skvideo.io.vread(str(files[0]))
blinkvideodata = np.asarray([skvideo.io.vshape(frame)[0] for frame in blinkvideodata], dtype=np.uint8)
print(files[0])


In [None]:
# display camaravideo with slider
sh.display_volume_z(videodata,1)

In [None]:
# display 2pstackvideo with slider
sh.display_volume_z(blinkvideodata,1)

In [None]:
# display 2pstackvideo with slider
sh.display_volume_z(labeled_videodata,1)

In [None]:
#dimensionts of the original movie (frames, x,y,rgb)
print(videodata.shape)
print(blinkvideodata.shape)
print(labeled_videodata.shape)

#### Synchronization!

In [None]:
## Get the timestamp data and gate / offset cameraframes

# from the event table get the main recording gate start / end timestamps.
auxgatetimestamp_end = (event.Event()  &  "event_type='main_track_gate'" &  scan_key ).fetch('event_end_time')
auxgatetimestamp_start = (event.Event()  &  "event_type='main_track_gate'" &  scan_key ).fetch('event_start_time')

# Then return camera start timestamps within the recording gate only (not necessary: he important bit is to subtract the offsetframes after alignment)
cameratimestamps = (event.Event()  &  "event_type='aux_cam'" & f"event_start_time>{auxgatetimestamp_start[0]}" & f"event_start_time<{auxgatetimestamp_end[0]}" & scan_key).fetch('event_start_time')

cameraoffsetframes =len((event.Event()  &  "event_type='aux_cam'" & f"event_start_time<{auxgatetimestamp_start[0]}" & scan_key).fetch('event_start_time'))

#  and 2p timestamps (which will always be in the recording gate).
twoptimestamps = (event.Event()  &  "event_type='mini2p_frames'" &  scan_key ).fetch('event_start_time')

# align the two recordings by finding the indices of the closest camera timestamp that fits the 2p frame timestamps by sorted list insertion ("bisect"). Be aware: camera frames can be double.
aligned_cameraframes = np.array(get_closest_timestamps(twoptimestamps,cameratimestamps)) - cameraoffsetframes

# this should have the same shape as the 2p frames:
print(np.shape(aligned_cameraframes))

In [None]:
twoptimestamps =  twoptimestamps[:np.shape(blinkvideodata)[0]] # truncating 2p timestamps to number of 2p videoframes


In [None]:

# now use this camara frame indices to reslice the video (which now is aligned to the 2p frames on a frame-by-frame level)
# resliced_cam_video = videodata[aligned_cameraframes]
resliced_cam_video = labeled_videodata[aligned_cameraframes]


In [None]:
# display synchronized movie
sh.display_volume_z(resliced_cam_video,1)

In [None]:
# rescale camera movie to fit size of 2p movie (can take a lot of time and memory)
rescaled_cam_movie = resize_movie(resliced_cam_video, np.shape(blinkvideodata)[1],np.shape(blinkvideodata)[2])

In [None]:
np.shape(rescaled_cam_movie)

In [None]:
# concatenate and display movies
concatmovie = np.concatenate((blinkvideodata,rescaled_cam_movie), axis = 2)
sh.display_volume_z(concatmovie,1)

In [None]:
#save as new movie (without rescaling)


filename = str(directory) + '/aligned_stack_cam_movie.mp4'
fps = (scan.ScanInfo & scan_key).fetch('fps')
# p1 = 0
# p2 = 100
# trash = sh.make_stack_movie(concatmovie, filename, fps[0], p1, p2)

codecset = 'libx264'
import imageio
import imageio.plugins.ffmpeg as ffmpeg

# Create an imageio VideoWriter object to write the video
writer = imageio.get_writer(filename, fps=fps[0], codec=codecset, output_params=['-crf', '19'])

# # Calculate the 1st and 99th percentile
# p1, p99 = np.percentile(running_z_projection[:500,:,:], (p1set, p2set))

# # rescale to 8 bit
# rescaled_image_8bit = rescale_image_multithreaded(running_z_projection, p1, p99)

for page in concatmovie:
    writer.append_data(page)

# Close the video writer
writer.close()


In [None]:
# speed up, add timestamps etc - all with fast ffmpeg operations

import os

spedby = 5
setpts_value = 1/spedby # change this to your desired value
newfps = fps[0]*spedby

input_filename = filename
# 2. Add timestamps

output_filename = str(directory) + '/' + scansi + '_top_video_concatenated' + 'withtimestamps.mp4'
command = f"""ffmpeg -y -i {input_filename} -vf "drawtext=fontfile=/Library/Fonts/Arial.ttf:timecode='00\\\\:00\\\\:00\\\\:00':rate={fps[0]}:text='':fontsize=20:fontcolor=white:x=530:y=20:box=1:boxcolor=0x00000000@1" -f mp4 {output_filename}"""

os.system(command)


input_filename = output_filename  # 'sped_up_video.mp4'
output_filename = str(directory) + '/' +  scansi + '_top_video_concatenated_spedup_' + str(spedby) + 'fold_withtimestamps_labels.mp4'

command = f'ffmpeg -y -i {input_filename} -vf "setpts={setpts_value}*PTS" -r {newfps}  {output_filename}'



os.system(command)


## Plot activity

### Get the fluorescence traces of this recording

In [None]:
# get mask positions of masks that are classified as cells and that are larger than a certain pixel size
mask_xpix, mask_ypix = (
    imaging.Segmentation.Mask * imaging.MaskClassification.MaskType
    & scan_key
    & "mask_npix > 10"
    & "curation_id = 10" #SET CURATION ID OF MANUAL CURATION HERE
).fetch("mask_xpix", "mask_ypix")

Using this query, we've fetched the coordinates of segmented masks. We can overlay these
masks onto our average image.

In [None]:
mask_image = np.zeros(np.shape(average_image), dtype=bool)
for xpix, ypix in zip(mask_xpix, mask_ypix):
    mask_image[ypix, xpix] = True

In [None]:
# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(img_params)

plt.imshow(average_image, cmap='gray', vmin = cmin, vmax = cmax)
plt.contour(mask_image, colors="red", linewidths=0.5)
plt.show()

One more example using queries - plot fluorescence and deconvolved activity
traces:

Here we fetch the primary key attributes of the entry with `curation_id=0` for the
current session in the `imaging.Curation` table. 

Then, we fetch all cells that fit the
restriction criteria from `imaging.Segmentation.Mask` and
`imaging.MaskClassification.MaskType` as a `projection`. 

We then use this projection as
a restriction to fetch and plot fluorescence and deconvolved activity traces from the
`imaging.Fluorescence.Trace` and `imaging.Activity.Trace` tables, respectively.

In [None]:
curation_key = (imaging.Curation & scan_key & "curation_id=10").fetch1("KEY")
query_cells = (
    imaging.Segmentation.Mask * imaging.MaskClassification.MaskType
    & curation_key
    & "mask_center_z=0"
    & "mask_npix > 10"
).proj()

# query_cells

In [None]:
imaging.Fluorescence.Trace & scan_key

In [None]:
neuropilcorr = True

fluorescence_traces = (imaging.Fluorescence.Trace & query_cells).fetch(
    "fluorescence", order_by="mask"
)

neuropil_traces = (imaging.Fluorescence.Trace & query_cells).fetch(
    "neuropil_fluorescence", order_by="mask"
)

if neuropilcorr:
    print("DOING VANILLA NEUROPIL CORRECTION NOW!")
    fluorescence_traces = fluorescence_traces - 0.7 * neuropil_traces

activity_traces = (imaging.Activity.Trace & query_cells).fetch(
    "activity_trace", order_by="mask"
)

sampling_rate = (scan.ScanInfo & curation_key).fetch1("fps")

# timebase_2p = np.r_[: fluorescence_traces[0].size] * 1 / sampling_rate

timebase_2p = np.linspace(0, fluorescence_traces[0].size / sampling_rate, fluorescence_traces[0].shape[0])


In [None]:
from rastermap import Rastermap
from scipy import stats 
from scipy.stats import zscore
from scipy.ndimage import gaussian_filter1d



# stack fluorescence for rastermap
fluos = np.vstack(fluorescence_traces)

nan_mask = np.isnan(fluos).any(axis=1)

# Create a mask for rows containing only zeros
zero_rows = np.all(fluos == 0, axis=1)

# Create a mask for rows containing only inf
inf_rows = np.all(np.isinf(fluos), axis=1)

# Create a mask for rows containing only NaN
nan_rows = np.all(np.isnan(fluos), axis=1)

# Combine the masks using logical OR
mask_to_remove = zero_rows | inf_rows | nan_rows | nan_mask

S = fluos[~mask_to_remove]
S = zscore(S, axis=1)

rmmodel = Rastermap(n_clusters=None, # None turns off clustering and sorts single neurons 
                  n_PCs=24, # use fewer PCs than neurons
                  locality=0.15, # some locality in sorting (this is a value from 0-1)
                  time_lag_window=15, # use future timepoints to compute correlation
                  grid_upsample=0, # 0 turns off upsampling since we're using single neurons
                ).fit(S)


y = rmmodel.embedding # neurons x 1
isort = rmmodel.isort

# sort by embedding and smooth over neurons (uncomment)

# Sfilt = gaussian_filter1d(S[isort], np.minimum(1,np.maximum(1,int(S.shape[0]*0.001))),axis=0)
Sfilt = S[isort]
Sfilt_backup = Sfilt

In [None]:
# Plot sorted data
# load plot styles for display

mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(plot_params)

plt.figure(figsize=(5,3))
plt.imshow(Sfilt, vmin = -0.1, vmax=2, extent= [timebase_2p[0], timebase_2p[-1], 0, Sfilt.shape[0]], aspect='auto', cmap='gray_r', interpolation='none')
plt.xlabel('time [s]')
plt.ylabel('sorted neurons')
plt.show()

In [None]:
import seaborn as sns
figure = plt.figure(figsize=(15,15))
ax = figure.add_subplot(111)

offset_scaler = 10 # We want to plot every cell with a little offset to the last one
for no,trace in enumerate(Sfilt):
    # if no == 25: break # not more than 80

    # get the neuropil corrected values for that trace:
    # trace = Sfilt
    ax.plot(timebase_2p,trace + (no*offset_scaler),lw=1,c='k',alpha=.8)

ax.set_xlim(0,timebase_2p[-1])

ax.get_yaxis().set_ticks([])
ax.set_title('Sorted z-scored traces')    

ax.set_ylabel('Cells')
ax.set_xlabel('Time [s]')
sns.despine(left=True)

plt.show()

Binarized event visualization


In [None]:
import seaborn as sns
from scipy.ndimage import binary_dilation


method = "tank"
# method = "sd"
# method = "poor"

cutoff_std = float(2)
min_transient_length = .1

filtered_event_mask = []

for cell in range(len(activity_traces)):  
    # Horst's event filtering 
    re = th.FilterEvents(activity_traces[cell])
    if method == "tank":
        transient_dict = re.transients(fluorescence_traces[cell],
                    np.ones_like(fluorescence_traces[cell], dtype=bool),
                    sampling_rate, cutoff_std, min_transient_length, plot=False)
    elif method == "sd":
        transient_dict = re.robust(cutoff_std)  
    filtered_event_mask.append(transient_dict['mask_events'])

event_matrix = np.array(filtered_event_mask)


if method == "poor":
    event_matrix = np.array([item > cutoff_std for item in Sfilt], dtype=bool).astype(int)


# broaden events to generate non-event baseline for SNR calculation and second pass event detection with proper f0 baseline

cutoff_std = float(2)
min_transient_length = .250 # s

broaden_by = 1 # s

# Structuring element for one-second dilation
structure = np.ones(int(broaden_by * sampling_rate))

filtered_event_mask_2ndpass =[]
broadened_events = []
filtered_events = []
SNR = []
dFF0_traces = []

if method == "tank":
    for cell in range(len(activity_traces)):
        re = th.FilterEvents(activity_traces[cell])
        
        # make dF/F0 from first pass events
        F0mean = np.mean(fluorescence_traces[cell][np.logical_not(event_matrix[cell])])
        dFF0 = (fluorescence_traces[cell] - F0mean) / F0mean
        
        # Apply binary dilation on previous events and invert to get baseline
        
        broadened_events = binary_dilation(filtered_event_mask[cell], structure=structure) 
        
        transient_dict = re.transients(dFF0,
                    np.logical_not(broadened_events),
                    sampling_rate, cutoff_std, min_transient_length, plot=False)
 
        filtered_event_mask_2ndpass.append(transient_dict['mask_events'])
        filtered_events.append(transient_dict['filtered_events'])      
        
        SNR.append(np.mean(filtered_events[cell] / np.nanstd(dFF0[np.logical_not(broadened_events)])))
        dFF0_traces.append(dFF0)
        
    event_matrix_2ndpass = np.array(filtered_event_mask_2ndpass)



# sort events based on previous rastermap embedding
event_matrix = event_matrix[isort]
event_matrix_2ndpass = event_matrix_2ndpass[isort]
dFF0_traces = np.array(dFF0_traces)[isort]


# plot events and non-event epochs (with traces)

figure = plt.figure(figsize=(15,20))
ax = figure.add_subplot(111)


offset_scaler = 1.5 # We want to plot every cell with a little offset to the last one
for no, (trace, trace2) in enumerate(zip(event_matrix, event_matrix_2ndpass)):
    if no == 80: break # not more than 80

    ax.plot(timebase_2p, trace + (no * offset_scaler), lw=1, c='k', alpha=.8)
    ax.plot(timebase_2p, trace2 + (no * offset_scaler), lw=2, c='r', alpha=.8)


ax.set_xlim(0,timebase_2p[-1])

ax.get_yaxis().set_ticks([])
ax.set_title('Sorted binarized deconvolved traces')    

ax.set_ylabel('Cells')
ax.set_xlabel('Time [s]')
sns.despine(left=True)

plt.show()


#set used event_matrix
event_matrix = event_matrix_2ndpass
event_matrix_backup = event_matrix_2ndpass
SNR_matrix = np.array(SNR)[isort]

In [None]:
import seaborn as sns
figure = plt.figure(figsize=(8,8))
ax = figure.add_subplot(111)

SNRthresh = .5

offset_scaler = 12 # We want to plot every cell with a little offset to the last one
for no, (trace,evnts) in enumerate(zip(dFF0_traces[SNR_matrix > SNRthresh,:], event_matrix[SNR_matrix > SNRthresh,:] * 1)):
    if no == 15: break # not more than 25

    # get the neuropil corrected values for that trace:
    # trace = Sfilt
    ax.plot(timebase_2p,trace + (no*offset_scaler),lw=1,c='k',alpha=.8)
    ax.plot(timebase_2p, evnts + (no*offset_scaler - 2),lw=2,c='r',alpha=.8)

ax.set_xlim(0,timebase_2p[-1])

ax.get_yaxis().set_ticks([])
ax.set_title('Sorted, SNR-filtered dF/F0 traces')    

ax.set_ylabel('Cells')
ax.set_xlabel('Time [s]')
sns.despine(left=True)

plt.show()




In [None]:
# limit data to SNR threshold from hereon
Sfilt = Sfilt_backup[SNR_matrix > SNRthresh]
event_matrix = event_matrix[SNR_matrix > SNRthresh]

## Synchronization!

In [8]:
## Get the timestamp data and gate / offset cameraframes

# from the event table get the main recording gate start / end timestamps.
auxgatetimestamp_end = (event.Event()  &  "event_type='main_track_gate'" &  scan_key ).fetch('event_end_time')
auxgatetimestamp_start = (event.Event()  &  "event_type='main_track_gate'" &  scan_key ).fetch('event_start_time')

# Then return camera start timestamps within the recording gate only 
cameratimestamps = (event.Event()  &  "event_type='aux_cam'" & f"event_start_time>{auxgatetimestamp_start[0]}" & f"event_start_time<{auxgatetimestamp_end[0]}" & scan_key).fetch('event_start_time')
cameratimestamps_end = (event.Event()  &  "event_type='aux_cam'"  & f"event_start_time<{auxgatetimestamp_end[0]}" & scan_key).fetch('event_start_time')
cameraoffsetframes =len((event.Event()  &  "event_type='aux_cam'" & f"event_start_time<{auxgatetimestamp_start[0]}" & scan_key).fetch('event_start_time'))

cameratimestamps_low = (event.Event()  &  "event_type='aux_cam'"  & f"event_start_time<{auxgatetimestamp_end[0]}" & scan_key).fetch('event_start_time')
cameratimestamps_all = (event.Event()  &  "event_type='aux_cam'"   & scan_key).fetch('event_start_time')


#  and 2p timestamps (which will always be in the recording gate).
twoptimestamps = (event.Event()  &  "event_type='mini2p_frames'" &  scan_key ).fetch('event_start_time')

# align the two recordings by finding the indices of the closest camera timestamp that fits the 2p frame timestamps by sorted list insertion ("bisect"). Be aware: camera frames can be double.
aligned_cameraframes = np.array(get_closest_timestamps(twoptimestamps,cameratimestamps)) - cameraoffsetframes

# this should have the same shape as the 2p frames:
print(np.shape(aligned_cameraframes))
print(cameraoffsetframes)

# cameraoffsetframes = 0

(26500,)
115


In [None]:
print(len(cameratimestamps))
print(len(cameratimestamps_low))
print(cameraoffsetframes)
# np.shape(labeled_videodata)
np.shape(videodata)

### check for camera framedrops

In [None]:
from adamacs.paths import get_experiment_root_data_dir
behavior_path_relative = (event.BehaviorRecording.File & scan_key &  "filepath LIKE '%.mat%'").fetch1("filepath")
camera_timestamp_paths = list(find_full_path(
    get_experiment_root_data_dir(), behavior_path_relative
).parent.glob("*top_video_timestamps*.csv"))

camera_timestamp_paths[0]

In [None]:
allchans_sync = CamLoader_sync(camera_timestamp_paths[0]).data_for_insert()

In [None]:
allchans_sync[0]

In [None]:
np.diff(allchans_sync[0]["time"])

In [None]:
len(np.where(np.diff(allchans_sync[0]["time"]) > 30)[0])

In [None]:
# Define the bin width (1 ms in this case)
bin_width = .1

data = np.diff(new_timestamps)

# Calculate the number of bins
num_bins = int((max(data) - min(data)) / bin_width)

# Create the histogram
plt.hist(data, bins=num_bins, edgecolor='black')
# plt.hist(data, edgecolor='black')

# Add labels and title
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram with 1ms Bin Width')


In [None]:
import numpy as np

def detect_frame_drops(timestamps, threshold):
    # Calculate time intervals between consecutive frames
    intervals = np.diff(timestamps)
    
    # Identify indices where the interval exceeds the threshold
    drop_indices = np.where(intervals > threshold)[0]
    
    return drop_indices

def insert_frames(timestamps, drop_indices):
    new_timestamps = np.copy(timestamps)
    
    for idx in drop_indices:
        # Insert a new timestamp halfway between the adjacent frames
        new_timestamp = (timestamps[idx] + timestamps[idx + 1]) / 2.0
        new_timestamps = np.insert(new_timestamps, idx + 1, new_timestamp)
    
    return new_timestamps

# Example usage:
timestamps = allchans_sync[0]["time"]
threshold = 25  # Adjust this threshold based on your application

drop_indices = detect_frame_drops(timestamps, threshold)
new_timestamps = insert_frames(timestamps, drop_indices)

print("Original Timestamps:", timestamps)
print("Indices of Frame Drops:", drop_indices)
print("New Timestamps with Inserted Frames:", new_timestamps)


In [None]:
drop_indices

### align movie to individual events and concatenate

In [None]:
# get cricket data aligned
leftcricket_start = (trial.Trial()  &  "trial_type='BPOD: SubCricket; 60'" &  scan_key ).fetch('trial_start_time')
rightcricket_start = (trial.Trial()  &  "trial_type='BPOD: SubCricket; -60'" &  scan_key ).fetch('trial_start_time')

allcricket_triggered = (event.Event()  &  "event_type='aux_bpod_visual'" &  scan_key ).fetch('event_start_time')
reward_triggered = (event.Event()  &  "event_type='aux_bpod_reward'" &  scan_key ).fetch('event_start_time')

target_triggered = (event.Event()  &  "event_type='bpod_at_target'" &  scan_key ).fetch('event_start_time')

# rightcricket_start = (trial.Trial()  &  "trial_type='BPOD: SubCricket; -60'" &  scan_key ).fetch('trial_start_time')


aligned_cameraframes_leftcricket_start = np.array(get_closest_timestamps(leftcricket_start,cameratimestamps)) - cameraoffsetframes

aligned_cameraframes_rightcricket_start = np.array(get_closest_timestamps(rightcricket_start,cameratimestamps)) - cameraoffsetframes


aligned_cameraframes_allcricket_triggered = np.array(get_closest_timestamps(allcricket_triggered,cameratimestamps)) - cameraoffsetframes
aligned_cameraframes_reward_triggered = np.array(get_closest_timestamps(reward_triggered,cameratimestamps)) - cameraoffsetframes

aligned_cameraframes_target_triggered = np.array(get_closest_timestamps(target_triggered,cameratimestamps)) - cameraoffsetframes

In [None]:
aligned_cameraframes_target_triggered[0]

In [None]:
# Fetching data from the event.Event and trial.Trial tables
event_data = (event.Event & scan_key).fetch('event_type', 'event_start_time', as_dict=True)
trial_data = (trial.Trial & scan_key).fetch('trial_id', 'trial_type', 'trial_start_time', as_dict=True)

pltn = djh.plot_event_trial_start_times(event_data, trial_data)
pltn.xlim([0, 20])
pltn.show()

In [None]:
# Create an array with even-aligned videosnippets 

# Assuming videodata and cameratimestamps are defined
cameraspeed = 1 / (np.mean(np.diff(cameratimestamps)))
frames_per_second = int(cameraspeed)
starter = int(.1 * frames_per_second)
ender = int(2 * frames_per_second)

def create_snippets(videodata, indices, fps):
    snippets = []
    for index in indices:
        start = max(index - starter, 0)
        end = min(index + ender, len(videodata))
        snippet = videodata[start:end]
        snippets.append(snippet)
    return snippets

snippets = create_snippets(labeled_videodata, aligned_cameraframes_rightcricket_start, frames_per_second)
# vertical_concat = np.concatenate(snippets, axis=1)  # Vertical concatenation
# horizontal_concat = np.concatenate(snippets, axis=2)  # Horizontal concatenation
# overview_video = np.concatenate([vertical_concat, horizontal_concat], axis=0)




In [None]:
sh.display_volume_z(snippets[10],0)

In [None]:
# Equalize the snippet length before concatenating

import math

def equalize_snippets(snippets):
    # Find the maximum number of frames in any snippet
    max_frames = max(snippet.shape[0] for snippet in snippets)

    # Equalize the number of frames in each snippet
    equalized_snippets = []
    for snippet in snippets:
        if snippet.shape[0] < max_frames:
            # Pad the snippet with blank frames or repeat the last frame
            padding = np.zeros((max_frames - snippet.shape[0],) + snippet.shape[1:])
            snippet_padded = np.concatenate([snippet, padding], axis=0)
            equalized_snippets.append(snippet_padded)
        else:
            equalized_snippets.append(snippet)
    
    return equalized_snippets

# Equalize the snippets before concatenating
equalized_snippets = equalize_snippets(snippets)
# vertical_concat = np.concatenate(equalized_snippets, axis=1)
# horizontal_concat = np.concatenate(equalized_snippets, axis=2)
# overview_video = np.concatenate([vertical_concat, horizontal_concat], axis=0)


# Assuming equalized_snippets is a list of video snippets of equal frame counts
def concatenate_to_grid(snippets):
    # Determine the grid size
    num_snippets = len(snippets)
    grid_size = int(math.ceil(math.sqrt(num_snippets)))

    # Initialize placeholders for rows and the final grid
    rows = []
    final_grid = None

    # Concatenate snippets into rows
    for i in range(0, num_snippets, grid_size):
        row = snippets[i:i + grid_size]
        while len(row) < grid_size:  # Pad the row if necessary
            row.append(np.zeros_like(snippets[0]))
        concatenated_row = np.concatenate(row, axis=2)  # Concatenate along width
        rows.append(concatenated_row)

    # Concatenate rows to form the grid
    final_grid = np.concatenate(rows, axis=1)  # Concatenate along height

    return final_grid

# Concatenate the snippets into a grid
grid_video = concatenate_to_grid(equalized_snippets[1:])


In [None]:
sh.display_volume_z(grid_video,0)

In [None]:
# scale to 4k movie

import cv2
import numpy as np

def resize_video_frames(grid_data, target_width=5504, target_height=5008):
    resized_video = []

    for frame in grid_data:
        # Resize frame to 4K resolution
        resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
        resized_video.append(resized_frame)

    return np.array(resized_video)

# Assuming grid_data is a numpy array of video frames
grid_data_4k = resize_video_frames(grid_video)


In [None]:
# write movie

import imageio
import imageio.plugins.ffmpeg as ffmpeg

#save as new movie (without rescaling)
scandir = (scan.ScanPath & scan_key).fetch('path')[0]

directory = Path(scandir + "/")

filename = str(directory) + '/concatenated_peri_rightcricket_movie2.mp4'
fps = frames_per_second
# p1 = 0
# p2 = 100
# trash = sh.make_stack_movie(concatmovie, filename, fps[0], p1, p2)

codecset = 'libx264'


# Create an imageio VideoWriter object to write the video
writer = imageio.get_writer(filename, fps=fps, codec=codecset, output_params=['-crf', '19'])

# # Calculate the 1st and 99th percentile
# p1, p99 = np.percentile(running_z_projection[:500,:,:], (p1set, p2set))

# # rescale to 8 bit
# rescaled_image_8bit = rescale_image_multithreaded(running_z_projection, p1, p99)

for page in grid_data_4k:
    writer.append_data(page)

# Close the video writer
writer.close()
print(filename)

In [None]:
leftcricket_start = (trial.Trial()  &  "trial_type='BPOD: SubCricket; 60'" &  scan_key ).fetch('trial_start_time')
rightcricket_start = (trial.Trial()  &  "trial_type='BPOD: SubCricket; -60'" &  scan_key ).fetch('trial_start_time')

In [None]:
# the timestamps of the video synchronization from above are the one to use for synchronized plotting of positions etc: aligned_cameraframes
print(np.shape(cameratimestamps))
print(np.shape(aligned_cameraframes))
print(np.shape(twoptimestamps))

### now do some positional plotting!

In [12]:
dlc_scan_key = (model.PoseEstimation & f'recording_id = "{scan_key[0]["scan_id"]}"').fetch('KEY')
path = (model.VideoRecording.File & scan_key).fetch("file_path")
path

array(['/datajoint-data/data/tobiasr/NK_ROS-1629_2023-10-19_scan9FKNRW9Y_sess9FKNRW9Y/scan9FKNRW9Y_mini2p1_top_video_2023-10-19T10_44_15.mp4'],
      dtype=object)

In [15]:
f'recording_id = "{scan_key[0]["scan_id"]}"'


'recording_id = "scan9FKNRW9Y"'

In [17]:
model.PoseEstimation.task

AttributeError: type object 'PoseEstimation' has no attribute 'task'

In [10]:
#reduce dataframe to xy coordinates
# dlc_scan_key =dlc_scan_key[0]
df=model.PoseEstimation.get_trajectory(dlc_scan_key)
df_xy = df.iloc[:,df.columns.get_level_values(2).isin(["x","y"])]['Topcam_2bin_without_scope']
# df_xy.mean()
# df_xy
df_xy.plot().legend(loc='right')
plt.show()

TypeError: list indices must be integers or slices, not str

In [None]:
df_flat = df_xy.copy()
df_flat.columns = df_flat.columns.map('_'.join)

fig,ax=plt.subplots()
# df_flat.plot(x='body_middle_x',y='body_middle_y',ax=ax)
df_flat.plot(x='head_middle_x',y='head_middle_y', ax=ax)
# df_flat.plot(x='tail_x',y='tail_y', ax=ax)
ax.set_aspect('equal')
plt.title(scan_key)
plt.show()

In [None]:
position = df_flat[['body_middle_x', 'body_middle_y']].values

In [None]:
# plot events over position

position = df_flat[['body_middle_x', 'body_middle_y']].values
position = position[aligned_cameraframes].T/10 # synchronize to 2pframes and translate for opexebo - THIS IS A GUESSTIMATE!  pretending 1px = 1mm NEEDS CALIBRATION - tracking needs to be in xy real-world coordinates (in cm)


total_cells = np.shape(event_matrix)[0] # Change this to the desired number of cells

# Determine the grid dimensions (for a roughly square arrangement)
nrows = int(np.ceil(np.sqrt(total_cells)))
ncols = int(np.ceil(total_cells / nrows))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(30, 30))
fig.subplots_adjust(hspace=0.1) # Add some space between the subplots

# If axs is not already a 2D array (e.g., if total_cells is a perfect square), make it one
# If axs is not already a 2D array (e.g., if total_cells is a perfect square), make it one
if total_cells != nrows * ncols:
    axs = axs.reshape(-1)
else:
    axs = axs.flatten()



# load image styles for display
# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(map_params)

for cell in range(total_cells):
    # try:    
    ax = axs[cell]
    # Plotting the line plot first
    ax.plot(position[0], position[1], color='grey')

    # spike events at position
    spikes_at_pos = np.vstack((position[0, event_matrix[cell].astype("bool")], position[1, event_matrix[cell].astype("bool")]))
    
    # Then plotting the scatter plot so that it's on top of the line
    ax.scatter(spikes_at_pos[0], spikes_at_pos[1], color='red', alpha = 0.2, zorder=2)

    ax.set_aspect('equal')
    ax.set_title(scan_key[0]["scan_id"] + "_" + str(cell+1))
    # except:
    #     print(f'error at cell{cell}')
# Remove any extra subplots
for cell in range(total_cells, nrows * ncols):
    fig.delaxes(axs[cell])

plt.show()

In [None]:
scan_key[0]

In [None]:
# make masked spatial occupancy map - OPEXEBO

import opexebo

arena_size = 100 # in cm - NEEDS CALIBRATED TRACKING COORDS!
arena_shape = "circle"
bin_width =  4 # cm

masked_occupancy_map, coverage, bin_edges = opexebo.analysis.spatial_occupancy(timebase_2p, position, arena_size, arena_shape = arena_shape, bin_width = bin_width)

plt.figure(figsize=(5,3))
plt.imshow(np.flipud(masked_occupancy_map))
cbar = plt.colorbar()
cbar.set_label('time / bin [s]')
plt.title(f'spatial occupancy - {scan_key[0]["scan_id"]}')
plt.show()

In [None]:
# plot rate maps - OPEXEBO
from scipy.ndimage import gaussian_filter


# Determine the grid dimensions (for a roughly square arrangement)
nrows = int(np.ceil(np.sqrt(total_cells)))
ncols = int(np.ceil(total_cells / nrows))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 15))
fig.subplots_adjust(hspace=0.1) # Add some space between the subplots

# If axs is not already a 2D array (e.g., if total_cells is a perfect square), make it one
if total_cells != nrows * ncols:
    axs = axs.reshape(-1)
else:
    axs = axs.flatten()


# load image styles for display
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(map_params)

for cell in range(total_cells):
    ax = axs[cell]
    # print(cell)
    try:
        # spike events at position
        spikes_at_pos = np.vstack((position[0, event_matrix[cell].astype("bool")], position[1, event_matrix[cell].astype("bool")]))
        
        # time at postiion
        time_at_pos = (timebase_2p[event_matrix[cell].astype("bool")])
        
        # spikes_tracking [t,x,y]
        spikes_tracking = np.vstack((time_at_pos, spikes_at_pos))
        
        # make ratemap
        rate_map = opexebo.analysis.rate_map(masked_occupancy_map, spikes_tracking, arena_size, arena_shape = arena_shape, bin_width = bin_width)
        rate_map = np.flipud(rate_map)
        # filtered_rate_map = gaussian_filter(rate_map, sigma = 0.5)
        
        
        im = ax.imshow(rate_map, vmin = 0, vmax = 2)

        ax.set_aspect('equal')
        ax.set_title(scan_key[0]["scan_id"] + "-" + str(cell+1))
        # cbar = plt.colorbar(im, ax=ax) # Pass the image object and the ax to plt.colorbar
        # cbar.set_label('events / bin [s]')
    except:
        print(f'error at cell{cell}')

# Remove any extra subplots
for cell in range(total_cells, nrows * ncols):
    fig.delaxes(axs[cell])

plt.show()

### speed tuning - freely moving

In [None]:
# get running speed - OPEXEBO
new_speed = opexebo.analysis.calc_speed(timebase_2p, position[0], position[1], moving_average = 7)

In [None]:

from matplotlib.ticker import MaxNLocator

mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(plot_params)

# get some scaling values from pixel distribution
scalemin = 0
scalemax = 99.7

cmin = np.percentile(new_speed,scalemin)  
cmax = np.percentile(new_speed,scalemax)

kp_colors = np.array([[0.55,0.55,0.55]])


# timepoints to visualize
tstart = 0
tend =  timebase_2p[-1] -1

xmin = int(np.floor(tstart * sampling_rate))
xmax = int(np.floor(tend * sampling_rate))

# make figure with grid for easy plotting
fig = plt.figure(figsize=(8,5), dpi = 200)
grid = plt.GridSpec(9, 20, figure=fig, wspace = 0.05, hspace = 0.3)

# plot running speed
ax = plt.subplot(grid[:2, :-1])
ax.plot(new_speed,  color=kp_colors[0])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([cmin, cmax])

ax.axis("off")
ax.set_title("freely moving running speed", color=kp_colors[0])
# ax.set_xlabel("running speed")


# plot superneuron activity
ax = plt.subplot(grid[2:, :-1])
ax.imshow(Sfilt[:, xmin:xmax], cmap="gray_r", vmin=-0.1, vmax=1,  extent= [timebase_2p[xmin], timebase_2p[xmax], 0, Sfilt.shape[0]], aspect="auto", interpolation='none')
ax.set_xlabel("time [s]")
ax.set_ylabel("sorted cells")

ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
plt.show()

# ax = plt.subplot(grid[1:, -1])
# ax.imshow(np.arange(0, len(sn))[:,np.newaxis], cmap="gist_ncar", aspect="auto")
# ax.axis("off")

### IMU

In [None]:
# GET IMU data

accelerometer = (behavior.HarpRecording.Channel() & scan_key & "channel_name LIKE 'IMU accelerometer %'").fetch("data")
gyroscope = (behavior.HarpRecording.Channel() & scan_key & "channel_name LIKE 'IMU gyroscope %'").fetch("data")
magnetometer = (behavior.HarpRecording.Channel() & scan_key & "channel_name LIKE 'IMU magnetometer %'").fetch("data")
IMU_twopframes = (behavior.HarpRecording.Channel() & scan_key & "channel_name LIKE '2p %'").fetch("data")
IMU_time = (behavior.HarpRecording.Channel() & scan_key & "channel_name LIKE '2p %'").fetch("time")

In [None]:
# SYNC IMU data

propersync = np.max(IMU_twopframes[0]).astype("bool")

## Get the timestamp data

# from the event table get the main recording gate start / end HARP gate timestamps.
harpgatetimestamp_end = (event.Event()  &  "event_type='HARP_gate'" &  scan_key ).fetch('event_end_time')
harpgatetimestamp_start = (event.Event()  &  "event_type='HARP_gate'" &  scan_key ).fetch('event_start_time')
#
# if propersync:
    # print("2p timestamps detected")
# else:
print("No 2p timestamps in IMU rec detected - poor man's single-point sync")
IMU_time = IMU_time - harpgatetimestamp_start
# HARP and AUX not in sync!
print(IMU_time[0][-1]/1000)  
print(harpgatetimestamp_end[0])

# Therefore: space number of HARP samples evenly between HARP gate timestampa.
harpgate_sync_timestamps = np.squeeze(np.linspace(harpgatetimestamp_start, harpgatetimestamp_end, np.shape(magnetometer[0])[0]))

#  and 2p timestamps (which will always be in the recording gate).
twoptimestamps = (event.Event()  &  "event_type='mini2p_frames'" &  scan_key ).fetch('event_start_time')


# get indices
aligned_IMU_indices = get_closest_timestamps(twoptimestamps,harpgate_sync_timestamps) #smoothing windwo from above

In [None]:
# UPDATE: use bidirectional filtering to prevent phase shifts between original and filtered signal
from scipy.signal import butter, filtfilt

# Design the Butterworth filter for accelerometer and magnetometer
N = 6 # Order of the filter
Wn = 0.03 # Cutoff frequency (example value, should be chosen based on your specific needs)
b, a = butter(N, Wn, btype='low')

# Design the Butterworth filter for gyroscope
N = 6 # Order of the filter
Wn = 0.005 # Cutoff frequency (example value, should be chosen based on your specific needs)
d, c = butter(N, Wn, btype='low')

filtered_accelerometer = [filtfilt(b, a, array) for array in accelerometer]
filtered_gyroscope = [filtfilt(d, c, array) for array in gyroscope]
filtered_magnetometer = [filtfilt(b, a, array) for array in magnetometer]
filtered_newspeed = filtfilt(b, a, new_speed)


In [None]:
# generate event trace for light stim

# from the event table get the main recording gate start / end HARP gate timestamps.
flash_timestamp_end = (event.Event()  &  "event_type='arena_LED'" &  scan_key ).fetch('event_end_time')
flash_timestamp_start = (event.Event()  &  "event_type='arena_LED'" &  scan_key ).fetch('event_start_time')
twoptimestamps = (event.Event()  &  "event_type='mini2p_frames'" &  scan_key ).fetch('event_start_time')

# get indices
aligned_flash_timestamp_end  = get_closest_timestamps(flash_timestamp_end, twoptimestamps) #smoothing windwo from above
aligned_flash_timestamp_start  = get_closest_timestamps(flash_timestamp_start, twoptimestamps) #smoothing windwo from above

# Create a linear array of zeros of length 10
array_length = np.shape(twoptimestamps)[0]
flash_array = np.zeros(array_length)

# Iterate through the start and stop times and set the corresponding elements to 1
for start, stop in zip(aligned_flash_timestamp_start, aligned_flash_timestamp_end):
    flash_array[start:stop] = 1

In [None]:
from matplotlib.ticker import MaxNLocator

plotdata = Sfilt_backup
# plotdata = Sfilt

mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(plot_params)

# get some scaling values from pixel distribution
# SPEED 
scalemin = 0
scalemax = 99.7

cmin = np.percentile(new_speed,scalemin)  
cmax = np.percentile(new_speed,scalemax)

# SPEED threshold
# Calculate the 50 th percentile on the lowpass filtered speed trace to obtain thresholds for (smooth) division into running / not running 
cropped_speed = filtered_newspeed[filtered_newspeed <= 500]
perc = 50
speed_thresh = np.percentile(cropped_speed, perc)

# ACC 
scalemin = 0
scalemax = 100

acc_cmin = np.percentile(np.concatenate(filtered_accelerometer),scalemin)  
acc_cmax = np.percentile(np.concatenate(filtered_accelerometer),scalemax)

# GYR 
scalemin = 0
scalemax = 100

gyr_cmin = np.percentile(np.concatenate(filtered_gyroscope),scalemin)  
gyr_cmax = np.percentile(np.concatenate(filtered_gyroscope),scalemax)

# MAG 
scalemin = 0
scalemax = 100

mag_cmin = np.percentile(np.concatenate(filtered_magnetometer),scalemin)  
mag_cmax = np.percentile(np.concatenate(filtered_magnetometer),scalemax)



kp_colors = np.array([[0,0,0], [0.55,0.55,0.55], [0,0.9,0.9]])

imu_colors = np.array([
    [[1.0, 0.75, 0.8], [0.8, 0.65, 0.68], [0.55, 0.55, 0.55]],
    [[0.95, 0.6, 0.95], [0.75, 0.58, 0.75], [0.55, 0.55, 0.55]],
    [[0.65, 0.95, 0.95], [0.6, 0.75, 0.75], [0.55, 0.55, 0.55]]
])

# timepoints to visualize
tstart = 0
tend =  timebase_2p[-1] -1

xmin = int(np.floor(tstart * sampling_rate))
xmax = int(np.floor(tend * sampling_rate))

# make figure with grid for easy plotting
fig = plt.figure(figsize=(16,10), dpi = 200)
grid = plt.GridSpec(20, 20, figure=fig, wspace = 0.5, hspace = 0.3)

# plot running speed
ax = plt.subplot(grid[:2, :-1])
ax.plot(new_speed,  color=kp_colors[0])
# ax.plot(filtered_newspeed,  color=kp_colors[2])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([cmin, cmax])
ax.tick_params(left=False, right=False, bottom=False, top=False,
               labelleft=False, labelbottom=False)
for spine in ax.spines.values():
    spine.set_visible(False)
    
ax.fill_between(range(len(new_speed)), scalemax, where=filtered_newspeed>speed_thresh, color='red', alpha=0.1)

ax.set_ylabel("speed")
ax.yaxis.set_label_position("left")

# ax.set_title("freely moving running speed", color=kp_colors[0])
# ax.set_xlabel("running speed")

# plot accelerometer
sliced_accelerometer = [array[aligned_IMU_indices] for array in filtered_accelerometer]
ax = plt.subplot(grid[2:4, :-1])
for i, arr in enumerate(sliced_accelerometer):
    ax.plot(arr, label=f'accelerometer {i+1}', color=imu_colors[0][i])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([acc_cmin, acc_cmax])
ax.tick_params(left=False, right=False, bottom=False, top=False,
               labelleft=False, labelbottom=False)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_ylabel("acc")
ax.yaxis.set_label_position("left")

# ax.set_title("freely moving running speed", color=kp_colors[0])
# ax.set_xlabel("running speed")

# plot gyroscope
sliced_gyroscope = [array[aligned_IMU_indices] for array in filtered_gyroscope]
ax = plt.subplot(grid[4:6, :-1])
for i, arr in enumerate(sliced_gyroscope):
    ax.plot(arr, label=f'gyroscope {i+1}', color=imu_colors[1][i])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([gyr_cmin, gyr_cmax])
ax.tick_params(left=False, right=False, bottom=False, top=False,
               labelleft=False, labelbottom=False)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_ylabel("gyr")
ax.yaxis.set_label_position("left")

# plot magnetometer
sliced_magnetometer= [array[aligned_IMU_indices] for array in filtered_magnetometer]
ax = plt.subplot(grid[6:8, :-1])
for i, arr in enumerate(sliced_magnetometer):
    ax.plot(arr, label=f'magnetometer {i+1}', color=imu_colors[2][i])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([mag_cmin, mag_cmax])
ax.tick_params(left=False, right=False, bottom=False, top=False,
               labelleft=False, labelbottom=False)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_ylabel("mag")
ax.yaxis.set_label_position("left")
# ax.set_title("freely moving running speed", color=kp_colors[0])
# ax.set_xlabel("running speed")

# plot LIGHT flash
ax = plt.subplot(grid[8:10, :-1])
ax.plot(flash_array,  color=kp_colors[1])
ax.set_xlim([0, xmax-xmin])
ax.set_ylim([-.1, 1.1])
ax.tick_params(left=False, right=False, bottom=False, top=False,
               labelleft=False, labelbottom=False)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_ylabel("flash")
ax.yaxis.set_label_position("left")
ax.fill_between(range(len(new_speed)), 1, where=filtered_newspeed>speed_thresh, color='red', alpha=0.2)


# plot neuronal activity
ax = plt.subplot(grid[10:, :-1])
ax.imshow(plotdata[:, xmin:xmax], cmap="gray_r", vmin=0.2, vmax=1,  extent= [timebase_2p[xmin], timebase_2p[xmax], 0, plotdata.shape[0]], aspect="auto", interpolation = "none")
ax.set_xlabel("time [s]")
ax.set_ylabel("sorted cells")

ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
plt.show()

# ax = plt.subplot(grid[1:, -1])
# ax.imshow(np.arange(0, len(sn))[:,np.newaxis], cmap="gist_ncar", aspect="auto")
# ax.axis("off")

### Event-aligned plotting - Flash onsets / Movement onsets

In [None]:

# PSTH function
def plot_aligned_data(data, eventarray, color, colortrace, label, plot_average=False):
    if len(data.shape) == 1:
        data = data[np.newaxis, :]
    
    segments = []
    for row in data:
        for event in eventarray:
            start = int(event - pre_event_window)
            end = int(event + post_event_window)
            segment_length = end - start
            time_range = np.linspace(-pre_event_window, post_event_window, segment_length) / sampling_rate

            if len(time_range) != len(row[start:end]):
                # print(f"Skipping segment due to mismatched length: {len(time_range)} vs {len(row[start:end])}")
                continue
            segments.append(row[start:end])

    if plot_average:
        avg_segment = np.mean(segments, axis=0)
        std_dev = np.std(segments, axis=0)
        plt.fill_between(time_range, avg_segment - std_dev, avg_segment + std_dev, color=color, alpha=0.3)
        plt.plot(time_range, avg_segment, color = colortrace, lw=2, label='Average ' + label)
        plt.axvline(0, color='black', linestyle='--')
        plt.xlabel("Time [s]")
        plt.ylabel(r"$\Delta$F/F$_0$ [%]")
        plt.ylim([cmin, cmax])


In [None]:
# Running indices
running_indices = filtered_newspeed > speed_thresh

# PSTH
pre = .5  # s
post = 2  # s

cmin = -100
cmax = 250

# Window around events to plot: 2 seconds before and 3 seconds after
pre_event_window = np.floor(pre * sampling_rate)
post_event_window = np.floor(post * sampling_rate)
aligned_flash_timestamp_start = np.array(aligned_flash_timestamp_start)[(aligned_flash_timestamp_start > pre_event_window)]


# Get the total number of cells to plot - HERE ONLY  CROSSING SNRthresh
num_cells = dFF0_traces[SNR_matrix > SNRthresh, :].shape[0]
# num_cells = dFF0_traces.shape[0]

# Determine the layout for the subplots
cols = int(math.ceil(math.sqrt(num_cells)))
rows = int(math.ceil(num_cells / cols))

fig, axarr = plt.subplots(rows, cols, figsize=(8,10))

for cell in range(num_cells):
    row = cell // cols
    col = cell % cols

    timeseries_data = dFF0_traces[SNR_matrix > SNRthresh, :][cell] * 100
    # timeseries_data = dFF0_traces[cell] * 100


    plt.sca(axarr[row, col])  # Set the current subplot

    plot_aligned_data(timeseries_data[:], aligned_flash_timestamp_start, 'grey', 'red', 'Cell ' + str(cell), plot_average=True)
    plt.title("Cell " + str(cell))

    
plt.suptitle("all data")
plt.tight_layout()


plt.show()





In [None]:
#  Running indices
running_indices = filtered_newspeed > speed_thresh

running_event_indices = np.array(sorted(list(set(aligned_flash_timestamp_start) & set(np.where(running_indices)[0]))))
not_running_event_indices = np.array(sorted(list(set(aligned_flash_timestamp_start) - set(np.where(running_indices)[0]))))


# PSTH
pre = .5  # s
post = 2  # s

cmin = -100
cmax = 250

# Window around events to plot: 2 seconds before and 3 seconds after
pre_event_window = np.floor(pre * sampling_rate)
post_event_window = np.floor(post * sampling_rate)
aligned_flash_timestamp_start = np.array(aligned_flash_timestamp_start)[(aligned_flash_timestamp_start > pre_event_window)]

# Adjust the event timestamps for the running condition
adjusted_events = []
for event in aligned_flash_timestamp_start:
    adjustment = np.sum(running_indices[:event])
    adjusted_events.append(event + adjustment)

# Get the total number of cells to plot - HERE ONLY  CROSSING SNRthresh
num_cells = dFF0_traces[SNR_matrix > SNRthresh, :].shape[0]
# num_cells = dFF0_traces.shape[0]


# Determine the layout for the subplots
cols = int(math.ceil(math.sqrt(num_cells)))
rows = int(math.ceil(num_cells / cols))

fig, axarr = plt.subplots(rows, cols, figsize=(8,10))

for cell in range(num_cells):
    row = cell // cols
    col = cell % cols

    # timeseries_data = dFF0_traces[cell] * 100
    timeseries_data = dFF0_traces[SNR_matrix > SNRthresh, :][cell] * 100
    plt.sca(axarr[row, col])  # Set the current subplot

    plot_aligned_data(timeseries_data[:], running_event_indices, 'red', 'red', 'Cell' + str(cell), plot_average=True)
    plot_aligned_data(timeseries_data[:], not_running_event_indices, 'grey', 'black',  'Cell ' + str(cell), plot_average=True)
    plt.title("Cell " + str(cell))

plt.suptitle("all data")
plt.tight_layout()
plt.show()


In [None]:
aligned_flash_timestamp_start[np.where(running_indices)[0]]

In [None]:
np.shape(aligned_flash_timestamp_start)

In [None]:
np.shape([np.where(running_indices)[0]])


In [None]:
set((aligned_flash_timestamp_start, np.where(running_indices)[0]))

In [None]:
aligned_flash_timestamp_start


In [None]:
running_event_indices = np.array(sorted(list(set(aligned_flash_timestamp_start) & set(np.where(running_indices)[0]))))
not_running_event_indices = np.array(sorted(list(set(aligned_flash_timestamp_start) - set(np.where(running_indices)[0]))))


In [None]:
intersection_indices

In [None]:
not_intersecting_indices

In [None]:
# Running indices
running_indices = filtered_newspeed > speed_thresh

# PSTHdd
pre = .5  # s
post = 2  # s

cmin = -100
cmax = 250

# Window around events to plot: 2 seconds before and 3 seconds after
pre_event_window = np.floor(pre * sampling_rate)
post_event_window = np.floor(post * sampling_rate)
aligned_flash_timestamp_start = np.array(aligned_flash_timestamp_start)[(aligned_flash_timestamp_start > pre_event_window)]


# Get the total number of cells to plot
# num_cells = dFF0_traces[SNR_matrix > SNRthresh, :].shape[0]
num_cells = dFF0_traces.shape[0]

# Determine the layout for the subplots
cols = int(math.ceil(math.sqrt(num_cells)))
rows = int(math.ceil(num_cells / cols))

fig, axarr = plt.subplots(rows, cols, figsize=(12,15))

for cell in range(num_cells):
    row = cell // cols
    col = cell % cols

    # timeseries_data = dFF0_traces[SNR_matrix > SNRthresh, :][cell] * 100
    timeseries_data = dFF0_traces[running_indices][cell] * 100


    plt.sca(axarr[row, col])  # Set the current subplot

    plot_aligned_data(timeseries_data[:], 'red', 'Cell ' + str(cell), plot_average=True)
    plt.title("Cell " + str(cell))

    
plt.suptitle("running")
plt.tight_layout()


plt.show()


# some behavior analyis


data['custom']['database.prefix']