In [None]:
# Import relevant modules

import pandas as pd
import datajoint as dj
import cv2
import os

In [None]:
def get_all_trial_video_frames(trial_video_file_path):
    """Read every frame of a video file into memory.

    Args:
        trial_video_file_path: Path of the video file, handed to
            ``cv2.VideoCapture``.

    Returns:
        list: The decoded frames, in reading order. The list is empty when
        the capture is not opened or the first read already fails.
    """
    cap = cv2.VideoCapture(trial_video_file_path)
    frames_list = []

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # A failed read ends the loop; nothing more is appended.
                break

            frames_list.append(frame)
    finally:
        # Release the capture even if reading raised, so repeated calls
        # (one per trial) do not accumulate open video handles.
        cap.release()

    return frames_list

def aligned_trial_video_frames(trial_video_frames_indexes, frames_list):
    """Cut a trial's frame list into groups given inclusive index ranges.

    Args:
        trial_video_frames_indexes: Iterable of index ranges. Each range is
            either empty (skipped) or a sequence whose first two items are
            the first and last frame index of the group (both inclusive).
            Lists and numpy arrays are both accepted.
        frames_list: Sequence of frames to slice.

    Returns:
        list: One slice of ``frames_list`` per non-empty range, in input
        order. Ranges reaching past the end are truncated by normal slicing
        rules.
    """
    frame_groups = []
    for index_range in trial_video_frames_indexes:
        # len() is used instead of `== []` because comparing a numpy array
        # with a list does not yield a plain boolean.
        if len(index_range) == 0:
            continue

        first_index, last_index = index_range[0], index_range[1]
        # `last_index` is inclusive, hence the +1. Slicing already stops at
        # the end of the list, so no special case is needed for the last group.
        frame_groups.append(frames_list[first_index:last_index + 1])

    return frame_groups


def get_trials_data_table_for_mouse_session(subject_id, session, camera_num=3):
    """Fetch the joined tracking / video-neural alignment rows of a session.

    Args:
        subject_id: Value used to restrict on ``subject_id``.
        session: Value used to restrict on ``session``.
        camera_num: Value used to restrict on ``tracking_device_id``.

    Returns:
        pandas.DataFrame: The fetched rows of ``TrackingTrial`` joined with
        ``VideoNeuralAlignment`` after applying the restrictions.
    """
    # Virtual modules give access to the tables of the existing schemas.
    tracking_schema = dj.VirtualModule('TRACKING', 'arseny_learning_tracking')
    alignment_schema = dj.VirtualModule('VIDEONEURAL', "JonathanStahl_VIDEONEURAL")

    session_restriction = {'subject_id': subject_id, 'session': session}
    camera_restriction = {'tracking_device_id': camera_num}

    # Join first, then narrow down to the requested session and camera.
    joined_tables = tracking_schema.TrackingTrial * alignment_schema.VideoNeuralAlignment
    restricted = joined_tables & session_restriction & camera_restriction
    return pd.DataFrame(restricted.fetch())


def get_trial_video_frames_groups(row, all_videos_path, subject_id, session_string, camera_num):
    """Load one trial's video and split its frames into index groups.

    Args:
        row: Mapping-like trial record providing ``tracking_datafile_num``
            and ``trial_video_frames_indexes_groups``.
        all_videos_path: Root directory of the videos.
        subject_id: Sub-directory name of the subject (joined as a path part).
        session_string: Sub-directory name of the session.
        camera_num: Camera number used in the video file name.

    Returns:
        list: The frame groups produced by ``aligned_trial_video_frames``.
    """
    datafile_num = row["tracking_datafile_num"]
    # File names look like video_cam_<camera>_v<zero-padded datafile number>.avi
    video_path = os.path.join(
        all_videos_path,
        subject_id,
        session_string,
        f"video_cam_{camera_num}_v{datafile_num:03d}.avi",
    )

    all_frames = get_all_trial_video_frames(video_path)
    return aligned_trial_video_frames(row["trial_video_frames_indexes_groups"], all_frames)


