In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os

In [None]:
# Execute the helper script inside this notebook's namespace (IPython %run magic).
# The helpers called in later cells (parse_filename, process_gf_cv_mouse_experiment,
# average_gf_cv_mouse_experiment_data, extract_approach_data, plot_metric_across_days)
# are not defined in this notebook, so presumably they come from this script -- TODO confirm.
%run livecricketcapture_functions.py

In [None]:
# Collect the per-trial Excel files of each group.
# glob.glob returns matches in arbitrary (filesystem-dependent) order, so sort them to make
# the processing order, and every printout derived from it, reproducible between runs.
GF_files = sorted(glob.glob('GermFree(GF)GroupA/*.xlsx'))
CV_files = sorted(glob.glob('Control(CV)GroupC/*.xlsx'))

In [None]:
# Spot-check the filename parser: list every GF file whose parsed fields are
# sex 'M', mouse_id '1' and day '1', together with the parsed values.
for gf_path in GF_files:
    group, day, mouse_id, sex, trial = parse_filename(gf_path)
    if (sex, mouse_id, day) != ('M', '1', '1'):
        continue
    print(
        f'group: {group}, day: {day}, mouse_id: {mouse_id}, sex: {sex}, trial: {trial}',
        gf_path,
        '--------------------------------',
        sep='\n',
    )

In [None]:
# Compute all metrics across all mice and all days, once per group.
# NOTE: slow -- the original author reports that this cell takes a good 10 minutes.
# process_gf_cv_mouse_experiment is not defined in this notebook; presumably it is provided
# by the %run'd livecricketcapture_functions.py -- TODO confirm.
GF_data = process_gf_cv_mouse_experiment(GF_files)
CV_data = process_gf_cv_mouse_experiment(CV_files)

In [None]:
# Average across groups and days.
# Later cells index the results as [key][day][metric], where day is a string
# ('1'..'7' for single days, or 'all').
GF_average_data = average_gf_cv_mouse_experiment_data(GF_data)
CV_average_data = average_gf_cv_mouse_experiment_data(CV_data)

In [None]:
# Inspect the top-level keys of the processed GF data (shown as this cell's output).
GF_data.keys()

In [None]:
# Report the number of approach events stored under `day` for every key of both groups
# (GF first, then CV). Keys lacking an entry for that day are reported instead of raising.
day = 'all'
for average_data in (GF_average_data, CV_average_data):
    for key in average_data.keys():
        try:
            num_approaches = average_data[key][day]["num_approach_events"]
        except (KeyError, TypeError):
            # KeyError: the day or the metric is missing; TypeError: the entry is not
            # subscriptable (e.g. None). Anything else, including KeyboardInterrupt,
            # now propagates instead of being silently reported as "no data".
            print(f'{key} has no data for day {day}')
        else:
            print(f'{key} has {num_approaches} approaches for day {day}')

In [None]:
# Extract approach data for individual days of training.
# The two results are the inputs of the plot_metric_across_days calls in the next cell.
GF_approach_data = extract_approach_data(GF_average_data)
CV_approach_data = extract_approach_data(CV_average_data)

In [None]:
def _plot_cv_vs_gf(metric, title, ylabel, **extra):
    """Plot one metric across days with CV as group 1 and GF as group 2.

    Forwards to plot_metric_across_days with the CV/GF approach data and group
    names filled in; any additional keyword arguments (e.g. plot_type) are
    passed through unchanged. Returns whatever plot_metric_across_days returns.
    """
    return plot_metric_across_days(CV_approach_data,
                                   GF_approach_data,
                                   metric,
                                   title=title,
                                   ylabel=ylabel,
                                   group1_name='CV',
                                   group2_name='GF',
                                   **extra)


# Only the first plot requests plot_type='both'; the rest use the function's default.
_plot_cv_vs_gf('num_approach_events',
               'Number of Interception Events per Trial',
               'Number of Interception Events per Trial',
               plot_type='both')

