In [1]:

import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
from scipy.io import loadmat
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import cv2
from collections import deque
from pinkrigs_tools.dataset.query import load_data, queryCSV
from scipy.ndimage import gaussian_filter1d
import glob
from utils.pipeline import load_and_process_session

In [1]:

# Session to annotate.
subject_id, date = 'AV043', '2024-03-05'
exp_kwargs = dict(subject=subject_id, expDate=date)

# Destination for the annotated copy, and the DLC-labelled source video on the lab share.
output_path = f"H:/Annotated_Videos/annotated_video_{subject_id}_{date}.mp4"
input_path = fr"\\znas\Lab\Share\Maja\labelled_DLC_videos\{subject_id}_{date}.mp4"




In [3]:
def _speed_ylim(speed):
    """Return the largest finite value of ``speed``, or 1.0 if there is none or it is <= 0.

    ``np.nanmax`` returns NaN for an all-NaN trace (and inf if the trace has an
    inf), and a zero maximum gives a degenerate ``set_ylim(0, 0)``; this helper
    always yields a finite, positive limit.
    """
    values = np.asarray(speed, dtype=float).ravel()
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 1.0
    peak = float(finite.max())
    return peak if peak > 0 else 1.0


def setup_plots(oa_speed, wh_speed, dpi=80):
    """Create the two speed figures (OA and WH) that are rendered under each video frame.

    Parameters
    ----------
    oa_speed, wh_speed : array-like
        1-D speed traces; sample ``i`` is plotted at x = ``i``. The two traces
        may have different lengths.
    dpi : int
        Resolution of both 8x3-inch figures (default 80, i.e. 640x240 px).

    Returns
    -------
    tuple
        ``(fig_oa, ax_oa, canvas_oa, line_oa, y_lim_oa,
        fig_wh, ax_wh, canvas_wh, line_wh, y_lim_wh)``. ``y_lim_*`` is the
        largest finite value of the trace (1.0 if there is none); the axes show
        ``[0, 1.1 * y_lim]``.
    """
    fig_oa, ax_oa = plt.subplots(figsize=(8, 3), dpi=dpi)
    fig_wh, ax_wh = plt.subplots(figsize=(8, 3), dpi=dpi)

    # Each trace gets its own x axis so OA and WH need not be the same length.
    line_oa, = ax_oa.plot(np.arange(len(oa_speed)), oa_speed, 'b-')
    line_wh, = ax_wh.plot(np.arange(len(wh_speed)), wh_speed, 'g-')

    y_lim_oa = _speed_ylim(oa_speed)
    y_lim_wh = _speed_ylim(wh_speed)
    ax_oa.set_ylim(0, y_lim_oa * 1.1)
    ax_wh.set_ylim(0, y_lim_wh * 1.1)

    # Agg canvases are drawn and read back as pixel buffers by get_plot_images.
    canvas_oa = FigureCanvasAgg(fig_oa)
    canvas_wh = FigureCanvasAgg(fig_wh)

    return (fig_oa, ax_oa, canvas_oa, line_oa, y_lim_oa,
            fig_wh, ax_wh, canvas_wh, line_wh, y_lim_wh)

def _draw_speed_window(ax, line, y_lim, window_start, window_end,
                       onsets, offsets, fill_color):
    """Redraw one speed axis for the frame window [window_start, window_end].

    Clears ``ax``, re-plots the part of ``line`` inside the window in the line's
    own colour/style, sets y to ``[0, 1.1 * y_lim]`` and x to the window, shades
    every onset -> next-offset span that overlaps the window with ``fill_color``,
    and draws dashed black lines at the onsets/offsets inside the window.
    """
    ax.clear()

    # Re-plot only the samples inside the window: the x-limits hide everything
    # else, and re-plotting the whole session trace on every video frame is
    # wasted work. Assumes the line's x data is ascending (setup_plots uses
    # np.arange).
    x_data = np.asarray(line.get_xdata())
    y_data = np.asarray(line.get_ydata())
    lo = np.searchsorted(x_data, window_start, side='left')
    hi = np.searchsorted(x_data, window_end, side='right')
    ax.plot(x_data[lo:hi], y_data[lo:hi],
            color=line.get_color(), linestyle=line.get_linestyle())
    ax.set_ylim(0, y_lim * 1.1)
    ax.set_xlim(window_start, window_end)

    onsets = np.sort(np.asarray(onsets).ravel())
    offsets = np.sort(np.asarray(offsets).ravel())

    # Pair each onset with the first offset strictly after it and shade the span
    # whenever it overlaps the window -- including events that began before
    # window_start or that are still ongoing at window_end (the current frame).
    next_offset_pos = np.searchsorted(offsets, onsets, side='right')
    for onset, pos in zip(onsets, next_offset_pos):
        if pos >= len(offsets):
            continue  # no offset follows this onset
        offset = offsets[pos]
        if offset < window_start or onset > window_end:
            continue  # span lies entirely outside the window
        ax.fill_between([onset, offset], 0, y_lim, alpha=0.3, color=fill_color)

    # Dashed markers only for event boundaries that fall inside the window.
    for onset in onsets[(onsets >= window_start) & (onsets <= window_end)]:
        ax.axvline(onset, color='black', linestyle='--', alpha=0.7)
    for offset in offsets[(offsets >= window_start) & (offsets <= window_end)]:
        ax.axvline(offset, color='black', linestyle='--', alpha=0.7)


