# Behavioural classification from features using DL models

This notebook trains and evaluates two classifiers, [Category Embedding](https://pytorch-tabular.readthedocs.io/en/latest/models/#category-embedding-model-multi-layer-perceptron) and [GANDALF](https://pytorch-tabular.readthedocs.io/en/latest/models/#gated-adaptive-network-for-deep-automated-learning-of-features-gandalf), using the PyTorch Tabular package.  
Create the [torch_tabular environment](../environment_torch_tabular.yml) and use it to run this notebook.  
The functions are set up to run classification on both a multiclass target (8 behavioural classes) and a binary target (activity/inactivity). See the [behavioural attribution notebook](03_combine_burst_attributions.ipynb) for details.


In [2]:
import gc
import glob
import os
import time
import traceback

import numpy as np
import pandas as pd
import torch
from pytorch_tabular import TabularModel
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GANDALFConfig,
)
from pytorch_tabular.models.common.heads import LinearHeadConfig
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


In [3]:
def load_features_parquet(
    pq_file_path,
    binary_col,
    multiclass_col,
    annotations_path="../data/temp/annotations/attributions_merged_majority_outliers.csv",
):
    """Load one feature parquet file and attach its behavioural annotations.

    The burst is derived from the file name: the fourth underscore-separated
    token (``2`` in ``annotated_features_burst_2_RB.parquet``) is prefixed with
    ``"burst_"`` and used to select the matching annotation rows. Annotations
    are left-joined on ``Ind_ID`` and ``new_burst``, so feature rows without a
    matching annotation are kept with missing labels.

    Args:
        pq_file_path (str): Path to the parquet file containing feature data.
        binary_col (str): Column name for binary classification (e.g., "activity").
        multiclass_col (str): Column name for multiclass classification
            (e.g., "attribution_merged").
        annotations_path (str): Path to the annotations CSV. Defaults to the
            location previously hard-coded in this function.

    Returns:
        pd.DataFrame: Features merged with annotations. String columns other
        than the two label columns and ``Ind_ID`` are dropped. Rows labelled
        "Remove" are dropped only when both label columns flag exactly the same
        rows; otherwise a message is printed and those rows are left in place.
    """
    features = pd.read_parquet(pq_file_path)
    annotations = pd.read_csv(annotations_path)

    # Identify the burst from the parquet file name
    burst = "burst_" + os.path.basename(pq_file_path).split("_")[3]
    print(burst)

    # Keep only this burst's annotations, then attach them to the features
    annotations = annotations[annotations["burst"] == burst]
    features = features.merge(annotations, on=["Ind_ID", "new_burst"], how="left")

    # Drop every string column except the labels and the individual ID
    keep_text_cols = [binary_col, multiclass_col, "Ind_ID"]
    text_cols = features.select_dtypes(include="object").columns
    features = features.drop(columns=text_cols.difference(keep_text_cols))

    # Rows marked "Remove" are only dropped when both label columns agree on them
    binary_remove = features[features[binary_col] == "Remove"].index.sort_values()
    multi_remove = features[features[multiclass_col] == "Remove"].index.sort_values()
    if binary_remove.equals(multi_remove):
        features = features.drop(index=binary_remove)
        print("Removed rows with 'Remove' in both classification columns")
    else:
        # Nothing is dropped here: the caller is expected to inspect the data
        print("Remove indices are not equal, please check the data")

    return features

In [None]:
def _class_index_from_prob_col(prob_col):
    """Return the integer class index encoded in a probability column name.

    Handles both ``"<class>_probability"`` and ``"<target>_<class>_probability"``:
    the class index is always the token directly before the ``probability`` suffix.
    """
    return int(prob_col.rsplit("_", 2)[-2])


def _roc_auc_from_probabilities(y_true, proba_values, num_classes, is_binary, model_name):
    """Compute ROC AUC from class probabilities, or NaN when it cannot be computed.

    ``proba_values`` must have its columns ordered by class index. Binary tasks
    use the positive-class column; multiclass tasks use weighted one-vs-rest.
    """
    n_cols = proba_values.shape[1]
    try:
        if is_binary:
            if n_cols >= 2:
                # Column 1 is the positive class once columns are sorted by class index
                return roc_auc_score(y_true, proba_values[:, 1])
            if n_cols == 1:
                # Only one probability column was returned; use it as-is
                return roc_auc_score(y_true, proba_values[:, 0])
            print(
                f"Warning: Unexpected probability shape for binary task in {model_name}: {proba_values.shape}"
            )
        elif n_cols == num_classes:
            return roc_auc_score(
                y_true, proba_values, multi_class="ovr", average="weighted"
            )
        else:
            print(
                f"Warning: Probability shape mismatch for multiclass task in {model_name}. Expected {num_classes} columns, got {n_cols}."
            )
    except ValueError as e:
        print(f"Could not calculate ROC AUC for {model_name}: {e}")
    return np.nan


