# HydraROCKET Inference Pipeline for Behaviour Classification

This notebook runs inference using a trained HydraROCKET model on new, unlabeled accelerometer data.  
We use the [HydraROCKET model](../models/hr_b1_unc/) trained on the uncorrected burst 1 data.  
The [tsai environment](../environment_tsai.yml) is required to run this notebook.  

In [1]:
import os
import re
import torch
import pandas as pd
import tsai.all as ts
from tqdm import tqdm
import pyarrow as pa
import pyarrow.parquet as pq
import pytz
import datetime
import numpy as np

In [2]:
def load_model_cuda(model_dir, model_name):
    """
    Restore a tsai learner and verify that its model lives entirely on CUDA.

    The learner is loaded with ``ts.load_learner_all`` from ``model_dir`` using the
    file stems ``{model_name}_dls``, ``{model_name}_model`` and ``{model_name}_learner``.
    Its sub-modules, then parameters, then buffers are inspected, and the first one
    whose device type is not ``"cuda"`` triggers an error.

    Args:
        model_dir: Directory handed to ``ts.load_learner_all`` as ``path``.
        model_name: Prefix shared by the saved dls / model / learner files.

    Returns:
        The loaded learner.

    Raises:
        RuntimeError: If any inspected module, parameter or buffer is not on CUDA.
    """
    learner = ts.load_learner_all(
        path=model_dir,
        dls_fname=f"{model_name}_dls",
        model_fname=f"{model_name}_model",
        learner_fname=f"{model_name}_learner",
    )
    net = learner.model

    # Pass 1: modules. A module that exposes its own ``device`` attribute is judged
    # by that; any other module is judged by its first (recursively collected)
    # parameter, if it has one.
    for mod_name, mod in net.named_modules():
        if not hasattr(mod, "device"):
            first_param = next(iter(mod.parameters()), None)
            if first_param is not None and first_param.device.type != "cuda":
                raise RuntimeError(
                    f"Module '{mod_name}' parameter is on {first_param.device}, not cuda. Model was saved to CPU."
                )
            continue
        if mod.device.type != "cuda":
            raise RuntimeError(
                f"Module '{mod_name}' is on {mod.device}, not cuda. Model was saved to CPU."
            )

    # Pass 2: every parameter, then every buffer, of the whole model.
    for label, named_tensors in (
        ("parameter", net.named_parameters()),
        ("buffer", net.named_buffers()),
    ):
        for tensor_name, tensor in named_tensors:
            if tensor.device.type != "cuda":
                raise RuntimeError(
                    f"Model {label} '{tensor_name}' is on {tensor.device}, not cuda. Model was saved to CPU."
                )

    print("All model parameters and buffers are on CUDA.")
    return learner


# test the function
# model_dir = "../models/hr_b1_unc"
# model_name = "hr_b1_unc"
# model_trial = load_model_cuda(model_dir, model_name)

