# Section for running Average Mutual Information (AMI) analysis and False Nearest Neighbours (FNN) analysis
### To execute the analyses, the processed OpenPose data is required.

# AMI

In [None]:
# Necessary imports and functions for AMI
import os
import pandas as pd
from utils_dir.ami_utils import ami, plot_ami_multiple
import numpy as np
from IPython.display import clear_output

def run_ami(dir, columns, conditions, min_lag, max_lag, exclude=None, all=False):
    """
    Computes Average Mutual Information (AMI) values for specified time series columns
    across conditions.

    Every '.csv' file in `dir` is considered unless its participant ID (the text
    before the first underscore of the file name) is listed in `exclude`. The
    condition of a file is the fifth underscore-separated token of its name after
    hyphens have been replaced by underscores. Files that lack one of the requested
    columns, or whose name has fewer than five tokens, are reported and skipped.

    Parameters:
        dir : str
            The directory path where CSV files containing time series data are stored.
        columns : list of str
            The names of the columns (time series) for which AMI should be calculated.
        conditions : list of str
            List of condition labels (e.g., ['trial0', 'trial1', 'trial2']).
            Files whose condition is not in this list are ignored, also when all=True.
        min_lag : int
            Minimum time lag to consider when calculating AMI.
        max_lag : int
            Maximum time lag to consider when calculating AMI. Will be adjusted downward
            if any file contains fewer complete (NaN-free) rows than this value.
        exclude : list of str, optional
            List of participant IDs to exclude based on the filename prefix.
            Defaults to no exclusions.
        all : bool
            If true, pools the AMI values of all conditions under the single key 'all'
            instead of storing them per condition.

    Returns:
        ami_dict : dict
            Nested dictionary structured as {condition: {column: [ami_values]}}
            (or {'all': {column: [ami_values]}} when all=True), where ami_values holds
            one list of AMI values per processed file.
    """
    # `dir` and `all` shadow builtins but are part of the public signature (callers
    # pass them by keyword), so they are only aliased locally.
    pool_conditions = all
    exclude = [] if exclude is None else exclude

    if pool_conditions:
        ami_dict = {'all': {column: [] for column in columns}}
    else:
        ami_dict = {condition: {column: [] for column in columns} for condition in conditions}

    # os.listdir returns entries in arbitrary order; sort so results are reproducible.
    files = []
    for f in sorted(os.listdir(dir)):
        if not f.endswith('.csv'):
            continue
        participant_id = f.split('_')[0]
        if participant_id in exclude:
            print(f"Excluded file: {f}")
        else:
            files.append(f)

    # Single pass over the files: find the shortest time series and keep the
    # requested columns of every file that will be analysed, so that no file has
    # to be read from disk twice.
    min_length = float('inf')
    usable = []  # (condition, DataFrame restricted to `columns`)
    for file in files:
        df = pd.read_csv(os.path.join(dir, file), sep=',')

        missing = [column for column in columns if column not in df.columns]
        if missing:
            print(f"Missing columns in {file}: {missing}")
            continue

        selected = df[columns]
        # As before, every readable file counts towards the shortest length,
        # regardless of its condition.
        min_length = min(min_length, len(selected.dropna()))

        name_parts = file.replace('-', '_').split('_')
        if len(name_parts) <= 4:
            print(f"Could not determine condition from file name: {file}")
            continue

        condition = name_parts[4]
        if condition in conditions:
            usable.append((condition, selected))

    # Adjust max_lag if the shortest time series is shorter than max_lag.
    if min_length < max_lag:
        print(f"Adjusted max_lag from {max_lag} to {min_length} due to short file.")
        max_lag = min_length

    # Main loop
    for condition, selected in usable:
        # Drop the first row (NaN values) and fill remaining gaps by interpolation.
        interpolated = selected.iloc[1:].interpolate()
        target = ami_dict['all'] if pool_conditions else ami_dict[condition]

        for col in interpolated:
            ami_values = ami(interpolated[col], min_lag=min_lag, max_lag=max_lag)
            target[col].append(ami_values.tolist())
            clear_output(wait=True)

    return ami_dict


