# Load and process

This notebook shows the use of the functions in the `subject.py` module. They are functions that deal with loading and calibrating data 
from a single subject, cutting the data into trials, and then getting a measure of success.

## Imports and definitions

In [1]:
import os
import sys
import pickle
import re
import copy

import numpy as np
import pandas as pd

from scipy import interpolate
from scipy import signal
from scipy.optimize import curve_fit
from scipy.stats import sem


import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import animation as animation
from matplotlib.widgets import Slider

# importing functions from our projects py files
module_path = os.path.abspath(os.path.join('..')) # the path to the source code
sys.path.insert(0, module_path)

from markers_analysis import constants as consts
from markers_analysis import subject as subj
from markers_analysis import billiards_table as billiards





In [2]:
# Plot configuration: use a larger default font size for every figure in the notebook.
mpl.rcParams.update({'font.size': 14})

# Interactive figures through the `widget` (ipympl) backend; the Qt backend is kept
# below, commented out, as an alternative.
%matplotlib widget  
# %matplotlib qt

## Loading data

### The regular expression

The regular expression in the `subject.py` module was built using [Regular Expression 101](https://regex101.com/). 

Actually, after that I needed a little help from ChatGPT.

In [3]:
# Output locations, relative to the notebook's working directory.
results_path = os.path.join('.', 'pkl')  # cached pickles (calibrated and cut data)
fig_path = os.path.join('.', 'img')      # saved figures

# Create both folders up front: the cells below write pickles and figures into them
# and would otherwise fail with FileNotFoundError on a fresh checkout.
for _output_dir in (results_path, fig_path):
    os.makedirs(_output_dir, exist_ok=True)

In [4]:
# Prefix of every recording file name; the trailing space is part of the name.
basename = 'white_ball_hit '

# One (subject id, recording date) pair per participant. Keeping each id next to its
# date makes it easy to analyse a subset: comment out the pairs that are not needed.
_sessions = [
    ('007', '2023_11_20'),
    ('008', '2023_11_20'),
    ('009', '2023_11_20'),
    ('010', '2023_12_03'),
    ('011', '2023_12_03'),
    ('012', '2023_12_03'),
    ('013', '2023_12_14'),
    ('014', '2023_12_14'),
]

subject_id_list = [session_id for session_id, _ in _sessions]
date_list = [session_date for _, session_date in _sessions]



### Remapping columns

In [5]:
# Translation table from the marker names found in the tracking files (keys) to the
# names used throughout the analysis (values). It is passed to
# `subj.load_ball_data_file` when the files are loaded.
column_mapping = dict(
    white_ball="cue_ball",
    purple_ball="object_ball",
    red_ball="target_ball",
    cue_mid="start_cue",
    cue_tip="end_cue",
)


### Loading the files

In [None]:
# Build `all_data`: one dictionary per subject with
#   'info' -> {'id', 'date', 'basename'}
#   'data' -> list of calibrated data dictionaries, one per recording file.
# The result for each subject is cached in `results_path`; delete the corresponding
# 'id<subject> calibrated.pkl' file to force the recordings to be processed again.
all_data = []
for subject_id, date in zip(subject_id_list, date_list):
    calibrated_file_name = os.path.join(results_path, f'id{subject_id} calibrated.pkl')

    if os.path.exists(calibrated_file_name):
        # Cache hit: reuse the previously calibrated data.
        # NOTE: pickle.load can execute arbitrary code, so only load cache files that
        # were produced by this notebook.
        with open(calibrated_file_name, 'rb') as file:
            subject_data = pickle.load(file)
    else:
        # Cache miss: load and calibrate every recording of this subject.
        subject_data = {
            'info': {'id': subject_id, 'date': date, 'basename': basename},
            'data': [],
        }

        # Use the configured `basename` rather than repeating the literal here, so the
        # search always matches what is stored in subject_data['info'].
        recordings = subj.find_ball_data_files(date=date, id=subject_id, basename=basename)
        for filenum, filename in recordings:
            # Load in pixel coordinates first, then calibrate as a separate step.
            # Debugging hint: both the pixel-space and the calibrated dictionaries can
            # be passed to `billiards.create_video` to inspect the calibration visually.
            data_dict_pixels = subj.load_ball_data_file(filename, column_mapping=column_mapping, calibrate_table=False)
            data_dict = subj.calibrate_table(data_dict_pixels)

            data_dict['filenum'] = filenum
            data_dict['filename'] = filename
            subject_data['data'].append(data_dict)
            print(f"************* Done with subject {subject_id} file number {filenum}: {filename}")

        with open(calibrated_file_name, 'wb') as file:
            pickle.dump(subject_data, file)

    all_data.append(subject_data)
        



Loaded ..\data\subject_007\2023_11_20\table\white_ball_hit 1_video4DLC_resnet50_purple objectNov8shuffle1_200000.h5
  Results returned in pixels
pixel_coords=[array([ 40.41199149, 142.40061707]), array([ 67.47350378, 143.3311501 ]), array([ 94.23731636, 144.30615744]), array([120.1031982 , 145.16072987]), array([ 24.50786529, 120.47332676]), array([26.38408262, 83.60748862]), array([136.06664908, 125.10479049]), array([136.23466782, 106.49962615]), array([135.94942888,  88.94708557]), array([43.69408282, 64.70350371]), array([70.10664092, 66.59071528]), array([95.37269991, 67.45593823]), array([120.38531098,  68.95852854])]
real_coords=

  Table calibrated with an error of 0.331 cm
************* Done with subject f007 file number 1: ..\data\subject_007\2023_11_20\table\white_ball_hit 1_video4DLC_resnet50_purple objectNov8shuffle1_200000.h5




Loaded ..\data\subject_007\2023_11_20\table\white_ball_hit 2_video4DLC_resnet50_purple objectNov8shuffle1_200000.h5
  Results returned in pixels
pixel_coords=[array([ 40.39713104, 142.39675764]), array([ 67.47761155, 143.35494839]), array([ 94.2395845 , 144.35229191]), array([120.11841693, 145.29954208]), array([ 24.48771717, 120.50534175]), array([26.34521655, 83.62137014]), array([136.07692594, 125.21448169]), array([136.1907426 , 106.64823699]), array([135.87270631,  88.97478756]), array([43.64674573, 64.72453448]), array([70.10387246, 66.6043267 ]), array([95.26561408, 67.37447814]), array([120.23761313,  69.18994833])]
real_coords=

  Table calibrated with an error of 0.343 cm
************* Done with subject f007 file number 2: ..\data\subject_007\2023_11_20\table\white_ball_hit 2_video4DLC_resnet50_purple objectNov8shuffle1_200000.h5




### Calculate velocity and speed
It is pretty easy to add new calculations based on the underlying function `subj.get_calculated_data` which runs over all the data and adds a calculation.   


In [None]:
def _add_ball_kinematics(data_dict):
    """Add the derived ball channels to one recording and return the same dictionary.

    The channels whose ``data_info`` type is "Ball" are extracted from the recording's
    time series, extended with the outputs of `subj.get_velocities`,
    `subj.get_speeds` and `subj.get_moving`, and merged back into the time series
    without overwriting existing channels.
    """
    ts = data_dict["ts"]
    ball_names = [name for name, info in ts.data_info.items() if info["Type"] == "Ball"]

    balls = ts.get_subset(ball_names)
    for add_channels in (subj.get_velocities, subj.get_speeds, subj.get_moving):
        add_channels(balls, merge_in_place=True)
    ts.merge(balls, in_place=True, overwrite=False)

    data_dict["ts"] = ts
    return data_dict


for subject_data in all_data:
    subject_data['data'] = [_add_ball_kinematics(recording) for recording in subject_data['data']]



## Cut the data
This section takes the long stream of data, identifies the beginning and end of each shot, and generates a list of time series with the data from each shot.

In [None]:
# Build `all_cut_data`: one dictionary per subject with
#   'info'  -> the subject's info dictionary
#   'shots' -> one dictionary per detected shot (file number, shot number, timing,
#              table and the time series restricted to that shot).
# Results are cached per subject in 'id<subject> cut.pkl' inside `results_path`.
all_cut_data = []

for subject_data in all_data:
    subject_id = subject_data['info']['id']
    cut_filename = os.path.join(results_path, f'id{subject_id} cut.pkl') 

    if not os.path.exists(cut_filename):
        # No cache yet: cut every recording of this subject into shots.
        shots_data = {'info': subject_data['info'], 'shots': []}

        for data_dict in subject_data['data']:
            for shotnum, shotrow in subj.get_shots(data_dict).iterrows():
                start_time = shotrow.start_time
                end_time = shotrow.end_time
                shot_ts = data_dict['ts'].get_ts_between_times(start_time, end_time, inclusive=True)

                # Move time zero to the moment the shot actually starts.
                zero_time = shot_ts.time[0] + consts.start_positions_time
                shot_ts.shift(-zero_time, in_place=True)

                shots_data['shots'].append({
                    'filenum': data_dict['filenum'],
                    'shotnum': shotnum,
                    'start_time': start_time,
                    'zero_time': zero_time,
                    'end_time': end_time,
                    'table': data_dict['table'],
                    'ts': shot_ts,
                })

        with open(cut_filename, 'wb') as file:
            pickle.dump(shots_data, file)
    else:
        # Cached result available: read it back instead of recomputing.
        with open(cut_filename, 'rb') as file:
            shots_data = pickle.load(file)

    all_cut_data.append(shots_data)




## Plot success

### Count hits

In [None]:
# Sliding-window success curve: for every subject, count the hits in each window of
# `bin_width` consecutive shots and save the curve as 'successes <id>.png'.
# A single figure is reused (cleared and redrawn) for all subjects.
bin_width = 30  # number of consecutive shots per window

ax = None
for shots_data in all_cut_data:
    subject_id = shots_data['info']['id']
    shots = shots_data['shots']
    num_shots = len(shots)

    # A window can start at every shot that still has `bin_width` shots from there on,
    # i.e. at indices 0 .. num_shots - bin_width inclusive. (Stopping one window
    # earlier would leave the last shot out of every window.) With fewer shots than
    # `bin_width` there is no complete window and the curve is empty.
    num_bins = max(num_shots - bin_width + 1, 0)
    bin_start = np.arange(num_bins)
    bin_stop = bin_start + bin_width      # exclusive end index of each window
    bin_middle = (bin_start + bin_stop) / 2
    bin_hits = np.zeros_like(bin_start)

    for start, stop in zip(bin_start, bin_stop):
        bin_hits[start] = subj.count_hits(shots[start:stop])

    if ax is None:
        fig, ax = plt.subplots()
    else:
        ax.clear()

    ax.set_xlim([0, 200])
    ax.plot(bin_middle, bin_hits)
    ax.set_ylabel('Hits')
    ax.set_xlabel('Shot')
    ax.set_title(f'Subject {subject_id} successes (bin width: {bin_width})')

    # Save through the figure object so the right figure is written even if another
    # figure happens to be the current one.
    fig_savename = os.path.join(fig_path, f"successes {subject_id}.png")
    fig.savefig(fig_savename)
    

### Angular error

#### Get the angular error

#### Plot absolute angular error learning curve

In [None]:
baseline_shots = 15      # shots with an index below this form the baseline block
max_abs_error_deg = 60   # absolute errors above this (degrees) are treated as outliers


def exponential_function(x, a, b, c):
    """Exponential learning curve ``a * exp(-b * x) + c``."""
    return a * np.exp(-b * x) + c


def fit_exponential_with_nans(x, y):
    """Fit `exponential_function` to ``(x, y)``, ignoring the NaN samples of ``y``.

    Parameters
    ----------
    x, y : numpy arrays of equal length.

    Returns
    -------
    numpy array
        The fitted curve evaluated at every ``x``. It is all NaN when the fit cannot
        be made: fewer valid points than parameters, or `curve_fit` not converging.
    """
    valid_indices = ~np.isnan(y)
    x_valid = x[valid_indices]
    y_valid = y[valid_indices]

    failed_params = [np.nan, np.nan, np.nan]
    if x_valid.size < len(failed_params):
        # curve_fit needs at least as many points as parameters.
        params = failed_params
    else:
        try:
            params, _ = curve_fit(exponential_function, x_valid, y_valid)
        except RuntimeError:
            # The optimiser did not converge; draw no curve for this subject.
            params = failed_params

    return exponential_function(x, *params)


# One figure is reused (cleared and redrawn) for all subjects; each subject's plot is
# saved as 'error <id>.png'.
ax = None
for shots_data in all_cut_data:
    subject_id = shots_data['info']['id']

    angles = subj.get_angles(shots_data['shots'])
    hits = np.array(subj.is_hit(shots_data['shots']))

    # Absolute angular error in degrees, with the big outliers blanked out.
    angles_deg = np.abs(angles * 180 / np.pi)
    angles_deg[angles_deg > max_abs_error_deg] = np.nan

    shot_number = np.arange(len(angles_deg))
    training_epoch = shot_number >= baseline_shots

    # Learning curve fitted to the training shots only.
    exponential_fit = fit_exponential_with_nans(shot_number[training_epoch], angles_deg[training_epoch])

    if ax is None:
        fig, ax = plt.subplots(figsize=(21/2.54, 7/2.54))
    else:
        ax.clear()

    # Training shots in blue, successful ones over-plotted in red, baseline in grey.
    ax.plot(shot_number[training_epoch], angles_deg[training_epoch], 'bo', label='Angular error')
    ax.plot(shot_number[hits & training_epoch], angles_deg[hits & training_epoch], 'ro', label='Success')
    ax.plot(shot_number[~training_epoch], angles_deg[~training_epoch], 'o', color='grey')
    ax.axvspan(xmin=0, xmax=baseline_shots, color='#D2B48C', alpha=0.3, label='Baseline')

    ax.plot(shot_number[training_epoch], exponential_fit, 'b-', linewidth=2)
    ax.set_ylim([-20, 75])
    ax.set_xlim([0, 200])
    ax.set_title(f'Subject {subject_id}: Absolute angular error')
    ax.set_xlabel('Shot')
    ax.set_ylabel('$|Error|$ (deg)')

    ax.legend(loc='upper right')

    fig_savename = os.path.join(fig_path, f"error {subject_id}.png")
    fig.savefig(fig_savename, bbox_inches='tight')

In [None]:
# Development helper: reload the project package so that edits to its source files
# take effect without restarting the kernel. Not needed for a normal top-to-bottom run.
import importlib

import markers_analysis

# Reload the package itself first, then its submodules.
importlib.reload(markers_analysis)
importlib.reload(markers_analysis.constants)
importlib.reload(markers_analysis.subject)
importlib.reload(markers_analysis.billiards_table)
# NOTE(review): `markers_analysis.math` is never imported explicitly by this notebook;
# this line presumably relies on another submodule having imported it -- confirm.
importlib.reload(markers_analysis.math)
# `consts`, `billiards` and `subj` are the notebook's aliases for the constants,
# billiards_table and subject submodules, so these calls reload those modules a
# second time, now after `math` has been reloaded.
importlib.reload(consts)
importlib.reload(billiards)
importlib.reload(subj)


# Make average learning plot

## Get data for all subjects into dataframes

In [None]:
# Collect the signed angular error (degrees) and the hit flag of every shot into two
# data frames with one row per subject (same order as `all_cut_data`) and one column
# per shot. Subjects with fewer shots are padded with NaN by the DataFrame constructor.
angles_list = []
hits_list = []

for shots_data in all_cut_data:
    angles = subj.get_angles(shots_data['shots'])
    hits = subj.is_hit(shots_data['shots'])

    # Signed error in degrees (the sign is kept here, unlike in the per-subject plots).
    angles_deg = angles * 180 / np.pi

    # Blank out the big outliers. They are left as NaN and handled by the statistics
    # computed later, so no interpolation is done here.
    outliers = np.abs(angles_deg) > 60
    angles_deg[outliers] = np.nan

    angles_list.append(angles_deg)
    hits_list.append(hits)

angles_df = pd.DataFrame(angles_list)
hits_df = pd.DataFrame(hits_list)


In [None]:
from statsmodels.robust import mad

baseline_end = 15  # shots with an index below this form the baseline block


def RGB(p, b):
    """Colour for a hit fraction ``p``: from grey level ``b`` (p = 0) to pure red (p = 1)."""
    return ((1 - b) * p + b, -b * p + b, -b * p + b)


num_subj = angles_df.shape[0]

# `mad` is given a NaN-free frame: every shot's missing values are replaced by that
# shot's median across subjects.
# NOTE(review): filling with the median pulls the MAD towards zero for shots with
# many missing values -- confirm this is acceptable.
angles_no_nan = angles_df.apply(lambda x: x.fillna(x.median()))

# Median and Median Absolute Deviation (MAD) across subjects for each shot.
# NOTE(review): the median is negated here; presumably a sign convention for the plot
# -- confirm.
median_angles = -angles_df.median(axis=0, skipna=True)
mad_angles = mad(angles_no_nan, axis=0, center=np.median)
mad_series = pd.Series(mad_angles, index=median_angles.index)

# Smooth both curves with a centred rolling mean, then interpolate linearly over any
# NaN that is left.
window_size = 10  # Adjust the window size as needed
smoothed_median = median_angles.rolling(window=window_size, min_periods=1, center=True).mean().interpolate(method='linear')
smoothed_mad = mad_series.rolling(window=window_size, min_periods=1, center=True).mean().interpolate(method='linear')
smoothed_sem = smoothed_mad  # historical name kept for any code that still uses it; the values are MAD, not SEM

# Fraction of subjects that hit on each shot (0..1), and the marker colour it maps to.
percentage_hits = hits_df.mean(axis=0, skipna=True)
color_dict = {shot: RGB(percentage_hit, 0.5) for shot, percentage_hit in percentage_hits.items()}

# Plotting
fig, ax = plt.subplots(figsize=(1.2*21/2.54, 1.2*1.2*7/2.54))

# Shaded band: smoothed median +/- smoothed MAD.
x_vals = np.arange(len(smoothed_median))
band_handle = ax.fill_between(x_vals, smoothed_median - smoothed_mad, smoothed_median + smoothed_mad,
                              alpha=0.6, color='darkgray', label='Smoothed median +/- MAD')

# One marker per shot: black during baseline, shaded by hit fraction afterwards.
shot_handles = {}
for shot in percentage_hits.index:
    color = color_dict[shot] if shot >= baseline_end else 'k'
    shot_handles[shot] = ax.scatter(shot, median_angles[shot], color=color, s=50)

legend_handles = []

# Legend entry for the median: any baseline (black) marker will do; use the last one.
baseline_markers = [shot for shot in shot_handles if shot < baseline_end]
if baseline_markers:
    median_handle = shot_handles[baseline_markers[-1]]
    median_handle.set_label(f'Median (N={num_subj})')
    legend_handles.append(median_handle)

# Legend entries for the colour scale: the training shots with the lowest and the
# highest hit fraction. Only the training shots are searched because only they are
# coloured; looking at all shots could select a black baseline marker and leave the
# legend handle undefined.
training_hits = percentage_hits[percentage_hits.index >= baseline_end].dropna()
if not training_hits.empty:
    extreme_shots = dict.fromkeys([training_hits.idxmin(), training_hits.idxmax()])  # de-duplicated, order kept
    for shot in extreme_shots:
        handle = shot_handles[shot]
        handle.set_label(f'{round(training_hits[shot] * 100, -1):.0f}% hits')
        legend_handles.append(handle)

baseline_handle = ax.axvspan(xmin=0, xmax=baseline_end, color='#D2B48C', alpha=0.3, label='Baseline')
ax.axhline(0, color='black', linestyle='--', dashes=(8,4))

legend_handles.extend([band_handle, baseline_handle])

ax.set_xlim(0, 200)
ax.set_ylim(-15, 50)

# Customize plot
ax.set_xlabel('Shot')
ax.set_ylabel('Angular Error')
ax.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1, 1.2))
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()

fig_savename = os.path.join(fig_path, "median error.png")
fig.savefig(fig_savename, bbox_inches='tight', dpi=600)

In [None]:
# Inspect the per-shot spread with a narrower smoothing window.
# `mad_angles` holds one MAD value per shot; index it by shot like `median_angles`.
mad_series = pd.Series(mad_angles, index=median_angles.index)

window_size = 3  # Adjust the window size as needed

# Centred rolling mean, followed by linear interpolation over any NaN that is left.
# Despite its name, `smoothed_sem` holds the smoothed MAD.
smoothed_sem = (
    mad_series
    .rolling(window=window_size, min_periods=1, center=True)
    .mean()
    .interpolate(method='linear')
)

smoothed_sem