In [None]:
# Standard library imports
import os
from pathlib import Path
import subprocess
from io import BytesIO

# Third-party library imports
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import seaborn as sns
import requests
from PIL import Image

# Custom library imports
from aind_vr_foraging_analysis.utils.parsing import data_access

# Use seaborn's "talk" plotting context (larger fonts/lines than the default) for all figures.
sns.set_context(context="talk")

## Import the dataset

In [None]:
## This section is to find the data in VAST
date_string = "2025-4-2"
mouse = '789918'

session_paths = data_access.find_sessions_relative_to_date(
    mouse=mouse,
    date_string=date_string,
    when='on'
)

# Load every matching session; keep the last one in `session_paths` that loads
# successfully. Failures are reported and skipped, but `session_path` must end up
# pointing at a session that actually loaded (not merely the last one iterated).
loaded_session_path = None
for candidate_path in session_paths:
    try:
        all_epochs, stream_data, data = data_access.load_session(candidate_path)
    except Exception as e:
        print(f"Error loading {candidate_path.name}: {e}")
        continue
    reward_sites = all_epochs.loc[all_epochs['label'] == 'OdorSite']
    loaded_session_path = candidate_path

# Fail here with a clear message instead of a NameError in a later cell.
if loaded_session_path is None:
    raise RuntimeError(f"No session for mouse {mouse} on {date_string} could be loaded.")
session_path = loaded_session_path

In [None]:
# Alternatively, you can specify the session path directly if you know it.
# Clip parameters used by the video cell below. The active values correspond to the
# session located above (mouse 789918, 2025-04-02) so the notebook runs top to bottom;
# the commented blocks are clips we used in the past for other sessions.
frame_skip = 8  # Skip frames for performance
start_time = 0  # seconds
end_time = 15   # seconds
window = 2      # time window shown on plot in seconds

# session_path = Path(r"C:\Users\tiffany.ona\Downloads\789918_2025-04-02T203144Z")

# session_path = Path(r"C:\Users\tiffany.ona\Downloads\789924_2025-04-23T183014Z")
# frame_skip = 8  # Skip frames for performance
# start_time = 275  # seconds
# end_time = 314   # seconds
# window = 5      # time window shown on plot in seconds

# session_path = Path(r"C:\Users\tiffany.ona\Downloads\789907_2025-04-24T181053Z")
# frame_skip = 8  # Skip frames for performance
# start_time = 1457  # seconds
# end_time = 1483   # seconds
# window = 5  

# session_path = Path(r"C:\Users\tiffany.ona\Downloads\789911_2025-04-24T201010Z")
# frame_skip = 8  # Skip frames for performance
# start_time = 3350  # seconds
# end_time = 3386   # seconds
# window = 5  

# (Re)load so the data always matches `session_path`, including when it was set manually above.
all_epochs, stream_data, data = data_access.load_session(
    session_path
)

## Main loop for generating the video

In [None]:
video_path = os.path.join(session_path, "behavior-videos")

# === Settings ===
camera_view = "SideCamera"  # camera whose metadata.csv defines the frame timeline below
output_path = f"synced_video_output_{session_path.name}.mp4"

camera_names = ["SideCamera", "FrontCamera", "FaceCamera"]
video_paths = {cam: os.path.join(video_path, cam, "video.mp4") for cam in camera_names}
metadata_paths = {cam: os.path.join(video_path, cam, "metadata.csv") for cam in camera_names}

# Load metadata. Raise instead of calling `exit()`: exit() is not a clean way to stop
# a notebook cell (IPython intercepts it to shut the kernel down).
try:
    sync = pd.read_csv(metadata_paths[camera_view])
except FileNotFoundError as err:
    raise FileNotFoundError(f"❌ Metadata file not found for {camera_view}.") from err

# Re-reference time to the first camera frame and frame numbers to the first frame,
# then keep only the frames inside the requested clip.
first_time = sync['ReferenceTime'].iloc[0]
sync['Frame'] = sync['CameraFrameNumber'] - sync['CameraFrameNumber'].iloc[0]
sync['ReferenceTime'] = sync['ReferenceTime'] - first_time
sync = sync[(sync['ReferenceTime'] >= start_time) & (sync['ReferenceTime'] <= end_time)].reset_index(drop=True)

# An empty clip would silently produce an empty video further down.
if sync.empty:
    raise ValueError(f"❌ No {camera_view} frames between {start_time} s and {end_time} s.")

