# 2D gait analysis of top-camera data

> Analysis of DLC tracking data from top-camera recordings

In [None]:
#| default_exp twoD/topcam

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from typing import Tuple, List, Dict, Optional, Union


from gait_analysis.core import TrackedRecording

Notes:

- TODO: rename `self.full_df_from_file` to `self.loaded_tracking_df`

In [None]:
#| export

class Tracked2DRecording(TrackedRecording):
    """
    Very customized subclass of `TrackedRecording` that was designed to run the gait analysis
    on 2D tracking data obtained from a single camera with a top-down-view on the subject.

    Attributes set by this class (in addition to those of `TrackedRecording`):
        fps, framerate, metadata, cam_id, mouse_line, mouse_id, paradigm, recording_date,
        and - only after a successful `preprocess` - normalized_df and bodyparts.

    NOTE(review): relies on `filepath`, `loaded_tracking_df`, `valid_mouse_lines`,
    `valid_paradigms` and the private preprocessing helpers being provided by
    `TrackedRecording` - confirm against the base class.
    """


    def _load_remaining_metadata(self) -> None:
        """
        Determine the frame rate and parse the metadata encoded in the file name.

        Sets `self.fps`, `self.framerate` and `self.metadata`.
        """
        self.fps = self._get_correct_fps()
        # Despite its name, this is the duration of a single frame in seconds (1 / fps).
        self.framerate = 1/self.fps
        # NOTE(review): `.name` assumes `self.filepath` is a `pathlib.Path`.
        self.metadata = self._retrieve_metadata(filename = self.filepath.name)


    def _get_correct_fps(self) -> int:
        """
        Infer the frame rate from the number of rows in `self.loaded_tracking_df`.

        Recordings with more than 25,000 rows are assumed to be recorded at 80 fps,
        all others at 30 fps.
        """
        n_rows = self.loaded_tracking_df.shape[0]
        return 80 if n_rows > 25_000 else 30


    def _retrieve_metadata(self, filename: str)->Dict:
        """
        Very much dependent on the following file naming convention:
        LLL_FS-SS_YYMMDD_PPP_whatever.h5 (or .csv)
        Where:
            LLL: three digits mouse line code
            FS-SS: the subject ID (must start with capital F) with generation (e.g. F2 for second generation) and the mouse ID as two digits
            YYMMDD: the date of the recording as two digits year (YY), month (MM), and day (DD)
            PPP: the three letter string of the experimental paradigm
        For instance:
            196_F7-27_220826_OTT_whatever.h5

        Invalid parts are corrected interactively via `_check_metadata`; the camera is always 'Top'.

        Returns:
            Dict with the keys 'recording_date', 'animal' ('<mouse_line>_<mouse_id>'),
            'paradigm' and 'cam'.
        """
        splits = filename.split('_')
        line_id, mouse_id, date = splits[0], splits[1], splits[2]
        paradigm_id = splits[3][0:3]  # only the first three characters encode the paradigm
        cam_id = 'Top'
        self._check_metadata(metadata = (line_id, mouse_id, date, paradigm_id, cam_id))
        return {'recording_date': self.recording_date,
                'animal': f'{self.mouse_line}_{self.mouse_id}',
                'paradigm': self.paradigm,
                'cam': self.cam_id}


    @staticmethod
    def _prompt_until_valid(prompt: str, is_valid, error_message: str) -> str:
        """
        Ask the user for input until `is_valid(entered_input)` returns True.

        Args:
            prompt: text passed to `input()`.
            is_valid: callable taking the entered string and returning a bool.
            error_message: text printed after every rejected input.

        Returns:
            The first entered string that passed `is_valid`.
        """
        while True:
            entered_input = input(prompt)
            if is_valid(entered_input):
                return entered_input
            print(error_message)


    @staticmethod
    def _is_integer_string(value: str) -> bool:
        """Return True if `int(value)` succeeds, False otherwise."""
        try:
            int(value)
        except (TypeError, ValueError):
            return False
        return True


    # ToDo - replace with something more generalizable that could be put to utils
    def _check_metadata(self,
                        metadata: Tuple[str, ...]
                       ) -> None:
        """
        Validate the metadata parsed from the file name and store it on the instance.

        Any invalid entry is replaced by asking the user (via `input()`) until a valid
        value is entered.

        Args:
            metadata: (mouse line, animal id, recording date, paradigm, cam id).

        Sets `self.cam_id`, `self.mouse_line`, `self.mouse_id`, `self.paradigm` and
        `self.recording_date`.
        """
        animal_line, animal_id, recording_date, paradigm, cam_id = metadata[0], metadata[1], metadata[2], metadata[3], metadata[4]
        self.cam_id = cam_id

        if animal_line in self.valid_mouse_lines:
            self.mouse_line = animal_line
        else:
            self.mouse_line = self._prompt_until_valid(
                prompt = f'Mouse line for {self.filepath}',
                is_valid = lambda entered: entered in self.valid_mouse_lines,
                error_message = f'Entered mouse line does not match any of the defined mouse lines. \nPlease enter one of the following lines: {self.valid_mouse_lines}')

        if animal_id.startswith('F'):
            self.mouse_id = animal_id
        else:
            self.mouse_id = self._prompt_until_valid(
                prompt = f'Mouse ID for {self.filepath}',
                is_valid = lambda entered: entered.startswith('F'),
                error_message = 'Animal ID has to start with F. Example: F2-14')

        if paradigm in self.valid_paradigms:
            self.paradigm = paradigm
        else:
            self.paradigm = self._prompt_until_valid(
                prompt = f'Paradigm for {self.filepath}',
                is_valid = lambda entered: entered in self.valid_paradigms,
                error_message = f'Entered paradigm does not match any of the defined paradigms. \nPlease enter one of the following paradigms: {self.valid_paradigms}')

        # The entered value itself is validated and stored, so the prompt
        # terminates as soon as the user enters a parsable date.
        if self._is_integer_string(recording_date):
            self.recording_date = recording_date
        else:
            self.recording_date = self._prompt_until_valid(
                prompt = f'Recording date for {self.filepath}',
                is_valid = self._is_integer_string,
                error_message = 'Entered recording date has to be an integer in shape YYMMDD. Example: 220812')


    def preprocess(self,
                   marker_ids_to_compute_coverage: Optional[List[str]]=None, # List of marker_ids on base of which the tracking coverage will be computed (default: ['TailBase', 'Snout'])
                   coverage_threshold: float=0.75, # If coverage of the above defined markers is lower than this threshold, the recording will be excluded from the analysis and therefore not further processed
                   max_seconds_to_interpolate: float=0.5, # Maximum time interval (in seconds) in which consecutive NaNs will be interpolated
                   likelihood_threshold: float=0.5, # Minimum DLC prediction likelihood that is required to accept a predicted marker position as valid
                   marker_ids_to_compute_center_of_gravity: Optional[List[str]]=None, # marker ids that will be used to compute the center of gravity (default: ['TailBase', 'Snout'])
                   relative_maze_normalization_error_tolerance: float=0.25 # relative error that is tolerated when estimating the maze position for normalizing its position
                   ) -> None:
        """
        Smooth, interpolate and - if tracking coverage suffices - normalize the tracking data.

        Steps:
            1. Smooth coordinates/likelihoods (window derived from `max_seconds_to_interpolate`, polyorder 3).
            2. Interpolate low-likelihood intervals up to that window length.
            3. Add a 'CenterOfGravity' marker derived from the given markers and interpolate it.
            4. Compute and log the coverage of the critical markers together with all parameters.
            5. Only if that coverage is >= `coverage_threshold`: normalize the maze coordinates,
               set `self.normalized_df` and `self.bodyparts`, and log the normalization
               parameters, normalized maze corners and the CenterOfGravity coverage.
               Otherwise nothing beyond the logs of step 4 is produced.
        """
        # None defaults avoid sharing one mutable list object between calls / recordings
        # (the lists are also stored in the logs).
        if marker_ids_to_compute_coverage is None:
            marker_ids_to_compute_coverage = ['TailBase', 'Snout']
        if marker_ids_to_compute_center_of_gravity is None:
            marker_ids_to_compute_center_of_gravity = ['TailBase', 'Snout']
        initial_logs_to_add = {'critical_markers_to_compute_coverage': marker_ids_to_compute_coverage,
                               'coverage_threshold': coverage_threshold,
                               'max_seconds_to_interpolate': max_seconds_to_interpolate,
                               'min_likelihood_threshold': likelihood_threshold,
                               'center_of_gravity_based_on': marker_ids_to_compute_center_of_gravity,
                               'relative_error_tolerance_corner_detection': relative_maze_normalization_error_tolerance}
        # The same odd frame count is used as smoothing window and as maximal interpolation gap.
        window_length = self._get_max_odd_n_frames_for_time_interval(fps = self.fps, time_interval = max_seconds_to_interpolate)
        # `loaded_tracking_df` is the attribute name also used by `_get_correct_fps`
        # (formerly `full_df_from_file`, see notes above).
        marker_ids_to_preprocess = self._get_preprocessing_relevant_marker_ids(df = self.loaded_tracking_df)
        smoothed_df = self._smooth_tracked_coords_and_likelihood(marker_ids = marker_ids_to_preprocess, window_length = window_length, polyorder = 3)
        interpolated_df = self._interpolate_low_likelihood_intervals(df = smoothed_df, marker_ids = marker_ids_to_preprocess, max_interval_length = window_length)
        interpolated_df_with_cog = self._add_new_marker_derived_from_existing_markers(df = interpolated_df,
                                                                                      existing_markers = marker_ids_to_compute_center_of_gravity,
                                                                                      new_marker_id = 'CenterOfGravity',
                                                                                      likelihood_threshold = likelihood_threshold)
        preprocessed_df = self._interpolate_low_likelihood_intervals(df = interpolated_df_with_cog,
                                                                     marker_ids = ['CenterOfGravity'],
                                                                     max_interval_length = window_length)
        coverage_critical_markers = self._compute_coverage(df = preprocessed_df,
                                                           critical_marker_ids = marker_ids_to_compute_coverage,
                                                           likelihood_threshold = likelihood_threshold)
        initial_logs_to_add['coverage_critical_markers'] = coverage_critical_markers
        self._add_to_logs(logs_to_add = initial_logs_to_add)
        if coverage_critical_markers < coverage_threshold:
            # Insufficient tracking coverage: the recording is not processed further.
            return
        normalization_params = self._get_parameters_to_normalize_maze_coordinates(df = preprocessed_df,
                                                                                  relative_error_tolerance = relative_maze_normalization_error_tolerance)
        self.normalized_df = self._normalize_df(df = preprocessed_df, normalization_parameters = normalization_params)
        self.bodyparts = self._create_bodypart_objects()
        normalized_maze_corner_coords = self._get_normalized_maze_corners(normalization_parameters = normalization_params)
        coverage_center_of_gravity = self._compute_coverage(df = preprocessed_df,
                                                            critical_marker_ids = ['CenterOfGravity'],
                                                            likelihood_threshold = likelihood_threshold)
        additional_logs_to_add = {'coverage_CenterOfGravity': coverage_center_of_gravity}
        additional_logs_to_add.update(normalization_params)
        additional_logs_to_add.update({f'normalized_{key}_coords': value
                                       for key, value in normalized_maze_corner_coords.items()})
        self._add_to_logs(logs_to_add = additional_logs_to_add)
        

In [None]:
#| export
def test():
    """Placeholder no-op; always returns None."""
    return None

In [None]:
#| hide
# Run nbdev's export so the `#| export` cells are written to the module named by `#| default_exp`.
import nbdev
nbdev.nbdev_export()