# Dewan Lab EPM Analysis
## 0: Run once at the beginning of a project to create all needed directories

## STEP 1: Always Execute! Load Libraries and User Settings
### STEP 1A: Import Libraries

In [None]:
# Enable IPython's autoreload extension so edits to imported modules (e.g. dewan_calcium)
# are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import cv2
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm, trange

from dewan_calcium.helpers import IO, parse_json, EPM
from dewan_calcium.helpers.project_folder import ProjectFolder

from dewan_calcium import plotting, deconv

from dewan_manual_curation import dewan_manual_curation

# from roipoly import MultiRoi  ## Must be installed from GitHub as the PyPi version is obsolete; pip install git+https://github.com/jdoepfert/roipoly.py

### STEP 1B: User Configurables

In [None]:
# Animal ID and recording date; together they form the prefix of every file this notebook saves
animal, date = 'ANIMAL_GOES_HERE', 'DATE_GOES_HERE'

# Length of the EPM session, in minutes
EXPERIMENT_TIME = 10

In [None]:
# Build the ProjectFolder, which gathers and holds every file path used by this notebook.
# Uncomment the path that matches the machine you are running on.

# test_path = Path("/mnt/dev/Test_Data/Odor/VGLUT-20")  # On Fedora
test_path = Path("C:\\Projects\\Test_Data\\EPM\\VGLUT-20")  # On Desktop
project_folder = ProjectFolder(project_dir=test_path)

# Prefix shared by every file saved to / loaded from disk: '<animal>-<date>-'
file_header = f'{animal}-{date}-'

In [None]:
# If this is the first time the project folder has been created,
# move the files to the appropriate directories and then run this cell; otherwise skip this cell
project_folder.get_data()

In [None]:
# Get the settings from the imaging session's JSON file and display them for the user

gain, LED_power, focal_planes = parse_json.get_session_settings(project_folder.raw_data_dir.session_json_path)

print(f'Recording Gain: {gain}')
print(f'LED_power: {LED_power}')
print(f'Focal Planes: {focal_planes}')

## 2A: Import and pre-process the raw data

#### Copy the DLC output .h5 file and labeled video -> EPM_Analysis\RawData

In [None]:
# New way that we should be grabbing the data

# STEP 2A.1: LOAD DLC DATA
# Read in the DeepLabCut output: the tracked points (.h5) and the labeled video.

tracked_points = pd.read_hdf(project_folder.raw_data_dir.points_h5_path)  # Load tracked points

labeled_video_path = project_folder.raw_data_dir.labeled_video_path
labeled_video = cv2.VideoCapture(str(labeled_video_path))  # Load Video
if not labeled_video.isOpened():
    # cv2.VideoCapture does not raise on a missing/unreadable file; it returns a closed capture
    raise FileNotFoundError(f'Unable to open labeled video: {labeled_video_path}')

VIDEO_FPS = labeled_video.get(cv2.CAP_PROP_FPS)
if VIDEO_FPS <= 0:
    # The frame rate is used to convert between frame numbers and seconds, so it must be valid
    raise ValueError(f'Could not read a valid frame rate from {labeled_video_path}')

In [None]:
# STEP 2A.2: LOAD INSCOPIX DATA
# Cell traces, GPIO channels, per-cell properties, and cell contour outlines exported from Inscopix

inscopix_dir = project_folder.inscopix_dir

cell_trace_data = pd.read_csv(inscopix_dir.cell_trace_path, engine='pyarrow')
GPIO_data = pd.read_csv(inscopix_dir.GPIO_path, engine='pyarrow', header=0)
all_cell_props = pd.read_csv(inscopix_dir.props_path, engine='pyarrow', header=0)

# TODO: remove cell keys from the json function
cell_outlines = parse_json.get_outline_coordinates(inscopix_dir.contours_path)

In [None]:
# STEP 2A.2: PREPROCESSING (Inscopix cell traces, GPIO, and cell properties)

# STEP 2A.2.1: Drop the first row, which contains only 'undecided' labels (the Inscopix default label).
cell_trace_data = cell_trace_data.drop([0])

# STEP 2A.2.2: Force all values to be numbers; anything that cannot be parsed becomes NaN
cell_trace_data = cell_trace_data.apply(pd.to_numeric, errors='coerce')

# Round the times (first column) to 2 decimal places, then make them the index so the
# remaining columns are all dF/F values
cell_trace_data[cell_trace_data.columns[0]] = cell_trace_data[cell_trace_data.columns[0]].round(2)
cell_trace_data = cell_trace_data.set_index(cell_trace_data.columns[0]) 

