<a href="https://colab.research.google.com/github/aexomir/AML_mistake_detection/blob/feat%2Frnn/notebooks/rnn_baseline_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RNN Baseline Training for CaptainCook4D SupervisedER

Train an RNN/LSTM baseline for mistake detection and compare it against the MLP and Transformer baselines.

## Workflow:
1. Setup repository and dependencies
2. Load features and annotations from HuggingFace
3. Train RNN models (Omnivore + SlowFast)
4. Evaluate and compare results

## Prerequisites:
- Set Colab secrets: `WANDB_API_KEY`, `HF_TOKEN`
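- A HuggingFace dataset repo `<your-hf-username>/captaincook4d-features` containing `omnivore.zip` and `slowfast.zip` (plus an optional `annotations.zip`; otherwise the annotations from the Git repository are used)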
- Run all cells sequentially


In [None]:
# ============================================
# CONFIGURE YOUR REPOSITORY
# ============================================
# Option 1 (recommended): clone the code from GitHub.
REPO_URL = "https://github.com/aexomir/AML_mistake_detection.git"
# Branch checked out after cloning; an empty string keeps the default branch.
REPO_BRANCH = "feat/rnn"

# Option 2: manual upload — set REPO_URL = "" and upload the files yourself.
# REPO_URL = ""

# Local folder the code lives in; later cells run from inside it.
REPO_DIR = "code"

print(f"Repository URL: {REPO_URL or 'Manual upload mode'}")
print(f"Repository branch: {REPO_BRANCH or 'default'}")
print(f"Repository directory: {REPO_DIR}")


In [None]:
import os
import shutil
import subprocess


def _run_and_echo(cmd):
    """Run *cmd* (an argument list, no shell) and print its combined output.

    Output is captured and printed explicitly so it shows up in the notebook
    cell; returns the process exit code.
    """
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, check=False)
    if proc.stdout:
        print(proc.stdout.rstrip())
    return proc.returncode


# This cell ends by changing into REPO_DIR. When it is re-run from there, step
# back out first; otherwise a second clone would be nested inside the first.
if os.path.basename(os.getcwd()) == REPO_DIR:
    os.chdir('..')

# Remove existing directory if it exists
if os.path.exists(REPO_DIR):
    print(f"Removing existing {REPO_DIR} directory...")
    shutil.rmtree(REPO_DIR)

# Clone repository
if REPO_URL:
    print(f"Cloning repository from {REPO_URL}...")
    clone_cmd = ["git", "clone", REPO_URL, REPO_DIR]
    result = _run_and_echo(clone_cmd)

    if result != 0:
        print(f"⚠ Clone failed. Please check the URL or upload files manually.")
        os.makedirs(REPO_DIR, exist_ok=True)
    else:
        print("✓ Repository cloned successfully")

        # Checkout specific branch if specified. `git -C` runs inside REPO_DIR
        # without changing the notebook's working directory.
        if REPO_BRANCH:
            print(f"Checking out branch: {REPO_BRANCH}")
            if _run_and_echo(["git", "-C", REPO_DIR, "checkout", REPO_BRANCH]) == 0:
                print(f"✓ Switched to branch: {REPO_BRANCH}")
            else:
                print(f"⚠ Could not check out branch {REPO_BRANCH}; staying on the default branch")
else:
    print("Manual upload mode: Creating directory...")
    os.makedirs(REPO_DIR, exist_ok=True)

# Change to repository directory
if os.path.exists(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"\n✓ Changed to directory: {os.getcwd()}")
    print(f"\nRepository contents:")
    _run_and_echo(["ls", "-la"])
else:
    print(f"✗ Error: {REPO_DIR} directory not found!")


In [None]:
# Verify repository structure
# Sanity-check that the working directory holds the files/directories the
# training and evaluation cells rely on. Missing items only produce a warning.
import os

print(f"Current working directory: {os.getcwd()}")
print(f"\nChecking repository structure...")

required_items = [
    'scripts/train_rnn_baseline.py',
    'core/evaluate.py',
    'dataloader',
    'base.py',
    'constants.py',
]

missing = []
for entry in required_items:
    if not os.path.exists(entry):
        print(f"✗ Missing: {entry}")
        missing.append(entry)
        continue
    print(f"✓ Found: {entry}")

if not missing:
    print(f"\n✓ Repository structure looks good!")
else:
    print(f"\n⚠ Warning: Some required files/directories are missing!")
    print(f"Please ensure all files are present before proceeding.")


In [None]:
# Install dependencies
# Installs the Python packages needed by the training/evaluation scripts, then
# reports the PyTorch build and whether a CUDA device is visible.
# NOTE: `os` is not imported in this cell; it relies on an earlier cell's import.
# Colab comes with PyTorch pre-installed, so we'll work with that
# Remove PyTorch version constraints to avoid conflicts
# (the sed calls delete the pinned `torch==` / `torchvision==` lines from
# requirements.txt in place, so pip keeps Colab's preinstalled builds)
if os.path.exists('requirements.txt'):
    !sed -i '/^torch==/d' requirements.txt 2>/dev/null || true
    !sed -i '/^torchvision==/d' requirements.txt 2>/dev/null || true

# Install torcheval (required for evaluation metrics)
!pip install -q torcheval

# Install all remaining dependencies from requirements.txt
# (requirements-cpu.txt is only used when requirements.txt is absent)
if os.path.exists('requirements.txt'):
    !pip install -q -r requirements.txt
elif os.path.exists('requirements-cpu.txt'):
    !pip install -q -r requirements-cpu.txt

# Install additional dependencies for RNN baseline
!pip install -q wandb loguru

# NOTE(review): IPython `!` commands do not raise on a non-zero exit status, so
# this message is printed even if one of the installs above failed.
print("✓ All dependencies installed successfully")

# Verify PyTorch installation
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


In [None]:
# HuggingFace configuration placeholders. The authentication cell fills these
# in from the validated HF token (username and derived repo ids).
HF_USERNAME = HF_DATASET_REPO = HF_BASELINES_REPO = None

print(
    "=" * 60,
    "HuggingFace Configuration",
    "=" * 60,
    "Configuration will be set after authentication",
    "=" * 60,
    sep="\n",
)


In [None]:
## 2.5. Authentication & HuggingFace Setup


In [None]:
# Authentication Configuration
# Reads the WANDB_API_KEY and HF_TOKEN Colab secrets, validates the HF token via
# `whoami`, and derives the HuggingFace repo ids used by later cells.
from google.colab import userdata
from huggingface_hub import HfApi


def _read_colab_secret(name):
    """Return the Colab secret *name*, or None if it cannot be read.

    Colab's ``userdata.get`` raises (e.g. when the secret does not exist or
    notebook access was not granted) instead of returning None. Catch that
    here so the check below can report every missing secret in one go.
    """
    try:
        return userdata.get(name)
    except Exception as err:
        print(f"✗ Could not read {name}: {type(err).__name__}: {err}")
        return None


WANDB_API_KEY = _read_colab_secret('WANDB_API_KEY')
HF_TOKEN = _read_colab_secret('HF_TOKEN')

print("="*60)
print("Authentication")
print("="*60)

# Report only presence and length: echoing even part of a secret would leave it
# in the saved notebook output (e.g. when the notebook is committed to GitHub).
for secret_name, secret_value in (("WANDB_API_KEY", WANDB_API_KEY), ("HF_TOKEN", HF_TOKEN)):
    if secret_value:
        print(f"✓ {secret_name}: set ({len(secret_value)} chars)")
    else:
        print(f"✗ {secret_name} not found")

print("="*60)

if not WANDB_API_KEY or not HF_TOKEN:
    raise ValueError("Missing required secrets in Colab")

# Auto-detect HF username (also validates the token)
api = HfApi(token=HF_TOKEN)
HF_USERNAME = api.whoami()['name']
print(f"✓ HuggingFace user: {HF_USERNAME}")

# Update repo configs
HF_DATASET_REPO = f"{HF_USERNAME}/captaincook4d-features"
HF_BASELINES_REPO = f"{HF_USERNAME}/captaincook4d-baselines"

print(f"✓ Dataset repo: {HF_DATASET_REPO}")
print(f"✓ Baselines repo: {HF_BASELINES_REPO}")


In [None]:
# Create data directory structure
# Feature, checkpoint and annotation folders expected by later cells; creating
# them is idempotent (exist_ok=True).
import os

for required_dir in (
    'data/video/omnivore',
    'data/video/slowfast',
    'checkpoints',
    'annotations/annotation_json',
    'annotations/data_splits',
    'er_annotations',
):
    os.makedirs(required_dir, exist_ok=True)

print("✓ Directory structure created")


In [None]:
# Load features from HuggingFace
from huggingface_hub import hf_hub_download
import zipfile
import os
import shutil

print("=" * 60, "Loading Features from HuggingFace", "=" * 60, sep="\n")

def extract_features(zip_path, target_dir, feature_name):
    """Extract a feature archive into *target_dir* and return its ``.npz`` count.

    *target_dir* is wiped first. The archive may store its files at the root,
    under a ``<feature_name>/`` folder, or under a doubly nested
    ``<feature_name>/<feature_name>/`` folder. In every case the contents end
    up directly inside *target_dir*.

    Args:
        zip_path: Path to the downloaded ``.zip`` archive.
        target_dir: Destination directory (e.g. ``data/video/omnivore``).
        feature_name: Folder name the archive may wrap its files in.

    Returns:
        Number of ``.npz`` files directly inside *target_dir* after extraction.
    """
    import tempfile

    # Clean target directory
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)

    # Extract into a scratch directory next to target_dir, not straight into the
    # parent. Otherwise a flat archive would spill its files outside target_dir.
    # Keeping the scratch dir on the same filesystem makes the moves cheap renames.
    parent_dir = os.path.dirname(os.path.abspath(target_dir))
    staging_dir = tempfile.mkdtemp(prefix=f".{feature_name}_extract_", dir=parent_dir)
    try:
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(staging_dir)

        # Descend through any "<feature_name>/" wrapper folders.
        source_dir = staging_dir
        while os.path.isdir(os.path.join(source_dir, feature_name)):
            source_dir = os.path.join(source_dir, feature_name)

        for item in os.listdir(source_dir):
            shutil.move(os.path.join(source_dir, item), os.path.join(target_dir, item))
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)

    # Count features
    return sum(1 for f in os.listdir(target_dir) if f.endswith('.npz'))

def _fetch_and_extract(name):
    """Download ``<name>.zip`` from the HF dataset repo and unpack it into data/video/<name>.

    Returns the local archive path and the number of extracted ``.npz`` files.
    """
    print(f"Downloading {name} features...")
    archive_path = hf_hub_download(
        repo_id=HF_DATASET_REPO,
        filename=f"{name}.zip",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    n_extracted = extract_features(archive_path, f"data/video/{name}", name)
    print(f"✓ Extracted {n_extracted} {name} features")
    return archive_path, n_extracted


# Omnivore first, then SlowFast (same order as before).
omnivore_zip, omnivore_count = _fetch_and_extract("omnivore")
slowfast_zip, slowfast_count = _fetch_and_extract("slowfast")

print("=" * 60)


In [None]:
# Load annotations (HuggingFace or Git repository)
# Prefer annotations.zip from the HF dataset repo. If that fails for any reason,
# keep the annotation files already present from the Git checkout, then report
# which required files exist.
import os
import zipfile
from huggingface_hub import hf_hub_download

print("=" * 60)
print("Loading Annotations")
print("=" * 60)

try:
    print("Attempting to load from HuggingFace...")
    annotations_zip = hf_hub_download(
        repo_id=HF_DATASET_REPO,
        filename="annotations.zip",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    with zipfile.ZipFile(annotations_zip, 'r') as archive:
        archive.extractall('.')
    print("✓ Loaded annotations from HuggingFace")
except Exception as e:
    # Deliberate best-effort: the Git repository copy is the fallback.
    print(f"Note: {e}")
    print("Using annotations from Git repository")

# Verify required annotation files
required_files = [
    'annotations/annotation_json/step_annotations.json',
    'annotations/annotation_json/error_annotations.json',
    'er_annotations/recordings_combined_splits.json',
]

missing = [path for path in required_files if not os.path.exists(path)]
for path in required_files:
    print(f"✗ Missing: {path}" if path in missing else f"✓ {path}")

print(f"\n⚠ Warning: {len(missing)} file(s) missing!" if missing else "\n✓ All annotation files present")

print("=" * 60)


## 2.6. WandB Login & HuggingFace Helper Functions


In [None]:
# Initialize WandB authentication
import wandb

# wandb.login returns whether an API key is configured. Fail loudly here instead
# of reporting success unconditionally and failing later inside training.
if not wandb.login(key=WANDB_API_KEY):
    raise RuntimeError("WandB login failed: check the WANDB_API_KEY Colab secret")
print("✓ WandB authenticated")


In [None]:
# HuggingFace Helper Functions
from huggingface_hub import upload_file, upload_folder, create_repo, hf_hub_download
import os

def get_hf_repo_name(wandb_run_name, backbone):
    """Return the HuggingFace model repo id for one training run.

    The id has the form ``<HF_USERNAME>/rnn-<backbone>-<wandb_run_name>``,
    where ``HF_USERNAME`` is the module-level username set during authentication.
    """
    owner = HF_USERNAME
    repo_name = f"rnn-{backbone}-{wandb_run_name}"
    return f"{owner}/{repo_name}"

def upload_checkpoints_to_hf(local_path, wandb_run_name, backbone):
    """Upload the ``*.pt`` checkpoints in *local_path* to a private HF model repo.

    The repo id comes from ``get_hf_repo_name`` and the repo is created if it
    does not exist yet.

    Returns:
        True on success. False when *local_path* is missing, holds no top-level
        ``.pt`` file, or the create/upload call raises (the error is printed).
    """
    banner = "=" * 60
    print(banner)
    print(f"Uploading to HuggingFace: {backbone}")
    print(banner)

    if not os.path.exists(local_path):
        print(f"✗ Checkpoint path not found: {local_path}")
        return False

    n_checkpoints = sum(1 for name in os.listdir(local_path) if name.endswith('.pt'))
    if n_checkpoints == 0:
        print(f"✗ No checkpoint files found")
        return False

    repo_id = get_hf_repo_name(wandb_run_name, backbone)

    try:
        create_repo(repo_id=repo_id, private=True, exist_ok=True, token=HF_TOKEN)
        upload_folder(folder_path=local_path, repo_id=repo_id, token=HF_TOKEN, allow_patterns="*.pt")
    except Exception as e:
        print(f"✗ Upload failed: {e}")
        return False

    print(f"✓ Uploaded {n_checkpoints} checkpoints to {repo_id}")
    return True

def download_checkpoint_from_hf(repo_id, filename):
    """Fetch *filename* from the HF model repo *repo_id*.

    Returns the local cached path, or None if the download raised (the error
    is printed).
    """
    local_path = None
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model", token=HF_TOKEN)
    except Exception as e:
        print(f"✗ Download failed: {e}")
    return local_path

def upload_results_to_hf(csv_path, wandb_run_name, backbone):
    """Upload a results CSV as ``results.csv`` to the run's HF model repo.

    The repo is created first if needed (private, like the checkpoint repo).
    This means results can be uploaded even when no checkpoints were pushed.

    Returns:
        True on success. False if *csv_path* does not exist or the create/upload
        call raises (the error is printed).
    """
    if not os.path.exists(csv_path):
        print(f"✗ Results file not found: {csv_path}")
        return False

    repo_id = get_hf_repo_name(wandb_run_name, backbone)
    try:
        create_repo(repo_id=repo_id, private=True, exist_ok=True, token=HF_TOKEN)
        upload_file(path_or_fileobj=csv_path, path_in_repo="results.csv", repo_id=repo_id, token=HF_TOKEN)
        return True
    except Exception as e:
        print(f"✗ Results upload failed: {e}")
        return False

def create_model_card(wandb_run_name, backbone, config, wandb_url):
    """Upload a README.md model card to the run's HF model repo.

    Args:
        wandb_run_name: WandB run name (part of the repo id).
        backbone: Feature backbone name, e.g. ``"omnivore"``.
        config: Training config dict. ``rnn_type``, ``rnn_hidden_size``,
            ``rnn_num_layers``, ``num_epochs`` and ``learning_rate`` are
            required. The other listed hyper-parameters are optional and
            rendered as ``n/a`` when absent.
        wandb_url: Link to the WandB run.

    Returns:
        True if the upload succeeded, False otherwise (the error is printed).
        NOTE(review): this does not create the repo; it presumably relies on
        upload_checkpoints_to_hf having created it.
    """
    repo_id = get_hf_repo_name(wandb_run_name, backbone)
    card = f"""---
tags:
- video-understanding
- error-recognition
- rnn
datasets:
- {HF_DATASET_REPO}
---

# RNN Baseline: {backbone.upper()}

## Model Details
- Backbone: {backbone}
- Architecture: {config['rnn_type']}
- Bidirectional: {config.get('rnn_bidirectional', 'n/a')}
- Hidden Size: {config['rnn_hidden_size']}
- Layers: {config['rnn_num_layers']}
- Dropout: {config.get('rnn_dropout', 'n/a')}

## Training
- WandB Run: {wandb_url}
- Epochs: {config['num_epochs']}
- Batch Size: {config.get('batch_size', 'n/a')}
- Learning Rate: {config['learning_rate']}
- Weight Decay: {config.get('weight_decay', 'n/a')}

## Results
See results.csv for detailed metrics.
"""

    try:
        upload_file(path_or_fileobj=card.encode(), path_in_repo="README.md", repo_id=repo_id, token=HF_TOKEN)
        return True
    except Exception as e:
        print(f"✗ Model card upload failed: {e}")
        return False

# Reached only if every helper definition in this cell executed without error.
print("✓ HuggingFace helper functions loaded")


## 3. Train RNN Baseline with Omnivore Features


In [None]:
# Pre-training checks for Omnivore
# Before launching training, verify that the following are in place:
#   - the training script
#   - a writable checkpoint directory
#   - the Colab secrets
#   - WandB and HF access
#   - the HF helper functions
#   - the Omnivore features and the split file
# Raises RuntimeError if any required check fails.
import os
from google.colab import userdata
from huggingface_hub import HfApi
import wandb

print("="*60)
print("Pre-training Checks: Omnivore")
print("="*60)

checks_passed = True

if not os.path.exists("scripts/train_rnn_baseline.py"):
    print("✗ Training script not found")
    checks_passed = False
else:
    print("✓ Training script found")

# Prove the checkpoint directory is writable by creating and removing a probe file.
os.makedirs("checkpoints/error_recognition/RNN/omnivore", exist_ok=True)
try:
    test_file = "checkpoints/error_recognition/RNN/omnivore/.test_write"
    with open(test_file, 'w') as f:
        f.write("test")
    os.remove(test_file)
    print("✓ Checkpoint directory writable")
except Exception as e:
    print(f"✗ Checkpoint directory not writable: {e}")
    checks_passed = False

try:
    test_key = userdata.get('WANDB_API_KEY')
    if not test_key:
        print("✗ WANDB_API_KEY not found in userdata")
        checks_passed = False
    else:
        print("✓ WANDB_API_KEY available")
except Exception as e:
    print(f"✗ Failed to get WANDB_API_KEY: {e}")
    checks_passed = False

try:
    test_token = userdata.get('HF_TOKEN')
    if not test_token:
        print("✗ HF_TOKEN not found in userdata")
        checks_passed = False
    else:
        print("✓ HF_TOKEN available")
except Exception as e:
    print(f"✗ Failed to get HF_TOKEN: {e}")
    checks_passed = False

if not wandb.api.api_key:
    print("✗ WandB not authenticated")
    checks_passed = False
else:
    print("✓ WandB authenticated")

try:
    api = wandb.Api()
    api.viewer()
    print("✓ WandB API accessible")
except Exception as e:
    print(f"✗ WandB API not accessible: {e}")
    checks_passed = False

try:
    hf_token = userdata.get('HF_TOKEN')
    hf_api = HfApi(token=hf_token)
    hf_api.whoami()
    print("✓ HuggingFace token valid")
except Exception as e:
    print(f"✗ HuggingFace token invalid: {e}")
    checks_passed = False

# Explicit checks rather than `assert`, which is skipped under `python -O`.
if 'upload_checkpoints_to_hf' in globals() and 'create_model_card' in globals():
    print("✓ HuggingFace helper functions available")
else:
    print("✗ HuggingFace helper functions not available")
    checks_passed = False

# TRAINING_CONFIG is defined in the "Training Configuration" cell, which comes
# *after* this one. On a top-to-bottom run it therefore does not exist yet, and
# failing here would block the notebook. Only warn; the training cell uses it.
if 'TRAINING_CONFIG' in globals():
    print("✓ TRAINING_CONFIG defined")
else:
    print("⚠ TRAINING_CONFIG not defined yet - run the 'Training Configuration' cell before training")

required_data = [
    "data/video/omnivore",
    "er_annotations/recordings_combined_splits.json"
]

for path in required_data:
    if not os.path.exists(path):
        print(f"✗ Missing: {path}")
        checks_passed = False
    else:
        print(f"✓ Found: {path}")

if not checks_passed:
    raise RuntimeError("Pre-training checks failed")

print("="*60)
print("✓ All pre-training checks passed")
print("="*60)


In [None]:
# Training Configuration
# Hyper-parameters forwarded as CLI flags to scripts/train_rnn_baseline.py by both
# training cells (and also passed to create_model_card).
TRAINING_CONFIG = dict(
    batch_size=4,
    num_epochs=20,
    learning_rate=1e-3,
    weight_decay=1e-3,
    rnn_hidden_size=256,
    rnn_num_layers=2,
    rnn_dropout=0.2,
    rnn_bidirectional=True,
    rnn_type="LSTM",
)

print("=" * 60)
print("Training Configuration")
print("=" * 60)
print("\n".join(f"  {name}: {setting}" for name, setting in TRAINING_CONFIG.items()))
print("=" * 60)

In [None]:
# Train RNN with Omnivore features
# Runs scripts/train_rnn_baseline.py as a subprocess. On success it looks up the
# newest WandB run and pushes the checkpoints and a model card to HuggingFace.
import subprocess
import sys
import os
import wandb
from google.colab import userdata

BACKBONE = "omnivore"

print(f"{'='*60}")
print(f"Training RNN + {BACKBONE.upper()}")
print(f"{'='*60}")

cmd = [
    sys.executable, "scripts/train_rnn_baseline.py",
    "--variant", "RNN",
    "--backbone", BACKBONE,
    "--split", "recordings",
    "--batch_size", str(TRAINING_CONFIG["batch_size"]),
    "--num_epochs", str(TRAINING_CONFIG["num_epochs"]),
    "--lr", str(TRAINING_CONFIG["learning_rate"]),
    "--weight_decay", str(TRAINING_CONFIG["weight_decay"]),
    "--rnn_hidden_size", str(TRAINING_CONFIG["rnn_hidden_size"]),
    "--rnn_num_layers", str(TRAINING_CONFIG["rnn_num_layers"]),
    "--rnn_dropout", str(TRAINING_CONFIG["rnn_dropout"]),
    "--rnn_bidirectional", str(TRAINING_CONFIG["rnn_bidirectional"]),
    "--rnn_type", TRAINING_CONFIG["rnn_type"],
]

env = os.environ.copy()
# Put the repo root first on PYTHONPATH without leaving an empty trailing entry.
existing_pythonpath = env.get('PYTHONPATH', '')
env['PYTHONPATH'] = os.pathsep.join(p for p in (os.getcwd(), existing_pythonpath) if p)
env['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')

# stderr is merged into stdout and printed after the process exits, so the log
# shows up in the notebook cell.
result = subprocess.run(cmd, cwd=os.getcwd(), env=env, check=False,
                        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

if result.stdout:
    print(result.stdout)

if result.returncode != 0:
    print(f"✗ Training failed for {BACKBONE} (exit code {result.returncode}); skipping HuggingFace upload")
else:
    local_ckpt_path = f"checkpoints/error_recognition/RNN/{BACKBONE}"
    project_name = f"error_recognition_recordings_{BACKBONE}_RNN_video"
    try:
        # Presumably the project the training script logs to; its newest run is
        # taken to be the one just trained.
        api = wandb.Api()
        entity = wandb.api.viewer()['entity']
        runs = api.runs(f"{entity}/{project_name}", order="-created_at", per_page=1)
        latest_run = runs[0]
    except Exception as e:
        # Training already succeeded and checkpoints are on disk; report and skip upload.
        latest_run = None
        print(f"✗ Could not look up the WandB run in {project_name}: {e}")

    if latest_run is not None:
        wandb_run_name = latest_run.name
        wandb_run_url = latest_run.url
        upload_checkpoints_to_hf(local_ckpt_path, wandb_run_name, BACKBONE)
        create_model_card(wandb_run_name, BACKBONE, TRAINING_CONFIG, wandb_run_url)


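## 3.5. Apply Compatibility Patch

This patch makes `core/models/blocks.py` pass `lengths.cpu()` to `pack_padded_sequence`, because PyTorch requires the lengths tensor to be on the CPU. This cell sits after the Omnivore training cell. If Omnivore training failed with a `lengths` device error, run this cell and then re-run Omnivore training.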

In [None]:
# Apply compatibility patch for pack_padded_sequence
# PyTorch requires the `lengths` argument of pack_padded_sequence to be a CPU
# tensor (or list). Rewrite the call in blocks.py to pass `lengths.cpu()` so it
# also works when lengths lives on the GPU.
import os

file_path = "core/models/blocks.py"

original = "x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)"
patched = "x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)"

if not os.path.exists(file_path):
    print(f"✗ {file_path} not found")
else:
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if original in content:
        content = content.replace(original, patched)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"✓ Patched {file_path}")
    elif patched in content:
        print("✓ Patch already applied")
    else:
        # Neither form is present: the source changed, so don't claim success.
        print(f"⚠ pack_padded_sequence call not found in {file_path}; patch not applied - check the code manually")


In [None]:
# Pre-training checks for SlowFast
# Before launching training, verify that the following are in place:
#   - the training script
#   - a writable checkpoint directory
#   - the Colab secrets
#   - WandB and HF access
#   - the HF helper functions and TRAINING_CONFIG
#   - the SlowFast features and the split file
# Raises RuntimeError if any check fails.
import os
from google.colab import userdata
from huggingface_hub import HfApi
import wandb

print("="*60)
print("Pre-training Checks: SlowFast")
print("="*60)

checks_passed = True

if not os.path.exists("scripts/train_rnn_baseline.py"):
    print("✗ Training script not found")
    checks_passed = False
else:
    print("✓ Training script found")

# Prove the checkpoint directory is writable by creating and removing a probe file.
os.makedirs("checkpoints/error_recognition/RNN/slowfast", exist_ok=True)
try:
    test_file = "checkpoints/error_recognition/RNN/slowfast/.test_write"
    with open(test_file, 'w') as f:
        f.write("test")
    os.remove(test_file)
    print("✓ Checkpoint directory writable")
except Exception as e:
    print(f"✗ Checkpoint directory not writable: {e}")
    checks_passed = False

try:
    test_key = userdata.get('WANDB_API_KEY')
    if not test_key:
        print("✗ WANDB_API_KEY not found in userdata")
        checks_passed = False
    else:
        print("✓ WANDB_API_KEY available")
except Exception as e:
    print(f"✗ Failed to get WANDB_API_KEY: {e}")
    checks_passed = False

try:
    test_token = userdata.get('HF_TOKEN')
    if not test_token:
        print("✗ HF_TOKEN not found in userdata")
        checks_passed = False
    else:
        print("✓ HF_TOKEN available")
except Exception as e:
    print(f"✗ Failed to get HF_TOKEN: {e}")
    checks_passed = False

if not wandb.api.api_key:
    print("✗ WandB not authenticated")
    checks_passed = False
else:
    print("✓ WandB authenticated")

try:
    api = wandb.Api()
    api.viewer()
    print("✓ WandB API accessible")
except Exception as e:
    print(f"✗ WandB API not accessible: {e}")
    checks_passed = False

try:
    hf_token = userdata.get('HF_TOKEN')
    hf_api = HfApi(token=hf_token)
    hf_api.whoami()
    print("✓ HuggingFace token valid")
except Exception as e:
    print(f"✗ HuggingFace token invalid: {e}")
    checks_passed = False

# Explicit checks rather than `assert`, which is skipped under `python -O`.
if 'upload_checkpoints_to_hf' in globals() and 'create_model_card' in globals():
    print("✓ HuggingFace helper functions available")
else:
    print("✗ HuggingFace helper functions not available")
    checks_passed = False

if 'TRAINING_CONFIG' in globals():
    print("✓ TRAINING_CONFIG defined")
else:
    print("✗ TRAINING_CONFIG not defined")
    checks_passed = False

required_data = [
    "data/video/slowfast",
    "er_annotations/recordings_combined_splits.json"
]

for path in required_data:
    if not os.path.exists(path):
        print(f"✗ Missing: {path}")
        checks_passed = False
    else:
        print(f"✓ Found: {path}")

if not checks_passed:
    raise RuntimeError("Pre-training checks failed")

print("="*60)
print("✓ All pre-training checks passed")
print("="*60)


## 4. Train RNN Baseline with SlowFast Features


In [None]:
# Train RNN with SlowFast features
# Runs scripts/train_rnn_baseline.py as a subprocess. On success it looks up the
# newest WandB run and pushes the checkpoints and a model card to HuggingFace.
import subprocess
import sys
import os
import wandb
from google.colab import userdata

BACKBONE = "slowfast"

print(f"{'='*60}")
print(f"Training RNN + {BACKBONE.upper()}")
print(f"{'='*60}")

cmd = [
    sys.executable, "scripts/train_rnn_baseline.py",
    "--variant", "RNN",
    "--backbone", BACKBONE,
    "--split", "recordings",
    "--batch_size", str(TRAINING_CONFIG["batch_size"]),
    "--num_epochs", str(TRAINING_CONFIG["num_epochs"]),
    "--lr", str(TRAINING_CONFIG["learning_rate"]),
    "--weight_decay", str(TRAINING_CONFIG["weight_decay"]),
    "--rnn_hidden_size", str(TRAINING_CONFIG["rnn_hidden_size"]),
    "--rnn_num_layers", str(TRAINING_CONFIG["rnn_num_layers"]),
    "--rnn_dropout", str(TRAINING_CONFIG["rnn_dropout"]),
    "--rnn_bidirectional", str(TRAINING_CONFIG["rnn_bidirectional"]),
    "--rnn_type", TRAINING_CONFIG["rnn_type"],
]

env = os.environ.copy()
# Put the repo root first on PYTHONPATH without leaving an empty trailing entry.
existing_pythonpath = env.get('PYTHONPATH', '')
env['PYTHONPATH'] = os.pathsep.join(p for p in (os.getcwd(), existing_pythonpath) if p)
env['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')

# stderr is merged into stdout and printed after the process exits, so the log
# shows up in the notebook cell.
result = subprocess.run(cmd, cwd=os.getcwd(), env=env, check=False,
                        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

if result.stdout:
    print(result.stdout)

if result.returncode != 0:
    print(f"✗ Training failed for {BACKBONE} (exit code {result.returncode}); skipping HuggingFace upload")
else:
    local_ckpt_path = f"checkpoints/error_recognition/RNN/{BACKBONE}"
    project_name = f"error_recognition_recordings_{BACKBONE}_RNN_video"
    try:
        # Presumably the project the training script logs to; its newest run is
        # taken to be the one just trained.
        api = wandb.Api()
        entity = wandb.api.viewer()['entity']
        runs = api.runs(f"{entity}/{project_name}", order="-created_at", per_page=1)
        latest_run = runs[0]
    except Exception as e:
        # Training already succeeded and checkpoints are on disk; report and skip upload.
        latest_run = None
        print(f"✗ Could not look up the WandB run in {project_name}: {e}")

    if latest_run is not None:
        wandb_run_name = latest_run.name
        wandb_run_url = latest_run.url
        upload_checkpoints_to_hf(local_ckpt_path, wandb_run_name, BACKBONE)
        create_model_card(wandb_run_name, BACKBONE, TRAINING_CONFIG, wandb_run_url)


## 5. Evaluate Trained Models


In [None]:
# Evaluate trained RNN models
import re
import subprocess
import sys
import os

EVAL_THRESHOLD = 0.6

# Epoch number in checkpoint names such as "..._epoch_12.pt" or "...epoch12.pt".
_EPOCH_RE = re.compile(r"epoch[_-]?(\d+)")


def _epoch_sort_key(filename):
    """Order checkpoint names by numeric epoch (missing -> -1), then by name."""
    match = _EPOCH_RE.search(filename)
    return (int(match.group(1)) if match else -1, filename)


def find_best_checkpoint(backbone):
    """Return the checkpoint path to evaluate for *backbone*, or None.

    Looks in ``checkpoints/error_recognition/RNN/<backbone>``:
      1. If any ``*_best.pt`` file exists, return the most recently modified one.
      2. Otherwise return the ``*.pt`` file whose name contains ``epoch`` with the
         highest epoch number. The comparison is numeric, so ``epoch_10`` beats
         ``epoch_9``.
      3. Otherwise (or if the directory is missing) return None.
    """
    local_dir = f"checkpoints/error_recognition/RNN/{backbone}"

    if not os.path.isdir(local_dir):
        return None

    filenames = os.listdir(local_dir)

    best_files = [f for f in filenames if f.endswith('_best.pt')]
    if best_files:
        # Several runs may each have left a "_best" file; prefer the newest.
        newest = max(best_files, key=lambda f: (os.path.getmtime(os.path.join(local_dir, f)), f))
        return os.path.join(local_dir, newest)

    epoch_files = [f for f in filenames if f.endswith('.pt') and 'epoch' in f]
    if epoch_files:
        return os.path.join(local_dir, max(epoch_files, key=_epoch_sort_key))

    return None

def evaluate_model(backbone):
    """Run ``python -m core.evaluate`` on the selected checkpoint for *backbone*.

    The checkpoint comes from ``find_best_checkpoint`` and the decision
    threshold from ``EVAL_THRESHOLD``. The evaluator's output is captured and
    printed so it appears in the notebook, matching the training cells.

    Returns:
        True if evaluation exited with status 0. False if no checkpoint was
        found or the evaluation subprocess failed.
    """
    print(f"{'='*60}")
    print(f"Evaluating RNN + {backbone.upper()}")
    print(f"{'='*60}")

    checkpoint_path = find_best_checkpoint(backbone)

    if not checkpoint_path or not os.path.exists(checkpoint_path):
        print(f"✗ No checkpoint found")
        return False

    print(f"Checkpoint: {checkpoint_path}")

    cmd = [
        sys.executable, "-m", "core.evaluate",
        "--variant", "RNN",
        "--backbone", backbone,
        "--split", "recordings",
        "--ckpt", checkpoint_path,
        "--threshold", str(EVAL_THRESHOLD)
    ]

    env = os.environ.copy()
    # Put the repo root first on PYTHONPATH without leaving an empty trailing entry.
    existing_pythonpath = env.get('PYTHONPATH', '')
    env['PYTHONPATH'] = os.pathsep.join(p for p in (os.getcwd(), existing_pythonpath) if p)

    result = subprocess.run(cmd, cwd=os.getcwd(), env=env, check=False,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    if result.stdout:
        print(result.stdout)

    if result.returncode != 0:
        print(f"✗ Evaluation failed (exit code {result.returncode})")
        return False
    print(f"✓ Evaluation complete")
    return True

print("=" * 60)
print("RNN Model Evaluation")
print("=" * 60)

# Evaluate sequentially: Omnivore first, then SlowFast.
omnivore_success, slowfast_success = (evaluate_model(b) for b in ("omnivore", "slowfast"))

print(f"\n{'=' * 60}")
print("Evaluation Summary")
print("=" * 60)
for display_name, succeeded in (("Omnivore", omnivore_success), ("SlowFast", slowfast_success)):
    print(f"{display_name}: {'✓' if succeeded else '✗'}")
print("=" * 60)


## 6. Compare Results


In [None]:
# Comprehensive results comparison: RNN vs MLP vs Transformer
# Loads the combined evaluation CSV, prints a table of the 'recordings' split
# sorted by Step F1, and saves two charts (per-metric panels and F1 by backbone).
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

# Configuration. The results file name embeds the decision threshold, so reuse
# EVAL_THRESHOLD from the evaluation cell when it has been run (falling back to
# 0.6, the value used there) to keep both cells in sync.
threshold = globals().get('EVAL_THRESHOLD', 0.6)
results_file = f"results/error_recognition/combined_results/step_True_substep_True_threshold_{threshold}.csv"


def _variant_color(label):
    """Return the bar colour for a "Variant+Backbone" label (RNN / MLP / other)."""
    if 'RNN' in label:
        return '#1f77b4'
    if 'MLP' in label:
        return '#ff7f0e'
    return '#2ca02c'


def _plot_metric_panels(comparison_df, labels, metrics):
    """Plot one bar panel per metric on a 2x2 grid and save model_comparison.png."""
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Model Comparison: RNN vs MLP vs Transformer', fontsize=16, fontweight='bold')
    colors = [_variant_color(label) for label in labels]
    positions = range(len(labels))
    panels = list(axes.flat)  # row-major: (0,0), (0,1), (1,0), (1,1)

    for ax, metric in zip(panels, metrics):
        values = comparison_df[metric].tolist()
        bars = ax.bar(positions, values, color=colors, alpha=0.8, edgecolor='black')
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(metric, fontsize=13, fontweight='bold')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.set_ylim([0, 1.0])

        # Value labels on top of each bar (skip missing values).
        for bar, val in zip(bars, values):
            if pd.isna(val):
                continue
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontsize=9)

    # Hide the panels left empty when fewer than four metrics are available.
    for ax in panels[len(metrics):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig('model_comparison.png', dpi=150, bbox_inches='tight')
    print("✓ Saved comparison plot to: model_comparison.png")
    plt.show()


def _plot_f1_by_backbone(comparison_df):
    """Plot Step F1 grouped by backbone (one bar per variant); save f1_comparison_by_backbone.png."""
    fig, ax = plt.subplots(figsize=(14, 8))

    backbones_unique = sorted(comparison_df['Backbone'].unique())
    variants_unique = sorted(comparison_df['Variant'].unique())
    x = np.arange(len(backbones_unique))
    # 0.25 fits up to three variants per group; shrink beyond that so groups don't overlap.
    width = min(0.25, 0.8 / max(len(variants_unique), 1))

    for i, variant in enumerate(variants_unique):
        variant_data = []
        for backbone in backbones_unique:
            matching = comparison_df[(comparison_df['Variant'] == variant) &
                                     (comparison_df['Backbone'] == backbone)]
            # A missing (variant, backbone) combination is drawn as a zero-height bar.
            variant_data.append(matching['Step F1'].iloc[0] if len(matching) > 0 else 0)

        offset = width * (i - len(variants_unique) / 2 + 0.5)
        bars = ax.bar(x + offset, variant_data, width, label=variant, alpha=0.8, edgecolor='black')

        # Add value labels
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    ax.set_xlabel('Backbone', fontsize=13, fontweight='bold')
    ax.set_ylabel('Step F1 Score', fontsize=13, fontweight='bold')
    ax.set_title('F1 Score Comparison by Backbone and Variant', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(backbones_unique)
    ax.legend(title='Variant', fontsize=11)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_ylim([0, 1.0])

    plt.tight_layout()
    plt.savefig('f1_comparison_by_backbone.png', dpi=150, bbox_inches='tight')
    print("✓ Saved F1 comparison plot to: f1_comparison_by_backbone.png")
    plt.show()


print(f"{'='*80}")
print("Comprehensive Model Comparison")
print(f"{'='*80}")
print(f"Results file: {results_file}")
print()

if not os.path.exists(results_file):
    print(f"⚠ Results file not found: {results_file}")
    print("Please run evaluation first.")
else:
    df = pd.read_csv(results_file)

    # Filter for recordings split
    df_filtered = df[df['Split'] == 'recordings'].copy()

    if len(df_filtered) == 0:
        print("⚠ No results found for 'recordings' split")
    else:
        # NOTE(review): assumes re-running evaluation appends another row for the
        # same model. Keep only the last row per (Variant, Backbone) so each model
        # appears once in the table and the charts.
        if {'Variant', 'Backbone'}.issubset(df_filtered.columns):
            df_filtered = df_filtered.drop_duplicates(subset=['Variant', 'Backbone'], keep='last')

        columns_to_show = ['Variant', 'Backbone', 'Step Precision', 'Step Recall',
                           'Step F1', 'Step Accuracy', 'Step AUC']
        existing_columns = [col for col in columns_to_show if col in df_filtered.columns]

        print("="*80)
        print(f"Model Performance Comparison (Recordings Split, Threshold={threshold})")
        print("="*80)
        print()

        comparison_df = df_filtered[existing_columns].copy()
        has_f1 = 'Step F1' in comparison_df.columns

        # Sort by F1 score (descending)
        if has_f1:
            comparison_df = comparison_df.sort_values('Step F1', ascending=False)

        # Round numeric columns for display
        for col in comparison_df.select_dtypes(include=[np.number]).columns:
            comparison_df[col] = comparison_df[col].round(4)

        print(comparison_df.to_string(index=False))
        print()

        # Highlight best model (idxmax needs at least one non-NaN F1 value)
        if has_f1 and comparison_df['Step F1'].notna().any():
            best_idx = comparison_df['Step F1'].idxmax()
            best_model = comparison_df.loc[best_idx]
            print("="*80)
            print("Best Model (by F1 Score):")
            print("="*80)
            print(f"Variant: {best_model['Variant']}")
            print(f"Backbone: {best_model['Backbone']}")
            print(f"F1 Score: {best_model['Step F1']:.4f}")
            print()

        print("="*80)
        print("Generating Visualizations...")
        print("="*80)

        labels = [f"{v}+{b}" for v, b in zip(comparison_df['Variant'], comparison_df['Backbone'])]
        metrics = [m for m in ('Step Precision', 'Step Recall', 'Step F1', 'Step AUC')
                   if m in comparison_df.columns]

        if metrics:
            _plot_metric_panels(comparison_df, labels, metrics)
            if has_f1:
                _plot_f1_by_backbone(comparison_df)

            print()
            print("="*80)
            print("Analysis Complete!")
            print("="*80)
        else:
            print("⚠ No metrics found for visualization")

print()
print("Note: Make sure you've run training and evaluation for all models before comparison.")
