In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import glob
import os
from scipy.io import loadmat
from filterpy.kalman import KalmanFilter
import math
import seaborn as sns

In [9]:
# Insert paths for parquets and matlab files
# NOTE(review): both values are absolute, machine-specific Windows paths and
# must be edited before the notebook can run on another machine.

# Root folder that is searched recursively for *.parquet prediction files.
parquet_path = r'C:\Users\OfekSapir\Desktop\reptile_lab\retpile_lab\perdictions'
# Root folder that is searched recursively for *.mat files.
mat_path = r'C:\Users\OfekSapir\Desktop\reptile_lab\retpile_lab\mat_data'

In [10]:
# Data preparation
# Filtering by probability and positions based on given parameters
def data_prep(parquet, mat, prob_quantile=0.02, pos_std_factor=3):
    """Filter nose predictions and build a tidy per-frame table.

    Parameters
    ----------
    parquet : pandas.DataFrame
        Frame with two-level columns; the columns ('nose', 'x'), ('nose', 'y'),
        ('nose', 'prob'), ('time', '') and ('angle', '') are read. The input is
        not modified (all work happens on a copy).
    mat : mapping
        Object indexed as ``mat['arenaCSVs']['startFrameSh']`` and
        ``mat['arenaCSVs']['endFramSh']`` to obtain the trial start/end frames.
    prob_quantile : float, default 0.02
        Rows whose nose probability is below this quantile are dropped.
    pos_std_factor : float, default 3
        Rows whose absolute frame-to-frame jump in x or y exceeds
        ``mean + pos_std_factor * std`` of the jumps are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``frame_num`` (0..n-1 position after filtering),
        ``time_from_zero`` (seconds since the earliest kept timestamp),
        ``nose_x``, ``nose_y``, ``angle`` and ``in_trial`` (bool).

    Notes
    -----
    * The first row that survives the probability filter has no predecessor,
      so its jump is NaN and it never passes the positional filter.
    * If the start and end frame lists differ in length a warning is printed
      and only the pairs that exist in both lists are used.
    * NOTE(review): ``in_trial`` compares the *post-filter* row position with
      the frame ranges from ``mat``. If those ranges refer to original frame
      numbers, dropped rows shift the flags - confirm the intended meaning.
    """

    def _as_int(value):
        # loadmat-style entries arrive as 1-element arrays / numpy scalars.
        if isinstance(value, (np.ndarray, np.generic)):
            return int(value.item())
        return int(value)

    def _jump_threshold(jumps):
        # mean + k * std over the defined (non-NaN) jumps.
        valid = jumps.dropna()
        return np.mean(valid) + pos_std_factor * np.std(valid)

    # Drop the least confident detections.
    nose_prob = parquet[('nose', 'prob')]
    prob_threshold = nose_prob.quantile(prob_quantile)
    filtered_parquet = parquet[nose_prob >= prob_threshold].copy()

    # Absolute frame-to-frame displacement along each axis.
    filtered_parquet['diff_x'] = filtered_parquet[('nose', 'x')].diff().abs()
    filtered_parquet['diff_y'] = filtered_parquet[('nose', 'y')].diff().abs()

    x_threshold = _jump_threshold(filtered_parquet['diff_x'])
    y_threshold = _jump_threshold(filtered_parquet['diff_y'])

    # Drop rows with implausibly large jumps (NaN jumps compare False and are
    # dropped as well).
    filtered_parquet = filtered_parquet[
        (filtered_parquet['diff_x'] <= x_threshold) &
        (filtered_parquet['diff_y'] <= y_threshold)
    ]

    # Keep only the columns needed downstream and flatten the two-level names
    # ('nose', 'x') -> 'nose_x', ('time', '') -> 'time_'.
    raw = filtered_parquet[[('nose', 'x'), ('nose', 'y'), ('time', ''), ('angle', '')]].copy()
    raw.columns = raw.columns.map('_'.join)
    n_rows = len(raw)

    # Timestamps are interpreted as Unix seconds and re-based to start at zero.
    time_unix = pd.to_datetime(raw['time_'], unit='s')
    time_delta = (time_unix - time_unix.min()).dt.total_seconds()

    # Trial frame ranges from the mat structure.
    start_frames = list(mat['arenaCSVs']['startFrameSh'])[0][0]
    end_frames = list(mat['arenaCSVs']['endFramSh'])[0][0]

    if len(start_frames) != len(end_frames):
        print("Warning: Mismatch in start and end frame lengths. Proceeding, but results may be incorrect.")

    # zip() pairs only what exists in both lists, so a mismatch cannot raise.
    frame_range_lst = [
        range(_as_int(start), _as_int(end))
        for start, end in zip(start_frames, end_frames)
    ]

    # Flag every row position that falls inside any [start, end) range.
    in_trial = np.zeros(n_rows, dtype=bool)
    for frames in frame_range_lst:
        first = max(frames.start, 0)
        last = min(frames.stop, n_rows)
        if first < last:
            in_trial[first:last] = True

    data = pd.DataFrame({
        'frame_num': list(range(n_rows)),
        'time_from_zero': time_delta,
        'nose_x': raw['nose_x'].tolist(),
        'nose_y': raw['nose_y'].tolist(),
        'angle': raw['angle_'],
        'in_trial': in_trial,
    })

    return data