# STEP 2A.2.3: Remove spaces from column names and contents
cell_trace_data.columns = cell_trace_data.columns.str.replace(" ", "")
GPIO_data.columns = GPIO_data.columns.str.replace(" ", "")
GPIO_data['ChannelName'] = GPIO_data['ChannelName'].str.replace(" ", "")

# STEP 2A.2.4: Reduce properties to only include the cells with only one component
all_cell_props = all_cell_props[all_cell_props['NumComponents']==1]  # We only want cells that have one component
all_cell_props = all_cell_props.drop(columns='Status').reset_index(drop=True)
cell_names = all_cell_props['Name'].values  # Names of the single-component cells, passed to ManualCuration

# STEP 2A.2.5: PARSE GPIO DATA
# Split the GPIO table by channel: GPIO-1 holds the sniff signal, GPIO-2 the final valve (FV) signal.
# The index is reset so row labels match row positions in each subset.
sniff_data = GPIO_data[GPIO_data['ChannelName'] == "GPIO-1"].reset_index(drop=True)
FV_data = GPIO_data[GPIO_data['ChannelName'] == "GPIO-2"].reset_index(drop=True)

# OPTIONAL UNUSED DATA
# running_data = GPIO_data[GPIO_data['ChannelName'] == "GPIO-3"]  # Running Wheel Data
# lick_data = GPIO_data[GPIO_data['ChannelName'] == "GPIO-4"]  # Lick Data


In [None]:
# STEP 2A.3: PREPROCESSING DLC Data
# Replace the DLC column header with short names: x, y, and p for the mouse, then the same for the LED.
# NOTE(review): assumes the DLC model tracks exactly these two points, in this order — confirm.
cols = [f'{point}_{axis}' for point in ('mouse', 'led') for axis in ('x', 'y', 'p')]
tracked_points.columns = cols

### STEP 2B: Manual Curation

In [None]:
# STEP 2B.2: Run ManualCuration GUI
curated_cells = dewan_manual_curation.launch_gui(
    project_folder_override=project_folder,
    cell_trace_data_override=cell_trace_data,
    cell_contours_override=cell_outlines,
    cell_names_override=cell_names,
)

# Stop here instead of letting the next cells fail with confusing errors on a missing/empty selection
if curated_cells is None or len(curated_cells) == 0:
    raise RuntimeError('Error, no good cells selected!')

### STEP 2C: Apply Manual Curation Results and Additional Preprocessing

In [None]:
# STEP 2C.1: Keep only the cells marked as good in ManualCuration
curated_cell_props = (
    all_cell_props
    .loc[all_cell_props['Name'].isin(curated_cells)]
    .reset_index(drop=True)
)
curated_trace_data = cell_trace_data.loc[:, curated_cells]  # One dF/F column per curated cell
cell_names = curated_cell_props['Name']

### STEP 2D: Pickle and Save all preprocessed data

In [None]:
# Pickle the preprocessed data in case it is needed later.
# Saves the curated cell traces, GPIO data, FV data, curated cell properties, sniff data,
# and the DLC tracked points. Once these have been saved, they don't need to be re-run on
# the same data unless the data itself changes.

folder = project_folder.analysis_dir.preprocess_dir.path

data_to_save = {
    'curated_trace_data': curated_trace_data,
    'GPIO_data': GPIO_data,
    'FV_data': FV_data,
    'curated_cell_props': curated_cell_props,
    'sniff_table': sniff_data,
    'tracked_points': tracked_points,
}
for data_name, data in data_to_save.items():
    IO.save_data_to_disk(data, data_name, file_header, folder)

In [None]:
# Open the saved pickle files. If the files have already been saved, the notebook can be
# re-run starting from this cell.
# Only the items written by the save cell are loaded; odor_data / odor_list are never saved
# by this notebook, so trying to load them would fail.

folder = project_folder.analysis_dir.preprocess_dir.path

curated_trace_data = IO.load_data_from_disk('curated_trace_data', file_header, folder)
GPIO_data = IO.load_data_from_disk('GPIO_data', file_header, folder)
FV_data = IO.load_data_from_disk('FV_data', file_header, folder)
curated_cell_props = IO.load_data_from_disk('curated_cell_props', file_header, folder)
sniff_data = IO.load_data_from_disk('sniff_table', file_header, folder)
cell_names = curated_cell_props['Name']  # List of cells, referenced periodically

tracked_points = IO.load_data_from_disk('tracked_points', file_header, folder)

