In [53]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import json
import scipy
from bisect import bisect_left
import pickle as pkl
import os

In [54]:
def load_bpod_data(dpath):
    """
    Read a Bpod session file stored as JSON and return its parsed content.

    dpath: path of the JSON file to open (opened in text read mode).
    Returns the object produced by `json.load` for that file.
    """
    with open(dpath, 'r') as handle:
        return json.load(handle)

def get_event_time(event, bpod_data):
    """
    Return the first onset time of a state for every trial without a WrongPort visit.

    Parameters
    ----------
    event : str
        Name of the state to look up in each trial's 'States' dictionary.
    bpod_data : dict
        Parsed Bpod session; trials are read from
        bpod_data["SessionData"]['RawEvents']['Trial'].

    Returns
    -------
    np.ndarray of float
        One onset per kept trial (time stamp relative to trial start). A trial
        is kept when the first entry of its 'WrongPort' state is falsy.
        If the first entry of the requested state is null, the onset is NaN so
        that the output stays aligned with the kept trials.
    """
    onsets = []
    for trial in bpod_data["SessionData"]['RawEvents']['Trial']:
        states = trial['States']
        # Skip trials in which the WrongPort state has a truthy first entry.
        if states['WrongPort'][0]:
            continue

        first = states[event][0]
        # A state entered several times is stored as a list of [start, end]
        # pairs; a state entered once is a flat [start, end] pair.
        if isinstance(first, (list, tuple)):
            first = first[0]

        # JSON decodes whole-number times (e.g. 0) as int, so do not rely on
        # the value being a Python float; a null entry (presumably an unvisited
        # state - TODO confirm against the exporter) becomes NaN.
        onsets.append(np.nan if first is None else float(first))
    return np.array(onsets, dtype=float)

def take_closest(ms_stamps, bpod_stamps):
    """
    For each stamp in bpod_stamps, find the index of the closest value in ms_stamps.

    Assumes ms_stamps is sorted in ascending order and non-empty. If two
    neighbouring values are equally close, the smaller one is chosen. Stamps
    that fall before the first or after the last value of ms_stamps are mapped
    to the first or last index respectively.

    Parameters
    ----------
    ms_stamps : sorted 1-D sequence or array
        Reference stamps to search in.
    bpod_stamps : iterable
        Stamps to align to ms_stamps.

    Returns
    -------
    np.ndarray of int
        Indices into ms_stamps, one per entry of bpod_stamps. The result is
        squeezed, so a single input stamp yields a 0-d array.

    Raises
    ------
    IndexError
        If ms_stamps is empty while there is at least one stamp to align.
    """
    n_stamps = len(ms_stamps)
    msAlignedIndex = []
    for bpod_stamp in bpod_stamps:
        if n_stamps == 0:
            raise IndexError("ms_stamps is empty")
        pos = bisect_left(ms_stamps, bpod_stamp)
        if pos == 0:
            # Before (or equal to) the first stamp: clamp to the first index
            # instead of aborting the whole alignment.
            msAlignedIndex.append(0)
        elif pos == n_stamps:
            # After the last stamp: clamp to the last index.
            msAlignedIndex.append(n_stamps - 1)
        else:
            before = ms_stamps[pos - 1]
            after = ms_stamps[pos]
            # Strict comparison so that ties resolve to the smaller stamp.
            if after - bpod_stamp < bpod_stamp - before:
                msAlignedIndex.append(pos)
            else:
                msAlignedIndex.append(pos - 1)
    return np.squeeze(np.array(msAlignedIndex, dtype=np.intp))

In [55]:
# Session selection: these three values are interpolated into the data paths
# used by the following cells.
mouse = 'ZZ0024-LR'  # used in the ws_data and minian_data file names
mouse_pair = 'ZZ0024-L_ZZ0024-LR'  # used in the bpod_data, LED-frame pickle and behavior video file names
date = '20240926'  # session date string shared by all file names

In [56]:
# Load every input of the session selected by `mouse`, `mouse_pair` and `date`.
# NOTE(review): all paths are absolute and machine-specific; consider a single
# configurable data directory.
# NOTE(review): `scipy.io` is reached through the bare `import scipy` at the top
# of the notebook - confirm the installed SciPy exposes the submodule that way.

