# Model Training for Behaviour Classification

This notebook trains and saves the final shortlisted models for behaviour classification based on feature data.  
We use TabPFN trained on Burst 4 basal-corrected data.  
The [tabpfn environment](../environment_tabpfn.yml) is required to run this notebook.  

In [1]:
import os
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import joblib

# For TabPFN
from tabpfn import TabPFNClassifier

# For metrics and other utilities
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    precision_recall_fscore_support,
)

In [2]:
def load_training_features(
    pq_file_path,
    multiclass_col="attribution_merged",
    annotations_path="../data/temp/annotations/attributions_merged_majority_outliers.csv",
):
    """
    Load feature data from a parquet file and attach behaviour annotations.

    The burst is read from the fourth underscore-separated token of the file
    name (extension removed), e.g. "annotated_features_burst_4_rotbasal.parquet"
    gives "burst_4". Only annotations whose "burst" column equals that value are
    merged onto the features, using the keys "Ind_ID" and "new_burst".

    After the merge, every string (object) column except ``multiclass_col`` and
    "Ind_ID" is dropped, rows labelled "Remove" are dropped, and rows left
    without a label are dropped. The class distribution of the remaining rows
    is printed.

    Args:
        pq_file_path (str): Path to the parquet file containing feature data.
        multiclass_col (str): Column name for multiclass classification
            (default: "attribution_merged").
        annotations_path (str): Path to the annotations CSV. Defaults to the
            location this notebook has always read from.
    Returns:
        pd.DataFrame: Features merged with annotations, restricted to rows that
            have a usable label. The index is not reset.
    Raises:
        IndexError: If the file name has fewer than four underscore-separated
            tokens, so no burst can be read from it.
    """
    print(f"Loading features from: {pq_file_path}")

    # Load feature parquet file
    features = pd.read_parquet(pq_file_path)

    # Load annotations
    annotations = pd.read_csv(annotations_path)

    # Identify burst from parquet file name. The extension is stripped first so
    # that a name ending right after the burst token ("..._burst_4.parquet")
    # yields "4" instead of "4.parquet", which would match no annotation rows.
    pq_file_stem = os.path.splitext(os.path.basename(pq_file_path))[0]
    burst = "burst_" + pq_file_stem.split("_")[3]

    # Filter annotations for the current burst
    annotations = annotations[annotations["burst"] == burst]

    # Join annotations with feature data
    features = features.merge(annotations, on=["Ind_ID", "new_burst"], how="left")

    # Remove other string columns except multiclass_col and IDs
    features = features.drop(
        columns=features.select_dtypes(include="object").columns.difference(
            [multiclass_col, "Ind_ID"]
        )
    )

    # Remove rows with "Remove" value in multiclass column
    remove_indices = features[features[multiclass_col] == "Remove"].index
    features = features.drop(index=remove_indices)
    print(f"Removed {len(remove_indices)} rows with 'Remove' value")

    # The left merge keeps feature rows that have no matching annotation; their
    # label is missing, so they cannot be used as training targets and would
    # also distort the percentages printed below.
    unlabelled = features[multiclass_col].isna()
    if unlabelled.any():
        features = features[~unlabelled]
        print(
            f"Removed {int(unlabelled.sum())} rows without a '{multiclass_col}' label"
        )

    # Print class distribution
    class_dist = features[multiclass_col].value_counts()
    print("\nClass distribution:")
    for class_name, count in class_dist.items():
        print(f"  {class_name}: {count} ({count / len(features) * 100:.1f}%)")

    return features


# Test the function
# pq_file_path = (
#     "../data/raw/features/Burst_1/annotated_features_burst_1_uncorrected.parquet"
# )
# features_df = load_training_features(pq_file_path)

