# TabPFN Inference Pipeline for Behaviour Classification

This notebook runs inference with a trained TabPFN model on new, unlabeled accelerometer data.  
We use the [TabPFN model](../models/tabpfn_b4_basal/) trained on Burst 4 basal-corrected data.  
The [tabpfn environment](../environment_tabpfn.yml) is required to run this notebook.

In [None]:
import os
import time
import datetime

import numpy as np
import pandas as pd
import joblib

from tqdm import tqdm
import pyarrow as pa
import pyarrow.parquet as pq

import warnings

In [None]:
def load_tabpfn_model(model_dir, model_name):
    """
    Load a TabPFN model and its label encoder from disk.

    Both artefacts are joblib dumps stored in ``model_dir`` and named
    ``{model_name}_tabpfn_model.joblib`` and ``{model_name}_label_encoder.joblib``.

    Args:
        model_dir: Directory containing the saved artefacts (a trailing
            slash is fine; paths are joined with ``os.path.join``).
        model_name: File-name prefix shared by both artefacts.

    Returns:
        (model, label_encoder)

    Note:
        ``joblib.load`` unpickles arbitrary objects, so only load artefacts
        from trusted sources.
    """
    model_path = os.path.join(model_dir, f"{model_name}_tabpfn_model.joblib")
    encoder_path = os.path.join(model_dir, f"{model_name}_label_encoder.joblib")
    model = joblib.load(model_path)
    label_encoder = joblib.load(encoder_path)
    print(f"Loaded TabPFN model from {model_path}")
    print(f"Loaded label encoder from {encoder_path}")
    return model, label_encoder


# Example usage:
# model, label_encoder = load_tabpfn_model("../models/tabpfn_b1_unc", "tabpfn_b1_unc")

In [None]:
import pandas as pd
import pyarrow.parquet as pq
import pytz
import datetime


