# Behaviour Cloning Evaluation

In [1]:
import sys
import os
import torch
import pandas as pd
from typing import List

# On Google Colab: mount Drive and, when the project's notebooks folder is
# present, make it the working directory so the relative paths below resolve.
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount("/content/drive")
    notebooks_dir = 'drive/MyDrive/Projects/Offline_RL_BSc_Thesis/notebooks'
    if os.path.isdir(notebooks_dir):
        os.chdir(notebooks_dir)


# The project root is the parent of the current working directory; put it on
# sys.path (once) so that the `src.*` imports in later cells can be resolved.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.append(project_root)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = torch.device('cuda:0')
else:
    torch_device = torch.device('cpu')

## Test Evaluation

In [2]:
from src.utils.experiments import prepare_data
from src.normalization import NormalizationModule
from src.datasets import BCDataset

from torch.utils.data import DataLoader
from sklearn.metrics import (
    balanced_accuracy_score,
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def evaluate_model_test_dataset(model: torch.nn.Module,
                                device: torch.device,
                                test_df: pd.DataFrame,
                                norm_technique_script: NormalizationModule,
                                selected_features: List[str],
                                batch_size: int = 512) -> dict:
    """Evaluate a classifier on a test DataFrame and report classification metrics.

    The model is switched to eval mode and moved to ``device``. The test data is
    built with ``prepare_data`` (normalized and filtered, per the helper's use
    here), wrapped in a ``BCDataset`` and scored batch by batch without gradient
    tracking. The predicted class is the argmax over dimension 1 of the model
    output.

    Args:
        model: Module that maps a batch of states to per-class logits.
            NOTE(review): assumed to return a 2-D (batch, n_classes) tensor,
            since argmax is taken over dim=1 -- confirm.
        device: Device on which the forward passes are run.
        test_df: Test data, forwarded to ``prepare_data``.
        norm_technique_script: Normalization module, forwarded to
            ``prepare_data`` as ``norm_script``.
        selected_features: Column names, forwarded to ``prepare_data``.
        batch_size: Number of samples per evaluation batch (default 512).
            It only affects speed/memory, not the reported metrics.

    Returns:
        A dict with the keys ``accuracy``, ``balanced_accuracy``,
        ``precision_macro``, ``recall_macro``, ``f1_macro``,
        ``precision_weighted``, ``recall_weighted``, ``f1_weighted``,
        ``precision_per_class``, ``recall_per_class``, ``f1_per_class``
        (arrays with one entry per class), ``confusion_matrix`` and
        ``classification_report`` (sklearn's report as a dict).

    Side effects:
        Leaves ``model`` in eval mode on ``device`` and prints a short summary
        of the headline metrics.
    """
    model.eval()
    model.to(device)

    # Prepare normalized and filtered test data
    X_test, y_test = prepare_data(df=test_df,
                                  selected_features=selected_features,
                                  norm_script=norm_technique_script)

    test_dataset = BCDataset(states=X_test, actions=y_test)
    # shuffle=False keeps predictions and labels aligned in dataset order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    all_preds, all_labels = [], []

    with torch.no_grad():
        for states, labels in test_loader:
            logits = model(states.to(device))
            # softmax is strictly increasing, so argmax on the raw logits picks
            # the same class while skipping the extra pass (and avoiding ties
            # introduced by probabilities that round to the same float).
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            # Labels are only consumed on the host by sklearn, so they are not
            # copied to `device` first; .cpu() is a no-op for CPU tensors.
            all_labels.append(labels.cpu())

    # Aggregate predictions and labels
    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    # === Compute metrics ===
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)

    # average=None -> one value per class; zero_division=0 reports 0 (without a
    # warning) for classes that have no predicted or no true samples.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )

    cm = confusion_matrix(y_true, y_pred)
    class_report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    # === Package results ===
    results = {
        "accuracy": acc,
        "balanced_accuracy": bal_acc,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "precision_weighted": precision_weighted,
        "recall_weighted": recall_weighted,
        "f1_weighted": f1_weighted,
        "precision_per_class": precision,
        "recall_per_class": recall,
        "f1_per_class": f1,
        "confusion_matrix": cm,
        "classification_report": class_report
    }

    # Pretty print key metrics
    print("=" * 100)
    print("Test Set Evaluation")
    print(f"Accuracy:           {acc:.4f}")
    print(f"Balanced Accuracy:  {bal_acc:.4f}")
    print(f"Macro F1 Score:     {f1_macro:.4f}")
    print(f"Weighted F1 Score:  {f1_weighted:.4f}")
    print("=" * 100)

    return results

### Replay Buffer