In [11]:
# Screen view probability calculation
# Running on data after going through data_prep
def screen_view_probability_cal(data):
    """Add ``angle_deviation`` and ``probability`` columns to ``data`` in place.

    The deviation is computed from ``data['angle']`` (treated as radians):

    * angle in [0, pi]            -> 0
    * angle in (pi, 3*pi/2)       -> angle - pi
    * angle in [3*pi/2, 2*pi]     -> 2*pi - angle
    * anything else (incl. NaN)   -> NaN

    The interval end points are inclusive so that an angle lying exactly on
    0, pi, 3*pi/2 or 2*pi receives the value of the adjacent branch (the
    mapping is continuous there) instead of NaN.

    ``probability`` is ``1 - deviation / (pi/2)`` clipped to [0, 1]; it is NaN
    wherever the deviation is NaN.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain an ``angle`` column. Modified in place.

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object, for convenience.
    """
    # Edges of the zero-deviation band.
    reference_angle_1 = 0
    reference_angle_2 = np.pi
    three_quarter_turn = 3 * np.pi / 2
    full_turn = 2 * np.pi

    angle = data['angle']

    zero_band = (reference_angle_1 <= angle) & (angle <= reference_angle_2)
    near_full_turn = (three_quarter_turn <= angle) & (angle <= full_turn)
    just_past_pi = (reference_angle_2 < angle) & (angle < three_quarter_turn)

    data['angle_deviation'] = np.where(
        zero_band,
        0,
        np.where(
            near_full_turn,
            np.abs(full_turn - (angle - reference_angle_1)),
            np.where(
                just_past_pi,
                np.abs(angle - reference_angle_2),
                np.nan  # Angle outside [0, 2*pi] or missing
            )))

    # Normalize the deviation to a score: 0 deviation -> 1, >= pi/2 -> 0.
    threshold = np.pi / 2
    data['probability'] = np.clip(1 - (data['angle_deviation'] / threshold), 0, 1)

    return data