In [None]:
# The model may erroneously identify the LED for very short time periods.
# EPM.find_led_start bins the possible LED-on frames (anywhere led_p > 0.98); each bin is a
# [start, end] frame pair. The bin spanning the most frames is most likely the period where
# the experimenter actually turned on the LED.
led_bins = np.array(EPM.find_led_start(tracked_points))
if led_bins.size == 0:
    raise ValueError('No LED-on period was detected in the tracked points; check the led_p column')

true_led_bin = np.argmax(led_bins[:, 1] - led_bins[:, 0])

led_on = int(led_bins[true_led_bin][0])  # First frame where the LED is 'on'
experiment_frames = int(VIDEO_FPS * 60 * EXPERIMENT_TIME)  # FPS * 60 s/min * experiment length in minutes --> number of frames
end_frame = led_on + experiment_frames

available_frames = len(tracked_points) - led_on
if available_frames < experiment_frames:
    print(f'Warning: only {available_frames} of the expected {experiment_frames} frames '
          f'are available after the LED turned on')

# Subset the frames from LED_ON -> EXPERIMENT_TIME minutes later and renumber them from 0.
# reset_index is assigned rather than done in place because iloc returns a slice of tracked_points.
good_points = tracked_points.iloc[led_on:end_frame].reset_index(drop=True)

# Get X, Y coordinates, cast to int, and combine them into (x, y) tuples
head_x = good_points['mouse_x'].astype(int)
head_y = good_points['mouse_y'].astype(int)
coordinates = np.fromiter(zip(head_x, head_y), dtype=object)

In [None]:
# STEP 3A: Parse the final valve (FV) channel to find when the valve is open vs closed, based on
# the TTL pulse from the Arduino. A jump of more than +10000 between consecutive samples marks an
# opening; a drop of more than 10000 marks a closing.
FV_values = FV_data['Value'].astype(float).values
FV_steps = np.diff(FV_values)  # FV_steps[k] == FV_values[k + 1] - FV_values[k]

FV_on_indexes = []
FV_off_indexes = []
valve_open = False  # The valve is assumed closed when the recording starts

for sample, step in enumerate(tqdm(FV_steps, desc="Processing: ")):
    if not valve_open and step > 10000:
        FV_on_indexes.append(sample + 1)  # First sample after the rising edge
        valve_open = True
    elif valve_open and step < -10000:
        FV_off_indexes.append(sample)  # Last sample before the falling edge
        valve_open = False

# Pair each opening with its closing; an opening that never closes is dropped by zip
FV_indexes = pd.DataFrame(zip(FV_on_indexes, FV_off_indexes), columns=['On', 'Off'])

In [None]:
# Use the first final-valve opening as the trial start and locate the matching span of cell traces
if FV_indexes.empty:
    raise ValueError('No final valve openings were detected in FV_data; cannot determine the trial start')

experiment_start_index = FV_indexes['On'].iloc[0]
FV_timestamps = FV_data['Time(s)']
# Trial start time (s); assumed to share a clock with the cell-trace times it is compared against below
trial_start_time = FV_timestamps.iloc[experiment_start_index]
trial_end_time = trial_start_time + (EXPERIMENT_TIME * 60)  # End time is whatever the duration of the experiment was in minutes

cell_trace_times = curated_trace_data.index.values

# Last trace sample at or before the trial start
samples_before_start = np.where(cell_trace_times <= trial_start_time)[0]
if samples_before_start.size == 0:
    raise ValueError('The trial starts before the first cell trace sample')
cell_trace_on_index = samples_before_start[-1]
# Last trace sample at or before the trial end. We can't overshoot, otherwise the coordinates
# will not match, so we may drop a single frame.
cell_trace_off_index = np.where(cell_trace_times <= trial_end_time)[0][-1]

In [None]:
# Trim the traces to the trial window and re-zero their timestamps so t = 0 is the trial start
trimmed_trace_data = curated_trace_data.iloc[cell_trace_on_index:cell_trace_off_index, :]

trimmed_cell_trace_times = trimmed_trace_data.index.values
shifted_cell_trace_times = trimmed_cell_trace_times - trimmed_cell_trace_times[0]
rounded_cell_trace_times = shifted_cell_trace_times.round(2)  # 10 ms resolution
trimmed_trace_data.index = rounded_cell_trace_times

# Convert the DLC row positions (frame numbers since the LED turned on) into seconds
good_points_index = good_points.index.values
good_points_time = good_points_index / VIDEO_FPS
good_points.index = good_points_time