def features_dl_models(X, y, random_seed, is_binary=False, validation_size=None):
    """Train the Category Embedding and GANDALF models on one split and score them.

    The target is label-encoded, the data is split 75/25 (stratified) and each
    model is trained with PyTorch Tabular. Metrics are computed afterwards with
    scikit-learn on the held-out 25%.

    NOTE(review): by default the held-out test split is also passed to
    ``fit`` as the validation set, so early stopping and best-checkpoint
    selection see the data the metrics are reported on. Pass
    ``validation_size`` to monitor a stratified slice of the training split
    instead. The default is unchanged so existing results stay reproducible.

    Args:
        X (pd.DataFrame): Feature matrix; every column is treated as continuous.
        y (pd.Series): Target labels.
        random_seed (int): Seed for the train/test (and optional validation) split.
        is_binary (bool): If True, ROC AUC uses the positive-class probability and
            "Active" is encoded as 1 when the classes are "Active"/"Inactive".
        validation_size (float | None): Fraction of the training split to hold
            out for early stopping. ``None`` validates on the test split.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: One metrics row per model, and a
        long-format confusion matrix (one row per model/actual/predicted cell)
        labelled with the original class names.

    Raises:
        ValueError: If the train or test split does not contain every class.
    """
    score_names = ("accuracy", "balanced_acc", "f1_score", "precision", "recall")

    # --- Data preparation ---
    target_encoder = LabelEncoder()
    y_encoded = target_encoder.fit_transform(y)
    num_classes = len(target_encoder.classes_)

    # If "Active" is not already encoded as 1, reverse the class order so it is
    if (
        is_binary
        and num_classes == 2
        and "Active" in target_encoder.classes_
        and "Inactive" in target_encoder.classes_
        and target_encoder.transform(["Active"])[0] != 1
    ):
        target_encoder.classes_ = target_encoder.classes_[::-1]
        y_encoded = target_encoder.transform(y)
    class_labels = target_encoder.classes_

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_encoded, stratify=y_encoded, test_size=0.25, random_state=random_seed
    )

    # Verify number of classes in train and test match
    if len(np.unique(y_train)) != num_classes or len(np.unique(y_test)) != num_classes:
        raise ValueError(
            f"Mismatch in number of classes: Train classes {np.unique(y_train)}, Test classes {np.unique(y_test)}"
        )

    train_df = X_train.copy()
    test_df = X_test.copy()
    train_df["target"] = y_train
    test_df["target"] = y_test
    train_df.reset_index(drop=True, inplace=True)
    test_df.reset_index(drop=True, inplace=True)

    if validation_size is None:
        # Original behaviour: the test split doubles as the validation set
        fit_df, valid_df = train_df, test_df
    else:
        fit_df, valid_df = train_test_split(
            train_df,
            stratify=train_df["target"],
            test_size=validation_size,
            random_state=random_seed,
        )
        fit_df = fit_df.reset_index(drop=True)
        valid_df = valid_df.reset_index(drop=True)

    # --- PyTorch Tabular configuration ---
    use_gpu = torch.cuda.is_available()

    data_config = DataConfig(
        target=["target"],
        continuous_cols=list(X_train.columns),
        categorical_cols=[],
        pin_memory=use_gpu,
        num_workers=0,
    )

    trainer_config = TrainerConfig(
        auto_lr_find=True,
        batch_size=64,
        max_epochs=100,
        early_stopping="valid_loss",
        early_stopping_mode="min",
        early_stopping_min_delta=0.001,
        early_stopping_patience=10,
        checkpoints="valid_loss",
        load_best=True,
        track_grad_norm=2,
        progress_bar="none",
        # Fall back to CPU so the function also runs on hosts without CUDA
        accelerator="gpu" if use_gpu else "cpu",
    )

    head_config = LinearHeadConfig(
        layers="",
        dropout=0.1,
        initialization="kaiming",
    ).__dict__

    optimizer_config = OptimizerConfig(
        lr_scheduler="CosineAnnealingWarmRestarts",
        lr_scheduler_params={"T_0": 100, "T_mult": 1, "eta_min": 1e-5},
    )

    # Parameters shared by every model configuration
    common_params = {
        "task": "classification",
        "head": "LinearHead",
        "head_config": head_config,
    }

    model_list = [
        CategoryEmbeddingModelConfig(layers="1024-512-256", **common_params),
        GANDALFConfig(gflu_stages=15, learnable_sparsity=False, **common_params),
    ]

    # --- Training and evaluation loop ---
    metrics_results = []
    confusion_data = []

    for model_config in model_list:
        model_name = type(model_config).__name__.replace("Config", "")
        model = None

        try:
            print(f"  Training {model_name}")
            start_time = time.time()

            model = TabularModel(
                data_config=data_config,
                model_config=model_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                verbose=False,
            )
            model.fit(train=fit_df, validation=valid_df)

            pred_df = model.predict(test_df)

            # The prediction column is "prediction" or, when the library prefixes
            # output columns with the target name, "target_prediction"
            pred_col = (
                "prediction" if "prediction" in pred_df.columns else "target_prediction"
            )
            y_pred = pred_df[pred_col].values

            # Probability columns, ordered by class index so column k is class k
            prob_cols = sorted(
                (col for col in pred_df.columns if str(col).endswith("_probability")),
                key=_class_index_from_prob_col,
            )
            if prob_cols:
                roc_auc = _roc_auc_from_probabilities(
                    y_test, pred_df[prob_cols].values, num_classes, is_binary, model_name
                )
            else:
                print(
                    f"Warning: Probability columns not found for {model_name}. Cannot calculate ROC AUC."
                )
                roc_auc = np.nan

            if y_pred.size > 0:
                scores = {
                    "accuracy": accuracy_score(y_test, y_pred),
                    "balanced_acc": balanced_accuracy_score(y_test, y_pred),
                    "f1_score": f1_score(
                        y_test, y_pred, average="weighted", zero_division=0
                    ),
                    "precision": precision_score(
                        y_test, y_pred, average="weighted", zero_division=0
                    ),
                    "recall": recall_score(
                        y_test, y_pred, average="weighted", zero_division=0
                    ),
                }
                # labels= forces a full num_classes x num_classes matrix
                cm = confusion_matrix(y_test, y_pred, labels=np.arange(num_classes))
                model_confusion = [
                    {
                        "model": model_name,
                        "actual_label": class_labels[i],
                        "predicted_label": class_labels[j],
                        "count": count,
                    }
                    for (i, j), count in np.ndenumerate(cm)
                ]
            else:
                scores = dict.fromkeys(score_names, np.nan)
                model_confusion = []

            metrics_results.append(
                {
                    "model": model_name,
                    "roc_auc": roc_auc,
                    **scores,
                    "time": time.time() - start_time,
                    "random_seed": random_seed,
                }
            )
            confusion_data.extend(model_confusion)

        except Exception as e:
            # Best effort: one failing model must not abort the other one
            print(f"!! Error training/evaluating {model_name}: {str(e)}")
            print(traceback.format_exc())
            metrics_results.append(
                {
                    "model": model_name,
                    "roc_auc": np.nan,
                    **dict.fromkeys(score_names, np.nan),
                    "time": np.nan,
                    "random_seed": random_seed,
                }
            )

        finally:
            # Release the model (and cached GPU memory) on success and on failure
            model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    metrics_df = pd.DataFrame(metrics_results)
    confusion_df = pd.DataFrame(confusion_data)

    return metrics_df, confusion_df

