In [16]:
import altair as alt
import h5py
from pathlib import Path
import numpy as np
import polars as pl
import yaml
from tqdm import tqdm
import pprint
import random
import sys
print(sys.executable)
alt.data_transformers.enable("vegafusion")

C:\Users\f0397\miniforge3\envs\neural-feature-identification-pipeline\python.exe


DataTransformerRegistry.enable('vegafusion')

In [17]:
def load_project_config(config_path: Path) -> dict:
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

In [18]:
def read_hdf_datasets(filepath: Path, keys: list[str]) -> dict:
    """
    Reads specified datasets from a single HDF5 file.
    :param filepath:
    :param keys:
    :return:
    """
    if not filepath.exists():
        print(f"Warning File not found at {filepath}. Skipping.")
        return {}

    data = {}
    with h5py.File(filepath, 'r') as f:
        for key in keys:
            if key in f:
                data[key] = f[key][:]
            else:
                print(f"Warning. Key '{key}' not found in {filepath.name}.")
    return data

In [19]:
def tidy_and_transform_data(session_data: dict, project_config: dict) -> pl.DataFrame:
    """
    Structures the data into a tidy, long-format DataFrame for VIS with Altair.
    Transforms the data for VIS by downsampling and applying a normalization transformation to each channel.
    :param project_config:
    :param session_data:
    :return:
    """
    # Create initial nested DataFrame
    df_constructor = {
        'timestamps': all_data['kinematics']['nip_time'].T,
        'kinematics': all_data['kinematics']['kinematics'].T,
    }
    feature_names = project_config['analysis']['feature_sets']
    for name in feature_names:
        key_name = name.lower()
        if 'features' in session_data.get(key_name, {}):
            df_constructor[name] = session_data[key_name]['features'].T
    nested_df = pl.DataFrame(df_constructor)

    # Flatten timestamp column
    flat_df = nested_df.explode('timestamps')

    # Dynamically create expression to unnest each array column into individual columns
    unnested_expressions = []
    for name in ['kinematics'] + feature_names:
        col_name = name if name != 'kinematics' else 'kinematics'
        if col_name in flat_df.columns:
            list_len = flat_df.select(pl.col(col_name).arr.len().first()).item()
            if list_len is not None:
                unnested_expressions.extend(
                    [pl.col(col_name).arr.get(i).alias(f"{name}_{i+1}") for i in range(list_len)]
                )

    # Build the wide DataFrame by applying the unnesting expressions
    wide_df = flat_df.select(
        pl.col('timestamps'),
        *unnested_expressions
    )

    # Aggregate/downsample data to prevent browser memory issues
    print(f"Original data has {len(wide_df)} timestamps")
    plt_point_count = project_config['vis']['num_of_x_points']
    time_range = wide_df['timestamps'].max() - wide_df['timestamps'].min()
    resampling_interval = time_range / plt_point_count
    if resampling_interval < 1: resampling_interval = 1
    print(f"Resampling data by averaging over {resampling_interval:0.2f} timestamp units")
    wide_df_resampled = wide_df.group_by(
        (pl.col("timestamps") // resampling_interval).alias("time_bin")
    ).agg(pl.all().mean()).drop("time_bin")
    print(f"Resampled data has {len(wide_df_resampled)} timestamps")

    # Melt (wide2long/unpivot transform) the wide DataFrame into a long, tidy format
    tidy_df = wide_df_resampled.unpivot(index=['timestamps'], variable_name='feature_id', value_name='value')

    # Add a column for easy filtering ('feature_type': 'kinematics', 'nfr', ...)
    tidy_df = tidy_df.with_columns(
        pl.col('feature_id').str.split_exact(by='_', n=1).struct.field('field_0').alias('feature_type')
    )

    # Channel by channel normalization to improve heatmap visibility
    kinematics_df  = tidy_df.filter(pl.col('feature_type') == 'kinematics')
    features_df = tidy_df.filter(pl.col('feature_type') != 'kinematics')
    features_df_normalized = features_df.with_columns(
        min_val=pl.min('value').over('feature_id'),
        max_val=pl.max('value').over('feature_id')
    ).with_columns(
        range_val=(pl.col('max_val') - pl.col('min_val'))
    ).with_columns(
        norm_val=pl.when(pl.col('range_val')>0)
                 .then((pl.col('value') - pl.col('min_val')) / pl.col('range_val'))
                 .otherwise(0.0)
    ).with_columns(
        value=(pl.col('norm_val') * pl.col('max_val').sqrt()).fill_nan(0)
    ).drop('min_val', 'max_val', 'range_val', 'norm_val')
    tidy_df_transformed = pl.concat([kinematics_df, features_df_normalized])

    return tidy_df_transformed

In [20]:
def make_kinematics_plot(plt_df: pl.DataFrame, project_config: dict, x_domain: list) -> alt.Chart():
    """
    Creates a stacked line chart of the kinematic labels
    :param plt_df:
    :param project_config:
    :return:
    """
    kinematics_df = plt_df.filter(pl.col('feature_type') == 'kinematics')
    plt_offset = project_config['vis']['kinematics_offset']
    kinematics_df_offset = kinematics_df.with_columns(
        pl.col('feature_id').str.split_exact(by='_', n=1).struct.field('field_1').cast(pl.Int32).alias('dof_id')
    ).with_columns(
        (pl.col('value') + (pl.col('dof_id')*plt_offset)).alias('plot_value')
    )
    plt_kinematics =  alt.Chart(kinematics_df_offset).mark_line().encode(
        x=alt.X('timestamps:Q', title='Time (NIP Units)', scale=alt.Scale(zero=False, domain=x_domain)),
        y=alt.Y('plot_value:Q', title='Kinematic Position (Offset)', axis=alt.Axis(labels=False, ticks=False, grid=False)),
        color=alt.Color('feature_id:N', title="DOF ID", sort=alt.EncodingSortField(field='dof_id', order='descending')),
    ).properties(
        width=1800,
        height=360,
    )

    return plt_kinematics

def make_event_markers_plot(trial_start_stamps: np.ndarray, trial_stop_stamps: np.ndarray, x_domain: list) -> alt.Chart():
    starts_df = pl.DataFrame({'timestamp': trial_start_stamps.flatten(), 'event': 'start'})
    stops_df = pl.DataFrame({'timestamp': trial_stop_stamps.flatten(), 'event': 'stop'})
    events_df = pl.concat([starts_df, stops_df])
    plt_event_markers = alt.Chart(events_df).mark_rule(strokeDash=[4, 4], size=2).encode(
        x=alt.X('timestamp:Q', scale=alt.Scale(zero=False, domain=x_domain)),
        color=alt.Color(
            'event:N',
            scale=alt.Scale(domain=['start', 'stop'], range=['green', 'red']),
        )
    ).properties(
        width=1800,
        height=36,
    )


    return plt_event_markers

In [21]:
def make_features_heatmap(plt_df: pl.DataFrame, feature_type: str, color_scheme: str, selected_chans: list[str] = None) -> alt.Chart:
    feature_data = plt_df.filter(pl.col('feature_type') == feature_type)

    if selected_chans:
        feature_data = feature_data.filter(pl.col('feature_id').is_in(selected_chans))

    return alt.Chart(feature_data).mark_rect().encode(
        x=alt.X('timestamps:Q', title='Time (NIP Units)', scale=alt.Scale(zero=False)),
        y=alt.Y('feature_id:O', title='Feature Index', sort=None, axis=alt.Axis(labels=False, ticks=False)),
        detail='feature_id:N',
    ).properties(
        title=f'{feature_type.upper()} Features Vs NIP Time',
        width=1800,
        height=720
    )

def make_feature_line_plot(plt_df: pl.DataFrame, feature_type: str, selected_channels: list[str] = None, x_domain: list = None) -> alt.Chart:
    """
    Create a line chart for a given feature set with superimposed, transparent channels
    :param plt_df:
    :param feature_type:
    :return:
    """
    feature_data = plt_df.filter(pl.col('feature_type') == feature_type)

    # Apply subsampling if a list of selected channels is provided. Intended for the DWT feature set
    if selected_channels:
        feature_data = feature_data.filter(pl.col('feature_id').is_in(selected_channels))

    return alt.Chart(feature_data).mark_line(opacity=0.12).encode(
        x=alt.X('timestamps:Q', title='Time (NIP Units)', scale=alt.Scale(zero=False, domain=x_domain)),
        y=alt.Y('value:Q', title=f'{feature_type.upper()} Normalized Activation', axis=alt.Axis(labels=False, ticks=False,grid=False)),
        detail='feature_id:N',
    ).properties(width=1800, height=360)


In [22]:
#def main():
config = load_project_config(Path("../config.yaml"))
BASE_DIR = Path("C:/Users/f0397/Desktop/75")

files_to_load = {
    'events': {'path': BASE_DIR / "events.h5", 'keys': ['trial_start_idxs', 'trial_stop_idxs']},
    'kinematics': {'path': BASE_DIR / "kinematics.h5", 'keys': ['kinematics', 'nip_time']},
}
for feature_name in config['analysis']['feature_sets']:
    key_name = feature_name.lower()
    files_to_load[key_name] = {'path': BASE_DIR / "features" / f"{feature_name}.h5", 'keys': ['features']}

all_data = {}
for name, details in tqdm(files_to_load.items(), desc='Loading data files'):
    all_data[name] = read_hdf_datasets(details['path'], details['keys'])

Loading data files: 100%|██████████| 6/6 [00:00<00:00, 62.96it/s]


In [23]:
tidy_df = tidy_and_transform_data(all_data, config)
print(tidy_df['feature_type'].unique())
tidy_df

Original data has 13586 timestamps
Resampling data by averaging over 20384.84 timestamp units
Resampled data has 661 timestamps
shape: (5,)
Series: 'feature_type' [str]
[
	"DWT-DB4"
	"kinematics"
	"NFR"
	"SBP-RAW"
	"MAV"
]


timestamps,feature_id,value,feature_type
f64,str,f64,str
4.1452e7,"""kinematics_1""",0.0,"""kinematics"""
4.1004e7,"""kinematics_1""",0.0,"""kinematics"""
4.3002e7,"""kinematics_1""",0.0,"""kinematics"""
3.7029e7,"""kinematics_1""",0.0,"""kinematics"""
3.9720e7,"""kinematics_1""",0.0,"""kinematics"""
…,…,…,…
4.2227444e7,"""MAV_192""",0.471449,"""MAV"""
3.2728e7,"""MAV_192""",0.305229,"""MAV"""
3.8436e7,"""MAV_192""",0.363485,"""MAV"""
4.3124e7,"""MAV_192""",0.46972,"""MAV"""


In [24]:
random.seed(42)
selected_channels_map = {}
dwt_subsample_count = config['vis']['dwt_subsample_count']
dwt_df = tidy_df.filter(pl.col('feature_type') == 'DWT-DB4')
all_dwt_channels = dwt_df.get_column('feature_id').unique().to_list()
if len(all_dwt_channels) > dwt_subsample_count:
    selected_channels_map['dwt-db4'] = random.sample(all_dwt_channels, dwt_subsample_count)

In [25]:
start_stamps = all_data.get('events', {}).get('trial_start_idxs', np.array([]))
stop_stamps = all_data.get('events', {}).get('trial_stop_idxs', np.array([]))

In [26]:
x_domain = [tidy_df['timestamps'].min(), tidy_df['timestamps'].max()]
x_domain

[31978596.6, 45423141.45454545]

In [27]:
plt_kinematics = make_kinematics_plot(tidy_df, config, x_domain=x_domain)
plt_events = make_event_markers_plot(start_stamps, stop_stamps, x_domain=x_domain)
plt_features = []
for f_name in config['analysis']['feature_sets']:
    feature_key = f_name.lower()
    selected_channels = selected_channels_map.get(feature_key)
    plt = make_feature_line_plot(tidy_df, f_name, selected_channels, x_domain=x_domain)
    plt_features.append(plt)

In [28]:
final_chart = alt.vconcat(
    plt_kinematics,
    plt_events,
    *plt_features,
).resolve_scale(
    color='independent',
).properties(
)

In [29]:
final_chart.save("session_vis.html")