_plot_cv_vs_gf('distance_travelled',
               'Distance Travelled per Trial',
               'Distance Travelled per Trial (cm)')

_plot_cv_vs_gf('average_speed_during_interception',
               'Average Speed During Interception',
               'Average Speed During Interception (cm/s)')

_plot_cv_vs_gf('max_speed_during_interception',
               'Max Speed During Interception',
               'Max Speed During Interception (cm/s)')

_plot_cv_vs_gf('time_to_capture',
               'Time to Capture',
               'Time to Capture (s)')

_plot_cv_vs_gf('distance_to_cricket_at_approach_start',
               'Distance to Cricket at Approach Start',
               'Distance to Cricket at Approach Start (cm)')

# Kept as the cell's final expression so its return value is displayed as before.
_plot_cv_vs_gf('num_incomplete_approach_events',
               'Number of Incomplete Approach Events per Trial',
               'Number of Incomplete Approach Events per Trial')

In [None]:
# Plot, for each training day, the absolute heading around interception for GF vs CV.
#   * faint lines: one trace per key -- the nan-median over that key's stacked headings,
#     smoothed with a two-sample moving average, then made absolute;
#   * solid lines: the nan-median over that day's traces, smoothed once more.
days = ['1', '2', '3', '4', '5', '6', '7']

for day in days:
    # Start from empty lists for every day. Creating them once outside the loop made the
    # "day N" median also include every trace collected for days 1..N-1.
    GF_all_headings = []
    CV_all_headings = []

    plt.figure(figsize=(7, 4), dpi=100)
    for group_data, group_traces, color in ((GF_average_data, GF_all_headings, '#ff7f0e'),
                                            (CV_average_data, CV_all_headings, '#1f77b4')):
        for key in group_data.keys():
            try:
                headings = np.stack(group_data[key][day]['heading_cricket_mouse+-5'])
            except (KeyError, TypeError, ValueError):
                # KeyError: day/metric missing; TypeError: entry not subscriptable or not
                # iterable; ValueError: np.stack got nothing (or mismatched shapes) to stack.
                print(f'{key} has no data for day {day}')
                continue
            avg_heading = np.nanmedian(headings, axis=0)
            # Apply sliding window average over two frames to individual traces
            avg_heading_filtered = (avg_heading[:-1] + avg_heading[1:]) / 2
            plt.plot(np.abs(avg_heading_filtered), color=color, alpha=0.1)
            group_traces.append(np.abs(avg_heading_filtered))

    # Median across all keys of the day, smoothed again over two frames.
    # np.stack raises on an empty list, so a group without traces for this day is skipped.
    if GF_all_headings:
        GF_overall_avg_heading = np.nanmedian(np.stack(GF_all_headings), axis=0)
        GF_filtered = (GF_overall_avg_heading[:-1] + GF_overall_avg_heading[1:]) / 2
        plt.plot(GF_filtered, color='#ff7f0e', label='GF')
    else:
        print(f'No GF heading traces for day {day}')
    if CV_all_headings:
        CV_overall_avg_heading = np.nanmedian(np.stack(CV_all_headings), axis=0)
        CV_filtered = (CV_overall_avg_heading[:-1] + CV_overall_avg_heading[1:]) / 2
        plt.plot(CV_filtered, color='#1f77b4', label='CV')
    else:
        print(f'No CV heading traces for day {day}')

    plt.axvline(150, color='black', linestyle='--', alpha=0.5)
    plt.ylim(-2, 50)
    plt.xlim(80, 220)
    plt.title(f'Median Heading around interception for day {day}')
    plt.xlabel('Time before interception (s)')
    plt.ylabel('Absolute Heading (degrees)')

    # Relabel the x axis relative to the dashed line at sample 150.
    # NOTE(review): the division by 30 presumably reflects 30 samples per second -- confirm.
    # Tick positions are pinned (restricted to the visible range so the limits do not grow)
    # before relabelling, so the custom labels cannot drift away from the ticks.
    ax = plt.gca()
    x_min, x_max = ax.get_xlim()
    tick_positions = [tick for tick in ax.get_xticks() if x_min <= tick <= x_max]
    ax.set_xticks(tick_positions)
    _ = ax.set_xticklabels([f'{(tick - 150) / 30:.1f}' for tick in tick_positions])

    plt.legend()
    plt.show()

