# Behavioural classification from features using DL model (TabPFN)

This script is used to run the [TabPFN classifier](https://github.com/PriorLabs/TabPFN).  
Create the [tabpfn environment](../environment_tabpfn.yml) and use it to run this script.  
The functions are set up to run classifications on both multiclass (8 behavioural classes) and binary (activity/inactivity) targets. See the [behavioural attribution notebook](03_combine_burst_attributions.ipynb) for details.


In [1]:
import glob
import os
import time
import warnings

import pandas as pd
from sklearn.metrics import (
    balanced_accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tabpfn import TabPFNClassifier


In [2]:
def load_features_parquet(
    pq_file_path,
    binary_col,
    multiclass_col,
    annotations_path="../data/temp/annotations/attributions_merged_majority_outliers.csv",
):
    """Load one feature parquet file and attach its annotations.

    The burst is identified from the parquet file name: the fourth
    underscore-separated token is prefixed with "burst_" and used to select
    the matching rows of the annotation table. Features and annotations are
    then left-joined on "Ind_ID" and "new_burst".

    Args:
        pq_file_path (str): Path to the parquet file containing feature data.
        binary_col (str): Column name for binary classification (e.g., "activity").
        multiclass_col (str): Column name for multiclass classification
            (e.g., "attribution_merged").
        annotations_path (str): Path to the annotations CSV. Defaults to the
            merged majority-vote attribution file used throughout this script.

    Returns:
        pd.DataFrame: Features merged with annotations. All object-dtype columns
        other than the two target columns and "Ind_ID" are dropped. Rows marked
        "Remove" are dropped only when both target columns mark exactly the same
        rows; otherwise a message is printed and those rows are left in place.
    """
    features = pd.read_parquet(pq_file_path)
    annotations = pd.read_csv(annotations_path)

    # File names look like "annotated_features_burst_1_uncorrected.parquet",
    # so the fourth underscore-separated token is the burst number.
    burst_number = os.path.basename(pq_file_path).split("_")[3]
    burst = "burst_" + burst_number

    # Keep only the annotations of this burst before joining.
    annotations = annotations[annotations["burst"] == burst]

    # Left join: every feature row is kept, annotated or not.
    features = features.merge(annotations, on=["Ind_ID", "new_burst"], how="left")

    # Drop string columns that are neither a target nor the individual ID.
    columns_to_keep = [binary_col, multiclass_col, "Ind_ID"]
    text_columns = features.select_dtypes(include="object").columns
    features = features.drop(columns=text_columns.difference(columns_to_keep))

    # Rows flagged "Remove" are expected to be flagged in both target columns.
    binary_remove = features[features[binary_col] == "Remove"].index.sort_values()
    multi_remove = features[features[multiclass_col] == "Remove"].index.sort_values()

    if binary_remove.equals(multi_remove):
        features = features.drop(index=binary_remove)
        print("Removed rows with 'Remove' in both classification columns")
    else:
        # Nothing is dropped here: the caller gets the "Remove" rows back and
        # has to resolve the inconsistency in the annotations.
        print("Remove indices are not equal, please check the data")

    return features

In [3]:
def features_tabpfn_model(X, y, random_seed, is_binary=False):
    """Fit TabPFN on one stratified train/test split and report its performance.

    The target is label-encoded, the data is split 75/25 (stratified on the
    encoded target), and the classifier is fitted on the training part and
    evaluated on the held-out part. Warnings raised while fitting and scoring
    are suppressed.

    Args:
        X (pd.DataFrame): Feature matrix.
        y (pd.Series): Target variable (either binary or multiclass).
        random_seed (int): Random seed for the train/test split. The classifier
            itself is always created with ``random_state=42``.
        is_binary (bool): Whether this is binary classification. Controls the
            "Active"/"Inactive" encoding and how ROC AUC is computed.

    Returns:
        tuple: ``(metrics_df, confusion_df)``.
            ``metrics_df`` has one row per model with the columns "model",
            "roc_auc", "accuracy", "balanced_acc", "f1_score", "precision",
            "recall" and "time" (seconds spent on fitting, predicting and
            computing the metrics).
            ``confusion_df`` is the confusion matrix in long format with the
            columns "model", "actual_label", "predicted_label" and "count",
            where the labels are the original class names.
    """
    # Encode target variables
    target_encoder = LabelEncoder()
    y_encoded = target_encoder.fit_transform(y)

    # For the binary target make sure "Active" is the positive class (1) and
    # "Inactive" the negative class (0), so that column 1 of predict_proba is
    # the score ROC AUC should be computed from.
    has_activity_labels = (
        "Active" in target_encoder.classes_ and "Inactive" in target_encoder.classes_
    )
    if is_binary and len(target_encoder.classes_) == 2 and has_activity_labels:
        if target_encoder.transform(["Active"])[0] != 1:
            target_encoder.classes_ = target_encoder.classes_[::-1]
            y_encoded = target_encoder.transform(y)

    original_classes = target_encoder.classes_
    # Every encoded label, whether or not it ends up in the test split.
    encoded_labels = list(range(len(original_classes)))

    # Train and test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_encoded, stratify=y_encoded, test_size=0.25, random_state=random_seed
    )

    # Define the models
    models = {
        "TabPFN": TabPFNClassifier(random_state=42),
    }

    # Create lists to store the results and confusion matrix data
    metrics_results = []
    confusion_data = []

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for model_name, model in models.items():
            start_time = time.time()
            model.fit(X_train, y_train)

            y_pred = model.predict(X_test)
            y_pred_proba = model.predict_proba(X_test)

            accuracy = model.score(X_test, y_test)

            # ROC AUC: positive-class probability for the binary target,
            # weighted one-vs-rest over all classes otherwise.
            if is_binary:
                roc_auc = roc_auc_score(y_test, y_pred_proba[:, 1])
            else:
                roc_auc = roc_auc_score(
                    y_test, y_pred_proba, multi_class="ovr", average="weighted"
                )

            balanced_accuracy = balanced_accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average="weighted")
            precision = precision_score(
                y_test, y_pred, average="weighted", zero_division=0
            )
            recall = recall_score(y_test, y_pred, average="weighted")

            train_time = time.time() - start_time

            metrics_results.append(
                {
                    "model": model_name,
                    "roc_auc": roc_auc,
                    "accuracy": accuracy,
                    "balanced_acc": balanced_accuracy,
                    "f1_score": f1,
                    "precision": precision,
                    "recall": recall,
                    "time": train_time,
                }
            )

            # Pin the label set so that row/column position k always stands
            # for encoded class k. Without it the matrix only covers labels
            # that occur in y_test or y_pred, and positions would shift
            # whenever a class is missing from the split.
            cm = confusion_matrix(y_test, y_pred, labels=encoded_labels)

            # Long format, one row per (actual, predicted) pair, already
            # mapped back to the original class names.
            for actual_code, cm_row in enumerate(cm):
                for predicted_code, count in enumerate(cm_row):
                    confusion_data.append(
                        {
                            "model": model_name,
                            "actual_label": original_classes[actual_code],
                            "predicted_label": original_classes[predicted_code],
                            "count": count,
                        }
                    )

            print(f"{model_name} Done")

    # Convert results to DataFrames
    metrics_df = pd.DataFrame(metrics_results)
    confusion_df = pd.DataFrame(confusion_data)

    return metrics_df, confusion_df