In [3]:
def train_and_evaluate_tabpfn(
    features_df, target_col="attribution_merged", random_seed=42, test_size=0.2
):
    """
    Train a TabPFN model and evaluate its performance (overall and per-class).

    The target column is label-encoded, the data is split into stratified
    train/validation parts, a TabPFNClassifier is fitted on the training part
    and all metrics are computed on the validation part.

    Args:
        features_df (pd.DataFrame): Features plus the target column. The columns
            ``target_col``, "Ind_ID" and "new_burst" are excluded from the model
            inputs when present; every other column is used as a feature.
        target_col (str): Name of the label column.
        random_seed (int): Seed passed to the split and to TabPFNClassifier.
        test_size (float): Fraction passed to train_test_split as ``test_size``.

    Returns:
        tuple: ``(model, label_encoder, metrics_dict, per_class_df, confusion_df)``
            - model: the fitted TabPFNClassifier.
            - label_encoder: the LabelEncoder fitted on the target column.
            - metrics_dict: accuracy, balanced accuracy, weighted F1 / precision /
              recall, plus "train_samples" and "valid_samples" (the split sizes).
            - per_class_df: one row per class with precision, recall, f1_score
              and support.
            - confusion_df: long-format confusion matrix with one row per
              (actual_label, predicted_label) pair and its count.
    """

    # Prepare features and target
    X = features_df.drop([target_col, "Ind_ID", "new_burst"], axis=1, errors="ignore")
    y = features_df[target_col]

    # Encode target
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    # Split data
    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y_encoded, stratify=y_encoded, test_size=test_size, random_state=random_seed
    )
    train_samples = len(X_train)
    valid_samples = len(X_valid)

    # Train TabPFN
    model = TabPFNClassifier(random_state=random_seed)
    model.fit(X_train, y_train)

    # Predict
    y_pred = model.predict(X_valid)
    class_names = label_encoder.classes_
    # Encoded ids of every class the encoder saw. Passing these explicitly keeps
    # the confusion matrix and per-class metrics aligned with class_names even
    # when a class is missing from both the validation labels and predictions.
    label_ids = list(range(len(class_names)))

    # Overall metrics (zero_division=0 on all weighted scores so undefined
    # per-class values count as 0 without emitting a warning)
    metrics_dict = {
        "accuracy": accuracy_score(y_valid, y_pred),
        "balanced_accuracy": balanced_accuracy_score(y_valid, y_pred),
        "f1_score": f1_score(y_valid, y_pred, average="weighted", zero_division=0),
        "precision": precision_score(
            y_valid, y_pred, average="weighted", zero_division=0
        ),
        "recall": recall_score(y_valid, y_pred, average="weighted", zero_division=0),
        "train_samples": train_samples,
        "valid_samples": valid_samples,
    }

    # Confusion matrix: without labels= the matrix only covers classes that
    # occur in y_valid or y_pred, and indexing it by class position would fail.
    cm = confusion_matrix(y_valid, y_pred, labels=label_ids)
    confusion_df = pd.DataFrame(
        [
            {
                "actual_label": actual,
                "predicted_label": predicted,
                "count": int(cm[i, j]),
            }
            for i, actual in enumerate(class_names)
            for j, predicted in enumerate(class_names)
        ]
    )

    # Per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        y_valid, y_pred, labels=label_ids, zero_division=0
    )
    per_class_df = pd.DataFrame(
        {
            "class": class_names,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "support": support,
        }
    )

    return model, label_encoder, metrics_dict, per_class_df, confusion_df


# model, label_encoder, metrics_dict, per_class_df, confusion_df = (
#     train_and_evaluate_tabpfn(features_df)
# )
# print(metrics_dict)
# print(per_class_df)
# print(confusion_df)

In [4]:
def save_tabpfn_model_and_metadata(
    model,
    label_encoder,
    metrics_dict,
    confusion_df,
    per_class_df,
    model_dir,
    model_name,
):
    """
    Save TabPFN model, label encoder, and associated metadata and metrics.

    Six files are written into ``model_dir`` (created if missing), all prefixed
    with ``model_name``: the model and the label encoder as joblib dumps, and
    four CSVs (training metadata, class mapping, confusion matrix, per-class
    metrics). The path of each file is printed as it is written.

    Args:
        model: Trained TabPFNClassifier
        label_encoder: Fitted LabelEncoder
        metrics_dict: Dictionary of training metrics; its entries become extra
            columns of the metadata CSV (and override the built-in metadata
            keys if they share a name)
        confusion_df: DataFrame of confusion matrix
        per_class_df: DataFrame of per-class metrics
        model_dir: Directory to save files
        model_name: Base name for files

    Returns:
        dict: Paths of the written files under the keys "model_path",
            "encoder_path", "metadata_path", "class_mapping_path",
            "confusion_path" and "per_class_path".
    """
    os.makedirs(model_dir, exist_ok=True)

    # Native Python values rather than NumPy scalars: the metadata CSV stores
    # this list as text, and NumPy >= 2 scalar reprs would otherwise be written
    # as "np.str_('...')" inside the cell.
    class_names = label_encoder.classes_.tolist()

    # 1. Save the model
    model_path = os.path.join(model_dir, f"{model_name}_tabpfn_model.joblib")
    joblib.dump(model, model_path)
    print(f"Model saved to: {model_path}")

    # 2. Save the label encoder
    encoder_path = os.path.join(model_dir, f"{model_name}_label_encoder.joblib")
    joblib.dump(label_encoder, encoder_path)
    print(f"Label encoder saved to: {encoder_path}")

    # 3. Save training metadata (single-row CSV)
    metadata = {
        "model_name": model_name,
        "model_type": "TabPFN",
        "training_date": datetime.now().isoformat(),
        "num_classes": len(class_names),
        "class_names": class_names,
        **metrics_dict,
    }
    metadata_df = pd.DataFrame([metadata])
    metadata_path = os.path.join(model_dir, f"{model_name}_training_metadata.csv")
    metadata_df.to_csv(metadata_path, index=False)
    print(f"Training metadata saved to: {metadata_path}")

    # 4. Save class mapping (encoded index -> class name)
    class_mapping = pd.DataFrame(
        {
            "class_index": range(len(class_names)),
            "class_name": class_names,
        }
    )
    class_mapping_path = os.path.join(model_dir, f"{model_name}_class_mapping.csv")
    class_mapping.to_csv(class_mapping_path, index=False)
    print(f"Class mapping saved to: {class_mapping_path}")

    # 5. Save confusion matrix
    confusion_path = os.path.join(model_dir, f"{model_name}_confusion_matrix.csv")
    confusion_df.to_csv(confusion_path, index=False)
    print(f"Confusion matrix saved to: {confusion_path}")

    # 6. Save per-class metrics
    per_class_path = os.path.join(model_dir, f"{model_name}_per_class_metrics.csv")
    per_class_df.to_csv(per_class_path, index=False)
    print(f"Per-class metrics saved to: {per_class_path}")

    return {
        "model_path": model_path,
        "encoder_path": encoder_path,
        "metadata_path": metadata_path,
        "class_mapping_path": class_mapping_path,
        "confusion_path": confusion_path,
        "per_class_path": per_class_path,
    }