In [12]:
# Plot function nose x-y plane real world
def plot_plane(data, date, path_to_save=r'C:\Users\OfekSapir\Desktop\reptile_lab\retpile_lab\filtered', save=False):
    """Plot the nose trajectory and highlight high-probability segments.

    ``nose_y`` is drawn on the horizontal axis and ``nose_x`` on the vertical
    axis. The whole path is drawn in black; every segment whose end point has
    ``probability > 0.8`` is overdrawn in red.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``nose_x``, ``nose_y`` and ``angle`` columns. Side effect:
        ``screen_view_probability_cal`` adds/overwrites the
        ``angle_deviation`` and ``probability`` columns on this object.
    date : str
        Text inserted into the figure title.
    path_to_save : str
        Folder for the PNG; created if missing. Only used when ``save`` is True.
    save : bool, default False
        When True the figure is written as ``plane_<n>.png`` using the first
        ``n`` (starting at 1) that does not exist yet.
    """
    # Local import: LineCollection is not among the notebook's top-level imports.
    from matplotlib.collections import LineCollection

    # Assign screen sight probability for each frame
    screen_view_probability_cal(data)

    # Rows whose probability is above 0.8 (NaN compares False).
    high_prob = (data['probability'] > 0.8).to_numpy()

    default_color = 'black'
    highlight_color = 'red'

    horizontal = data['nose_y'].to_numpy()
    vertical = data['nose_x'].to_numpy()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(horizontal, vertical, color=default_color, linewidth=2)

    # Segment k joins point k and k+1 and takes the colour of its end point.
    # All highlighted segments go into ONE artist instead of one plot call per
    # segment, which keeps long recordings fast to draw.
    points = np.column_stack([horizontal, vertical])
    segments = np.stack([points[:-1], points[1:]], axis=1)
    highlighted = segments[high_prob[1:]]
    if len(highlighted):
        # zorder above the base line (Line2D defaults to 2) so red stays on top.
        ax.add_collection(
            LineCollection(highlighted, colors=highlight_color, linewidths=2, zorder=3)
        )

    ax.set_title(f'Nose x-y Plane for {date}')
    ax.set_ylabel('X Coordinates (cm)')
    ax.set_xlabel('Y Coordinates (cm)')
    ax.grid(True)
    fig.tight_layout()

    if save:
        os.makedirs(path_to_save, exist_ok=True)
        # Pick the first free running number so earlier figures are kept.
        counter = 1
        filename = os.path.join(path_to_save, f'plane_{counter}.png')
        while os.path.exists(filename):
            counter += 1
            filename = os.path.join(path_to_save, f'plane_{counter}.png')
        fig.savefig(filename)

    plt.show()
    # Release the figure; this function is called many times in a loop.
    plt.close(fig)

In [8]:
# Plot function nose x & y coordinates over time
def plot_overtime(data, date, path_to_save=r'C:\Users\OfekSapir\Desktop\reptile_lab\retpile_lab\figs', save=False):
    """Plot ``nose_x`` and ``nose_y`` against ``time_from_zero``.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``time_from_zero``, ``nose_x``, ``nose_y`` and ``angle`` columns.
        Side effect: ``screen_view_probability_cal`` adds/overwrites the
        ``angle_deviation`` and ``probability`` columns on this object (they
        are not used by the plot itself).
    date : str
        Text inserted into the figure title.
    path_to_save : str
        Folder for the PNG; created if missing. Only used when ``save`` is True.
    save : bool, default False
        When True the figure is written as ``overtime_plane_<n>.png`` using the
        first ``n`` (starting at 1) that does not exist yet.
    """
    # Assign screen sight probability for each frame
    screen_view_probability_cal(data)

    fig, ax = plt.subplots(figsize=(15, 8))
    for col in ('nose_x', 'nose_y'):
        # col[-1] is the axis letter ('x' / 'y') used in the legend entry.
        ax.plot(data['time_from_zero'], data[col], label=f'Coordinate of {col[-1]}')

    # Axis decoration is identical for both curves, so it is applied once.
    ax.set_title(f'Location of Pogona Over Time for {date}')
    ax.set_xlabel('Timestamp')
    ax.set_ylabel('Coordinates (cm)')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()

    if save:
        os.makedirs(path_to_save, exist_ok=True)
        # Pick the first free running number so earlier figures are kept.
        counter = 1
        filename = os.path.join(path_to_save, f'overtime_plane_{counter}.png')
        while os.path.exists(filename):
            counter += 1
            filename = os.path.join(path_to_save, f'overtime_plane_{counter}.png')
        fig.savefig(filename)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [15]:
# First self written Kalman filter
# Example for parameters is given below
def kalman_filter_1(data, process_variance, measurement_variance, estimated_error, initial_value):
    """Smooth a 1-D sequence with a scalar Kalman-style recursion.

    For every sample the gain ``error / (error + measurement_variance)`` is
    computed, the estimate is moved towards the sample by that gain, and the
    error is updated to ``(1 - gain) * error + process_variance``.

    Parameters
    ----------
    data : array-like
        Measurements. Lists, NumPy arrays and pandas Series are accepted;
        values are always read by position, never by index label.
    process_variance : float
        Amount added to the error after every step.
    measurement_variance : float
        Measurement noise used in the gain.
    estimated_error : float
        Initial error of the estimate.
    initial_value : float
        Initial estimate.

    Returns
    -------
    numpy.ndarray
        One estimate per input sample (empty for empty input).
    """
    # Positional float array: a Series with a filtered (non 0..n-1) index would
    # otherwise be looked up by label in the loop below.
    measurements = np.asarray(data, dtype=float)
    n = len(measurements)
    kalman_estimates = np.zeros(n)

    # Initial guesses
    current_estimate = initial_value
    current_error = estimated_error

    for i, measurement in enumerate(measurements):
        # Kalman gain: share of the innovation that is trusted.
        kalman_gain = current_error / (current_error + measurement_variance)

        # Update estimate with measurement
        current_estimate = current_estimate + kalman_gain * (measurement - current_estimate)
        kalman_estimates[i] = current_estimate

        # Update the error covariance
        current_error = (1 - kalman_gain) * current_error + process_variance

    return kalman_estimates
# Example parameter values for kalman_filter_1 (module-level names).
process_variance = 1e-4  # Small value for smoother results (adjust as needed)
measurement_variance = 0.001  # Measurement noise (adjust based on your data)
estimated_error = 0.1  # Initial estimate error

In [18]:
# Second Kalman filter created by the package KalmanFilter
# Output is the filtered coordinates of one body part along one axis
def kalman_filter_2(coords):
    """Filter a 1-D coordinate sequence with a position/velocity Kalman filter.

    Uses ``filterpy.kalman.KalmanFilter`` with a two-element state
    (position, velocity) and a scalar position measurement.

    Parameters
    ----------
    coords : array-like
        Measured positions. Lists, NumPy arrays and pandas Series are
        accepted; values are read by position.

    Returns
    -------
    numpy.ndarray
        Filtered position for every input sample (empty for empty input).
    """
    # Positional float array; also makes coords[0] safe for a pandas Series.
    coords = np.asarray(coords, dtype=float)
    if coords.size == 0:
        # Nothing to filter; avoids an IndexError on coords[0] below.
        return np.array([])

    kf = KalmanFilter(dim_x=2, dim_z=1)  # 2 states (position, velocity), 1 measurement

    # Initial state: assume starting at first coordinate with zero initial velocity
    kf.x = np.array([[coords[0]], [0.0]])

    # State transition matrix (F)
    dt = 1  # Assuming 1 frame step; adjust if time interval is different
    kf.F = np.array([[1, dt], [0, 1]])

    # Measurement function (H): we only measure position
    kf.H = np.array([[1, 0]])

    # Measurement noise covariance (R): 1% of the variance of the input.
    # NOTE(review): a smaller R makes the filter follow the measurements more
    # closely (i.e. smooth less), which is the opposite of what the previous
    # comment ("smooth more aggressively") stated - confirm the intended scale.
    kf.R = np.array([[np.var(coords) * 0.01]])

    # Process noise covariance (Q)
    kf.Q = np.array([[1e-4, 0], [0, 1e-4]])  # Small values to maintain smoothness

    # Initial state covariance (P)
    kf.P = np.eye(2) * 500

    # Apply the filter: one predict/update cycle per measurement.
    smoothed_coords = []
    for z in coords:
        kf.predict()
        kf.update([[z]])
        smoothed_coords.append(kf.x[0, 0])  # Extract filtered position

    return np.array(smoothed_coords)