In [3]:
def extract_data_for_inference(pq_file_path, date):
    """
    Extract one local calendar day of bursts from a parquet file for inference.

    Rows whose ``burst_start_time`` lies in ``[00:00, next day 00:00)`` in the
    Africa/Johannesburg timezone are read. Bursts whose maximum ``sample_number``
    exceeds the expected length for the file's burst type are dropped. The
    remaining X/Y/Z signals are pivoted to one row per (burst id, axis) and
    converted to a float32 CUDA tensor with ``ts.df2xy``.

    Args:
        pq_file_path: Path to the input parquet file. Its file name must contain
            ``burst_<n>`` where ``n`` is one of 1, 2, 3, 4; ``n`` selects the
            maximum valid ``sample_number``.
        date: ``datetime.date`` of the local (Africa/Johannesburg) day to extract.

    Returns:
        Tuple ``(X_torch, unique_ids, unique_bursts)``:
            X_torch: float32 CUDA tensor with one entry per burst. Presumably
                shaped ``(n_bursts, n_axes, n_samples)`` by ``ts.df2xy``; confirm.
            unique_ids: array of ``"<Ind_ID>_<burst_id>_<new_burst>"`` ids,
                assumed aligned with the first axis of ``X_torch``.
            unique_bursts: DataFrame of unique (``Ind_ID``, ``new_burst`` as str,
                ``burst_start_time``) rows.

    Raises:
        ValueError: If no burst number can be parsed from the file name, or no
            expected sample count is defined for it.
    """
    # Resolve the burst type first so a badly named file fails before any I/O.
    match = re.search(r"burst_(\d+)", os.path.basename(pq_file_path))
    if match is None:
        raise ValueError("Could not extract burst number from file name.")
    burst_num = int(match.group(1))
    # Maximum valid sample_number per burst type; longer bursts are discarded.
    expected_samples_dict = {1: 138, 2: 69, 3: 47, 4: 34}
    if burst_num not in expected_samples_dict:
        raise ValueError(f"No expected sample count defined for burst {burst_num}.")
    expected_samples = expected_samples_dict[burst_num]

    # Half-open window [local midnight, next local midnight). An upper bound of
    # 23:59:59 combined with "<" would drop bursts starting in the final second
    # of the day.
    tz = pytz.timezone("Africa/Johannesburg")
    start_dt = tz.localize(datetime.datetime.combine(date, datetime.time(0, 0, 0)))
    end_dt = tz.localize(
        datetime.datetime.combine(
            date + datetime.timedelta(days=1), datetime.time(0, 0, 0)
        )
    )

    acc_cols = ["X", "Y", "Z"]
    id_cols = ["Ind_ID", "burst_id", "new_burst", "sample_number"]
    # burst_start_time is read only so it can be attached to the final predictions.
    table = pq.read_table(
        pq_file_path,
        filters=[
            ("burst_start_time", ">=", start_dt),
            ("burst_start_time", "<", end_dt),
        ],
        columns=id_cols + acc_cols + ["burst_start_time"],
    )
    filtered_df = table.to_pandas()

    # Find bursts whose highest sample_number exceeds the expected length.
    burst_sample_counts = (
        filtered_df.groupby(["Ind_ID", "new_burst"])["sample_number"]
        .max()
        .reset_index()
    )
    bad_bursts = burst_sample_counts[
        burst_sample_counts["sample_number"] > expected_samples
    ]
    if not bad_bursts.empty:
        # Left anti-join on (Ind_ID, new_burst) to remove the oversized bursts.
        filtered_df = filtered_df.merge(
            bad_bursts[["Ind_ID", "new_burst"]],
            on=["Ind_ID", "new_burst"],
            how="left",
            indicator=True,
        )
        filtered_df = filtered_df[filtered_df["_merge"] == "left_only"].drop(
            columns=["_merge"]
        )

    # One row per burst with its start time, merged onto predictions later.
    unique_bursts = filtered_df[
        ["Ind_ID", "new_burst", "burst_start_time"]
    ].drop_duplicates()
    # str so it can be merged with new_burst values parsed out of string ids.
    unique_bursts["new_burst"] = unique_bursts["new_burst"].astype(str)

    # Long format: one row per (burst, sample_number, axis).
    filtered_df = filtered_df.drop(columns=["burst_start_time"])
    filtered_df_melted = pd.melt(
        filtered_df,
        id_vars=id_cols,
        value_vars=acc_cols,
        var_name="feature",
        value_name="value",
    )
    # Burst id "<Ind_ID>_<burst_id>_<new_burst>" (assumes Ind_ID is a string column).
    filtered_df_melted["id"] = (
        filtered_df_melted["Ind_ID"]
        + "_"
        + filtered_df_melted["burst_id"].astype(str)
        + "_"
        + filtered_df_melted["new_burst"].astype(str)
    )
    filtered_df_melted.drop(["Ind_ID", "burst_id", "new_burst"], axis=1, inplace=True)
    filtered_df_melted = filtered_df_melted.sort_values(
        by=["id", "sample_number"]
    ).reset_index(drop=True)

    # Wide format: one row per (id, feature), one column per sample_number.
    # NOTE(review): pivot_table averages duplicate (id, feature, sample_number)
    # entries by default rather than rejecting them.
    filtered_df_wide = filtered_df_melted.pivot_table(
        index=["id", "feature"],
        columns="sample_number",
        values="value",
    )
    filtered_df_wide.columns = [str(col) for col in filtered_df_wide.columns]
    filtered_df_wide = filtered_df_wide.reset_index()
    unique_ids = filtered_df_wide["id"].unique()

    X, _ = ts.df2xy(
        filtered_df_wide,
        sample_col="id",
        feat_col="feature",
        target_col=None,
        data_cols=None,
    )
    X_torch = torch.from_numpy(X.astype(np.float64)).float().cuda()

    # Flag bursts containing any NaN in one vectorised pass. A per-burst Python
    # loop would force a device synchronisation for every burst.
    nan_mask = torch.isnan(X_torch).flatten(start_dim=1).any(dim=1)
    nan_burst_indices = nan_mask.nonzero().flatten().tolist()

    print(f"Extracted data for {len(unique_ids)} unique bursts on {date}.")
    if not bad_bursts.empty:
        print(
            "WARNING: Found bursts with more samples than expected. These will be removed:"
        )
        print(bad_bursts)
    if nan_burst_indices:
        print(
            f"WARNING: Found NaNs in {len(nan_burst_indices)} bursts at indices: {nan_burst_indices}"
        )
        print("Corresponding burst ids with NaNs:")
        for idx in nan_burst_indices:
            print(f"  index {idx}: {unique_ids[idx]}")
    return X_torch, unique_ids, unique_bursts