In [None]:
# Plot, for each training day, the mouse speed around interception for GF vs CV.
#   * faint lines: one trace per key -- the nan-median over that key's stacked speeds,
#     smoothed with a two-sample moving average;
#   * solid lines: the nan-median over that day's traces, smoothed once more.
days = ['1', '2', '3', '4', '5', '6', '7']

for day in days:
    # Start from empty lists for every day. Creating them once outside the loop made the
    # "day N" median also include every trace collected for days 1..N-1.
    GF_all_speeds = []
    CV_all_speeds = []

    plt.figure(figsize=(7, 4), dpi=100)
    for group_data, group_traces, color in ((GF_average_data, GF_all_speeds, '#ff7f0e'),
                                            (CV_average_data, CV_all_speeds, '#1f77b4')):
        for key in group_data.keys():
            try:
                speeds = np.stack(group_data[key][day]['mouse_speed+-5'])
            except (KeyError, TypeError, ValueError):
                # KeyError: day/metric missing; TypeError: entry not subscriptable or not
                # iterable; ValueError: np.stack got nothing (or mismatched shapes) to stack.
                print(f'{key} has no data for day {day}')
                continue
            avg_speed = np.nanmedian(speeds, axis=0)
            # Apply sliding window average over two frames to individual traces
            avg_speed_filtered = (avg_speed[:-1] + avg_speed[1:]) / 2
            plt.plot(avg_speed_filtered, color=color, alpha=0.1)
            group_traces.append(avg_speed_filtered)

    # Median across all keys of the day, smoothed again over two frames.
    # np.stack raises on an empty list, so a group without traces for this day is skipped.
    if GF_all_speeds:
        GF_overall_avg_speed = np.nanmedian(np.stack(GF_all_speeds), axis=0)
        GF_filtered = (GF_overall_avg_speed[:-1] + GF_overall_avg_speed[1:]) / 2
        plt.plot(GF_filtered, color='#ff7f0e', label='GF')
    else:
        print(f'No GF speed traces for day {day}')
    if CV_all_speeds:
        CV_overall_avg_speed = np.nanmedian(np.stack(CV_all_speeds), axis=0)
        CV_filtered = (CV_overall_avg_speed[:-1] + CV_overall_avg_speed[1:]) / 2
        plt.plot(CV_filtered, color='#1f77b4', label='CV')
    else:
        print(f'No CV speed traces for day {day}')

    plt.axvline(150, color='black', linestyle='--', alpha=0.5)
    plt.xlim(80, 220)
    plt.ylim(0, 30)
    plt.title(f'Median Speed around interception for day {day}')
    plt.xlabel('Time (s)')
    plt.ylabel('Speed (cm/s)')
    plt.legend()

    # Relabel the x axis relative to the dashed line at sample 150.
    # NOTE(review): the division by 30 presumably reflects 30 samples per second -- confirm.
    # Tick positions are pinned (restricted to the visible range so the limits do not grow)
    # before relabelling, so the custom labels cannot drift away from the ticks.
    ax = plt.gca()
    x_min, x_max = ax.get_xlim()
    tick_positions = [tick for tick in ax.get_xticks() if x_min <= tick <= x_max]
    ax.set_xticks(tick_positions)
    _ = ax.set_xticklabels([f'{(tick - 150) / 30:.1f}' for tick in tick_positions])

    # Relabelling and show() now run for every day's figure; previously they sat outside
    # the loop, so only the last figure got second-based tick labels.
    plt.show()