In [None]:
def process_parquet_file(
    file_path,
    random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
):
    """Run binary and multiclass DL classification on one parquet file for every seed.

    The burst is taken from the parent folder name and the correction type from
    the last underscore-separated token of the file name (without extension).
    For each seed the binary task runs first, then the multiclass task.

    Args:
        file_path (str): Path to the parquet file
        random_seed_list (list): List of random seeds to use
        binary_col (str): Column name for binary classification
        multiclass_col (str): Column name for multiclass classification

    Returns:
        dict: DataFrames under the keys "binary_metrics", "binary_conf",
        "multiclass_metrics" and "multiclass_conf"
    """
    file_name = os.path.basename(file_path)
    burst = os.path.basename(os.path.dirname(file_path))
    correction_type = file_name.split("_")[-1].split(".")[0]
    print(f"Processing {correction_type} for {burst}")

    def _tag(frame, target_label, seed):
        # Stamp a result frame with the metadata describing this run
        frame["correction_type"] = correction_type
        frame["burst"] = burst
        frame["target"] = target_label
        frame["random_seed"] = seed

    def _combine(frames, column_order):
        # Stack per-seed frames; only reorder columns when there is data
        stacked = pd.concat(frames) if frames else pd.DataFrame()
        return stacked if stacked.empty else stacked[column_order]

    collected = {
        "binary_metrics": [],
        "binary_conf": [],
        "multiclass_metrics": [],
        "multiclass_conf": [],
    }

    # Features are loaded once and shared by every seed and both tasks
    features_df = load_features_parquet(file_path, binary_col, multiclass_col)
    X = features_df.drop(
        columns=[binary_col, multiclass_col, "new_burst", "Ind_ID"], errors="ignore"
    )

    # (result prefix, progress message, target, is_binary flag, target label)
    tasks = (
        (
            "binary",
            f"    Running binary classification ({binary_col})",
            features_df[binary_col],
            True,
            "Activity",
        ),
        (
            "multiclass",
            f"    Running multiclass classification ({multiclass_col})",
            features_df[multiclass_col],
            False,
            "Behaviour",
        ),
    )

    for seed in random_seed_list:
        print(f"  Processing with seed {seed}")
        for prefix, banner, y, binary_flag, target_label in tasks:
            print(banner)
            metrics, conf = features_dl_models(X, y, seed, is_binary=binary_flag)
            _tag(metrics, target_label, seed)
            _tag(conf, target_label, seed)
            collected[prefix + "_metrics"].append(metrics)
            collected[prefix + "_conf"].append(conf)

    # Metadata columns shared by both output layouts, followed by the payload
    id_columns = ["burst", "correction_type", "target", "random_seed", "model"]
    metric_columns = id_columns + [
        "roc_auc",
        "accuracy",
        "balanced_acc",
        "f1_score",
        "precision",
        "recall",
        "time",
    ]
    conf_columns = id_columns + ["actual_label", "predicted_label", "count"]

    result_dict = {
        "binary_metrics": _combine(collected["binary_metrics"], metric_columns),
        "binary_conf": _combine(collected["binary_conf"], conf_columns),
        "multiclass_metrics": _combine(collected["multiclass_metrics"], metric_columns),
        "multiclass_conf": _combine(collected["multiclass_conf"], conf_columns),
    }

    # Free Python-side garbage, then any cached CUDA memory if torch is usable
    gc.collect()
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except (ImportError, AttributeError):
        pass

    return result_dict