def obtain_avg_ami(ami_results, conditions, columns):
    """
    Get the average ami values within or across conditions

    For every requested condition/column pair, the AMI vectors are truncated to
    the length of the shortest one and averaged element-wise.

    Parameters:
        ami_results: dict
            Dictionary containing AMI results obtained from run_ami,
            structured as {condition: {column: [ami_values]}}.
        conditions: list
            Conditions to get average AMI for.
            If all=True in run_ami, this should = ['all']
        columns: list
            List of columns to evaluate.

    Returns:
        dict: dict
            A dictionary {condition: {column: [averages]}}. The inner list holds a
            single list of averaged AMI values, or is empty when no data was
            available for that condition and column.
    """

    avg_dict = {condition: {column: [] for column in columns} for condition in conditions}

    for condition in conditions:
        # A condition or column that is absent from the results is treated as
        # "no data" instead of raising KeyError.
        condition_results = ami_results.get(condition, {})

        for column in columns:
            series_list = condition_results.get(column)

            if not series_list:
                print(f"No data for condition '{condition}' and column '{column}'.")
                continue

            # Truncate each vector to the shortest length so they can be stacked.
            min_len = min(len(s) for s in series_list)
            data = np.array([s[:min_len] for s in series_list])

            averages = np.mean(data, axis=0)
            avg_dict[condition][column].append(averages.tolist())

    return avg_dict

In [None]:
# Run AMI on the time series

directory = "/.../Processed_Timeseries" # Set directory to processed data.
conditions = ['trial0', 'trial1', 'trial2'] # Set conditions
columns = ['headRel_ed_vel', 'body_ed_vel'] # Set columns to evaluate

# Dictionary holding the AMI results. Because all=True is passed, the values of
# all listed conditions are pooled under the single key 'all'.
ami_results = run_ami(dir=directory, columns=columns, conditions=conditions, min_lag=0, max_lag=100, all=True)


In [None]:
# Get averages.

# Dictionary holding the averaged AMI values. conditions=['all'] matches the
# all=True setting used when running AMI; use the trial labels otherwise.
ami_averages = obtain_avg_ami(ami_results=ami_results, conditions=['all'], columns=columns)

In [None]:
# Plot Results

# If all=True in Run AMI, set to ['all'], otherwise use the original conditions list.
conditions = ['all']

# One plot per condition/column pair; pairs without averaged data are reported and skipped.
for condition in conditions:
    for column in columns:
        task_data = ami_averages.get(condition, {}).get(column)
        ami_data_for_condition_column = {condition: task_data} if task_data else {}

        if len(ami_data_for_condition_column) == 0:
            print(f"No data available for condition '{condition}' and column '{column}'.")
            continue

        plot_ami_multiple(ami_data_for_condition_column, condition=condition, column=column)

# False Nearest Neighbours (FNN)

In [None]:
# Necessary imports and functions for FNN

import os
import pandas as pd
import numpy as np
from utils_dir.fnn_utils import fnn, plot_fnn
from IPython.display import clear_output

In [None]:
# Run the FNN Analysis

directory = "/.../Processed_Timeseries"  # Set directory to processed data.
plot_image = False  # If True, plots the FNN result of every individual file.

# Analysis settings, collected here so they can be changed in one place.
fnn_column = 'body_ed_vel'  # Column holding the continuous time series.
fnn_tlag = 5                # Time lag passed to fnn.
fnn_min_dimension = 1       # Smallest dimension passed to fnn.
fnn_max_dimension = 10      # Largest dimension passed to fnn.

fnn_all = []  # Store all the FNN results (one entry per processed file).

# os.listdir returns entries in arbitrary order; sort for a reproducible run.
file_names = sorted(os.listdir(directory))

for file_name in file_names:

    if not file_name.endswith('.csv'):
        continue

    print(f'Loading file: {file_name}')

    file_path = os.path.join(directory, file_name)
    data = pd.read_csv(file_path, sep=',')

    # Skip (rather than abort the whole run on) files that lack the column.
    if fnn_column not in data.columns:
        print(f"Missing columns in {file_name}: {fnn_column}")
        continue

    # Select the continuous time series and remove the NaNs left by the
    # windowed velocity calculations.
    continuous_data = data[fnn_column].dropna()

    # Compute FNN for the configured range of dimensions
    fnn_ds, fnn_percent = fnn(
        continuous_data,
        tlag=fnn_tlag,
        min_dimension=fnn_min_dimension,
        max_dimension=fnn_max_dimension,
    )

    # Collect results
    fnn_all.append(fnn_percent)

    # Optional: plot
    if plot_image:
        plot_fnn(fnn_ds, fnn_percent)

    print('FNN computed successfully!')


# Plot average FNN across all files
if fnn_all:
    clear_output(wait=True)
    fnn_all = np.array(fnn_all)
    avg_fnn = np.mean(fnn_all, axis=0)
    print(f"Average FNN (first {fnn_max_dimension} dimensions):")
    # fnn_ds comes from the last processed file; it is reused as the x-axis here.
    plot_fnn(dimensions=fnn_ds, fnn_percentages=avg_fnn, header='Upper Body Movement')
    print(avg_fnn)
else:
    print("No FNN data computed.")