In [None]:
## Align Cell Traces with the DLC Data
## The DLC data is typically recorded at ~6X the rate of the neural data, so several video frames
## could serve as the coordinate for each trace sample.
## For simplicity, we pick the frame whose timestamp is closest to the trace time (the exact match
## when one exists). Matching on nearest time rather than float equality guarantees exactly one
## coordinate per trace sample even when the rounded timestamps do not line up bit-for-bit.
## In the future we can do some averaging or picking the median, etc.

# Work on an explicit copy so the inserted columns never write into a slice of curated_trace_data,
# and so re-running this cell starts from a clean table.
trimmed_good_cell_trace_data = trimmed_trace_data.copy()

good_points_index = good_points.index.values  # Frame times in seconds, ascending
trace_times = trimmed_good_cell_trace_data.index.values

if len(good_points_index) < 2:
    raise ValueError('Not enough DLC frames to align with the cell traces')

# searchsorted finds the first frame at or after each trace time; compare it with the frame just
# before it and keep whichever is closer (ties go to the earlier frame). Clipping handles trace
# times that fall before the first or after the last frame.
next_frame = np.clip(np.searchsorted(good_points_index, trace_times), 1, len(good_points_index) - 1)
previous_frame = next_frame - 1
previous_is_closer = (trace_times - good_points_index[previous_frame]) <= (good_points_index[next_frame] - trace_times)
trace_coordinate_indexes = np.where(previous_is_closer, previous_frame, next_frame)

# Report how far apart the matched trace/frame times are, to flag misaligned data
max_offset = np.max(np.abs(good_points_index[trace_coordinate_indexes] - trace_times), initial=0)
print(f'Largest trace/frame time offset: {max_offset:.4f} s')

trace_coordinates = coordinates[trace_coordinate_indexes]

trimmed_good_cell_trace_data.insert(0, 'Coordinate_Index', trace_coordinate_indexes)
trimmed_good_cell_trace_data.insert(1, 'Coordinates', trace_coordinates)

In [None]:
## Save the paired coordinate / trace table to the preprocessing folder
folder = project_folder.analysis_dir.preprocess_dir.path
IO.save_data_to_disk(
    trimmed_good_cell_trace_data,
    'trimmed_good_cell_trace_data',
    file_header,
    folder,
)

In [None]:
# Show the instructions for drawing the EPM arm ROIs (provided by EPM.display_roi_instructions)
EPM.display_roi_instructions()

In [None]:
%matplotlib qt  
# Opens the matplotlib window using the QT backend

labeled_video.set(cv2.CAP_PROP_POS_FRAMES, led_on - 1) # Pull the frame that is our actual start
_, image = labeled_video.read()

arm_coordinates = EPM.get_arm_rois(image)

In [None]:
# Build the EPM region polygons from arm_coordinates (returned by EPM.get_arm_rois)
individual_regions, original_regions = EPM.get_region_polygons(arm_coordinates)  
# ([open_arm_1, open_arm_2, closed_arm_1, closed_arm_2, center_polygon], [open_arm, closed_arm, center])

%matplotlib inline
# Switch back to using inline displays
# Draw the combined regions over the start frame so they can be checked visually
fig, ax = plotting.plot_epm_roi(original_regions, image)

In [None]:
## Save the ROI figure and the ROI definitions to the figures folder
folder = project_folder.analysis_dir.figures_dir.path

image_path = folder.joinpath('EPM_ROI.pdf')
fig.savefig(str(image_path), dpi=600)

for roi_name, roi_data in (
    ('arm_coordinates', arm_coordinates),
    ('individual_regions', individual_regions),
    ('original_regions', original_regions),
):
    IO.save_data_to_disk(roi_data, roi_name, file_header, folder)


In [None]:
# Look up the EPM region for each trace-aligned coordinate (EPM.get_coordinate_region) and a
# per-coordinate distance (EPM.get_distances), then add both as columns of the trace table.
# NOTE(review): the exact meaning of coordinate_indexes and distances is defined inside the EPM
# helpers — confirm there.
animal_coordinates = trimmed_good_cell_trace_data['Coordinates']
coordinate_locations, coordinate_indexes = EPM.get_coordinate_region(animal_coordinates, individual_regions)
# Pair each coordinate with the region index returned for it
coordinate_pairs = list(zip(animal_coordinates, coordinate_indexes))
distances = EPM.get_distances(individual_regions, coordinate_pairs)

# DataFrame.insert raises if the column already exists, so re-running this cell requires
# rebuilding trimmed_good_cell_trace_data first
trimmed_good_cell_trace_data.insert(2, 'Location', coordinate_locations)
trimmed_good_cell_trace_data.insert(3, 'Distance', distances)