# test the function
# pq_file_path = (
#     "../data/raw/focal_sampling/focal_sampled_acc_burst_1_uncorrected.parquet"
# )
# date = datetime.date(2022, 7, 3)
# X_torch, unique_ids, unique_bursts = extract_data_for_inference(pq_file_path, date)

In [4]:
def predict_and_format_probas(
    model, X_torch, unique_ids, unique_bursts, class_mapping_path
):
    """
    Predict class probabilities for a batch of bursts and attach burst metadata.

    Args:
        model: Learner exposing ``get_X_preds(X)``. Its first return value must be
            a 2-D probability tensor with one row per burst and one column per
            class index.
        X_torch: Input passed unchanged to ``model.get_X_preds``.
        unique_ids: Burst ids ``"<Ind_ID>_<burst_id>_<new_burst>"``, one per
            prediction row and in the same order.
        unique_bursts: DataFrame with ``Ind_ID``, ``new_burst`` (as str) and
            ``burst_start_time``. It is left-merged onto the predictions.
        class_mapping_path: CSV with ``class_index`` and ``class_name`` columns.

    Returns:
        DataFrame with columns ``Ind_ID``, ``burst_id``, ``new_burst``,
        ``burst_start_time`` and ``predicted_behaviour``, followed by one
        probability column per class (named by ``class_name``). ``burst_id`` and
        ``new_burst`` are strings here.
    """
    class_mappings = pd.read_csv(class_mapping_path)
    class_mapping_dict = dict(
        zip(class_mappings["class_index"], class_mappings["class_name"])
    )

    probas_tensor, _, _ = model.get_X_preds(X_torch)
    # Column i of the model output corresponds to class_index i.
    probas = pd.DataFrame(
        probas_tensor.cpu().numpy(),
        columns=[class_mapping_dict[i] for i in range(probas_tensor.shape[1])],
    )
    # Computed before any metadata columns are added, so only class columns compete.
    probas["predicted_behaviour"] = probas.idxmax(axis=1)

    probas["id"] = unique_ids
    # rsplit with n=2: burst_id and new_burst are the last two "_"-separated
    # fields, so an Ind_ID that itself contains "_" is kept intact.
    probas[["Ind_ID", "burst_id", "new_burst"]] = probas["id"].str.rsplit(
        "_", n=2, expand=True
    )
    probas.drop("id", axis=1, inplace=True)

    probas = probas.merge(unique_bursts, on=["Ind_ID", "new_burst"], how="left")

    col_order = [
        "Ind_ID",
        "burst_id",
        "new_burst",
        "burst_start_time",
        "predicted_behaviour",
    ]
    probas = probas[col_order + [col for col in probas.columns if col not in col_order]]
    return probas

