# Error-Type-Aware Analysis for SupervisedER Task

This notebook runs per-error-type performance analysis for the SupervisedER task.

## Prerequisites:
You need to have:
- Pre-extracted features (Omnivore and SlowFast) in `.npz` format
- Checkpoints from the official release (`error_recognition_best` directory)
- Annotation files (should be in the repository or uploaded separately)

## Quick Start:
1. Upload your data to Google Drive (or use direct upload)
2. Configure the repository in the first code cell and the data paths in Section 2
3. Run all cells sequentially


In [None]:
# ============================================
# REPOSITORY SETTINGS
# ============================================
# Where the project code comes from. Two modes are supported:
#   * Clone from GitHub (recommended): keep REPO_URL set.
#   * Manual upload: set REPO_URL = "" and upload the files yourself.
REPO_URL = "https://github.com/aexomir/AML_mistake_detection.git"

# Branch to check out after cloning; an empty string keeps the default branch.
REPO_BRANCH = "feat/error-type-analysis-v2"

# Directory (relative to the current working directory) that holds the repo.
REPO_DIR = "aml_repo"

# Echo the effective settings so they are visible in the cell output.
print(f"Repository URL: {REPO_URL or 'Manual upload mode'}")
print(f"Repository branch: {REPO_BRANCH or 'default'}")
print(f"Repository directory: {REPO_DIR}")


In [None]:
import os
import shutil

# Remove existing directory if it exists
if os.path.exists(REPO_DIR):
    print(f"Removing existing {REPO_DIR} directory...")
    shutil.rmtree(REPO_DIR)

# Clone repository
if REPO_URL:
    print(f"Cloning repository from {REPO_URL}...")
    clone_cmd = f"git clone {REPO_URL} {REPO_DIR}"
    result = os.system(clone_cmd)

    if result != 0:
        print(f"⚠ Clone failed. Please check the URL or upload files manually.")
        os.makedirs(REPO_DIR, exist_ok=True)
    else:
        print("✓ Repository cloned successfully")

        # Checkout specific branch if specified
        if REPO_BRANCH:
            print(f"Checking out branch: {REPO_BRANCH}")
            os.chdir(REPO_DIR)
            os.system(f"git checkout {REPO_BRANCH}")
            os.chdir('..')
            print(f"✓ Switched to branch: {REPO_BRANCH}")
else:
    print("Manual upload mode: Creating directory...")
    os.makedirs(REPO_DIR, exist_ok=True)

# Change to repository directory
if os.path.exists(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"\n✓ Changed to directory: {os.getcwd()}")
    print(f"\nRepository contents:")
    !ls -la
else:
    print(f"✗ Error: {REPO_DIR} directory not found!")


In [None]:
# Install dependencies
# Colab comes with PyTorch pre-installed, so we'll work with that.
# NOTE: this cell uses `os` without importing it; it relies on the `import os`
# executed by an earlier cell of this notebook.
#
# Remove PyTorch version constraints to avoid conflicts: delete, in place, every
# line of requirements.txt that starts with `torch==` or `torchvision==`.
# `2>/dev/null || true` discards sed's error output and ignores a failing exit status.
if os.path.exists('requirements.txt'):
    !sed -i '/^torch==/d' requirements.txt 2>/dev/null || true
    !sed -i '/^torchvision==/d' requirements.txt 2>/dev/null || true

# Install torcheval (required for evaluation metrics)
%pip install -q torcheval

# Install all remaining dependencies from requirements.txt;
# requirements-cpu.txt is only used when requirements.txt is absent.
if os.path.exists('requirements.txt'):
    %pip install -q -r requirements.txt
elif os.path.exists('requirements-cpu.txt'):
    %pip install -q -r requirements-cpu.txt

# NOTE: this message is printed unconditionally; it does not check whether the
# %pip commands above actually succeeded.
print("✓ All dependencies installed successfully")

# Verify PyTorch installation: report the version and whether CUDA is usable
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Name of the first CUDA device (index 0)
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## 2. Load Data: Features, Checkpoints, and Annotations


In [None]:
# ============================================
# DATA LOCATIONS
# ============================================
# True  -> read everything from the Google Drive paths below (recommended for large files).
# False -> skip Drive and upload the files directly instead.
USE_GOOGLE_DRIVE = True

# Update these to match your Drive layout. The feature and checkpoint entries
# may each point at a .zip archive or at a directory.
OMNIVORE_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/omnivore.zip"
SLOWFAST_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/slowfast.zip"
CHECKPOINTS_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/error_recognition_best.zip"
# Optional when the annotation files already ship with the repository.
ANNOTATIONS_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/annotations"

# Echo the effective configuration, one "label: value" line per setting.
print("Data paths configured:")
for _label, _value in (
    ("Use Google Drive", USE_GOOGLE_DRIVE),
    ("Omnivore", OMNIVORE_DRIVE_PATH),
    ("SlowFast", SLOWFAST_DRIVE_PATH),
    ("Checkpoints", CHECKPOINTS_DRIVE_PATH),
    ("Annotations", ANNOTATIONS_DRIVE_PATH),
):
    print(f"  {_label}: {_value}")


In [None]:
# Google Drive is mounted only when the Drive-based workflow is selected;
# otherwise the user is reminded to upload the files by hand.
if not USE_GOOGLE_DRIVE:
    print("⚠ Google Drive not mounted. Please upload files directly using the file browser.")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted")


In [None]:
# Prepare the folder layout the rest of the notebook writes into
import os

for _required_dir in (
    'data/video/omnivore',
    'data/video/slowfast',
    'checkpoints',
    'annotations/annotation_json',
    'annotations/data_splits',
    'er_annotations',
    'analysis/outputs',
):
    # exist_ok=True makes the cell safe to re-run
    os.makedirs(_required_dir, exist_ok=True)

print("✓ Directory structure created")


In [None]:
# Load features from Google Drive or direct upload
import os
import shutil
import subprocess
import glob

def load_features(source_path, dest_path, feature_name):
    """Copy every ``.npz`` file found under ``source_path`` into ``dest_path``.

    ``source_path`` may be a directory or a zip archive; both are searched
    recursively. Files are copied flat into ``dest_path`` (the sub-directory
    structure is not preserved, so files sharing a basename overwrite each other).

    Args:
        source_path: Directory or zip archive holding the feature files.
        dest_path: Directory that receives the ``.npz`` files; created if missing.
        feature_name: Human-readable label used in the progress messages.

    Returns:
        True if at least one ``.npz`` file was copied, False otherwise
        (missing source, unreadable archive, or no ``.npz`` files found).
    """
    # Imported locally so the function does not depend on names the
    # surrounding cell does not already import.
    import tempfile
    import zipfile

    if not os.path.exists(source_path):
        print(f"⚠ {feature_name}: Source path not found: {source_path}")
        return False

    print(f"Loading {feature_name} features from: {source_path}")

    def _copy_all(npz_files):
        # shutil.copy2 only copies *into* dest_path when it is an existing
        # directory; otherwise every file would be written to a single file
        # named dest_path, each overwriting the previous one.
        os.makedirs(dest_path, exist_ok=True)
        for npz_file in npz_files:
            shutil.copy2(npz_file, dest_path)

    if os.path.isdir(source_path):
        print(f"  Detected directory, copying .npz files...")
        npz_files = glob.glob(os.path.join(source_path, '**/*.npz'), recursive=True)
        if not npz_files:
            print(f"  ⚠ No .npz files found in {source_path}")
            return False
        _copy_all(npz_files)
        print(f"  ✓ Copied {len(npz_files)} .npz files")
        return True

    # Decide by file content, not by name: a path that merely contains "zip"
    # is not necessarily an archive.
    if not zipfile.is_zipfile(source_path):
        print(f"  ✗ {feature_name}: source is neither a directory nor a zip archive: {source_path}")
        return False

    print(f"  Detected zip file, extracting...")
    try:
        # The temporary directory is removed on exit from the `with` block,
        # on success and on failure alike.
        with tempfile.TemporaryDirectory(prefix='features_') as work_dir:
            local_zip = os.path.join(work_dir, 'features.zip')
            extracted_dir = os.path.join(work_dir, 'extracted')

            # Extract from a local copy rather than from the (possibly
            # network-mounted) source location.
            shutil.copy(source_path, local_zip)
            with zipfile.ZipFile(local_zip) as archive:
                archive.extractall(extracted_dir)

            npz_files = glob.glob(os.path.join(extracted_dir, '**/*.npz'), recursive=True)
            if not npz_files:
                print(f"  ⚠ No .npz files found in extracted zip")
                return False

            _copy_all(npz_files)
            print(f"  ✓ Extracted and copied {len(npz_files)} .npz files")
            return True
    except Exception as e:
        # Best effort: report the problem and let the notebook continue.
        print(f"  ✗ Error extracting {feature_name} zip: {e}")
        return False

# Pull in both feature sets when the Drive workflow is active;
# otherwise explain how to place the files by hand.
if not USE_GOOGLE_DRIVE:
    print("⚠ Please upload features manually:")
    print("  1. Use the file browser to upload .npz files or zip files")
    print("  2. Extract/copy them to data/video/omnivore/ and data/video/slowfast/")
else:
    for _drive_src, _local_dir, _label in (
        (OMNIVORE_DRIVE_PATH, 'data/video/omnivore', 'Omnivore'),
        (SLOWFAST_DRIVE_PATH, 'data/video/slowfast', 'SlowFast'),
    ):
        load_features(_drive_src, _local_dir, _label)


def _count_npz(directory):
    """Return how many entries directly inside `directory` end with .npz (0 if it does not exist)."""
    if not os.path.exists(directory):
        return 0
    return sum(1 for entry in os.listdir(directory) if entry.endswith('.npz'))


# Sanity check: report how many feature files are now in place
omnivore_count = _count_npz('data/video/omnivore')
slowfast_count = _count_npz('data/video/slowfast')
print(f"\nFeature file counts:")
print(f"  Omnivore: {omnivore_count} .npz files")
print(f"  SlowFast: {slowfast_count} .npz files")


In [None]:
# Load checkpoints
import os
import shutil
import subprocess
import tempfile
import zipfile

# Local directory the analysis cells read checkpoints from.
checkpoint_dest = 'checkpoints/error_recognition_best'


def _find_checkpoint_root(base_dir):
    """Locate the directory under ``base_dir`` that holds the checkpoints.

    Walks ``base_dir`` top-down (starting with ``base_dir`` itself) and returns
    the first ``error_recognition_best`` directory found, or the first directory
    that directly contains an ``MLP`` or ``Transformer`` folder. Falls back to
    ``base_dir`` when neither is found.
    """
    for root, dirs, _files in os.walk(base_dir):
        # Skip the resource-fork tree that macOS adds to zip archives (it mirrors
        # the real folder names) and sort so the traversal order is deterministic.
        dirs[:] = sorted(d for d in dirs if d != '__MACOSX')
        if 'error_recognition_best' in dirs:
            return os.path.join(root, 'error_recognition_best')
        if 'MLP' in dirs or 'Transformer' in dirs:
            return root
    return base_dir


checkpoint_path = CHECKPOINTS_DRIVE_PATH if USE_GOOGLE_DRIVE else None

if checkpoint_path and os.path.exists(checkpoint_path):
    print(f"Loading checkpoints from: {checkpoint_path}")

    if os.path.isdir(checkpoint_path):
        # It's a directory
        print("Detected directory, copying...")
        source_root = _find_checkpoint_root(checkpoint_path)
        print(f"Copying from: {source_root}")
        # copytree copies the *contents* of source_root into checkpoint_dest.
        shutil.copytree(source_root, checkpoint_dest, dirs_exist_ok=True)
        print("✓ Checkpoints copied")
    elif zipfile.is_zipfile(checkpoint_path):
        # Decided by file content rather than by the presence of "zip" in the path.
        print("Detected zip file, extracting...")
        try:
            # The temporary directory (local zip copy + extracted tree) is removed
            # when the `with` block exits, on success and on failure alike.
            with tempfile.TemporaryDirectory(prefix='checkpoints_') as work_dir:
                local_zip = os.path.join(work_dir, 'checkpoints.zip')
                extracted_base = os.path.join(work_dir, 'extracted')

                shutil.copy(checkpoint_path, local_zip)
                with zipfile.ZipFile(local_zip) as archive:
                    archive.extractall(extracted_base)

                extracted_path = _find_checkpoint_root(extracted_base)
                print(f"Copying from: {extracted_path}")
                shutil.copytree(extracted_path, checkpoint_dest, dirs_exist_ok=True)
            print("✓ Checkpoints extracted")
        except Exception as e:
            # Best effort: report and continue so the verification below still runs.
            print(f"✗ Error extracting checkpoints: {e}")
    else:
        print(f"✗ Checkpoint path is neither a directory nor a zip archive: {checkpoint_path}")
else:
    print("⚠ Checkpoints not found. Please upload manually:")
    print("  1. Download from: https://utdallas.app.box.com/s/uz3s1alrzucz03sleify8kazhuc1ksl3")
    print("  2. Extract error_recognition_best directory")
    print("  3. Upload to checkpoints/error_recognition_best/")

# Verify checkpoints: count every .pt file below the destination directory
if os.path.exists(checkpoint_dest):
    pt_files = []
    for root, dirs, files in os.walk(checkpoint_dest):
        pt_files.extend([os.path.join(root, f) for f in files if f.endswith('.pt')])
    print(f"\n✓ Found {len(pt_files)} checkpoint files")
    if pt_files:
        print("\nSample checkpoint files:")
        for f in pt_files[:3]:
            print(f"  {f}")
else:
    print("\n✗ Checkpoints directory not found")


In [None]:
# Load annotations (if not already in repository)
import os
import shutil

# (sub-folder on Drive, local destination folder), processed in this order
_ANNOTATION_FOLDERS = (
    ('annotation_json', 'annotations/annotation_json'),
    ('data_splits', 'annotations/data_splits'),
    ('er_annotations', 'er_annotations'),
)

if USE_GOOGLE_DRIVE and os.path.exists(ANNOTATIONS_DRIVE_PATH):
    print(f"Loading annotations from: {ANNOTATIONS_DRIVE_PATH}")

    for _drive_subdir, _local_dir in _ANNOTATION_FOLDERS:
        _drive_dir = os.path.join(ANNOTATIONS_DRIVE_PATH, _drive_subdir)
        if not os.path.exists(_drive_dir):
            continue
        # Copy regular files only; sub-directories are skipped.
        for _entry in os.listdir(_drive_dir):
            _entry_src = os.path.join(_drive_dir, _entry)
            if os.path.isfile(_entry_src):
                shutil.copy2(_entry_src, os.path.join(_local_dir, _entry))
                print(f"  ✓ Copied {_entry}")
else:
    print("⚠ Annotations not found in Drive. Checking repository...")

# Check that every annotation file the analysis needs is in place
print("\nVerifying annotation files...")
required_files = [
    'annotations/annotation_json/step_annotations.json',
    'annotations/annotation_json/error_annotations.json',
    'er_annotations/recordings_combined_splits.json'
]

missing = []
for _required in required_files:
    _present = os.path.exists(_required)
    print(f"✓ Found: {_required}" if _present else f"✗ Missing: {_required}")
    if not _present:
        missing.append(_required)

if not missing:
    print("\n✓ All required annotation files are present!")
else:
    print(f"\n⚠ Warning: {len(missing)} required annotation file(s) are missing!")
    print("Please ensure these files are available before running the analysis.")


## 3. Run Error-Type Analysis

Run the analysis script with your model checkpoint and configuration.

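**Note**: Use threshold 0.6 for the `step` split and 0.4 for the `recordings` split. If a checkpoint path in the cells below does not exist, run the checkpoint-listing cell at the end of this section to find the correct epoch number.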


In [None]:
# Example: Omnivore - MLP - Step split
# Runs analysis/error_type_analysis.py through a shell escape (`!`) with the
# Omnivore backbone, the MLP variant and the `step` split, at threshold 0.6
# (the value this notebook recommends for the `step` split).
# --output-dir is the directory the results cells later in this notebook read from.
# This should reproduce: F1=24.26, AUC=75.74
# If the --ckpt file does not exist, update the epoch number in the path to
# match a file under checkpoints/error_recognition_best/.
!python analysis/error_type_analysis.py \
    --split step \
    --backbone omnivore \
    --variant MLP \
    --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_step_epoch_43.pt \
    --threshold 0.6 \
    --output-dir analysis/outputs


In [None]:
# Example: Omnivore - Transformer - Step split
# Same invocation as the MLP/step example, but with the Transformer variant and
# its checkpoint; threshold 0.6 is the value this notebook recommends for `step`.
# This should reproduce: F1=55.39, AUC=75.62
# Update the epoch number in the checkpoint path if it does not match a file
# under checkpoints/error_recognition_best/.
!python analysis/error_type_analysis.py \
    --split step \
    --backbone omnivore \
    --variant Transformer \
    --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_step_epoch_9.pt \
    --threshold 0.6 \
    --output-dir analysis/outputs


In [None]:
# Example: Omnivore - MLP - Recordings split
# Runs the analysis script with the MLP variant on the `recordings` split, at
# threshold 0.4 (the value this notebook recommends for `recordings`).
# This should reproduce: F1=55.42, AUC=63.03
# Update the epoch number in the checkpoint path if it does not match a file
# under checkpoints/error_recognition_best/.
!python analysis/error_type_analysis.py \
    --split recordings \
    --backbone omnivore \
    --variant MLP \
    --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_recordings_epoch_33.pt \
    --threshold 0.4 \
    --output-dir analysis/outputs


In [None]:
# Example: Omnivore - Transformer - Recordings split
# Runs the analysis script with the Transformer variant on the `recordings`
# split, at threshold 0.4 (the value this notebook recommends for `recordings`).
# This should reproduce: F1=40.73, AUC=62.27
# Update the epoch number in the checkpoint path if it does not match a file
# under checkpoints/error_recognition_best/.
!python analysis/error_type_analysis.py \
    --split recordings \
    --backbone omnivore \
    --variant Transformer \
    --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_recordings_epoch_31.pt \
    --threshold 0.4 \
    --output-dir analysis/outputs


In [None]:
# Print every .pt file below the checkpoint directory (sorted), so the epoch
# numbers used in the --ckpt arguments can be looked up
import os
import glob

checkpoint_base = 'checkpoints/error_recognition_best'
if not os.path.exists(checkpoint_base):
    print("Checkpoints directory not found")
else:
    print("Available checkpoints:")
    _ckpt_pattern = os.path.join(checkpoint_base, '**/*.pt')
    for ckpt_file in sorted(glob.glob(_ckpt_pattern, recursive=True)):
        print(f"  {ckpt_file}")


## 4. Display Results

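Load and display the results of the analysis. The cells below show the Omnivore – MLP – `step` split run (threshold 0.6); change the file names to view a different run.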


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load results CSV
# (written by the Omnivore / MLP / step-split run at threshold 0.6)
import os

results_csv_path = 'analysis/outputs/error_type_analysis_step_omnivore_MLP_threshold_0.6.csv'

if os.path.exists(results_csv_path):
    results_df = pd.read_csv(results_csv_path)

    # Display table
    print("Per-Error-Type Performance Metrics:")
    print("=" * 80)
    display(results_df)
else:
    # Mirror the plot cell: report a missing file instead of raising, e.g. when
    # the corresponding analysis run has not been executed or failed.
    results_df = None
    print(f"Results CSV not found at {results_csv_path}")


In [None]:
# Display saved plots
from IPython.display import Image, display

# Figure saved by the Omnivore / MLP / step-split run at threshold 0.6
fig_path = 'analysis/outputs/error_type_analysis_step_omnivore_MLP_threshold_0.6.png'
if not os.path.exists(fig_path):
    # Nothing to show; say where the figure was expected.
    print(f"Plot not found at {fig_path}")
else:
    display(Image(fig_path))