def get_dff_table_for_mouse_session(subject_id, session):
    """Fetch the ``ROIdeltaF`` rows of a session with traces spread into columns.

    Only rows whose ``session_epoch_type`` is ``"behav_only"`` are fetched.

    Args:
        subject_id: Value used to restrict on ``subject_id``.
        session: Value used to restrict on ``session``.

    Returns:
        pandas.DataFrame: The fetched columns without ``dff_trace``,
        ``session_epoch_type``, ``subject_id`` and ``session``, concatenated
        column-wise with one column per trace sample.
    """
    img_schema = dj.VirtualModule('IMG', 'arseny_learning_imaging')
    restriction = {"subject_id" : subject_id, "session" : session, "session_epoch_type" : "behav_only"}
    roi_table = pd.DataFrame((img_schema.ROIdeltaF & restriction).fetch())

    # Element [0] of every dff_trace entry becomes one row of the matrix
    # (presumably each entry is a 1 x n_frames array - TODO confirm).
    trace_matrix = pd.DataFrame([trace[0] for trace in roi_table["dff_trace"]])

    roi_metadata = roi_table.drop(columns=["dff_trace", "session_epoch_type", "subject_id", "session"])
    return pd.concat([roi_metadata, trace_matrix], axis=1)


def get_trial_neural_frames(dff_data, trial_neural_frames_indexes, trial_video_length=None, drop_frames_with_no_video=True):
    """Select a trial's neural frames from the dF/F data.

    Args:
        dff_data: Object indexed with the frame indexes
            (``dff_data[indexes]``), e.g. a DataFrame whose columns are
            frame numbers.
        trial_neural_frames_indexes: Sliceable sequence of the trial's
            neural frame indexes.
        trial_video_length: Number of leading indexes to keep when
            ``drop_frames_with_no_video`` is true. May be omitted otherwise.
        drop_frames_with_no_video: When true, only the first
            ``trial_video_length`` indexes are used.

    Returns:
        The result of ``dff_data[indexes]`` for the (possibly truncated)
        indexes.

    Raises:
        ValueError: If dropping is requested but ``trial_video_length`` is
            ``None``.
    """
    if drop_frames_with_no_video:
        if trial_video_length is None:
            raise ValueError("Can't drop neural frames with no data if video length is not provided")
        # Keep only as many neural frames as there are video frame groups.
        trial_neural_frames_indexes = trial_neural_frames_indexes[:trial_video_length]

    return dff_data[trial_neural_frames_indexes]
    

In [None]:
# Select the subject id and session number.
# These module-level names are read by get_session_trials_aligned_frames.
subject_id = "463189"
session = 1

In [None]:
# This function yields a dictionary with the neural frames and aligned video frames for each trial
# take_only_first_video_frame: You can set this True if you want to take only the first video frame for each neural frame
# drop_neural_frames_with_no_video: You can set this False if you want to keep neural frames that have no video frames associated with them

def get_session_trials_aligned_frames(take_only_first_video_frame = False, drop_neural_frames_with_no_video = True,
                                      all_videos_path = "D:\\admin\\SharedFolder\\Arseny_behavior_video",
                                      camera_num = 0):
    """Yield, trial by trial, the neural frames and their aligned video frames.

    The subject and session are taken from the module-level ``subject_id``
    and ``session`` variables.

    Args:
        take_only_first_video_frame: When true, each video frame group is
            reduced to its first frame.
        drop_neural_frames_with_no_video: When true, only as many neural
            frame indexes as there are video frame groups are kept.
        all_videos_path: Root directory containing
            ``<subject_id>/session<session>/`` video folders.
        camera_num: Camera number used to build the video file names.

    Yields:
        dict: ``{'trial_neural_frames': ..., 'trial_video_frames_groups': ...}``
        for every row of the session's trial table.
    """
    session_string = f"session{session}"

    # Get data from DataJoint
    # NOTE(review): trial rows are fetched with the helper's default
    # tracking_device_id (3) while video files are named with `camera_num`
    # (0 by default) - confirm both identify the same camera.
    trials_data = get_trials_data_table_for_mouse_session(subject_id, session)
    dff_data = get_dff_table_for_mouse_session(subject_id, session)

    for _, row in trials_data.iterrows():
        trial_video_frames = get_trial_video_frames_groups(row, all_videos_path, subject_id, session_string, camera_num)
        # The number of video groups is computed before any reduction below,
        # so it is the same whether or not only first frames are kept.
        trial_neural_frames = get_trial_neural_frames(
            dff_data,
            row["trial_neural_frames_indexes"],
            len(trial_video_frames),
            drop_neural_frames_with_no_video,
        )

        if take_only_first_video_frame:
            trial_video_frames = [video_frames_group[0] for video_frames_group in trial_video_frames]

        yield {'trial_neural_frames': trial_neural_frames, 'trial_video_frames_groups': trial_video_frames}