def update_plot_window(ax_oa, ax_wh, line_oa, line_wh, y_lim_oa, y_lim_wh,
                      window_start, window_end, oa_onsets, oa_offsets, 
                      wh_onsets, wh_offsets):
    """Redraw both speed plots for the frame window [window_start, window_end].

    Each axis is redrawn from its OWN line and y-limit (OA: ``line_oa`` /
    ``y_lim_oa``, shaded green; WH: ``line_wh`` / ``y_lim_wh``, shaded purple).
    Onset/offset arrays may be numpy arrays, pandas Series or lists of frame
    indices.
    """
    _draw_speed_window(ax_oa, line_oa, y_lim_oa, window_start, window_end,
                       oa_onsets, oa_offsets, 'green')
    _draw_speed_window(ax_wh, line_wh, y_lim_wh, window_start, window_end,
                       wh_onsets, wh_offsets, 'purple')

def get_plot_images(canvas_oa, canvas_wh):
    """Render both speed-plot canvases and return them as OpenCV-ready BGR images.

    Parameters
    ----------
    canvas_oa, canvas_wh : FigureCanvasAgg
        Canvases created by ``setup_plots``.

    Returns
    -------
    (img_oa_bgr, img_wh_bgr) : tuple of np.ndarray
        uint8, C-contiguous, shape (height, width, 3), channel order B, G, R.
    """
    def _canvas_to_bgr(canvas):
        canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed in 3.10) and already has shape (height, width, 4).
        rgba = np.asarray(canvas.buffer_rgba())
        # RGBA -> BGR by reversing the first three channels. The contiguous copy
        # decouples the image from the canvas buffer, which the next draw() reuses.
        return np.ascontiguousarray(rgba[:, :, 2::-1])

    return _canvas_to_bgr(canvas_oa), _canvas_to_bgr(canvas_wh)