In [4]:
def process_parquet_file(
    file_path,
    random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
):
    """Process a single parquet file for both binary and multiclass classification.

    The burst name is taken from the file's parent folder and the correction
    type from the last underscore-separated token of the file name (without
    extension). For every seed the binary and the multiclass target are each
    modelled once on the same feature matrix.

    Args:
        file_path (str): Path to the parquet file
        random_seed_list (list): List of random seeds to use
        binary_col (str): Column name for binary classification
        multiclass_col (str): Column name for multiclass classification

    Returns:
        dict: DataFrames under the keys "binary_metrics", "binary_conf",
        "multiclass_metrics" and "multiclass_conf", each holding the results
        of all seeds. A value is an empty DataFrame when there is nothing to
        report (e.g. an empty seed list).
    """
    # Extract metadata from path
    burst = os.path.basename(os.path.dirname(file_path))  # from folder
    correction_type = (
        os.path.basename(file_path).split("_")[-1].split(".")[0]
    )  # from filename
    print(f"Processing {correction_type} for {burst}")

    # Load feature data once
    features_df = load_features_parquet(file_path, binary_col, multiclass_col)

    # The predictors are the same for every seed and both targets, so build
    # them once instead of once per seed.
    X = features_df.drop(
        [binary_col, multiclass_col, "new_burst", "Ind_ID"], axis=1, errors="ignore"
    )

    # (result-key prefix, target column, value of the "target" column, is_binary)
    tasks = (
        ("binary", binary_col, "Activity", True),
        ("multiclass", multiclass_col, "Behaviour", False),
    )
    collected = {
        "binary_metrics": [],
        "binary_conf": [],
        "multiclass_metrics": [],
        "multiclass_conf": [],
    }

    for seed in random_seed_list:
        print(f"  Processing with seed {seed}")

        for task, target_col, target_name, is_binary in tasks:
            print(f"    Running {task} classification ({target_col})")
            metrics, conf = features_tabpfn_model(
                X, features_df[target_col], seed, is_binary=is_binary
            )

            # Tag both result tables with where they came from.
            for frame in (metrics, conf):
                frame["correction_type"] = correction_type
                frame["burst"] = burst
                frame["target"] = target_name
                frame["random_seed"] = seed

            collected[f"{task}_metrics"].append(metrics)
            collected[f"{task}_conf"].append(conf)

    # Standard column order for metrics
    metric_columns = [
        "burst",
        "correction_type",
        "target",
        "random_seed",
        "model",
        "roc_auc",
        "accuracy",
        "balanced_acc",
        "f1_score",
        "precision",
        "recall",
        "time",
    ]

    # Standard column order for confusion matrices
    conf_columns = [
        "burst",
        "correction_type",
        "target",
        "random_seed",
        "model",
        "actual_label",
        "predicted_label",
        "count",
    ]

    # Combine the per-seed tables and put the columns in the standard order.
    results = {}
    for key, frames in collected.items():
        combined = pd.concat(frames) if frames else pd.DataFrame()
        if not combined.empty:
            ordered = metric_columns if key.endswith("_metrics") else conf_columns
            combined = combined[ordered]
        results[key] = combined

    return results