def extract_features_for_inference(pq_file_path, date, chunk_size=10000):
    """
    Load all rows for one local calendar day from a parquet file and split them into chunks.

    The day is interpreted in the Africa/Johannesburg timezone and covers the
    half-open interval [00:00 on ``date``, 00:00 on the following day).

    Args:
        pq_file_path: Path to the input parquet file. It must contain
            ``burst_start_time`` (used for the date filter) and the metadata
            columns ``Ind_ID``, ``burst_id`` and ``new_burst``. Every other
            column is treated as a model feature.
        date: ``datetime.date`` to extract.
        chunk_size: Maximum number of rows per chunk (must be >= 1).

    Returns:
        List of ``(features, metadata)`` DataFrame pairs. Rows are sorted by
        ``Ind_ID``, ``burst_id`` and ``new_burst``. The list is empty when no
        rows match the date.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    tz = pytz.timezone("Africa/Johannesburg")
    start_dt = tz.localize(datetime.datetime.combine(date, datetime.time.min))
    # Exclusive upper bound at the next local midnight, so bursts starting
    # between 23:59:59 and midnight are not dropped.
    end_dt = tz.localize(
        datetime.datetime.combine(date + datetime.timedelta(days=1), datetime.time.min)
    )

    df = pq.read_table(
        pq_file_path,
        filters=[
            ("burst_start_time", ">=", start_dt),
            ("burst_start_time", "<", end_dt),
        ],
    ).to_pandas()

    if df.empty:
        print(f"No data found for {date}")
        return []

    # Metadata travels alongside predictions; everything else is a feature.
    metadata_cols = ["Ind_ID", "burst_id", "new_burst", "burst_start_time"]
    feature_columns = [col for col in df.columns if col not in metadata_cols]

    # Sort for a deterministic row order across runs.
    df = df.sort_values(by=["Ind_ID", "burst_id", "new_burst"]).reset_index(drop=True)

    n = len(df)
    chunks = []
    for start in range(0, n, chunk_size):
        chunk = df.iloc[start : start + chunk_size]
        chunks.append((chunk[feature_columns].copy(), chunk[metadata_cols].copy()))

    print(f"Extracted {n} rows for {date} and created {len(chunks)} chunks.")

    return chunks


# Example usage:
# pq_file_path = (
#     "../data/raw/focal_sampling/focal_sampled_features_burst_1_uncorrected.parquet"
# )
# date = datetime.date(2022, 7, 3)
# chunks = extract_features_for_inference(pq_file_path, date)
# features_df_1, metadata_df_1 = chunks[0]
# print(features_df_1.head())
# print(metadata_df_1.head())

In [None]:
def predict_and_format_probas_tabpfn(model, label_encoder, chunks):
    """
    Run TabPFN prediction on each (features, metadata) chunk and combine the results.

    Args:
        model: Fitted classifier exposing ``predict_proba``. Its probability
            columns are assumed to follow the order of ``label_encoder.classes_``
            (i.e. the model was trained on the encoded labels). TODO: confirm.
        label_encoder: Fitted encoder exposing ``classes_`` and ``inverse_transform``.
        chunks: Iterable of ``(features_df, metadata_df)`` pairs with matching
            row counts.

    Returns:
        A single DataFrame with columns ``Ind_ID``, ``burst_id``, ``new_burst``,
        ``burst_start_time``, ``predicted_behaviour``, then one probability
        column per class, then any remaining metadata columns. If ``chunks`` is
        empty, an empty DataFrame with the leading columns is returned instead
        of raising.
    """
    class_names = list(label_encoder.classes_)
    col_order = [
        "Ind_ID",
        "burst_id",
        "new_burst",
        "burst_start_time",
        "predicted_behaviour",
    ]

    frames = []
    for features_df, metadata_df in chunks:
        probas = model.predict_proba(features_df)
        # argmax gives the index of the most probable column, which maps back
        # to a class label through the encoder.
        preds = np.argmax(probas, axis=1)
        probas_df = pd.DataFrame(probas, columns=class_names)
        probas_df["predicted_behaviour"] = label_encoder.inverse_transform(preds)
        # Reset the metadata index so rows align positionally with probas_df.
        frames.append(
            pd.concat([metadata_df.reset_index(drop=True), probas_df], axis=1)
        )

    if not frames:
        # pd.concat([]) raises ValueError; return a well-formed empty frame instead.
        return pd.DataFrame(columns=col_order + class_names)

    final_df = pd.concat(frames, ignore_index=True)
    leading = col_order + [col for col in class_names if col in final_df.columns]
    trailing = [col for col in final_df.columns if col not in leading]
    return final_df[leading + trailing]


# Example usage:
# final_df = predict_and_format_probas_tabpfn(model, label_encoder, chunks)

In [None]:
def batch_predict_dates_tabpfn(
    pq_file_path,
    model_dir,
    model_name,
    output_parquet=None,
    schema=None,
):
    """
    Run TabPFN inference day by day over a parquet file and combine the results.

    Args:
        pq_file_path: Path to the input features parquet file.
        model_dir: Directory containing the saved model and label encoder.
        model_name: Artefact name prefix passed to ``load_tabpfn_model``.
        output_parquet: If given, the combined predictions are written to this
            path (zstd, level 9) and ``None`` is returned.
        schema: Optional pyarrow schema for the output file. When given, its
            field names also define the output column order. Otherwise the
            columns are the metadata columns, ``predicted_behaviour`` and one
            probability column per ``label_encoder.classes_`` entry.

    Returns:
        DataFrame with predictions for all dates, or ``None`` when
        ``output_parquet`` is set. If no rows produce predictions, the result
        is empty but keeps the expected columns (instead of raising).
    """
    model, label_encoder = load_tabpfn_model(model_dir, model_name)
    class_names = list(label_encoder.classes_)

    # Discover the calendar dates present in the file.
    # NOTE(review): dates are taken in burst_start_time's own timezone, while
    # extract_features_for_inference filters on Africa/Johannesburg days. These
    # only agree if the column is stored in that zone; confirm.
    dates_df = pd.read_parquet(pq_file_path, columns=["burst_start_time"])
    unique_dates = sorted(pd.to_datetime(dates_df["burst_start_time"]).dt.date.unique())
    print(f"Found {len(unique_dates)} unique dates in the data.")
    del dates_df

    probas_list = []
    for date in tqdm(unique_dates, desc="Processing dates"):
        chunks = extract_features_for_inference(pq_file_path, date)
        if not chunks:
            continue
        probas_list.append(predict_and_format_probas_tabpfn(model, label_encoder, chunks))

    leading_cols = [
        "Ind_ID",
        "burst_id",
        "new_burst",
        "burst_start_time",
        "predicted_behaviour",
    ]
    # Keep the output aligned with the schema it will be written with, or with
    # the classes the model actually predicts when no schema is supplied.
    output_cols = list(schema.names) if schema is not None else leading_cols + class_names

    if probas_list:
        final_df = pd.concat(probas_list, ignore_index=True)
    else:
        # pd.concat([]) raises ValueError; continue with a well-formed empty frame.
        print("No predictions were generated.")
        final_df = pd.DataFrame(columns=output_cols)

    # Normalise dtypes before writing.
    final_df["Ind_ID"] = final_df["Ind_ID"].astype(str)
    final_df["burst_id"] = final_df["burst_id"].astype(int)
    final_df["new_burst"] = final_df["new_burst"].astype(int)
    final_df["predicted_behaviour"] = final_df["predicted_behaviour"].astype(str)
    for col in class_names:
        if col in final_df.columns:
            final_df[col] = final_df[col].astype(float).round(4)

    # Columns missing from the predictions are added as NaN; extras are dropped.
    final_df = final_df.reindex(columns=output_cols)

    if output_parquet:
        table = pa.Table.from_pandas(final_df, schema=schema, preserve_index=False)
        pq.write_table(table, output_parquet, compression="zstd", compression_level=9)
        print(f"Saved predictions to {output_parquet}")
        return None
    return final_df

## Run inference
We run inference on the focal-sampled feature dataset (Burst 4, `rotbasal` features).

In [None]:
# Output parquet schema: identifying metadata, the predicted label, then one
# float32 probability column per behaviour class (in this order).
parquet_schema = pa.schema(
    [
        ("Ind_ID", pa.string()),
        ("burst_id", pa.int32()),
        ("new_burst", pa.int32()),
        ("burst_start_time", pa.timestamp("ns", tz="Africa/Johannesburg")),
        ("predicted_behaviour", pa.string()),
    ]
    + [
        (behaviour, pa.float32())
        for behaviour in (
            "Eating",
            "Grooming actor",
            "Grooming receiver",
            "Resting",
            "Running",
            "Self-scratching",
            "Sleeping",
            "Walking",
        )
    ]
)

In [None]:
import pandas as pd

# Quick sanity check of an input feature file: print its dimensions and columns.
# NOTE(review): this inspects the Burst 2 uncorrected file, whereas the
# inference cell below runs on the Burst 4 rotbasal file; confirm which is intended.
df = pd.read_parquet(
    path="../data/raw/focal_sampling/features/focal_sampled_features_burst_2_uncorrected.parquet"
)
print("Shape: {}".format(df.shape))
print("Columns: {}".format(list(df.columns)))

In [None]:
# Collect warnings
collected_warnings = []


def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
    collected_warnings.append(f"{category.__name__}: {message} ({filename}:{lineno})")

In [None]:
# TabPFN Burst 4 Rotbasal: predict over the focal-sampled features and save to
# parquet. Warnings raised during the run are routed into `collected_warnings`
# (catch_warnings restores the previous filters and showwarning afterwards).
with warnings.catch_warnings():
    warnings.simplefilter(action="always")
    warnings.showwarning = custom_warning_handler

    batch_predict_dates_tabpfn(
        model_name="tabpfn_b4_basal",
        model_dir="../models/tabpfn_b4_basal/",
        pq_file_path="../data/raw/focal_sampling/features/focal_sampled_features_burst_4_rotbasal.parquet",
        schema=parquet_schema,
        output_parquet="../data/output/inference_results/tabpfn_b4_basal_focal_sampled_random_predictions.parquet",
    )