In [19]:
# Setting up the data
# glob returns paths in arbitrary order, but parquet and mat files are paired
# by position below, so both lists are sorted to make the pairing deterministic.
# NOTE(review): this assumes matching parquet/mat files sort into the same
# order - confirm against the folder layout.
parquet_files = sorted(glob.glob(os.path.join(parquet_path, '**', '*.parquet'), recursive=True))
mat_files = sorted(glob.glob(os.path.join(mat_path, '**', '*.mat'), recursive=True))

parquet_dict = {}
mat_lst = []
data_lst = []
data_lst_with_prob = []

# Load the parquets data, keyed by a 'DD/MM/YYYY'-style label cut from fixed
# character positions at the end of the file path.
# NOTE(review): two files that yield the same label overwrite each other.
for file in parquet_files:
    date = str(file[-17:-15] + '/' + file[-19:-17] + '/' + file[-23:-19])
    parquet_dict[date] = pd.read_parquet(file, engine='pyarrow')

# Load the matlab tables data
for mat_file in mat_files:
    mat_lst.append(loadmat(mat_file))

# Report a count mismatch BEFORE pairing: zip() below silently stops at the
# shorter of the two collections.
if len(parquet_dict.keys()) != len(mat_lst):
    print('Parquet files and mat files are not synced!')

# Preparing the data for work
for parquet_val, mat in zip(parquet_dict.values(), mat_lst):
    data_lst.append(data_prep(parquet_val, mat, pos_std_factor=0))

# Calculating screen view probabilities for the data.
# screen_view_probability_cal modifies its argument in place and returns it, so
# data_lst_with_prob holds the same objects as data_lst.
for data in data_lst:
    prob_data = screen_view_probability_cal(data)
    data_lst_with_prob.append(prob_data)



In [ ]:
# Plotting positional distribution box-whisker plot / violin plot

fig, ax = plt.subplots(figsize=(15, 8))
labels = list(parquet_dict.keys())
colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown']

# Absolute frame-to-frame change of the nose x coordinate, one Series per
# experiment. It is computed here from the raw ('nose', 'x') column so this
# cell does not depend on a 'diff_x' column being added by a later cell, and
# the loaded frames are left untouched.
values = []
for frame in parquet_dict.values():
    diff_x = frame[('nose', 'x')].diff().abs()
    # The first sample has no predecessor, so its difference is NaN; it is
    # replaced by 0 before plotting.
    if len(diff_x) and pd.isna(diff_x.iloc[0]):
        diff_x.iloc[0] = 0
    values.append(diff_x)