In [6]:
def wrap_run_feat_dl(
    folder_location,
    random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
    correction_filters=None,
):
    """Process all parquet files in a folder with multiple random seeds.

    Args:
        folder_location (str): Path to folder containing parquet files
            (searched recursively; a trailing slash is optional)
        random_seed_list (list): List of random seeds to use
        binary_col (str): Column name for binary classification
        multiclass_col (str): Column name for multiclass classification
        correction_filters (list | None): If given, only files whose base name
            contains at least one of these substrings are processed

    Returns:
        dict: Combined DataFrames under the keys "binary_metrics",
        "binary_conf", "multiclass_metrics" and "multiclass_conf". A key maps
        to an empty DataFrame when no file produced results for it.
    """
    # normpath drops any trailing separator so basename is the folder itself
    filetype = os.path.basename(os.path.normpath(folder_location))
    print(f"DL training started for {filetype}")

    # Find all parquet files; sorted so the processing order is deterministic
    pattern = os.path.join(folder_location, "**", "*.parquet")
    all_files = sorted(
        os.path.normpath(f).replace("\\", "/")
        for f in glob.glob(pattern, recursive=True)
    )

    # Apply correction type filters if provided
    if correction_filters is not None:
        all_files = [
            f
            for f in all_files
            if any(corr_type in os.path.basename(f) for corr_type in correction_filters)
        ]

    print(f"Found {len(all_files)} parquet files")

    all_results = [
        process_parquet_file(file_path, random_seed_list, binary_col, multiclass_col)
        for file_path in all_files
    ]

    # pd.concat raises on an empty list, so fall back to an empty frame
    combined_results = {}
    for key in ("binary_metrics", "binary_conf", "multiclass_metrics", "multiclass_conf"):
        frames = [r[key] for r in all_results if not r[key].empty]
        combined_results[key] = pd.concat(frames) if frames else pd.DataFrame()

    return combined_results