In [5]:
def wrap_run_feat_tabpfn(
    folder_location,
    random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
    correction_filters=None,
):
    """Process all parquet files in a folder with multiple random seeds.

    Args:
        folder_location (str): Path to folder containing parquet files. It is
            searched recursively; a trailing slash is optional.
        random_seed_list (list): List of random seeds to use
        binary_col (str): Column name for binary classification
        multiclass_col (str): Column name for multiclass classification
        correction_filters (list | None): If given, only files whose path
            contains at least one of these strings (case-insensitive) are
            processed. ``None`` processes every parquet file found.

    Returns:
        dict: Combined DataFrames under the keys "binary_metrics",
        "binary_conf", "multiclass_metrics" and "multiclass_conf". A value is
        an empty DataFrame when no file produced results for that key.
    """
    # Name of the folder itself, with or without a trailing separator.
    filetype = os.path.basename(os.path.normpath(folder_location))
    print(f"TabPFN training started for {filetype}")

    # Find all parquet files; sorted so the processing order is reproducible.
    pattern = os.path.join(folder_location, "**", "*.parquet")
    all_files = sorted(
        os.path.normpath(f).replace("\\", "/")
        for f in glob.glob(pattern, recursive=True)
    )

    # Filter by correction type
    if correction_filters is not None:
        wanted = [corr.lower() for corr in correction_filters]
        all_files = [
            f for f in all_files if any(corr in f.lower() for corr in wanted)
        ]

    print(f"Found {len(all_files)} parquet files")

    # Process all files and collect results
    all_results = [
        process_parquet_file(file_path, random_seed_list, binary_col, multiclass_col)
        for file_path in all_files
    ]

    # Combine results from all files. pd.concat raises on an empty list, so
    # fall back to an empty DataFrame when there is nothing to combine.
    result_keys = (
        "binary_metrics",
        "binary_conf",
        "multiclass_metrics",
        "multiclass_conf",
    )
    combined_results = {}
    for key in result_keys:
        frames = [r[key] for r in all_results if not r[key].empty]
        combined_results[key] = pd.concat(frames) if frames else pd.DataFrame()

    return combined_results

### Test functions sequentially

