In [1]:
# Import relevant modules

import datajoint as dj

In [None]:
# Attach to the existing pipeline schemas as DataJoint virtual modules.
# `img` and `tracking` are referenced by the helper functions and the table definition below.

img = dj.VirtualModule(module_name='IMG', schema_name='arseny_learning_imaging')
tracking = dj.VirtualModule(module_name='TRACKING', schema_name='arseny_learning_tracking')

In [45]:
def get_behav_epoch_neural_timestamps(key):
    """Return the neural frame timestamps that belong to the session epoch of `key`.

    Args:
        key: DataJoint restriction (dict) that must select exactly one row in
            both `img.FrameTime` and `img.SessionEpochFrame` (`fetch1` raises otherwise).

    Returns:
        The slice of the session's frame timestamps from the epoch start frame
        up to, but not including, the epoch end frame.

    Open question (kept from the original author): should the end frame be
    included? As written the slice is end-exclusive.
    """
    # Timestamps of every frame in the session; element 0 of the stored blob is used
    # (presumably the blob is a 1 x N array -- TODO confirm).
    frame_time_row = (img.FrameTime & key).fetch1()
    all_session_timestamps = frame_time_row["frame_timestamps"][0]

    # Frame bounds of the epoch within the session
    epoch_row = (img.SessionEpochFrame & key).fetch1()
    epoch_first_frame = int(epoch_row["session_epoch_start_frame"])
    epoch_last_frame = int(epoch_row["session_epoch_end_frame"])

    # End-exclusive slice: the frame at `epoch_last_frame` itself is not returned
    return all_session_timestamps[epoch_first_frame:epoch_last_frame]
    
def get_trial_neural_timestamps_and_indexes(key, behav_epoch_neural_timestamps):
    """Return the neural timestamps and frame indexes of the trial selected by `key`.

    Args:
        key: DataJoint restriction (dict) selecting exactly one row of
            `img.FrameStartTrial` (`fetch1` raises otherwise).
        behav_epoch_neural_timestamps: timestamps of the session epoch's neural
            frames; must support slicing and element-wise subtraction
            (presumably a numpy array -- TODO confirm).

    Returns:
        (timestamps_zero, frame_indexes):
            timestamps_zero -- trial timestamps shifted so the first one is 0.
            frame_indexes   -- list of frame indexes from the trial start frame up to,
                               but not including, the trial end frame.

    Raises:
        IndexError: if the trial slice is empty (there is no first timestamp to subtract).
    """
    # Trial bounds, expressed as frame indexes within the session epoch
    trial_row = (img.FrameStartTrial & key).fetch1()
    first_frame = int(trial_row["session_epoch_trial_start_frame"])
    end_frame = int(trial_row["session_epoch_trial_end_frame"])

    # End-exclusive window; whether the end frame should be included is an open question
    trial_window = slice(first_frame, end_frame)
    trial_timestamps = behav_epoch_neural_timestamps[trial_window]

    # Re-reference the timestamps to the first frame of the trial
    timestamps_zero = trial_timestamps - trial_timestamps[0]
    frame_indexes = list(range(first_frame, end_frame))

    return timestamps_zero, frame_indexes


def get_trial_timestamps_and_indexes(key):
    """Build the timestamp of every video frame of the trial selected by `key`.

    Despite the name, only timestamps are returned; a frame's index is its
    position in the returned list.

    Args:
        key: DataJoint restriction (dict) that, together with the tracking device
            restriction, selects exactly one row of `tracking.TrackingTrial`.

    Returns:
        List of floats: `tracking_start_time + i / tracking_sampling_rate`
        for i in 0 .. tracking_num_samples - 1.
    """
    # The tracking device number is arbitrary, they both contain same data
    device_restriction = {'tracking_device_id': 3}
    video_row = (tracking.TrackingTrial & key & device_restriction).fetch1()

    frame_count = video_row["tracking_num_samples"]
    frame_rate = float(video_row["tracking_sampling_rate"])
    start_time = float(video_row["tracking_start_time"])

    # Frames are assumed evenly spaced at the sampling rate, starting at start_time
    video_timestamps = []
    for frame_number in range(frame_count):
        video_timestamps.append(start_time + (frame_number / frame_rate))

    return video_timestamps


def _video_group_edges(trial_video_timestamps, lower, upper=None):
    """Return ([first_ts, last_ts], [first_idx, last_idx]) of the video frames with
    lower <= timestamp < upper (no upper bound when `upper` is None); ([], []) if none match."""
    # A single scan collects index and timestamp together, so the two outputs
    # can never disagree and the list is not traversed twice.
    matches = [
        (idx, val)
        for idx, val in enumerate(trial_video_timestamps)
        if lower <= val and (upper is None or val < upper)
    ]
    if not matches:
        return [], []
    (first_idx, first_ts), (last_idx, last_ts) = matches[0], matches[-1]
    return [first_ts, last_ts], [first_idx, last_idx]