# Alternative view: a seaborn violin plot of the same values in long format.
ax.boxplot(values, tick_labels=labels)
ax.set_xticks(range(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45)
ax.set_ylabel('Amount')
ax.set_title('Position Difference Distribution')
plt.show()

In [ ]:
# Plotting positional difference distribution histogram
# Relies on `labels` and `colors` defined in the box-plot cell above.
fig, ax = plt.subplots(figsize=(15, 8))
above_10 = []
for label, value, color in zip(labels, parquet_dict.values(), colors):
    # NOTE(review): this stores the difference as a 'diff_x' column on the
    # loaded frame itself (kept because other cells may read it).
    value['diff_x'] = value[('nose', 'x')].diff().abs()
    diff_x = value['diff_x']

    mean_diff_x = np.mean(diff_x.dropna())
    std_diff_x = np.std(diff_x.dropna())
    tenth_threshold = mean_diff_x + std_diff_x * 10
    # Number of frames whose jump is at least 10 standard deviations above the mean.
    above_10.append((diff_x >= tenth_threshold).sum())

    ax.hist(diff_x, bins=200, alpha=0.7)
    ax.axvline(mean_diff_x, color='green', linestyle='dashed', linewidth=1, label=f'Mean: {mean_diff_x:.2f}')
    ax.axvline(
        tenth_threshold, color=color, linestyle='dashed', linewidth=1,
        label=f'10 STDs Threshold: {tenth_threshold:.2f} for {label}'
    )

# Zoom on the tail of the distribution; set once instead of on every iteration.
ax.set_xlim(0.95, 50)
ax.set_ylim(0, 20)

# One summary line per plotted experiment. Built from the data so it works for
# any number of experiments instead of exactly six.
text = "\n".join(
    ["Number of frames above 10 STDs"]
    + [f"{count} for {name}" for count, name in zip(above_10, labels)]
)
ax.text(
    0.95, 0.95,
    text,
    fontsize=12,
    color="white",
    ha="right",
    va="top",
    transform=ax.transAxes,
    bbox=dict(facecolor="darkolivegreen", edgecolor="none", boxstyle="round,pad=0.5")
)
ax.set_xlabel('Difference Between Coordinates')
ax.set_ylabel('Amount')
ax.legend()
ax.set_title('Position Difference Distribution')
plt.show()

In [ ]:
# Plotting probabilities distribution box-whisker plot
fig, ax = plt.subplots(figsize=(15, 8))
# One Series of nose prediction probabilities per loaded experiment.
data = [v[('nose', 'prob')] for v in parquet_dict.values()]
labels = list(parquet_dict.keys())

ax.boxplot(data, tick_labels=labels)
# Reference line at the 0.8 probability level.
ax.axhline(y=0.8, color='black', linestyle='--', linewidth=1.5, label="Threshold = 0.8")
ax.set_xticks(range(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45)
ax.set_ylabel("Probability")
# The experiment count comes from the loaded data rather than a fixed "6".
ax.set_title(f"Distribution of Predictions Probabilities for {len(labels)} Experiments")
ax.legend()
plt.show()

In [ ]:
# Plotting nose x-y plane real world
# For every experiment three figures are saved: the data filtered with
# std factor 10, and the std-factor-0 data smoothed by each Kalman filter.

# kalman_filter_1 parameters (identical for every experiment)
process_variance = 1e-4
measurement_variance = 0.001
estimated_error = 0.1

# Pair each parquet frame with its mat file by position; zip() stops at the
# shorter collection, so a surplus mat file cannot cause an IndexError.
for (date_label, parquet_frame), mat in zip(parquet_dict.items(), mat_lst):
    data = data_prep(parquet_frame, mat, pos_std_factor=10)

    data_for_kalman = data_prep(parquet_frame, mat, pos_std_factor=0)

    # Self-written filter, started from the first measurement of each axis.
    kf1_data = data_for_kalman.copy()
    nose_x_kf1 = np.array(kf1_data['nose_x'])
    nose_y_kf1 = np.array(kf1_data['nose_y'])
    initial_value_x = nose_x_kf1[0]
    initial_value_y = nose_y_kf1[0]
    kf1_data['nose_x'] = kalman_filter_1(nose_x_kf1, process_variance, measurement_variance, estimated_error, initial_value_x)
    kf1_data['nose_y'] = kalman_filter_1(nose_y_kf1, process_variance, measurement_variance, estimated_error, initial_value_y)

    # filterpy-based filter on the same input.
    kf2_data = data_for_kalman.copy()
    kf2_data['nose_x'] = kalman_filter_2(kf2_data['nose_x'].values)
    kf2_data['nose_y'] = kalman_filter_2(kf2_data['nose_y'].values)

    plot_plane(data, f'{date_label}, std factor = 10', save=True)
    plot_plane(kf1_data, f'{date_label}, kalman filter 1', save=True)
    plot_plane(kf2_data, f'{date_label}, kalman filter 2', save=True)