In [10]:
# Smoke test: run the per-file pipeline on a single parquet file with one seed.
# NOTE(review): the "b3_" prefix suggests burst 3, but the path points at
# Burst_1 - confirm which one is intended.
pq_file_path = "../data/raw/features/Burst_1/annotated_features_burst_1_uncorrected.parquet"
random_seed_list = [42]
b3_unc = process_parquet_file(
    file_path=pq_file_path,
    random_seed_list=random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
)

Processing uncorrected for Burst_1
burst_1
Removed rows with 'Remove' in both classification columns
  Processing with seed 42
    Running binary classification (activity)...
TabPFN Done
    Running multiclass classification (attribution_merged)...
TabPFN Done


In [11]:
# Pull the four result tables of the single-file run out of the results dict.
(
    b3_unc_bin_metrics,
    b3_unc_bin_conf,
    b3_unc_multi_metrics,
    b3_unc_multi_conf,
) = (
    b3_unc[result_key]
    for result_key in (
        "binary_metrics",
        "binary_conf",
        "multiclass_metrics",
        "multiclass_conf",
    )
)

### Run over whole folders

In [6]:
# Seeds for the train/test splits; every file is modelled once per seed.
random_seeds = [42, 100, 123, 1234, 123456]
# For a quick single-seed run use: random_seeds = [42]

# Only files whose path contains one of these correction types are processed.
correction_types_to_process = ["uncorrected", "rotdaily", "rotbasal"]

# One feature folder per burst (Burst_1 to Burst_4).
folder_locations = [
    f"../data/raw/features/Burst_{burst_no}/" for burst_no in range(1, 5)
]

In [8]:
# Train and evaluate on every burst folder; one results dict per folder,
# in the same order as folder_locations.
all_burst_results = [
    wrap_run_feat_tabpfn(
        folder_location=burst_folder,
        random_seed_list=random_seeds,
        correction_filters=correction_types_to_process,
    )
    for burst_folder in folder_locations
]

TabPFN training started for Burst_1
Found 3 parquet files
Processing rotbasal for Burst_1
Removed rows with 'Remove' in both classification columns
  Processing with seed 42
    Running binary classification (activity)
TabPFN Done
    Running multiclass classification (attribution_merged)
TabPFN Done
  Processing with seed 100
    Running binary classification (activity)
TabPFN Done
    Running multiclass classification (attribution_merged)
TabPFN Done
  Processing with seed 123
    Running binary classification (activity)
TabPFN Done
    Running multiclass classification (attribution_merged)
TabPFN Done
  Processing with seed 1234
    Running binary classification (activity)
TabPFN Done
    Running multiclass classification (attribution_merged)
TabPFN Done
  Processing with seed 123456
    Running binary classification (activity)
TabPFN Done
    Running multiclass classification (attribution_merged)
TabPFN Done
Processing rotdaily for Burst_1
Removed rows with 'Remove' in both classif

In [9]:
# Regroup the per-burst result dicts into one list of DataFrames per result type
(
    binary_metrics_list,
    binary_conf_list,
    multiclass_metrics_list,
    multiclass_conf_list,
) = (
    [burst_result[result_key] for burst_result in all_burst_results]
    for result_key in (
        "binary_metrics",
        "binary_conf",
        "multiclass_metrics",
        "multiclass_conf",
    )
)

# Stack each list into a single dataframe with a fresh 0..n-1 index
binary_metrics, binary_conf, multiclass_metrics, multiclass_conf = (
    pd.concat(frames, ignore_index=True)
    for frames in (
        binary_metrics_list,
        binary_conf_list,
        multiclass_metrics_list,
        multiclass_conf_list,
    )
)

In [None]:
# Write the results to CSV files (without the index column).
# The binary (activity) exports are currently disabled:
# binary_metrics.to_csv(
#     "../data/output/activity_comparison/activity_features_dl_tabpfn_metrics.csv",
#     index=False,
# )
# binary_conf.to_csv(
#     "../data/output/activity_comparison/activity_features_dl_tabpfn_confusion.csv",
#     index=False,
# )
for results_frame, csv_path in (
    (
        multiclass_metrics,
        "../data/output/behaviour_comparison/behaviour_features_dl_tabpfn_metrics.csv",
    ),
    (
        multiclass_conf,
        "../data/output/behaviour_comparison/behaviour_features_dl_tabpfn_confusion.csv",
    ),
):
    results_frame.to_csv(csv_path, index=False)