In [5]:
# Wrapper function
def batch_predict_dates(
    pq_file_path,
    model_dir,
    model_name,
    output_parquet=None,
    schema=None,
):
    """
    Run inference for every date present in a parquet file and combine the results.

    Args:
        pq_file_path: Path to the input parquet file. It is read once to
            enumerate dates, then passed to ``extract_data_for_inference``
            for each date.
        model_dir: Directory containing the saved learner files and
            ``{model_name}_class_mapping.csv``.
        model_name: Prefix of the saved learner files and of the class mapping CSV.
        output_parquet: If given, the combined predictions are written to this
            path as zstd-compressed parquet and ``None`` is returned.
        schema: Optional ``pyarrow.Schema`` applied when writing ``output_parquet``.

    Returns:
        DataFrame with predictions for all dates, or ``None`` when
        ``output_parquet`` is set.

    Raises:
        ValueError: If the file contains no non-null ``burst_start_time`` values.
    """
    model = load_model_cuda(model_dir, model_name)
    # The class mapping CSV is expected next to the saved learner files.
    class_mapping_path = f"{model_dir}/{model_name}_class_mapping.csv"

    # Only burst_start_time is needed to enumerate the dates.
    start_times = pd.read_parquet(pq_file_path, columns=["burst_start_time"])[
        "burst_start_time"
    ]
    # dropna: a missing timestamp would yield NaT, which is not a usable date.
    unique_dates = sorted(pd.to_datetime(start_times).dt.date.dropna().unique())
    del start_times
    print(f"Found {len(unique_dates)} unique dates in the data.")
    if not unique_dates:
        raise ValueError(f"No burst_start_time values found in {pq_file_path}.")

    probas_list = []
    for date in tqdm(unique_dates, desc="Processing dates"):
        X_torch, unique_ids, unique_bursts = extract_data_for_inference(
            pq_file_path, date
        )
        probas = predict_and_format_probas(
            model, X_torch, unique_ids, unique_bursts, class_mapping_path
        )
        probas_list.append(probas)

    final_df = pd.concat(probas_list, ignore_index=True)

    # Normalise dtypes: identifiers and labels as str, burst numbers as int.
    meta_cols = [
        "Ind_ID",
        "burst_id",
        "new_burst",
        "burst_start_time",
        "predicted_behaviour",
    ]
    final_df["Ind_ID"] = final_df["Ind_ID"].astype(str)
    final_df["burst_id"] = final_df["burst_id"].astype(int)
    final_df["new_burst"] = final_df["new_burst"].astype(int)
    final_df["predicted_behaviour"] = final_df["predicted_behaviour"].astype(str)
    # Every remaining column is a per-class probability. The list is derived from
    # the data instead of hard-coding class names, so any model's classes work.
    class_cols = [col for col in final_df.columns if col not in meta_cols]
    for col in class_cols:
        final_df[col] = final_df[col].astype(float).round(4)

    if output_parquet:
        table = pa.Table.from_pandas(final_df, schema=schema, preserve_index=False)
        pq.write_table(table, output_parquet, compression="zstd", compression_level=9)
        print(f"Saved predictions to {output_parquet}")
        return None
    return final_df

## Run inference
We run inference on the focal-sampled accelerometer dataset.

In [6]:
# Output parquet schema: burst identifiers, burst start time (nanosecond
# timestamps in the Africa/Johannesburg timezone), the predicted label, then
# one float32 probability column per behaviour class.
parquet_schema = pa.schema(
    [
        ("Ind_ID", pa.string()),
        ("burst_id", pa.int32()),
        ("new_burst", pa.int32()),
        ("burst_start_time", pa.timestamp("ns", tz="Africa/Johannesburg")),
        ("predicted_behaviour", pa.string()),
    ]
    + [
        (behaviour, pa.float32())
        for behaviour in (
            "Eating",
            "Grooming actor",
            "Grooming receiver",
            "Resting",
            "Running",
            "Self-scratching",
            "Sleeping",
            "Walking",
        )
    ]
)

