# Behaviour Cloning Evaluation

In [1]:
import os
import sys
from typing import List, Optional

import pandas as pd
import torch

# In Colab, mount Google Drive and move into the repository's notebooks folder
# so that the relative paths used later ("../data", "../models") resolve.
COLAB_DRIVE_MOUNT = "/content/drive"
COLAB_NOTEBOOKS_DIR = f"{COLAB_DRIVE_MOUNT}/MyDrive/Projects/Offline_RL_BSc_Thesis/notebooks"

if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount(COLAB_DRIVE_MOUNT)
    # Absolute path: the check does not depend on the current working directory.
    if os.path.isdir(COLAB_NOTEBOOKS_DIR):
        os.chdir(COLAB_NOTEBOOKS_DIR)
    else:
        print(f"Warning: {COLAB_NOTEBOOKS_DIR} not found; staying in {os.getcwd()}")

# Notebooks have no __file__, so the repository root is taken as the parent of
# the current working directory (expected to be the notebooks/ folder).
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.append(project_root)

# First CUDA device when available, otherwise CPU.
torch_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Test Evaluation

In [2]:
from src.utils.experiments import prepare_data
from src.normalization import NormalizationModule
from src.datasets import BCDataset

from torch.utils.data import DataLoader
from sklearn.metrics import (
    balanced_accuracy_score,
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def evaluate_model_test_dataset(model: torch.nn.Module,
                                device: torch.device,
                                test_df: pd.DataFrame,
                                norm_technique_script: Optional[NormalizationModule],
                                selected_features: List[str],
                                batch_size: int = 512):
    """Evaluate a behaviour-cloning classifier on a test DataFrame.

    The DataFrame is converted into (state, action) tensors by ``prepare_data``
    (feature selection plus optional normalization) and wrapped in a
    ``BCDataset``. It is then fed through ``model`` in eval mode without
    gradients. The predicted action for each state is the arg-max of the
    model output over dim 1.

    Args:
        model: Network mapping a batch of states to per-class scores
            (assumed ``(batch, n_actions)`` logits — TODO confirm).
        device: Device that the model and the state batches are moved to.
        test_df: Test transitions. Must contain ``selected_features`` and
            whatever target column ``prepare_data`` extracts.
        norm_technique_script: Normalization module forwarded to
            ``prepare_data``, or ``None`` to use raw features.
        selected_features: Feature columns to feed the model. Their order
            presumably has to match the order used in training.
        batch_size: Number of samples per inference batch.

    Returns:
        A dict with accuracy, balanced accuracy, macro and weighted
        precision/recall/F1, per-class precision/recall/F1 arrays, the
        confusion matrix, and sklearn's ``classification_report`` as a dict.

    Raises:
        ValueError: If the prepared test set contains no samples.
    """
    model.eval()
    model.to(device)

    # Prepare normalized and filtered test data
    X_test, y_test = prepare_data(df=test_df,
                                  selected_features=selected_features,
                                  norm_script=norm_technique_script)

    test_dataset = BCDataset(states=X_test, actions=y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    all_preds, all_labels = [], []

    with torch.no_grad():
        for states, labels in test_loader:
            logits = model(states.to(device))
            # Softmax is monotonic, so the arg-max of the raw logits already
            # gives the predicted class; normalizing first is wasted work.
            all_preds.append(torch.argmax(logits, dim=1).cpu())
            # Labels are only consumed by sklearn on the CPU, so they never
            # need to visit the device.
            all_labels.append(labels.cpu())

    if not all_preds:
        raise ValueError("The prepared test set is empty; nothing to evaluate.")

    # Aggregate predictions and labels
    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    # === Compute metrics ===
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )

    cm = confusion_matrix(y_true, y_pred)
    class_report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    # === Package results ===
    results = {
        "accuracy": acc,
        "balanced_accuracy": bal_acc,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "precision_weighted": precision_weighted,
        "recall_weighted": recall_weighted,
        "f1_weighted": f1_weighted,
        "precision_per_class": precision,
        "recall_per_class": recall,
        "f1_per_class": f1,
        "confusion_matrix": cm,
        "classification_report": class_report
    }

    # Pretty-print the headline metrics
    print("=" * 100)
    print("✅ Test Set Evaluation")
    print(f"Accuracy:           {acc:.4f}")
    print(f"Balanced Accuracy:  {bal_acc:.4f}")
    print(f"Macro F1 Score:     {f1_macro:.4f}")
    print(f"Weighted F1 Score:  {f1_weighted:.4f}")
    print("=" * 100)

    return results

### Replay Buffer

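### Final Policy

Evaluation of the `BC_raw_stratified` checkpoint on the replay-buffer test split (`rb_test.parquet`).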

In [4]:
# Inspect the replay-buffer training split (loaded for display only, not kept).
pd.read_parquet(path='../data/replay_buffer_episodes/rb_train.parquet')

Unnamed: 0,X,Y,lv_X,lv_Y,angle,angular_velocity,leg_1,leg_2,action,reward,done,episode
0,-0.004820,1.401093,-0.488218,-0.436776,0.005592,0.110589,False,False,0,-1.197077,False,0
1,-0.009640,1.390689,-0.487559,-0.462411,0.011056,0.109301,False,False,1,-2.399095,False,0
2,-0.014555,1.379694,-0.499388,-0.488779,0.018885,0.156586,False,False,3,-0.752683,False,0
3,-0.019400,1.368104,-0.490665,-0.515199,0.024956,0.121428,False,False,2,3.208104,False,0
4,-0.024134,1.357258,-0.480146,-0.482138,0.031669,0.134276,False,False,3,-0.579223,False,0
...,...,...,...,...,...,...,...,...,...,...,...,...
801678,-0.020882,0.096869,0.037885,-0.071286,0.023694,0.022519,False,False,2,0.751179,False,1724
801679,-0.020388,0.095725,0.047698,-0.050899,0.025377,0.033660,False,False,0,-2.117649,False,1724
801680,-0.019894,0.093980,0.047698,-0.077566,0.027060,0.033660,False,False,0,-2.285686,False,1724
801681,-0.019400,0.091636,0.047698,-0.104233,0.028743,0.033660,False,False,2,1.475848,False,1724


In [5]:
# Replay-buffer test split; episode bookkeeping columns are not model inputs.
fp_test_df = pd.read_parquet('../data/replay_buffer_episodes/rb_test.parquet').drop(columns=['done', 'episode'])

# The `BC_raw_*` checkpoint is evaluated on un-normalized features, so no
# normalization module is passed. For a normalized checkpoint, use e.g.
# torch.jit.load('../models/replay_buffer/normalization/max_abs_normalization.pt').
fp_normalization_technique = None

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.jit.load('../models/replay_buffer/BC_raw_stratified.pt', map_location=torch_device)

# NOTE(review): if 'reward' is the reward received for this row's logged action
# (a standard (s, a, r) transition), feeding it as an input may leak the action
# label. The list and its order presumably must match training — confirm.
FP_SELECTED_FEATURES = ['X', 'Y', 'lv_X', 'lv_Y', 'reward', 'angle',
                        'angular_velocity', 'leg_1', 'leg_2']

fp_results = evaluate_model_test_dataset(model=model,
                                         device=torch_device,
                                         test_df=fp_test_df,
                                         norm_technique_script=fp_normalization_technique,
                                         selected_features=FP_SELECTED_FEATURES)
fp_results

✅ Test Set Evaluation
Accuracy:           0.6548
Balanced Accuracy:  0.6576
Macro F1 Score:     0.5891
Weighted F1 Score:  0.6743


{'accuracy': 0.6547694901630184,
 'balanced_accuracy': 0.6575967247215885,
 'precision_macro': 0.5704658348790328,
 'recall_macro': 0.6575967247215885,
 'f1_macro': 0.5891156980924553,
 'precision_weighted': 0.7370686562128997,
 'recall_weighted': 0.6547694901630184,
 'f1_weighted': 0.6742917257457048,
 'precision_per_class': array([0.61963831, 0.37361647, 0.91232639, 0.37628217]),
 'recall_per_class': array([0.71403852, 0.62056434, 0.63091917, 0.66486486]),
 'f1_per_class': array([0.6634975 , 0.46642031, 0.74596581, 0.48057918]),
 'confusion_matrix': array([[18982,  2849,  1072,  3681],
        [ 1652,  5806,  1247,   651],
        [ 8467,  6371, 35734,  6066],
        [ 1533,   514,  1115,  6273]]),
 'classification_report': {'0': {'precision': 0.6196383103740941,
   'recall': 0.7140385194101715,
   'f1-score': 0.6634975007864657,
   'support': 26584.0},
  '1': {'precision': 0.3736164736164736,
   'recall': 0.6205643437366396,
   'f1-score': 0.46642030848329047,
   'support': 9356.0}