# Sampling stamps stored under 'ms_frames_samplingstamps' and
# 'bpod_trialstart_samplingstamps' in the processed .mat files, squeezed to
# drop singleton dimensions ('ms' presumably = miniscope - TODO confirm).
ms_stamps = np.squeeze(scipy.io.loadmat(f"/Users/fgs/HMLworkplace/Arena_analysis/Data/ws_data/processed/{mouse}/{mouse}_{date}_ms.mat")['ms_frames_samplingstamps'])
bpod_trialstart_stamps = np.squeeze(scipy.io.loadmat(f"/Users/fgs/HMLworkplace/Arena_analysis/Data/ws_data/processed/{mouse}/{mouse}_{date}_bpod.mat")['bpod_trialstart_samplingstamps'])

# Bpod session data parsed from JSON.
bpod_data = load_bpod_data(f'/Users/fgs/HMLworkplace/Arena_analysis/Data/bpod_data/{mouse_pair}/{mouse_pair}_{date}.json')

# Minian output dataset; its `C` variable is used as the calcium traces.
minian_results = xr.open_dataset(f"/Users/fgs/HMLworkplace/Arena_analysis/Data/minian_data/{mouse}_{date}.netcdf")
calcium_traces = minian_results.C

# Pickled sequence converted to an array and used below as behavior-video frame
# numbers of trial starts.
# NOTE(review): the directory is spelled 'bahavior_led_frames' - presumably it
# matches the name on disk; verify before renaming anything.
# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with open(f'/Users/fgs/HMLworkplace/Arena_analysis/Data/bahavior_led_frames/{mouse_pair}_{date}.pickle', 'rb') as f:
    dlc_trialStartStamps = np.array(pkl.load(f))

In [57]:
# Take entries 10..20 (11 values) of the behavior-video trial-start frames as
# the sample window, print the span between the first and last of them, and
# display the values.
led_frames_sample = dlc_trialStartStamps[10:21]
print(led_frames_sample[-1] - led_frames_sample[0])
led_frames_sample

2985


array([3377, 3676, 3998, 4236, 4549, 4856, 5133, 5445, 5802, 6133, 6362])

In [58]:
# Align every Bpod trial-start stamp to ms_stamps with take_closest, then take
# the same 10..20 window as for the LED frames. The printed span can be compared
# with the LED-frame span printed in the previous cell; the window's first and
# last values are used further down as column bounds into the calcium traces.
ms_trialstarts = take_closest(ms_stamps,bpod_trialstart_stamps)
ms_sample_trialstarts = ms_trialstarts[10:21]
print(ms_sample_trialstarts[-1] - ms_sample_trialstarts[0])
ms_sample_trialstarts

2972


array([3832, 4129, 4450, 4688, 4999, 5304, 5580, 5889, 6246, 6576, 6804])

In [59]:
# Positions (along the first axis of minian_results.C) of the neurons to plot.
selected_neurons = [22,21,32,114,125,156,171,172]
# NOTE(review): `calciam_traces` is a misspelled duplicate of `calcium_traces`
# assigned in the data-loading cell; both refer to minian_results.C.
calciam_traces = minian_results.C
# Keep only the selected neurons ...
sample_C = calciam_traces[selected_neurons]
# ... and, as a plain NumPy array, only the columns from the first to the last
# sampled trial start (end bound excluded).
sample_C = np.array(sample_C)[:,ms_sample_trialstarts[0]:ms_sample_trialstarts[-1]]

In [60]:

# Render a side-by-side video: behavior frames on the left and, on the right,
# one stacked trace panel per selected neuron showing the recent history of
# `sample_C` up to the current sample.

# Input video path (animal behavior video)
behavior_video_path = f"/Users/fgs/HMLworkplace/Arena_analysis/Data/behavior_videos/{mouse_pair}_{date}_highsat.mp4"

# Neural signals to draw: one row of `sample_C` per selected neuron
fps = 30  # Frame rate written to the output video
time = np.arange(sample_C.shape[1])  # Sample index of every column of sample_C
neurons = sample_C
trace_history = 200  # Number of past samples drawn behind the current point

# Output video path
output_video_path = "/Users/fgs/HMLworkplace/Arena_analysis/Data/behavior_videos/test.mp4"