### Test functions sequentially

In [None]:
# Inputs for the step-by-step checks below: a single seed and one feature file
random_seed_list = [42]
pq_file_path = "../data/raw/features/Burst_2/annotated_features_burst_2_RB.parquet"

In [None]:
# Load the test file with its annotations merged in, using the "activity" and
# "attribution_merged" label columns
acc_feat = load_features_parquet(
    pq_file_path, binary_col="activity", multiclass_col="attribution_merged"
)

In [None]:
# Label columns for the two classification tasks
binary_col, multiclass_col = "activity", "attribution_merged"

# Empty accumulators for per-run results
binary_metrics_list, binary_conf_list = [], []
multiclass_metrics_list, multiclass_conf_list = [], []

# Feature matrix: everything except the labels and the ID columns
X = acc_feat.drop(
    columns=[binary_col, multiclass_col, "new_burst", "Ind_ID"], errors="ignore"
)

# Targets for the binary and multiclass tasks
y_binary, y_multiclass = acc_feat[binary_col], acc_feat[multiclass_col]

In [None]:
# Multiclass run with the first seed; yields the per-model metrics table and
# the long-format confusion table
multi_metrics, multi_conf = features_dl_models(
    X, y_multiclass, random_seed_list[0], is_binary=False
)

In [None]:
# Binary run with the first seed.
# NOTE: features_dl_models returns (metrics_df, confusion_df), so despite its
# name bin_sweep_df holds a 2-tuple of DataFrames, not a single DataFrame.
bin_sweep_df = features_dl_models(X, y_binary, random_seed_list[0], is_binary=True)

In [None]:
# End-to-end check on the single test file: runs the binary and multiclass
# tasks for every seed in random_seed_list and returns a dict of four DataFrames
acc_met = process_parquet_file(
    pq_file_path,
    random_seed_list,
    binary_col="activity",
    multiclass_col="attribution_merged",
)

In [None]:
# Pull the four result tables out of the dict returned by process_parquet_file.
# Each one is passed through pd.concat on its own, giving a new frame with the
# same contents.
acc_met_bin, acc_met_multi, acc_conf_bin, acc_conf_multi = (
    pd.concat([acc_met[key]])
    for key in (
        "binary_metrics",
        "multiclass_metrics",
        "binary_conf",
        "multiclass_conf",
    )
)

### Run over whole folders

In [7]:
# Full-run configuration: one feature folder per burst, the correction variants
# to include, and the seeds for the repeated splits
folder_locations = [
    "../data/raw/features/Burst_1/",
    "../data/raw/features/Burst_2/",
    "../data/raw/features/Burst_3/",
    "../data/raw/features/Burst_4/",
]
correction_types_to_process = ["uncorrected", "rotdaily", "rotbasal"]
# For a quick single-seed run use: random_seeds = [42]
random_seeds = [42, 100, 123, 1234, 123456]

In [None]:
# Run the ML training for all files in the specified folders
all_burst_results = [
    wrap_run_feat_dl(
        folder, random_seeds, correction_filters=correction_types_to_process
    )
    for folder in folder_locations
]

In [9]:
# Obtain individual dataframes for results and confusion matrices
binary_metrics_list = [result["binary_metrics"] for result in all_burst_results]
binary_conf_list = [result["binary_conf"] for result in all_burst_results]
multiclass_metrics_list = [result["multiclass_metrics"] for result in all_burst_results]
multiclass_conf_list = [result["multiclass_conf"] for result in all_burst_results]

# Concatenate all results into single dataframes
binary_metrics = pd.concat(binary_metrics_list, ignore_index=True)
binary_conf = pd.concat(binary_conf_list, ignore_index=True)
multiclass_metrics = pd.concat(multiclass_metrics_list, ignore_index=True)
multiclass_conf = pd.concat(multiclass_conf_list, ignore_index=True)

In [None]:
# Output to CSV files
# binary_metrics.to_csv(
#     "../data/output/activity_comparison/activity_features_dl_metrics.csv",
#     index=False,
# )
# binary_conf.to_csv(
#     "../data/output/activity_comparison/activity_features_dl_confusion.csv",
#     index=False,
# )
multiclass_metrics.to_csv(
    "../data/output/behaviour_comparison/behaviour_features_dl_metrics.csv",
    index=False,
)
multiclass_conf.to_csv(
    "../data/output/behaviour_comparison/behaviour_features_dl_confusion.csv",
    index=False,
)