# Behavioural classification from accelerometer timeseries using DL models

This script is used to run three classifiers: [LSTM](https://timeseriesai.github.io/tsai/models.rnn.html#lstm), [TSSequencer](https://timeseriesai.github.io/tsai/models.tssequencerplus.html) and [HydraMultiROCKET](https://timeseriesai.github.io/tsai/models.hydramultirocketplus.html) using the tsai package.  
Create the [tsai environment](../environment_tsai.yml) and use this for running this script.  
The functions are set up to run classifications on both multiclass (8 behavioural classes) and binary (activity/inactivity) targets. See the [behavioural attribution notebook](03_combine_burst_attributions.ipynb) for details.


In [1]:
import gc
import glob
import os
import time

import numpy as np

# For optimisation
# import wandb
import pandas as pd
import torch

# from tsai.all import *
import tsai.all as ts
from fastai.interpret import ClassificationInterpretation

In [2]:
def load_acc_parquet(
    pq_file_path,
    binary_col,
    multiclass_col,
    annotations_path="../data/temp/annotations/attributions_merged_majority_outliers.csv",
):
    """Load accelerometer data from a parquet file and attach its annotations.

    The accelerometer table is left-joined with the annotation table on
    ``Ind_ID`` and ``new_burst``. String columns other than the two label
    columns, ``feature`` and ``id`` are dropped, and rows labelled "Remove"
    are discarded when both label columns agree on which rows those are.

    Args:
        pq_file_path (str): Path to the parquet file containing accelerometer data.
            The fifth underscore-separated token of the file name is used as the
            burst number (e.g. ``annotated_acc_tsai_burst_1_uncorrected.parquet``
            gives ``burst_1``).
        binary_col (str): Column name for binary classification (e.g., "activity").
        multiclass_col (str): Column name for multiclass classification
            (e.g., "attribution_merged").
        annotations_path (str, optional): Path to the annotations CSV. It must
            provide the columns ``burst``, ``Ind_ID`` and ``new_burst`` plus the
            two label columns. Defaults to the project's merged attribution file.

    Returns:
        pd.DataFrame: Accelerometer data merged with annotations. If the
        "Remove" rows of the two label columns do not coincide, a message is
        printed and the rows are left in place.
    """
    acc = pd.read_parquet(pq_file_path)
    annotations = pd.read_csv(annotations_path)

    # The burst identifier is encoded in the file name, not in the data itself
    pq_file_name = os.path.basename(pq_file_path)
    burst = "burst_" + pq_file_name.split("_")[4]

    # Only the annotations of this burst are relevant for the join
    annotations = annotations[annotations["burst"] == burst]
    acc = acc.merge(annotations, on=["Ind_ID", "new_burst"], how="left")

    # Drop every string column that is neither a label nor a tsai key column
    keep_cols = [binary_col, multiclass_col, "feature", "id"]
    acc = acc.drop(
        columns=acc.select_dtypes(include="object").columns.difference(keep_cols)
    )

    # Rows flagged "Remove" are expected to be flagged in both label columns
    binary_remove = acc[acc[binary_col] == "Remove"].index.sort_values()
    multi_remove = acc[acc[multiclass_col] == "Remove"].index.sort_values()
    if binary_remove.equals(multi_remove):
        acc = acc.drop(index=binary_remove)
        print("Removed rows with 'Remove' in both classification columns")
    else:
        # Nothing is dropped here: the mismatch needs a manual look at the data
        print("Remove indices are not equal, please check the data")

    return acc


# Test function
# pq_file_path = "../data/temp/acc_tsai/Burst_1/annotated_acc_tsai_burst_1_uncorrected.parquet"
# acc = load_acc_parquet(pq_file_path, "activity", "attribution_merged")

In [3]:
def create_dsets(acc, random_seed, behav_attribution_column, cols_to_drop=None):
    """Build a tsai ``TSDatasets`` object from an annotated accelerometer table.

    The table is converted to ``X``/``y`` arrays with ``ts.df2xy`` (samples
    keyed by ``id``, features by ``feature``), split 80/20 into stratified
    train/validation sets, and wrapped in a ``TSDatasets`` with a
    ``Categorize`` transform on the labels.

    Args:
        acc (pd.DataFrame): DataFrame containing accelerometer data and labels.
        random_seed (int): Random seed for the train/validation split.
        behav_attribution_column (str): Column name containing the target labels.
        cols_to_drop (list, optional): Columns to remove before conversion. A
            single non-list value is treated as one column name. ``new_burst``
            is always removed when present.

    Returns:
        TSDatasets: A dataset object ready for use with tsai DataLoaders.
    """
    # Work on a private list so the caller's argument is never modified
    if cols_to_drop is None:
        drop_list = []
    elif isinstance(cols_to_drop, list):
        drop_list = cols_to_drop.copy()
    else:
        drop_list = [cols_to_drop]
    if "new_burst" in acc.columns:
        drop_list.append("new_burst")

    feature_table = acc.drop(columns=drop_list)
    X, y = ts.df2xy(
        feature_table,
        sample_col="id",
        feat_col="feature",
        target_col=behav_attribution_column,
        data_cols=None,
    )
    # Print shape of X and y
    print(f"X shape: {X.shape}")
    print(f"y shape: {y.shape}")

    # Stratified, shuffled 80/20 train/validation split
    splits = ts.get_splits(
        y,
        valid_size=0.2,
        stratify=True,
        random_state=random_seed,
        shuffle=True,
        show_plot=False,
    )

    def y_func(o):
        # Only the first column is used, since labels are the same across
        # the multivariate dataset
        labels = o[:, 0].astype("<U20")

        # Binary "Active"/"Inactive" task: encode "Inactive" as 0, "Active" as 1
        if set(np.unique(labels)) == {"Active", "Inactive"}:
            return np.where(labels == "Inactive", 0, 1)

        return labels

    # tsai expects floating point inputs
    X_float = X.astype(np.float64)

    return ts.TSDatasets(
        X_float,
        y=y_func(y),
        tfms=[None, [ts.Categorize()]],
        splits=splits,
        inplace=True,
    )


# Test function
# acc_dset = create_dsets(
#     acc, 42, "attribution_merged", cols_to_drop=["activity"]
# )

In [4]:
def run_models_on_dataset(dset, target_type="binary"):
    """Train each candidate architecture on a dataset and collect its scores.

    HydraMultiRocket, a 6-layer bidirectional LSTM and TSSequencer are each
    trained for 50 epochs with the one-cycle policy, using the learning rate
    suggested by ``lr_find``. The metrics of the final epoch and the
    validation confusion matrix (in long format) are recorded per model.

    Training is best-effort: an architecture that raises is reported with a
    printed message and skipped, so the returned frames may contain fewer
    models than were attempted (or be empty if all of them failed).

    Args:
        dset: TSDatasets object exposing ``train``, ``valid`` and ``vocab``.
        target_type: Type of classification ("binary" or "multiclass"). For
            "binary" the 0/1 class labels in the confusion table are replaced
            by "Inactive"/"Active".

    Returns:
        tuple: (metrics_df, confusion_df)
    """
    # Create dataloaders
    dls = ts.TSDataLoaders.from_dsets(dset.train, dset.valid, bs=[64, 128])

    # ROC AUC needs a different metric class for two-class problems
    is_binary = len(dls.vocab) == 2
    roc_metric = ts.RocAucBinary() if is_binary else ts.RocAuc()
    metrics = [
        ts.accuracy,
        roc_metric,
        ts.BalancedAccuracy(),
        ts.F1Score(average="weighted"),
        ts.Precision(average="weighted"),
        ts.Recall(average="weighted"),
    ]

    # Define architectures to test
    archs = [
        (ts.HydraMultiRocket, {}),
        (ts.LSTM, {"n_layers": 6, "bidirectional": True}),
        (ts.TSSequencer, {}),
        # Add more as needed
    ]

    results = []
    confusion_data = []

    for arch, params in archs:
        try:
            model = ts.create_model(arch, dls=dls, **params)
            learn = ts.Learner(dls, model, metrics=metrics)

            # Train model (the timing includes the learning-rate search)
            start = time.time()
            lr_max = learn.lr_find()
            learn.fit_one_cycle(50, lr_max)
            elapsed = time.time() - start

            # Last recorded epoch: train loss, valid loss, then the metrics
            # in the order they were passed to the Learner
            vals = learn.recorder.values[-1]
            results.append(
                {
                    "model": arch.__name__,
                    "params": str(params),
                    "train_loss": vals[0],
                    "valid_loss": vals[1],
                    "accuracy": vals[2],
                    "roc_auc": vals[3],
                    "balanced_acc": vals[4],
                    "f1_score": vals[5],
                    "precision": vals[6],
                    "recall": vals[7],
                    "time": int(elapsed),
                }
            )

            # Flatten the confusion matrix into one row per (actual, predicted)
            interp = ClassificationInterpretation.from_learner(learn)
            conf_matrix = interp.confusion_matrix()
            classes = dset.vocab

            for actual_idx, actual in enumerate(classes):
                for pred_idx, predicted in enumerate(classes):
                    confusion_data.append(
                        {
                            "model": arch.__name__,
                            "actual_label": actual,
                            "predicted_label": predicted,
                            "count": conf_matrix[actual_idx, pred_idx],
                        }
                    )

        except Exception as e:
            print(f"Error with {arch.__name__}: {str(e)}")
        finally:
            # Release the finished model before the next one is built so
            # memory does not accumulate across architectures
            model = learn = interp = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Make dataframes
    results_df = pd.DataFrame(results)
    confusion_df = pd.DataFrame(confusion_data)

    # For binary confusion matrix replace 0/1 with labels. When every model
    # failed the frame has no columns, so there is nothing to relabel.
    if target_type == "binary" and not confusion_df.empty:
        label_names = {0: "Inactive", 1: "Active"}
        confusion_df["actual_label"] = confusion_df["actual_label"].replace(
            label_names
        )
        confusion_df["predicted_label"] = confusion_df["predicted_label"].replace(
            label_names
        )

    return results_df, confusion_df

In [5]:
def process_parquet_file(file_path, random_seed_list, binary_col, multiclass_col):
    """Train and evaluate all models on one parquet file, once per random seed.

    The burst name is taken from the file's parent folder and the correction
    type from the last underscore-separated token of the file name. The file
    is loaded once; for every seed a binary and a multiclass dataset are
    built from it and passed to ``run_models_on_dataset``. Each result frame
    is tagged with ``burst``, ``correction_type``, ``random_seed`` and
    ``target`` ("Activity" or "Behaviour").

    Args:
        file_path: Path to parquet file
        random_seed_list: List of random seeds, one train/validation split each.
            An empty list yields empty result frames without loading the file.
        binary_col: Name of the binary label column (e.g. "activity").
        multiclass_col: Name of the multiclass label column
            (e.g. "attribution_merged").

    Returns:
        dict: DataFrames under the keys "binary_metrics", "binary_conf",
        "multiclass_metrics" and "multiclass_conf", with the metadata columns
        placed first.
    """
    # Extract metadata from path
    burst = os.path.basename(os.path.dirname(file_path))  # from folder
    correction_type = (
        os.path.basename(file_path).split("_")[-1].split(".")[0]
    )  # from filename
    print(f"Processing {correction_type} for {burst}")

    # Metadata columns that should lead every output frame
    first_columns = ["burst", "correction_type", "target", "random_seed"]

    def _tag(frame, seed, target):
        """Attach the run metadata columns to a result frame in place."""
        frame["burst"] = burst
        frame["correction_type"] = correction_type
        frame["random_seed"] = seed
        frame["target"] = target

    def _combine(frames):
        """Stack per-seed frames and move the metadata columns to the front."""
        if not frames:
            # pd.concat raises on an empty list
            return pd.DataFrame()
        combined = pd.concat(frames)
        if combined.empty:
            return combined
        remaining_cols = [col for col in combined.columns if col not in first_columns]
        return combined[first_columns + remaining_cols]

    collected = {
        "binary_metrics": [],
        "binary_conf": [],
        "multiclass_metrics": [],
        "multiclass_conf": [],
    }

    seeds = list(random_seed_list)
    # The data does not depend on the seed, so read and merge it only once
    acc_data = load_acc_parquet(file_path, binary_col, multiclass_col) if seeds else None

    for seed in seeds:
        print(f"  Processing with seed {seed}")

        # Each target gets its own dataset; the other label column is dropped
        binary_dset = create_dsets(
            acc_data, seed, binary_col, cols_to_drop=[multiclass_col]
        )
        multiclass_dset = create_dsets(
            acc_data, seed, multiclass_col, cols_to_drop=[binary_col]
        )

        binary_metrics, binary_conf = run_models_on_dataset(binary_dset, "binary")
        _tag(binary_metrics, seed, "Activity")
        _tag(binary_conf, seed, "Activity")

        multiclass_metrics, multiclass_conf = run_models_on_dataset(
            multiclass_dset, "multiclass"
        )
        _tag(multiclass_metrics, seed, "Behaviour")
        _tag(multiclass_conf, seed, "Behaviour")

        collected["binary_metrics"].append(binary_metrics)
        collected["binary_conf"].append(binary_conf)
        collected["multiclass_metrics"].append(multiclass_metrics)
        collected["multiclass_conf"].append(multiclass_conf)

    return {name: _combine(frames) for name, frames in collected.items()}

In [6]:
def run_model_comparisons(folder_path, random_seed_list, correction_filters=None):
    """Run model comparisons across multiple parquet files

    Every ``*.parquet`` file below ``folder_path`` (searched recursively) is
    passed to ``process_parquet_file`` with the label columns "activity" and
    "attribution_merged", and the per-file results are stacked.

    Args:
        folder_path: Path to folder containing parquet files
        random_seed_list: List of random seeds
        correction_filters: Optional list of substrings (a single string is
            also accepted); only files whose base name contains at least one
            of them are processed. ``None`` processes every file.

    Returns:
        dict: Combined DataFrames under the keys "binary_metrics",
        "binary_conf", "multiclass_metrics" and "multiclass_conf". All four
        are empty DataFrames when no file matches.
    """
    # Find all parquet files; sorted so runs process files in a stable order
    all_files = sorted(
        glob.glob(os.path.join(folder_path, "**/*.parquet"), recursive=True)
    )

    # Apply correction type filters if provided
    if correction_filters is not None:
        if isinstance(correction_filters, str):
            # A bare string would otherwise be matched character by character
            correction_filters = [correction_filters]
        all_files = [
            f
            for f in all_files
            if any(corr_type in os.path.basename(f) for corr_type in correction_filters)
        ]

    print(f"Found {len(all_files)} parquet files")

    # Process all files and collect results
    all_results = [
        process_parquet_file(
            file_path, random_seed_list, "activity", "attribution_merged"
        )
        for file_path in all_files
    ]

    result_keys = (
        "binary_metrics",
        "binary_conf",
        "multiclass_metrics",
        "multiclass_conf",
    )
    if not all_results:
        # pd.concat raises on an empty list; hand back empty frames instead
        return {key: pd.DataFrame() for key in result_keys}

    return {key: pd.concat([r[key] for r in all_results]) for key in result_keys}

### Test functions sequentially

In [None]:
# Check which parquet files are discovered for a single burst folder
folder_location = "../data/temp/acc_tsai/Burst_4/"

# The burst name is the last folder component of the path
burst = folder_location.rstrip("/").rsplit("/", 1)[-1]
print("DL training started for", burst)
correction_types_to_process = ["uncorrected", "rotdaily", "rotbasal"]

# Collect the parquet files recursively, keeping only those whose file name
# mentions one of the requested correction types (no filter when None)
all_files = [
    path
    for path in glob.glob(
        os.path.join(folder_location, "**/*.parquet"), recursive=True
    )
    if correction_types_to_process is None
    or any(tag in os.path.basename(path) for tag in correction_types_to_process)
]

print(all_files)

# Correction type = text after the last underscore, without the extension
filetypes = [
    os.path.basename(path).rsplit("_", 1)[-1].partition(".")[0]
    for path in all_files
]
print(filetypes)


DL training started for Burst_4
['../data/temp/acc_tsai/Burst_4\\annotated_acc_tsai_burst_4_uncorrected.parquet']
['uncorrected']


In [None]:
# Run the full pipeline on a single file with a single seed as a smoke test
random_seed_list = [42]
pq_file_path = (
    "../data/temp/acc_tsai/Burst_1/annotated_acc_tsai_burst_1_uncorrected.parquet"
)
b1_unc = process_parquet_file(
    file_path=pq_file_path,
    random_seed_list=random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
)

In [None]:
# Unpack the four result tables of the single-file test run
b1_unc_bin_met, b1_unc_beh_met, b1_unc_bin_conf, b1_unc_beh_conf = (
    b1_unc[key]
    for key in (
        "binary_metrics",
        "multiclass_metrics",
        "binary_conf",
        "multiclass_conf",
    )
)

## Run function over multiple files

In [7]:
# Correction variants to include and the seeds used for the repeated splits
correction_types_to_process = ["uncorrected", "rotdaily", "rotbasal"]
random_seeds = [42, 100, 123, 1234, 123456]

# One folder of parquet files per burst (Burst_1 .. Burst_4)
pq_folder_locations = [
    f"../data/temp/acc_tsai/Burst_{burst_number}/" for burst_number in range(1, 5)
]

In [None]:
# Train and evaluate every model on every burst folder (one result dict each)
all_burst_results = [
    run_model_comparisons(
        folder_path=burst_folder,
        random_seed_list=random_seeds,
        correction_filters=correction_types_to_process,
    )
    for burst_folder in pq_folder_locations
]

In [9]:
# Gather, per result type, the list of per-burst dataframes
(
    binary_metrics_list,
    binary_conf_list,
    multiclass_metrics_list,
    multiclass_conf_list,
) = (
    [burst_result[key] for burst_result in all_burst_results]
    for key in (
        "binary_metrics",
        "binary_conf",
        "multiclass_metrics",
        "multiclass_conf",
    )
)

# Stack each list into a single dataframe with a fresh index
binary_metrics, binary_conf, multiclass_metrics, multiclass_conf = (
    pd.concat(frames, ignore_index=True)
    for frames in (
        binary_metrics_list,
        binary_conf_list,
        multiclass_metrics_list,
        multiclass_conf_list,
    )
)

In [None]:
# Write the results to CSV files.
# The binary (activity) exports are currently switched off:
# binary_metrics.to_csv(
#     "../data/output/activity_comparison/activity_acceleration_dl_metrics.csv",
#     index=False,
# )
# binary_conf.to_csv(
#     "../data/output/activity_comparison/activity_acceleration_dl_confusion.csv",
#     index=False,
# )

# Multiclass (behaviour) metrics, written without the index column
multiclass_metrics.to_csv(
    path_or_buf="../data/output/behaviour_comparison/behaviour_acceleration_dl_metrics.csv",
    index=False,
)
# Multiclass (behaviour) confusion matrix in long format
multiclass_conf.to_csv(
    path_or_buf="../data/output/behaviour_comparison/behaviour_acceleration_dl_confusion.csv",
    index=False,
)