# Open behavior video
behavior_cap = cv2.VideoCapture(behavior_video_path)
if not behavior_cap.isOpened():
    raise IOError(f"Could not open behavior video: {behavior_video_path}")

# Get video properties
behavior_width = int(behavior_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
behavior_height = int(behavior_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(behavior_cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(behavior_height)
print(behavior_width)

# Window of behavior frames to render: it starts at the first sampled LED frame
# and advances one behavior frame per column of sample_C.
start_frame = led_frames_sample[0]
end_frame = start_frame + sample_C.shape[1]
# Seek to the first frame of the window
behavior_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

# Determine video dimensions for output
num_neurons = len(neurons)
neural_signal_width = 1024  # Width allocated for neural signals
neural_signal_height = behavior_height // num_neurons  # Height allocated per neuron
output_width = behavior_width + neural_signal_width
output_height = behavior_height

# Prepare output video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (output_width, output_height))

# Each neuron keeps fixed y-limits (its own min/max over the window), so they
# are computed once instead of once per frame.
y_limits = [(np.min(trace), np.max(trace)) for trace in neurons]

# Iterate through behavior video frames within the selected window
frame_index = start_frame
ms_index = 0
try:
    while frame_index < end_frame:
        ret, frame = behavior_cap.read()
        # Check the read result before touching `frame`: it is not a valid
        # image once the video runs out of frames.
        if not ret:
            break
        # Match the frame to the size reported by the capture so it fits the
        # left part of the output regardless of the video resolution.
        behavior_frame = cv2.resize(frame, (behavior_width, behavior_height))

        # Prepare a blank canvas for neural signals
        neural_canvas = np.zeros((behavior_height, neural_signal_width, 3), dtype=np.uint8)

        # Plot each neuron's signal in its allocated row
        for neuron_idx, signal in enumerate(neurons):
            fig, ax = plt.subplots(figsize=(4, 2), facecolor='black')  # Small figure per neuron
            start_idx = max(0, ms_index - trace_history)
            end_idx = ms_index
            # Recent history as a line, current sample as a dot
            ax.plot(time[start_idx:end_idx], signal[start_idx:end_idx], 'orange', lw=2)
            ax.scatter(time[ms_index], signal[ms_index], color='orange', s=50)
            ax.set_ylim(*y_limits[neuron_idx])
            ax.set_title(f'Neuron{selected_neurons[neuron_idx]}', loc='left', fontsize=10, color='white')
            ax.axis('off')

            # Round-trip the plot through a temporary PNG in the working directory
            plot_path = f"neuron_{neuron_idx}_frame_{ms_index}.png"
            fig.tight_layout(pad=0)
            fig.savefig(plot_path, dpi=100, bbox_inches='tight', pad_inches=0)
            plt.close(fig)  # Avoid accumulating open figures across frames

            # Load the image and resize it to fit the allocated row
            signal_img = cv2.imread(plot_path)
            signal_img = cv2.resize(signal_img, (neural_signal_width, neural_signal_height))

            # Place the signal image on the neural canvas
            start_y = neuron_idx * neural_signal_height
            end_y = start_y + neural_signal_height
            neural_canvas[start_y:end_y, :] = signal_img
            os.remove(plot_path)

        # Combine the behavior video frame and neural signals
        combined_frame = np.zeros((output_height, output_width, 3), dtype=np.uint8)
        combined_frame[:, :behavior_width] = behavior_frame  # Place behavior video
        combined_frame[:, behavior_width:] = neural_canvas  # Place neural signals

        # Write the combined frame to the output video
        out.write(combined_frame)

        frame_index += 1
        ms_index += 1
finally:
    # Release resources even if rendering fails part-way through
    behavior_cap.release()
    out.release()

print(f"Video saved at {output_video_path}")


2048
2048
Video saved at /Users/fgs/HMLworkplace/Arena_analysis/Data/behavior_videos/test.mp4


In [61]:
# Inspect the 'BackToCenter' state entry of the first trial to see how state
# timings are stored in the Bpod data.
bpod_data['SessionData']['RawEvents']['Trial'][0]['States']['BackToCenter']

[0.1, 3.331]

In [77]:
# Scratch cell: prints the integers 0 through 7, one per line.
# NOTE(review): does not use any session data; looks like a leftover experiment.
for i in range(8):
    print(i)


0
1
2
3
4
5
6
7