def get_grouped_trial_video_timestamps_and_indexes_by_neural_timestamps(trial_video_timestamps, trial_neural_timestamps_zero):
    """Assign video frames to the neural frame during which they were acquired.

    For neural frame i the group holds the video frames whose timestamp lies in
    [neural[i], neural[i + 1]); the last neural frame takes every video frame at
    or after its timestamp. Only the first and last member of each group are kept.

    Note: video frames earlier than the first neural timestamp belong to no group
    and are dropped.

    Args:
        trial_video_timestamps: sequence of video frame timestamps.
        trial_neural_timestamps_zero: sequence of neural frame timestamps
            (list or array; only `len` and indexing are used).

    Returns:
        (timestamps_groups, indexes_groups): two lists with one entry per neural
        frame. Each entry is `[first, last]` (timestamps, respectively positions
        in `trial_video_timestamps`) or `[]` when no video frame falls in the
        interval. Both lists are empty when there are no neural timestamps.
    """
    trial_video_timestamps_groups = []
    trial_video_frames_indexes_groups = []

    # len() rather than truthiness: the neural timestamps may be a numpy array
    neural_count = len(trial_neural_timestamps_zero)
    if neural_count == 0:
        # Previously this case crashed with IndexError on the trailing group
        return trial_video_timestamps_groups, trial_video_frames_indexes_groups

    for i in range(neural_count):
        lower = trial_neural_timestamps_zero[i]
        # The last neural frame has no successor, so its interval is open-ended
        upper = trial_neural_timestamps_zero[i + 1] if i + 1 < neural_count else None
        timestamps_group_edges, indexes_group_edges = _video_group_edges(trial_video_timestamps, lower, upper)
        trial_video_timestamps_groups.append(timestamps_group_edges)
        trial_video_frames_indexes_groups.append(indexes_group_edges)

    return trial_video_timestamps_groups, trial_video_frames_indexes_groups

def get_key_to_insert(key, trial_neural_frames_indexes, trial_video_frames_indexes_groups, trial_neural_timestamps_zero, trial_video_timestamps_groups):
    """Build the row to insert: the primary key plus the four alignment attributes.

    A new dict is returned and `key` itself is left untouched, so the caller's
    key can still be reused afterwards (e.g. when re-running a cell).

    Args:
        key: mapping holding the primary-key attributes.
        trial_neural_frames_indexes: stored under "trial_neural_frames_indexes".
        trial_video_frames_indexes_groups: stored under "trial_video_frames_indexes_groups".
        trial_neural_timestamps_zero: stored under "trial_neural_timestamps".
        trial_video_timestamps_groups: stored under "trial_video_timestamps_groups".

    Returns:
        dict with every entry of `key` plus the four attributes above.
    """
    # Copy first so the caller's dict is not modified as a side effect
    row = dict(key)
    row.update({
        "trial_neural_frames_indexes": trial_neural_frames_indexes,
        "trial_video_frames_indexes_groups": trial_video_frames_indexes_groups,
        "trial_neural_timestamps": trial_neural_timestamps_zero,
        "trial_video_timestamps_groups": trial_video_timestamps_groups,
    })
    return row


In [None]:
# Schema that holds the table defined below; its name is prefixed with the configured database user
schema = dj.Schema("{}_VIDEONEURAL".format(dj.config['database.user']))

@schema
class VideoNeuralAlignment(dj.Computed):
    """Per-trial alignment between neural imaging frames and behavioral video frames.

    For every trial of the included subjects, stores the trial's neural frame
    indexes and zero-referenced timestamps, together with the first/last video
    frame (index and timestamp) acquired during each neural frame.
    """
    definition = """
    -> img.FrameStartTrial
    ---
    trial_neural_frames_indexes: blob              # (frames) Relative to session epoch start
    trial_video_frames_indexes_groups: blob        # (frames) Relative to trial start
    trial_neural_timestamps: blob                  # (s) Relative to trial start
    trial_video_timestamps_groups: blob            # (s) Relative to trial start
    """

    # Only these subjects are processed
    _included_subject_ids = (464724, 464725, 463189, 463190)

    @property
    def key_source(self):
        """Keys to populate: the default key source restricted to the included subjects.

        Skipping other subjects inside `make` leaves their keys permanently
        unpopulated, so `populate(limit=...)` can spend its whole batch on them
        on every run. Restricting here removes them from the to-do list.
        """
        # A list of dicts is an OR restriction in DataJoint
        subject_restriction = [{"subject_id": subject_id} for subject_id in self._included_subject_ids]
        return super().key_source & subject_restriction

    def make(self, key):
        """Compute and insert the alignment row for one trial.

        Args:
            key: primary key of one `img.FrameStartTrial` row (includes "subject_id").
        """
        # Kept for direct calls to make(); populate() no longer yields other subjects
        if key["subject_id"] not in self._included_subject_ids:
            return

        behav_epoch_neural_timestamps = get_behav_epoch_neural_timestamps(key)
        trial_neural_timestamps_zero, trial_neural_frames_indexes = get_trial_neural_timestamps_and_indexes(key, behav_epoch_neural_timestamps)
        trial_video_timestamps = get_trial_timestamps_and_indexes(key)
        trial_video_timestamps_groups, trial_video_frames_indexes_groups = get_grouped_trial_video_timestamps_and_indexes_by_neural_timestamps(trial_video_timestamps, trial_neural_timestamps_zero)
        row = get_key_to_insert(key, trial_neural_frames_indexes, trial_video_frames_indexes_groups, trial_neural_timestamps_zero, trial_video_timestamps_groups)

        self.insert1(row)

In [None]:
# Run the computation for pending keys, checking at most 1000 of them in this call
VideoNeuralAlignment.populate(limit=1000)