# === Load data ===
# Work on copies so re-referencing the time axis leaves `stream_data` / `all_epochs` untouched.
sniff = stream_data.breathing.copy()
speed = stream_data.encoder_data.copy()
tone = stream_data.choice_feedback.copy()
lick = stream_data.lick_onset.copy()
reward = stream_data.give_reward.copy()
events_df = all_epochs.copy()

# Shift every time index so that t = 0 is the first camera frame (same clock as `sync`).
for stream in (sniff, tone, lick, reward, speed, events_df):
    stream.index = stream.index - first_time
events_df['stop_time'] = events_df['stop_time'] - first_time

# z-score the breathing trace (population std) so its axis is in standard deviations.
mean = np.mean(sniff.data)
std = np.std(sniff.data)
sniff.data = (sniff.data - mean) / std

# Open video captures
cap_face = cv2.VideoCapture(video_paths['FaceCamera'])
cap_side = cv2.VideoCapture(video_paths['SideCamera'])
# OpenCV does not raise on a missing/unreadable file; the capture just reports
# isOpened() == False and returns 0 for fps/size, which would yield a broken video.
if not (cap_face.isOpened() and cap_side.isOpened()):
    cap_face.release()
    cap_side.release()
    raise IOError(
        f"❌ Could not open camera videos: {video_paths['FaceCamera']}, {video_paths['SideCamera']}"
    )

