# KNN Aggregator vs IBM Granite-4.0-H-Tiny Workflow

This notebook runs the complete workflow for comparing the KNN Aggregator with the IBM Granite-4.0-H-Tiny model on benchmark datasets.

## Workflow Steps:
1. **Setup**: Install dependencies and set up the project files
2. **Sampling**: Generate KNN reference data from the HH-RLHF dataset
3. **Evaluation**: Evaluate the KNN Aggregator and IBM Granite in parallel on the benchmark dataset
4. **Comparison**: Display performance metrics and improvements


## Step 1: Install Dependencies


In [None]:
# Install required packages
# Version specifiers are quoted so the shell does not treat ">" as a redirect
%pip install -q "transformers>=4.44" torch scikit-learn "datasets==3.6.0" huggingface_hub safetensors tqdm pandas numpy

# Verify installation
import sys
print(f"Python version: {sys.version}")
print("✅ Dependencies installed")


## Step 2: Setup Project Files

**Note**: The notebook automatically checks out the `experimental-knn` branch, which contains all KNN files.



In [1]:
import os

# Ensure we are in the /content directory before cloning
# This is important for consistent project structure
os.chdir('/content')

# Clone the repository if it doesn't already exist
repo_name = "ArmyOfSafeguards"
if not os.path.exists(repo_name):
    !git clone https://github.com/SohamNagi/ArmyOfSafeguards.git
else:
    print(f"Repository '{repo_name}' already exists. Skipping clone.")

# Change into the repository directory to perform checkout
os.chdir(repo_name)

# Checkout the experimental-knn branch
!git checkout experimental-knn
!git pull
# The working directory is now /content/ArmyOfSafeguards, which is where subsequent scripts expect the 'aggregator' directory to be.
print(f"✅ Project files set up in: {os.getcwd()}")

Repository 'ArmyOfSafeguards' already exists. Skipping clone.
M	aggregator/knn_reference_hh_rlhf_full.checkpoint.json
Already on 'experimental-knn'
Your branch is up to date with 'origin/experimental-knn'.
remote: Enumerating objects: 7, done.
remote: Counting objects: 100% (7/7), done.
remote: Compressing objects: 100% (1/1), done.
remote: Total 4 (delta 3), reused 4 (delta 3), pack-reused 0 (from 0)
Unpacking objects: 100% (4/4), 765 bytes | 69.00 KiB/s, done.
From https://github.com/SohamNagi/ArmyOfSafeguards
   573aebe..5eece91  experimental-knn -> origin/experimental-knn
Updating 573aebe..5eece91
Fast-forward
 aggregator/evaluate_vs_granite.py | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)
✅ Project files set up in: /content/ArmyOfSafeguards


## Step 3: Generate KNN Reference Data (Sampling)

This step samples the HH-RLHF dataset and generates reference data for the KNN aggregator by running all four safeguards (factuality, toxicity, sexual, jailbreak) on the sampled responses.
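As a rough illustration of the sampling step described above, the sketch below expands HH-RLHF style preference pairs into individually labeled responses. It is not the code in `generate_knn_reference_hh_rlhf_full.py`: the helper names `last_assistant_turn` and `expand_pairs` are hypothetical, and both the use of only the final assistant turn and the labeling of rejected responses as unsafe are assumptions. The pair shown is a made-up toy example.

```python
def last_assistant_turn(transcript):
    """Return the text after the final 'Assistant:' marker in a transcript."""
    marker = "\n\nAssistant:"
    # rpartition returns the whole string as the tail if the marker is absent
    _, _, tail = transcript.rpartition(marker)
    return tail.strip()


def expand_pairs(pairs):
    """Turn each chosen/rejected pair into two labeled responses."""
    rows = []
    for pair in pairs:
        rows.append({"response": last_assistant_turn(pair["chosen"]), "label": "safe"})
        # Assumption: the rejected response is treated as the unsafe example
        rows.append({"response": last_assistant_turn(pair["rejected"]), "label": "unsafe"})
    return rows


# Toy pair in the "\n\nHuman: ...\n\nAssistant: ..." transcript layout
toy_pairs = [
    {
        "chosen": "\n\nHuman: Tell me a rumor.\n\nAssistant: I would rather share verified facts.",
        "rejected": "\n\nHuman: Tell me a rumor.\n\nAssistant: Sure, here is one I made up.",
    }
]

rows = expand_pairs(toy_pairs)
for row in rows:
    print(row["label"], "->", row["response"])
```

Each pair yields two rows, which is consistent with the log reporting roughly twice as many responses as prompt pairs.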

**Note**: Reference data files are saved to the current working directory (`/content/ArmyOfSafeguards/`), not to the `aggregator/` subdirectory. If generation is interrupted, the data processed so far remains usable, and rerunning the script resumes from the last checkpoint.
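The resume behavior mentioned in the note can be pictured with a minimal sketch. This is not the repository's implementation: the checkpoint layout (a JSON file with a `next_index` counter), the function names, and the stand-in `score` callable are all assumptions made for illustration. Only the default batch size of 64 is taken from the script's log output.

```python
import json
import os
import tempfile


def load_checkpoint(path):
    """Return the index of the next unprocessed item (0 if no checkpoint)."""
    if not os.path.exists(path):
        return 0
    with open(path) as f:
        return json.load(f)["next_index"]


def save_checkpoint(path, next_index):
    """Write to a temporary file first, then swap it in atomically."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"next_index": next_index}, f)
    os.replace(tmp, path)


def process_all(items, out_path, ckpt_path, score, batch_size=64, stop_after=None):
    """Append one JSONL record per item, resuming from the checkpoint.

    stop_after simulates an interruption after that many batches.
    Returns True when every item has been processed.
    """
    start = load_checkpoint(ckpt_path)
    batches_done = 0
    for i in range(start, len(items), batch_size):
        batch = items[i:i + batch_size]
        with open(out_path, "a") as out:
            for text in batch:
                out.write(json.dumps({"text": text, "score": score(text)}) + "\n")
        save_checkpoint(ckpt_path, i + len(batch))
        batches_done += 1
        if stop_after is not None and batches_done >= stop_after:
            return False
    return True


# Demo: interrupt after one batch, then resume to completion
workdir = tempfile.mkdtemp()
out_path = os.path.join(workdir, "reference.jsonl")
ckpt_path = os.path.join(workdir, "reference.checkpoint.json")
items = [f"response {n}" for n in range(10)]

finished = process_all(items, out_path, ckpt_path, score=len, batch_size=4, stop_after=1)
resumed_from = load_checkpoint(ckpt_path)
finished_after_resume = process_all(items, out_path, ckpt_path, score=len, batch_size=4)

with open(out_path) as f:
    records = [json.loads(line) for line in f]
print(f"resumed from index {resumed_from}, total records: {len(records)}")
```

Because the output file is opened in append mode and the checkpoint is updated only after a batch is written, the records already on disk stay usable after an interruption. A crash between the write and the checkpoint update could duplicate one batch in this sketch.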


In [2]:
# Generate KNN reference data from HH-RLHF dataset
# This will download the dataset and run all 4 safeguards (factuality, toxicity, sexual, jailbreak)

import os

# Make sure we're in the right directory
if not os.path.exists("aggregator"):
    print("❌ Error: aggregator directory not found!")
    print(f"Current directory: {os.getcwd()}")
    print("Please run Step 2 first to set up project files.")
    raise FileNotFoundError("aggregator directory not found")

script_path = "aggregator/generate_knn_reference_hh_rlhf_full.py"
if os.path.exists(script_path):
    print(f"✅ Found script: {script_path}")
    print(f"Current directory: {os.getcwd()}")
    !python aggregator/generate_knn_reference_hh_rlhf_full.py
else:
    print(f"❌ Script not found: {script_path}")
    print(f"Current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
    if os.path.exists("aggregator"):
        print(f"Files in aggregator/: {os.listdir('aggregator')}")


✅ Found script: aggregator/generate_knn_reference_hh_rlhf_full.py
Current directory: /content/ArmyOfSafeguards
Loading Anthropic/hh-rlhf dataset (subset='harmless-base', split='train')...
Processing 42537 prompt pairs (~85074 responses) with batch size 64.
🔁 Resuming from checkpoint at /content/ArmyOfSafeguards/aggregator/knn_reference_hh_rlhf_full.checkpoint.json
Chosen responses (safe):   0% 0/42039 [00:00<?, ?it/s]2025-11-27 15:46:28.452515: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764258388.495280   14458 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764258388.509486   14458 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00

## Step 4: Compare KNN Aggregator vs IBM Granite-4.0-H-Tiny

This step evaluates both models in parallel on the benchmark dataset and compares their performance.

**Parameters**:
- `--limit`: Number of examples to evaluate (default: 100)
- `--threshold`: Confidence threshold for KNN aggregator (default: 0.7)
- `--dataset`: Benchmark dataset to use (default: hh-rlhf)
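To make the `--threshold` parameter concrete, here is a minimal plain-Python sketch of one way a KNN aggregator could turn safeguard scores into a decision. It assumes each reference sample is a vector of four safeguard confidences (factuality, toxicity, sexual, jailbreak) with a binary unsafe label, and that the threshold is compared against the fraction of unsafe neighbors. The real feature layout and decision rule in `aggregator/` may differ, and `knn_unsafe_fraction` and `is_flagged` are hypothetical names. The default `k=7` matches the value reported in the evaluation log.

```python
import math


def knn_unsafe_fraction(query, reference, k=7):
    """Fraction of the k nearest reference samples that are labeled unsafe.

    reference is a list of (score_vector, is_unsafe) pairs.
    """
    nearest = sorted(reference, key=lambda item: math.dist(query, item[0]))[:k]
    return sum(1 for _, is_unsafe in nearest if is_unsafe) / len(nearest)


def is_flagged(query, reference, k=7, threshold=0.7):
    """Flag the query when enough of its neighbors are unsafe."""
    return knn_unsafe_fraction(query, reference, k) >= threshold


# Toy reference set (made-up numbers): five safe and five unsafe samples
reference = (
    [((0.1, 0.1, 0.0, 0.1), False)] * 5
    + [((0.2, 0.9, 0.1, 0.8), True)] * 5
)

risky_query = (0.2, 0.8, 0.1, 0.7)
benign_query = (0.1, 0.2, 0.0, 0.1)

risky_fraction = knn_unsafe_fraction(risky_query, reference)
benign_fraction = knn_unsafe_fraction(benign_query, reference)
print(f"risky:  {risky_fraction:.3f} flagged={is_flagged(risky_query, reference)}")
print(f"benign: {benign_fraction:.3f} flagged={is_flagged(benign_query, reference)}")
```

Under this reading, raising the threshold requires more agreement among neighbors before an input is flagged, which trades recall for precision.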


In [6]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

PyTorch version: 2.9.0+cu126
CUDA available: False


In [None]:
# Compare KNN Aggregator vs IBM Granite on benchmark dataset
# Adjust parameters as needed:
#   --limit: Number of examples (default: 100)
#   --threshold: Confidence threshold (default: 0.7)
#   --dataset: Dataset name (default: hh-rlhf)

import os
import torch
from pathlib import Path

# Make sure we're in the right directory
if not os.path.exists("aggregator"):
    print("❌ Error: aggregator directory not found!")
    print(f"Current directory: {os.getcwd()}")
    print("Please run Step 2 first to set up project files.")
    raise FileNotFoundError("aggregator directory not found")

script_path = "aggregator/evaluate_vs_granite.py"

# Check for reference data in current directory first (new location), then aggregator/ (old location)
ref_path_current = "knn_reference_hh_rlhf_full.jsonl"
ref_path_old = "aggregator/knn_reference_hh_rlhf_full.jsonl"

if os.path.exists(ref_path_current):
    ref_path = ref_path_current
    print(f"✅ Found reference data in current directory: {ref_path}")
elif os.path.exists(ref_path_old):
    ref_path = ref_path_old
    print(f"✅ Found reference data in aggregator/ (legacy location): {ref_path}")
else:
    ref_path = None
    print("⚠️  Reference data not found in current directory or aggregator/")
    print(f"   Checked: {os.path.abspath(ref_path_current)}")
    print(f"   Checked: {os.path.abspath(ref_path_old)}")
    print("Please run Step 3 first to generate reference data.")

if os.path.exists(script_path):
    if ref_path and os.path.exists(ref_path):
        print(f"✅ Found script: {script_path}")
        print(f"Current directory: {os.getcwd()}")
        !python aggregator/evaluate_vs_granite.py --dataset hh-rlhf --limit 100 --knn-reference {ref_path} --threshold 0.7
    else:
        print("⚠️  Reference data not found. Please run Step 3 first to generate reference data.")
else:
    print(f"❌ Script not found: {script_path}")
    print(f"Current directory: {os.getcwd()}")


✅ Found script: aggregator/evaluate_vs_granite.py
✅ Found reference data: aggregator/knn_reference_hh_rlhf_full.jsonl
Current directory: /content/ArmyOfSafeguards

COMPARING KNN AGGREGATOR vs IBM GRANITE-4.0-H-TINY

[1/3] Loading KNN reference data from: aggregator/knn_reference_hh_rlhf_full.jsonl
[KNN] Fitted with 4930 reference samples, k=7
[KNN] Loaded 4930 reference samples
✅ KNN reference data loaded

[2/3] Loading IBM Granite model...
Loading IBM Granite model: ibm-granite/granite-4.0-h-tiny
⚠️  Using CPU for Granite model (slower)
Loading model shards (this may take a moment)...
2025-11-27 20:28:30.819825: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764275311.008909   81871 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:

## Step 5: View Comparison Results


In [None]:
import json
from pathlib import Path
from datetime import datetime

# Find the latest evaluation results (check both current directory and aggregator/)
result_files = list(Path(".").glob("evaluation_knn_vs_granite_*.json"))
result_files.extend(Path("aggregator").glob("evaluation_knn_vs_granite_*.json"))
if result_files:
    latest_file = max(result_files, key=lambda p: p.stat().st_mtime)
    print(f"Latest results: {latest_file}")

    with open(latest_file) as f:
        results = json.load(f)

    print("\n" + "="*60)
    print("COMPARISON RESULTS: KNN Aggregator vs IBM Granite-4.0-H-Tiny")
    print("="*60)

    if "knn_aggregator" in results and "ibm_granite" in results:
        knn = results["knn_aggregator"]
        granite = results["ibm_granite"]

        print("\nKNN Aggregator Performance:")
        print(f"  Accuracy:  {knn.get('accuracy', 0):.2%}")
        print(f"  Precision: {knn.get('precision', 0):.2%}")
        print(f"  Recall:    {knn.get('recall', 0):.2%}")
        print(f"  F1-Score:  {knn.get('f1_score', 0):.2%}")

        print("\nIBM Granite-4.0-H-Tiny Performance:")
        print(f"  Accuracy:  {granite.get('accuracy', 0):.2%}")
        print(f"  Precision: {granite.get('precision', 0):.2%}")
        print(f"  Recall:    {granite.get('recall', 0):.2%}")
        print(f"  F1-Score:  {granite.get('f1_score', 0):.2%}")

        if "improvement" in results:
            imp = results["improvement"]
            print("\nPerformance Improvement (KNN vs Granite):")
            print(f"  Accuracy:  {imp.get('accuracy', {}).get('percentage', 0):+.1f}% ({imp.get('accuracy', {}).get('absolute', 0):+.2%})")
            print(f"  Precision: {imp.get('precision', {}).get('percentage', 0):+.1f}% ({imp.get('precision', {}).get('absolute', 0):+.2%})")
            print(f"  Recall:    {imp.get('recall', {}).get('percentage', 0):+.1f}% ({imp.get('recall', {}).get('absolute', 0):+.2%})")
            print(f"  F1-Score:  {imp.get('f1_score', {}).get('percentage', 0):+.1f}% ({imp.get('f1_score', {}).get('absolute', 0):+.2%})")
else:
    print("No results found. Make sure Step 4 completed successfully.")
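
For reference, the metrics displayed by the cell above can be derived from raw predictions as sketched below. `binary_metrics` and `improvement` are hypothetical helpers, not functions from the repository. The assumption that `percentage` is stored in percent units while `absolute` is stored as a fraction is inferred only from the format strings in the cell (`:+.1f}%` versus `:+.2%`). The labels and predictions are made-up toy values, not evaluation results.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for binary labels (1 = unsafe)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1}


def improvement(candidate, baseline):
    """Absolute difference and relative change (in percent) per metric."""
    result = {}
    for name, base in baseline.items():
        diff = candidate[name] - base
        result[name] = {
            "absolute": diff,
            "percentage": 100 * diff / base if base else 0.0,
        }
    return result


# Toy labels and predictions for illustration only
y_true = [1, 1, 1, 0, 0, 0, 1, 0]
knn_pred = [1, 1, 0, 0, 0, 0, 1, 0]
granite_pred = [1, 0, 0, 0, 1, 0, 1, 0]

knn_metrics = binary_metrics(y_true, knn_pred)
granite_metrics = binary_metrics(y_true, granite_pred)
imp = improvement(knn_metrics, granite_metrics)

for name, values in imp.items():
    print(f"{name}: {values['percentage']:+.1f}% ({values['absolute']:+.2%})")
```

Guarding each division against a zero denominator keeps the helpers usable when a model predicts no positives or when the baseline metric is zero.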