In [4]:
def read_video_frames(video_path):
    """Yield ``(frame_idx, frame)`` for each frame decoded from ``video_path``.

    ``frame_idx`` counts from 0 in decode order; ``frame`` is the array returned
    by ``cv2.VideoCapture.read``. Iteration stops at the first failed read.

    Raises
    ------
    ValueError
        If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    try:
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            yield frame_idx, frame
            frame_idx += 1
    finally:
        # Runs on exhaustion, on error, and when the consumer stops early
        # (generator close()), so the capture handle is never leaked.
        cap.release()




In [5]:
def annotate_frame(frame, frame_idx, plot_img_oa, plot_img_wh, 
                  bodypart_x, bodypart_y, median_x, median_y, roi_x, roi_y, radius):
    """Return ``frame`` with pose/ROI overlays, stacked above the two speed plots.

    The input frame is not modified. Colours are OpenCV BGR tuples.

    Parameters
    ----------
    frame : np.ndarray
        Video frame as returned by ``cv2.VideoCapture.read``.
    frame_idx : int
        Row used to index the pose and median arrays.
    plot_img_oa, plot_img_wh : np.ndarray
        BGR plot images of equal width. They are stacked OA above WH and the
        strip is rescaled to the frame width if the widths differ.
    bodypart_x, bodypart_y : indexable or None
        ``bodypart_x[frame_idx]`` yields one x per body part; NaN points are skipped.
    median_x, median_y : indexable or None
        Per-frame centre position, drawn when finite.
    roi_x, roi_y, radius : float
        ROI centre and radius in pixels. A NaN centre skips the ROI; a NaN
        radius skips only the ring.

    Returns
    -------
    np.ndarray
        The annotated frame with the plot strip appended below it.
    """
    annotated_video_frame = frame.copy()

    # Body parts: small filled dots, BGR (255, 0, 0).
    if bodypart_x is not None and bodypart_y is not None and frame_idx < len(bodypart_x):
        for x_pos, y_pos in zip(bodypart_x[frame_idx], bodypart_y[frame_idx]):
            if not (np.isnan(x_pos) or np.isnan(y_pos)):
                cv2.circle(annotated_video_frame,
                           (int(x_pos), int(y_pos)), 3, (255, 0, 0), -1)

    # Median position: same marker style. Checked on its own so a missing or
    # shorter pose table does not hide it, and its own length is respected.
    if (median_x is not None and median_y is not None
            and frame_idx < min(len(median_x), len(median_y))):
        mx, my = median_x[frame_idx], median_y[frame_idx]
        if not (np.isnan(mx) or np.isnan(my)):
            cv2.circle(annotated_video_frame, (int(mx), int(my)), 3, (255, 0, 0), -1)

    # ROI does not depend on pose data, so it is drawn on every frame:
    # green ring of the given radius plus a filled red centre.
    if not (np.isnan(roi_x) or np.isnan(roi_y)):
        centre = (int(roi_x), int(roi_y))
        if not np.isnan(radius):  # int(nan) would raise ValueError
            cv2.circle(annotated_video_frame, centre, int(radius), (0, 255, 0), 2)
        cv2.circle(annotated_video_frame, centre, 5, (0, 0, 255), -1)

    plots_combined = np.vstack([plot_img_oa, plot_img_wh])
    frame_width = annotated_video_frame.shape[1]
    if plots_combined.shape[1] != frame_width:
        # np.vstack needs equal widths: rescale the plot strip to the video width
        # (keeping its aspect ratio) instead of failing on non-640 px videos.
        strip_height = max(1, int(round(
            plots_combined.shape[0] * frame_width / plots_combined.shape[1])))
        plots_combined = cv2.resize(plots_combined, (frame_width, strip_height),
                                    interpolation=cv2.INTER_AREA)

    return np.vstack([annotated_video_frame, plots_combined])


def write_annotated_video(input_video_path, output_video_path, 
                         oa_speed, wh_speed, oa_onsets, oa_offsets, wh_onsets, wh_offsets, 
                         bodypart_x, bodypart_y, median_x, median_y,
                         roi_x, roi_y, radius,  start_frame=0, end_frame=None, window_size=200):
    """Write an annotated copy of frames [start_frame, end_frame) of a video.

    Each output frame is the input frame with pose/ROI overlays
    (``annotate_frame``) stacked above two speed plots showing the previous
    ``window_size`` frames (``update_plot_window``). The video frame index is
    used directly as the x position in the speed plots and as the row into the
    pose arrays.

    Parameters
    ----------
    input_video_path, output_video_path : str
        Source video, and destination written with the 'mp4v' FourCC at the
        source frame rate.
    oa_speed, wh_speed, oa_onsets, oa_offsets, wh_onsets, wh_offsets :
        Forwarded to ``setup_plots`` / ``update_plot_window``.
    bodypart_x, bodypart_y, median_x, median_y, roi_x, roi_y, radius :
        Forwarded to ``annotate_frame``.
    start_frame : int
        First frame written; earlier frames are decoded and skipped.
    end_frame : int or None
        Exclusive stop frame. ``None`` means the frame count reported by the container.
    window_size : int
        Number of frames of history shown in the plots.

    Raises
    ------
    ValueError
        If the input cannot be opened or reports no usable frame rate.
    OSError
        If the output writer cannot be opened (e.g. missing folder or codec).
    """
    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # `not (fps > 0)` also rejects NaN; VideoWriter would otherwise fail silently.
    if not (fps > 0):
        raise ValueError(f"Could not read a valid frame rate from: {input_video_path}")

    if end_frame is None:
        end_frame = total_frames

    (fig_oa, ax_oa, canvas_oa, line_oa, y_lim_oa,
     fig_wh, ax_wh, canvas_wh, line_wh, y_lim_wh) = setup_plots(oa_speed, wh_speed)

    writer = None
    frames = read_video_frames(input_video_path)

    try:
        for frame_idx, frame in frames:
            if frame_idx < start_frame:
                continue
            if frame_idx >= end_frame:
                break

            # Trailing window ending at the current frame.
            window_start = max(0, frame_idx - window_size)
            window_end = frame_idx

            update_plot_window(ax_oa, ax_wh, line_oa, line_wh, y_lim_oa, y_lim_wh,
                               window_start, window_end, oa_onsets, oa_offsets,
                               wh_onsets, wh_offsets)
            plot_img_oa, plot_img_wh = get_plot_images(canvas_oa, canvas_wh)

            annotated_frame = annotate_frame(
                frame, frame_idx, plot_img_oa, plot_img_wh,
                bodypart_x, bodypart_y, median_x, median_y,
                roi_x, roi_y, radius
            )

            # The writer is sized from the first annotated frame (video + plots).
            if writer is None:
                h, w = annotated_frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                writer = cv2.VideoWriter(output_video_path, fourcc, fps, (w, h))
                if not writer.isOpened():
                    # cv2.VideoWriter does not raise on failure; write() would
                    # silently discard every frame.
                    raise OSError(f"Could not open video writer for: {output_video_path}")

            writer.write(annotated_frame)

            if frame_idx % 100 == 0:
                print(f"Processed frame {frame_idx}/{end_frame}")

    finally:
        # Close the frame generator explicitly so its capture is released now,
        # not whenever the generator happens to be garbage-collected.
        frames.close()
        if writer is not None:
            writer.release()
        plt.close(fig_oa)
        plt.close(fig_wh)

In [6]:

    

# Load the session for the configured subject/date via the project helper
# (utils.pipeline.load_and_process_session). The returned object is used below
# for pose tables, speed traces, event onsets/offsets, ROI and exp_onset.
# NOTE(review): target_freq=60 is presumably a resampling rate in Hz that
# should match the video frame rate -- confirm in utils.pipeline.
sesh = load_and_process_session(subject_id, date, target_freq=60)


neck: 482 NaN frames for x (0.3%)
neck: 482 NaN frames for y (0.3%)
mid_back: 1371 NaN frames for x (0.9%)
mid_back: 1371 NaN frames for y (0.9%)
mouse_center: 206 NaN frames for x (0.1%)
mouse_center: 206 NaN frames for y (0.1%)
mid_backend: 460 NaN frames for x (0.3%)
mid_backend: 460 NaN frames for y (0.3%)
mid_backend2: 1770 NaN frames for x (1.2%)
mid_backend2: 1770 NaN frames for y (1.2%)
mid_backend3: 2292 NaN frames for x (1.5%)
mid_backend3: 2292 NaN frames for y (1.5%)
Processing frame 10/10
151000
150275
151000
No probe1 data found


In [7]:
# Sanity-check the pose tables, then convert them to plain arrays so that
# annotate_frame can index a whole row of body parts with body_x[frame_idx].
print(f"bodypart_x type: {type(sesh.bodypart_x)}")
print(f"bodypart_y type: {type(sesh.bodypart_y)}")
print(f"Shape: {sesh.bodypart_x.shape}")
body_x, body_y = sesh.bodypart_x.to_numpy(), sesh.bodypart_y.to_numpy()

bodypart_x type: <class 'pandas.core.frame.DataFrame'>
bodypart_y type: <class 'pandas.core.frame.DataFrame'>
Shape: (151000, 6)


In [9]:

# Render 5000 frames starting at sesh.exp_onset (end_frame is exclusive).
write_annotated_video(
    input_path,
    output_path,
    oa_speed=sesh.oa_speed,
    wh_speed=sesh.wh_speed,
    oa_onsets=sesh.oa_onsets,
    oa_offsets=sesh.oa_offsets,
    wh_onsets=sesh.wh_onsets,
    wh_offsets=sesh.wh_offsets,
    bodypart_x=body_x,
    bodypart_y=body_y,
    median_x=sesh.x,
    median_y=sesh.y,
    roi_x=sesh.roi_x,
    roi_y=sesh.roi_y,
    radius=sesh.radius,
    start_frame=sesh.exp_onset,
    end_frame=sesh.exp_onset + 5000,
)

Processed frame 800/5726
Processed frame 900/5726
Processed frame 1000/5726
Processed frame 1100/5726
Processed frame 1200/5726
Processed frame 1300/5726
Processed frame 1400/5726
Processed frame 1500/5726
Processed frame 1600/5726
Processed frame 1700/5726
Processed frame 1800/5726
Processed frame 1900/5726
Processed frame 2000/5726
Processed frame 2100/5726
Processed frame 2200/5726
Processed frame 2300/5726
Processed frame 2400/5726
Processed frame 2500/5726
Processed frame 2600/5726
Processed frame 2700/5726
Processed frame 2800/5726
Processed frame 2900/5726
Processed frame 3000/5726
Processed frame 3100/5726
Processed frame 3200/5726
Processed frame 3300/5726
Processed frame 3400/5726
Processed frame 3500/5726
Processed frame 3600/5726
Processed frame 3700/5726
Processed frame 3800/5726
Processed frame 3900/5726
Processed frame 4000/5726
Processed frame 4100/5726
Processed frame 4200/5726
Processed frame 4300/5726
Processed frame 4400/5726
Processed frame 4500/5726
Processed fram