# Only every `frame_skip`-th frame is written, so divide the rate to keep real-time playback.
fps = cap_side.get(cv2.CAP_PROP_FPS) / frame_skip
frame_w = int(cap_side.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_h = int(cap_side.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Plot size: as wide as the two camera views placed side by side
plot_w, plot_h = 2 * frame_w, frame_h

# Set up matplotlib figure: corridor (ax3), sniffing (ax1), velocity + events (ax2)
fig, (ax3, ax1, ax2) = plt.subplots(3, 1, figsize=(22, 8), gridspec_kw={'height_ratios': [1, 1, 2]}, dpi=100)
canvas = FigureCanvas(fig)
sniff_line, = ax1.plot([], [], 'k-')
speed_line, = ax2.plot([], [], 'k-')
ax1.set_ylabel('Sniffing (a/u)')
ax2.set_ylabel('Velocity (m/s)')
ax1.set_ylim(-3.5, np.max(sniff.data) + 0.2)
ax1.xaxis.set_visible(False)
ax2.set_ylim(0, speed['filtered_velocity'].max())
ax2.set_xlabel('Time (s)')

vline1 = ax1.axvline(x=start_time, color='crimson', linestyle='solid')
vline2 = ax2.axvline(x=start_time, color='crimson', linestyle='solid')

# === Draw all event spans once ===
# Epoch label -> span colour; labels not listed here are not drawn.
EVENT_SPAN_COLORS = {
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    'OdorSite': '#d95f02',
}
for event_start, event in events_df.iterrows():
    event_end = event['stop_time']
    if event_end < start_time - window or event_start > end_time + window:
        continue  # Skip out-of-range events
    color = EVENT_SPAN_COLORS.get(event['label'])
    if color is None:
        continue  # Skip unknown labels

    ax1.axvspan(event_start, event_end, color=color, alpha=0.6)
    ax2.axvspan(event_start, event_end, color=color, alpha=0.6)

licks = ax2.scatter([], [], color='black', marker='|', s=500)
tones = ax2.scatter([], [], color='crimson', marker='s', s=500)
rewards = ax2.scatter([], [], color='steelblue', marker='o', s=500)

## === Draw the visual corridor ===
session_info = data['config'].streams.session_input.data

# ------------------ Textures ---------------------
# Fetch the floor texture from the task repository at the commit hash recorded in the session input.
remote_textures_root = f"https://github.com/AllenNeuralDynamics/Aind.Behavior.VrForaging/tree/{session_info['commit_hash']}/src/Textures"
texture_name = "Floor.jpg"
raw_url = remote_textures_root.replace("tree", "raw") + f"/{texture_name}"
response = requests.get(raw_url, timeout=30)  # don't hang the notebook on a stalled request
response.raise_for_status()  # Raise an error if the request failed
image = np.array(Image.open(BytesIO(response.content)).rotate(90, expand=True))
image = image[::5, ::5, :]  # Downsample
texture_size = image.shape[1]  # pixel width of one texture repeat along the corridor

# Settings
visual_corridor_scale = 5 * (end_time-start_time)  # cm per texture repeat

# Get position values at the start and end of video
current_position = data['operation_control'].streams.CurrentPosition.data.copy()
current_position.index = current_position.index - first_time


def _position_at_or_after(t):
    """Return the first recorded 'Position' at or after time `t` (s, re-referenced).

    Falls back to the last recorded position when `t` lies past the end of the recording.
    """
    later = current_position.loc[current_position.index >= t]
    if later.empty:
        return current_position['Position'].iloc[-1]
    return later['Position'].iloc[0]


start_pos = _position_at_or_after(start_time - window / 2)
end_pos = _position_at_or_after(end_time + window / 2)
win_abs = (start_pos, end_pos)
print(win_abs)
# Raise (not exit()) so the notebook stops cleanly; a non-positive width cannot be drawn.
if end_pos <= start_pos:
    raise ValueError(f"❌ End position is not after start position ({start_pos} -> {end_pos}).")

# Pixel width matching the spatial range covered by the clip
spatial_width = end_pos - start_pos
pixels_per_cm = image.shape[1] / visual_corridor_scale
expected_pixel_width = int(spatial_width * pixels_per_cm)

# The texture repeats every `visual_corridor_scale` cm, so the column under `start_pos`
# is its fractional phase within the current repeat. Tile enough copies to cover the
# window from that offset and slice out exactly the spatial range: this keeps the tile
# boundaries at multiples of `visual_corridor_scale`.
start_offset = int(((start_pos / visual_corridor_scale) % 1) * texture_size)
n_tiles = int(np.ceil((start_offset + expected_pixel_width) / texture_size)) + 1
stitched_image = np.tile(image, (1, n_tiles, 1))[:, start_offset:start_offset + expected_pixel_width, :]


def texture_current_position(pos):
    """Map a corridor position (same units as 'Position') to a pixel column of `stitched_image`."""
    return (pos - start_pos) * pixels_per_cm

# Spatial end of each epoch along the corridor (same units as 'start_position').
events_df['stop_position'] = events_df['start_position'] + events_df['length']
# Drop the final epoch. NOTE(review): presumably because the session ends mid-epoch
# (incomplete stop_time/length) — confirm.
events_df = events_df.iloc[:-1]

# Contrast multiplier for dimmed corridor segments: values < 1 pull pixels toward mid-gray.
FACTOR = 0.3

# Make a copy of the original stitched image to apply gray regions
gray_overlay_image = stitched_image.copy()

for _, row in events_df.iterrows():
    if row['label'] not in ('InterSite', 'OdorSite'):
        continue
    if pd.isna(row['start_position']) or pd.isna(row['stop_position']):
        continue  # No spatial extent to draw (would crash in int(nan) below)
    if row['stop_position'] < win_abs[0] or row['start_position'] > win_abs[1]:
        continue  # Fully outside

    # Clamp to the drawn window; after clamping both ends lie inside it by construction.
    start_grey = max(row['start_position'], win_abs[0])
    stop_grey = min(row['stop_position'], win_abs[1])

    start_idx = int(texture_current_position(start_grey))
    end_idx = int(texture_current_position(stop_grey))

    roi = gray_overlay_image[:, start_idx:end_idx, :]
    gray_overlay_image[:, start_idx:end_idx, :] = np.clip(
        (roi - 127.5) * FACTOR + 127.5, 0, 255
    ).astype(np.uint8)

# Finally, assign back to stitched_image:
stitched_image = gray_overlay_image
ax3.imshow(stitched_image)

# Initial rectangle mask (start by covering everything to the right of current position)
start_time_pos = current_position.loc[current_position.index >= start_time].iloc[0]['Position']
initial_pos = texture_current_position(start_time_pos)
mask_rect = Rectangle(
    (initial_pos, 0),
    stitched_image.shape[1] - initial_pos,
    stitched_image.shape[0],
    facecolor="white",
    edgecolor="white",
    zorder=10
)
ax3.add_patch(mask_rect)

vline3 = ax3.axvline(x=initial_pos, color='crimson', linestyle='solid', linewidth=5, zorder=100)

ax3.set_xlabel('Position (cm)')
ax3.axis("off")

sns.despine()
plt.tight_layout()

# Video writer: each output frame is the two camera views (2*frame_w x frame_h)
# stacked on top of the plot panel resized to the same size.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (plot_w, plot_h * 2))

# OdorSite indicator drawn on the side-camera frame (loop-invariant).
square_size = 50  # Size of the square in pixels
square_color = (0, 165, 255)  # Orange in BGR
square_thickness = -1  # Filled
top_left = (10, 10)
bottom_right = (top_left[0] + square_size, top_left[1] + square_size)

half_window = window / 2


def _in_window(stream, t):
    """Return the rows of `stream` whose time index lies in [t - window/2, t + window/2]."""
    return stream[(stream.index >= t - half_window) & (stream.index <= t + half_window)]


# === Main loop ===
# try/finally so the capture/writer handles and the figure are released even if a frame fails.
try:
    if not out.isOpened():
        raise IOError(f"❌ Could not open video writer for {output_path}")

    for i, row in sync.iterrows():
        if i % frame_skip != 0:
            continue

        print(f"{100 * i / len(sync):.0f}%", end="\r")
        current_time = row['ReferenceTime']
        frame_num = int(row['Frame'])

        # Set video position on both cameras.
        # NOTE(review): only SideCamera metadata is read, so this assumes FaceCamera frames
        # are index-aligned with SideCamera frames — confirm the cameras share a trigger.
        cap_face.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        cap_side.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret_face, frame_face = cap_face.read()
        ret_side, frame_side = cap_side.read()

        if not ret_face or not ret_side:
            print(f"❌ Failed to read frame {frame_num}")
            break

        frame_face = cv2.resize(frame_face, (frame_w, frame_h))
        frame_side = cv2.resize(frame_side, (frame_w, frame_h))

        # === Check for OdorSite ===
        ongoing_odor_events = events_df[
            (events_df['label'] == 'OdorSite') &
            (events_df.index <= current_time) &
            (events_df['stop_time'] >= current_time)
        ]
        if not ongoing_odor_events.empty:
            cv2.rectangle(frame_side, top_left, bottom_right, square_color, square_thickness)
        combined_video = np.hstack((frame_side, frame_face))

        # Slice continuous data and update the trace plots
        sniff_slice = _in_window(sniff, current_time)
        speed_slice = _in_window(speed, current_time)
        sniff_line.set_data(sniff_slice.index, sniff_slice['data'])
        speed_line.set_data(speed_slice.index, speed_slice['filtered_velocity'])
        ax1.set_xlim(current_time - half_window, current_time + half_window)
        ax2.set_xlim(current_time - half_window, current_time + half_window)
        vline1.set_xdata([current_time])
        vline2.set_xdata([current_time])

        # Discrete events drawn as markers at a fixed height on the velocity axis
        for markers, event_stream, y_level in ((licks, lick, 20), (tones, tone, 30), (rewards, reward, 30)):
            event_times = _in_window(event_stream, current_time).index
            y_positions = np.full(len(event_times), y_level, dtype=float)
            markers.set_offsets(np.column_stack((event_times, y_positions)))

        # Position of the animal at the nearest recorded sample (an exact match is the nearest)
        nearest_idx = current_position.index.get_indexer([current_time], method='nearest')[0]
        pos = current_position['Position'].iloc[nearest_idx]

        # Update rectangle mask to reveal the corridor up to the current position
        mask_x = texture_current_position(pos)
        mask_rect.set_x(mask_x)
        mask_rect.set_width(stitched_image.shape[1] - mask_x)
        vline3.set_xdata([mask_x])

        # Render the figure and stack it below the camera views
        canvas.draw()
        plot_img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
        plot_bgr = cv2.cvtColor(plot_img, cv2.COLOR_RGB2BGR)
        plot_resized = cv2.resize(plot_bgr, (plot_w, plot_h))
        final_frame = np.vstack((combined_video, plot_resized))
        out.write(final_frame)
finally:
    # === Cleanup ===
    cap_face.release()
    cap_side.release()
    out.release()
    plt.close(fig)

# Re-encode the raw mp4v output with H.264 to shrink it.
compressed_path = f"compressed_{output_path}"
print("Compressing video...")
result = subprocess.run([
    "ffmpeg", "-y",
    "-i", output_path,
    "-vcodec", "libx264",
    "-crf", "23",  # Lower = better quality, larger file. Try 23–28
    "-preset", "slow",  # or "fast", "medium"
    compressed_path,
])

# Only discard the raw video once ffmpeg has succeeded; otherwise the render would be lost.
if result.returncode == 0:
    os.remove(output_path)
    print(f"✅ Video saved to {compressed_path}")
else:
    print(f"❌ ffmpeg failed (exit code {result.returncode}); uncompressed video kept at {output_path}")