In [7]:
# Replay-buffer test episodes; the 'done' and 'episode' columns are dropped
# before the frame is handed to the evaluation helper.
rb_test_df = pd.read_parquet('../data/replay_buffer_episodes/rb_test.parquet').drop(columns=['done', 'episode'])
# rb_ prefix: this is the normalization script stored under the replay-buffer
# models folder (not the final-policy one).
rb_normalization_technique = torch.jit.load('../models/replay_buffer/normalization/standard_normalization.pt')
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; the evaluation helper moves the model to torch_device anyway.
model = torch.jit.load('../models/replay_buffer/BC_standard_refined.pt', map_location=torch_device)


evaluate_model_test_dataset(model=model,
                            device=torch_device,
                            test_df=rb_test_df,
                            norm_technique_script=rb_normalization_technique,
                            selected_features=['X', 'Y', 'lv_X', 'lv_Y', 'reward', 'angle', 'angular_velocity', 'leg_1', 'leg_2'])

Test Set Evaluation
Accuracy:           0.5887
Balanced Accuracy:  0.5892
Macro F1 Score:     0.5172
Weighted F1 Score:  0.6211


{'accuracy': 0.5887092821503142,
 'balanced_accuracy': 0.5891766647181432,
 'precision_macro': 0.5156121821357602,
 'recall_macro': 0.5891766647181432,
 'f1_macro': 0.5172053144836699,
 'precision_weighted': 0.7093929326363385,
 'recall_weighted': 0.5887092821503142,
 'f1_weighted': 0.6210704773761083,
 'precision_per_class': array([0.55189678, 0.30887554, 0.92087411, 0.28080229]),
 'recall_per_class': array([0.53904604, 0.58807183, 0.60637734, 0.62321145]),
 'f1_per_class': array([0.54539573, 0.40502043, 0.73124461, 0.38716049]),
 'confusion_matrix': array([[14330,  5121,  1130,  6003],
        [ 1453,  5502,   961,  1440],
        [ 8359,  6318, 34344,  7617],
        [ 1823,   872,   860,  5880]]),
 'classification_report': {'0': {'precision': 0.551896784132486,
   'recall': 0.5390460427324707,
   'f1-score': 0.5453957258939276,
   'support': 26584.0},
  '1': {'precision': 0.30887554033570985,
   'recall': 0.5880718255664814,
   'f1-score': 0.40502042769332697,
   'support': 9356.0}

### Final Policy

In [8]:
# Final-policy test episodes; the 'done' and 'episode' columns are dropped
# before the frame is handed to the evaluation helper.
fp_test_df = pd.read_parquet('../data/final_policy_episodes/fp_test.parquet').drop(columns=['done', 'episode'])
# Normalization script stored under the final-policy models folder.
fp_normalization_technique = torch.jit.load('../models/final_policy/normalization/standard_normalization.pt')
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; the evaluation helper moves the model to torch_device anyway.
model = torch.jit.load('../models/final_policy/BC_standard_refined.pt', map_location=torch_device)


evaluate_model_test_dataset(model=model,
                            device=torch_device,
                            test_df=fp_test_df,
                            norm_technique_script=fp_normalization_technique,
                            selected_features=['X', 'Y', 'lv_X', 'lv_Y', 'reward', 'angle', 'angular_velocity', 'leg_1', 'leg_2'])

Test Set Evaluation
Accuracy:           0.9065
Balanced Accuracy:  0.9328
Macro F1 Score:     0.8528
Weighted F1 Score:  0.9107


{'accuracy': 0.9065051728488519,
 'balanced_accuracy': 0.9328487110450876,
 'precision_macro': 0.8033221480454353,
 'recall_macro': 0.9328487110450876,
 'f1_macro': 0.8528045910737406,
 'precision_weighted': 0.9239602808037797,
 'recall_weighted': 0.9065051728488519,
 'f1_weighted': 0.9107081263730141,
 'precision_per_class': array([0.90650891, 0.63735121, 0.98973795, 0.67969052]),
 'recall_per_class': array([0.89863507, 0.96805336, 0.89896804, 0.96573837]),
 'f1_per_class': array([0.90255482, 0.76864111, 0.94217183, 0.7978506 ]),
 'confusion_matrix': array([[30417,  1501,   465,  1465],
        [   95,  5515,    16,    71],
        [ 2938,  1577, 48609,   948],
        [  104,    60,    23,  5271]]),
 'classification_report': {'0': {'precision': 0.9065089110091197,
   'recall': 0.8986350744504845,
   'f1-score': 0.902554820331741,
   'support': 33848.0},
  '1': {'precision': 0.6373512076736392,
   'recall': 0.9680533614182903,
   'f1-score': 0.7686411149825784,
   'support': 5697.0},