# Example usage
# Get number of samples
# train_samples = len(features_df) * 0.8  # or use actual split sizes if available
# valid_samples = len(features_df) * 0.2

# Save everything
# save_tabpfn_model_and_metadata(
#     model=model,
#     label_encoder=label_encoder,
#     metrics_dict=metrics_dict,
#     confusion_df=confusion_df,
#     per_class_df=per_class_df,
#     model_dir="../models/tabpfn_b1_unc/",
#     model_name="tabpfn_b1_unc",
# )

## TabPFN - Burst 4 - Basal

In [9]:
# Settings for the TabPFN run on the Burst 4 "rotbasal" feature file:
# input data, output directory/name, seed, validation fraction and target column.
CONFIG_TABPFN_B4_BASAL = dict(
    data_path="../data/raw/features/Burst_4/annotated_features_burst_4_rotbasal.parquet",
    model_dir="../models/tabpfn_b4_basal/",
    random_seed=42,
    valid_size=0.2,
    target_column="attribution_merged",
    model_name="tabpfn_b4_basal",
)

# Echo every setting so the saved cell output records what the run used.
print("Configuration:")
print(
    "\n".join(f"  {key}: {value}" for key, value in CONFIG_TABPFN_B4_BASAL.items())
)

# Seed the NumPy and PyTorch global RNGs from the configured value.
torch.manual_seed(CONFIG_TABPFN_B4_BASAL["random_seed"])
np.random.seed(CONFIG_TABPFN_B4_BASAL["random_seed"])
print(f"\nRandom seed set to: {CONFIG_TABPFN_B4_BASAL['random_seed']}")

Configuration:
  data_path: ../data/raw/features/Burst_4/annotated_features_burst_4_rotbasal.parquet
  model_dir: ../models/tabpfn_b4_basal/
  random_seed: 42
  valid_size: 0.2
  target_column: attribution_merged
  model_name: tabpfn_b4_basal

Random seed set to: 42


In [10]:
# Load the annotated features. The configured target column is passed through
# so the loader's "Remove" filtering and class summary use the same column
# that the model is trained on below.
features_df = load_training_features(
    CONFIG_TABPFN_B4_BASAL["data_path"],
    multiclass_col=CONFIG_TABPFN_B4_BASAL["target_column"],
)

# Fit TabPFN on the training split and score it on the validation split.
model, label_encoder, metrics_dict, per_class_df, confusion_df = (
    train_and_evaluate_tabpfn(
        features_df,
        target_col=CONFIG_TABPFN_B4_BASAL["target_column"],
        random_seed=CONFIG_TABPFN_B4_BASAL["random_seed"],
        test_size=CONFIG_TABPFN_B4_BASAL["valid_size"],
    )
)

# Persist the model, label encoder and evaluation artefacts; the returned
# dict holds the path of every file written.
model_metadata = save_tabpfn_model_and_metadata(
    model=model,
    label_encoder=label_encoder,
    metrics_dict=metrics_dict,
    confusion_df=confusion_df,
    per_class_df=per_class_df,
    model_dir=CONFIG_TABPFN_B4_BASAL["model_dir"],
    model_name=CONFIG_TABPFN_B4_BASAL["model_name"],
)

Loading features from: ../data/raw/features/Burst_4/annotated_features_burst_4_rotbasal.parquet
Removed 1026 rows with 'Remove' value

Class distribution:
  Resting: 1599 (36.5%)
  Eating: 1300 (29.7%)
  Walking: 532 (12.1%)
  Grooming actor: 385 (8.8%)
  Grooming receiver: 286 (6.5%)
  Self-scratching: 136 (3.1%)
  Sleeping: 110 (2.5%)
  Running: 31 (0.7%)




Model saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_tabpfn_model.joblib
Label encoder saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_label_encoder.joblib
Training metadata saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_training_metadata.csv
Class mapping saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_class_mapping.csv
Confusion matrix saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_confusion_matrix.csv
Per-class metrics saved to: ../models/tabpfn_b4_basal/tabpfn_b4_basal_per_class_metrics.csv