In [8]:
# HR Burst 1 Uncorrected. Because output_parquet is set, predictions are written
# to disk and the call returns None (nothing is displayed in the cell).
batch_predict_dates(
    model_name="hr_b1_unc",
    model_dir="../models/hr_b1_unc",
    pq_file_path="../data/raw/focal_sampling/acc/focal_sampled_acc_burst_1_uncorrected.parquet",
    schema=parquet_schema,
    output_parquet="../data/output/inference_results/hr_b1_unc_focal_sampled_random_predictions.parquet",
)

All model parameters and buffers are on CUDA.
Found 21 unique dates in the data.


Processing dates:   0%|          | 0/21 [00:00<?, ?it/s]

Extracted data for 25893 unique bursts on 2022-07-03.


Processing dates:   5%|▍         | 1/21 [00:31<10:38, 31.91s/it]

Extracted data for 25895 unique bursts on 2022-07-10.


Processing dates:  10%|▉         | 2/21 [01:03<10:00, 31.60s/it]

Extracted data for 25864 unique bursts on 2022-07-14.


Processing dates:  14%|█▍        | 3/21 [01:34<09:28, 31.58s/it]

Extracted data for 25893 unique bursts on 2022-07-15.


Processing dates:  19%|█▉        | 4/21 [02:06<08:55, 31.50s/it]

Extracted data for 25897 unique bursts on 2022-07-18.


Processing dates:  24%|██▍       | 5/21 [02:37<08:24, 31.55s/it]

Extracted data for 25894 unique bursts on 2022-07-19.


Processing dates:  29%|██▊       | 6/21 [03:09<07:53, 31.56s/it]

Extracted data for 25906 unique bursts on 2022-07-31.


Processing dates:  33%|███▎      | 7/21 [03:41<07:22, 31.63s/it]

Extracted data for 25915 unique bursts on 2022-09-03.


Processing dates:  38%|███▊      | 8/21 [04:12<06:51, 31.68s/it]

Extracted data for 25910 unique bursts on 2022-09-08.


Processing dates:  43%|████▎     | 9/21 [04:44<06:18, 31.50s/it]

Extracted data for 24722 unique bursts on 2022-09-09.


Processing dates:  48%|████▊     | 10/21 [05:13<05:40, 30.92s/it]

Extracted data for 25907 unique bursts on 2022-09-19.


Processing dates:  52%|█████▏    | 11/21 [05:44<05:09, 30.96s/it]

Extracted data for 25915 unique bursts on 2022-09-25.


Processing dates:  57%|█████▋    | 12/21 [06:15<04:39, 31.02s/it]

Extracted data for 25906 unique bursts on 2022-09-27.


Processing dates:  62%|██████▏   | 13/21 [06:46<04:08, 31.03s/it]

Extracted data for 25911 unique bursts on 2022-09-28.


Processing dates:  67%|██████▋   | 14/21 [07:18<03:37, 31.06s/it]

Extracted data for 21478 unique bursts on 2023-06-05.


Processing dates:  71%|███████▏  | 15/21 [07:43<02:56, 29.47s/it]

Extracted data for 21343 unique bursts on 2023-06-11.


Processing dates:  76%|███████▌  | 16/21 [08:09<02:21, 28.26s/it]

Extracted data for 21313 unique bursts on 2023-06-14.


Processing dates:  81%|████████  | 17/21 [08:35<01:49, 27.48s/it]

Extracted data for 21627 unique bursts on 2023-06-20.


Processing dates:  86%|████████▌ | 18/21 [09:01<01:21, 27.22s/it]

Extracted data for 21241 unique bursts on 2023-06-22.


Processing dates:  90%|█████████ | 19/21 [09:28<00:54, 27.01s/it]

Extracted data for 21366 unique bursts on 2023-06-28.


Processing dates:  95%|█████████▌| 20/21 [09:54<00:26, 26.67s/it]

Extracted data for 21256 unique bursts on 2023-06-30.


Processing dates: 100%|██████████| 21/21 [10:19<00:00, 29.51s/it]


Saved predictions to ../data/output/inference_results/hr_b1_unc_focal_sampled_random_predictions.parquet
