# AML - Step 2 & Step 3 Runner (Colab)

This notebook runs Step 2 (feature sanity check) and Step 3 (evaluation) using `scripts/run.py`.

## 1. Setup Environment

In [None]:
# Clone repository (or upload your code)
!git clone --recursive https://github.com/sapeirone/aml-2025-mistake-detection.git code || echo "Repo already exists or using uploaded code"

# Change to code directory
import os
os.chdir('code')
print(f"Current directory: {os.getcwd()}")

In [None]:
# Install dependencies
# torcheval is installed explicitly, separately from the requirements files.
!pip install -q torcheval
# Try requirements-cpu.txt first (its error output is discarded); if that
# install fails, fall back to requirements.txt.
!pip install -q -r requirements-cpu.txt 2>/dev/null || pip install -q -r requirements.txt

## 2. Mount Google Drive (Optional)

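If your features and checkpoints are in Google Drive, mount it here. Section 3 reads from Drive paths, so skip this step only if you point those paths at locations that are already available locally.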

In [None]:
# Mount Google Drive at /content/drive so that paths under
# /content/drive/MyDrive/... become readable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

## 3. Load Data & Checkpoints

Specify the Google Drive paths to your unzipped feature directories and to the checkpoints (either a zip file or an unzipped directory).

In [None]:
# ---------------------------------------------------------------
# Drive locations: edit the three constants below before running.
# ---------------------------------------------------------------

# Omnivore features: an unzipped directory on Drive holding the .npz files,
# either directly or inside a subdirectory.
OMNIVORE_DRIVE_PATH = "/content/drive/MyDrive/path/to/omnivore"  # UPDATE THIS

# SlowFast features: an unzipped directory on Drive holding the .npz files,
# either directly or inside a subdirectory.
SLOWFAST_DRIVE_PATH = "/content/drive/MyDrive/path/to/slowfast"  # UPDATE THIS

# error_recognition_best checkpoints. Two forms are accepted:
#   - a zip file, e.g. "/content/drive/MyDrive/path/to/error_recognition_best.zip"
#   - an unzipped directory, e.g. "/content/drive/MyDrive/path/to/error_recognition_best"
# Either way it should contain MLP/ and Transformer/ subdirectories.
CHECKPOINTS_DRIVE_PATH = "/content/drive/MyDrive/path/to/error_recognition_best.zip"  # UPDATE THIS (can be .zip or directory)

# Echo the configured values so a wrong path is easy to spot before copying.
print(
    "Paths configured:",
    f"  Omnivore: {OMNIVORE_DRIVE_PATH}",
    f"  SlowFast: {SLOWFAST_DRIVE_PATH}",
    f"  Checkpoints: {CHECKPOINTS_DRIVE_PATH}",
    sep="\n",
)

In [None]:
# Copy the Omnivore and SlowFast features from Drive into data/video/.
#
# This is done in Python rather than with `!cp` / `!find` because:
#   - a bare `{}` (as in `find -exec cp {}`) inside an IPython `!` command
#     interferes with `{VARIABLE}` expansion for the whole command, and
#   - shell `$(...)` substitution is not evaluated inside a Python f-string,
#     so the file counts have to be computed in Python.
import shutil
from pathlib import Path


def copy_features(src_dir, dst_dir):
    """Copy feature files from ``src_dir`` into ``dst_dir``.

    The whole directory tree is copied first. If that fails, every ``.npz``
    file found anywhere below ``src_dir`` is copied flat into ``dst_dir`` as a
    best-effort fallback; failures in the fallback are reported, not raised.

    Args:
        src_dir: Source directory on Drive.
        dst_dir: Destination directory; created if missing.

    Returns:
        Number of ``.npz`` files directly inside ``dst_dir`` afterwards.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copytree(src, dst, dirs_exist_ok=True)
    except OSError as err:  # shutil.Error is a subclass of OSError
        print(f"Warning: Could not copy from {src}: {err}")
        print("Trying alternative: copying contents if path points to parent directory...")
        if src.is_dir():
            for npz_file in src.rglob('*.npz'):
                try:
                    shutil.copy2(npz_file, dst / npz_file.name)
                except OSError as copy_err:
                    print(f"  Skipped {npz_file}: {copy_err}")
    return len(list(dst.glob('*.npz')))


# Copy omnivore features from Drive
print("Copying Omnivore features...")
omnivore_count = copy_features(OMNIVORE_DRIVE_PATH, 'data/video/omnivore')

# Copy slowfast features from Drive
print("Copying SlowFast features...")
slowfast_count = copy_features(SLOWFAST_DRIVE_PATH, 'data/video/slowfast')

# Verify features were copied
print("\nVerifying features...")
print(f"Omnivore files: {omnivore_count} .npz files")
print(f"SlowFast files: {slowfast_count} .npz files")

In [None]:
# Copy the checkpoints from Drive into checkpoints/error_recognition_best.
# CHECKPOINTS_DRIVE_PATH may be a zip archive or an unzipped directory.
import os
import subprocess  # no longer used here; kept so the name stays available in the notebook
import shutil
import tempfile
import zipfile

os.makedirs('checkpoints', exist_ok=True)

checkpoint_path = CHECKPOINTS_DRIVE_PATH
target_dir = 'checkpoints/error_recognition_best'

# Fail early with a readable message instead of an error from deep inside a copy.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint path does not exist: {checkpoint_path} "
        "(is Drive mounted and CHECKPOINTS_DRIVE_PATH updated?)"
    )

# Any regular file is treated as a zip archive; a directory is copied as-is.
is_zip = os.path.isfile(checkpoint_path)

if is_zip:
    print(f"Detected zip file: {checkpoint_path}")
    # A fresh temporary directory means leftovers from an interrupted run can
    # never collide with this extraction, and it is removed even on failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        print("Copying zip file to temporary location...")
        local_zip = shutil.copy(checkpoint_path, os.path.join(tmp_dir, 'checkpoints.zip'))

        print("Unzipping checkpoints...")
        extracted_base = os.path.join(tmp_dir, 'extracted')
        with zipfile.ZipFile(local_zip) as archive:
            archive.extractall(extracted_base)

        # Locate the checkpoint root inside the archive. The walk is top-down,
        # so the archive root is examined first: prefer a directory named
        # error_recognition_best, otherwise the first directory that directly
        # contains MLP/ or Transformer/. Fall back to the archive root.
        extracted_path = extracted_base
        for root, dirs, _files in os.walk(extracted_base):
            if 'error_recognition_best' in dirs:
                extracted_path = os.path.join(root, 'error_recognition_best')
                break
            if 'MLP' in dirs or 'Transformer' in dirs:
                extracted_path = root
                break

        print(f"Copying from extracted location: {extracted_path}")
        shutil.copytree(extracted_path, target_dir, dirs_exist_ok=True)
    print("✓ Checkpoints extracted and copied")
else:
    print(f"Detected directory: {checkpoint_path}")
    print("Copying checkpoints from directory...")
    # The directory's contents land directly under target_dir regardless of
    # what the source directory itself is called.
    shutil.copytree(checkpoint_path, target_dir, dirs_exist_ok=True)
    print("✓ Checkpoints copied")

# Verify checkpoints were loaded
print("\nVerifying checkpoints...")
ckpt_root = 'checkpoints/error_recognition_best'
if os.path.isdir(ckpt_root):
    print("✓ Checkpoints directory exists")

    # Count .pt files anywhere below the checkpoint root (sorted for stable output).
    pt_files = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(ckpt_root)
        for name in files
        if name.endswith('.pt')
    )
    print(f"Found {len(pt_files)} checkpoint files (.pt)")
    if pt_files:
        print("\nSample checkpoint files:")
        for pt_file in pt_files[:5]:
            print(f"  {pt_file}")

    # Show directory structure: each directory is printed exactly once, as the
    # walk reaches it. Pruning `dirs` in place limits both depth and breadth.
    print("\nDirectory structure:")
    max_depth = 3  # deepest level (relative to ckpt_root) that is listed
    max_dirs = 5   # subdirectories shown per directory
    for root, dirs, _files in os.walk(ckpt_root):
        dirs.sort()
        rel = os.path.relpath(root, ckpt_root)
        level = 0 if rel == '.' else rel.count(os.sep) + 1
        print(f"{'  ' * level}{os.path.basename(root)}/")
        if level >= max_depth:
            dirs[:] = []  # do not descend further
        elif len(dirs) > max_dirs:
            print(f"{'  ' * (level + 1)}(showing first {max_dirs} of {len(dirs)} subdirectories)")
            del dirs[max_dirs:]
else:
    print("✗ Checkpoints directory not found")

## 4. Step 2: Feature Sanity Check

In [None]:
# Run Step 2 with default path (data/)
# The script path is relative, so this must be run from the repository root
# (the setup cell changes the working directory into the cloned repo).
!python scripts/run.py step2

In [None]:
# Alternative: to point Step 2 at a custom features root instead of the
# default, uncomment the command below and replace the example path.
# !python scripts/run.py step2 --features_root /path/to/features

## 5. Step 3: Evaluation Reproduction

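Run evaluations for different backbones, variants, and splits. Each cell needs the checkpoint file for its configuration: wherever a checkpoint path contains the placeholder `XX`, replace it with the actual epoch number before running the cell.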

In [None]:
# Omnivore - MLP - Step split
# The checkpoint file name hard-codes epoch 43; adjust it if your checkpoint
# for this configuration was saved at a different epoch.
!python scripts/run.py step3 --split step --backbone omnivore --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_step_epoch_43.pt \
  --threshold 0.6

In [None]:
# Omnivore - MLP - Recordings split
!python scripts/run.py step3 --split recordings --backbone omnivore --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# Omnivore - Transformer - Step split
!python scripts/run.py step3 --split step --backbone omnivore --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# Omnivore - Transformer - Recordings split
!python scripts/run.py step3 --split recordings --backbone omnivore --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# SlowFast - MLP - Step split
!python scripts/run.py step3 --split step --backbone slowfast --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/slowfast/error_recognition_MLP_slowfast_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# SlowFast - MLP - Recordings split
!python scripts/run.py step3 --split recordings --backbone slowfast --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/slowfast/error_recognition_MLP_slowfast_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# SlowFast - Transformer - Step split
!python scripts/run.py step3 --split step --backbone slowfast --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/slowfast/error_recognition_Transformer_slowfast_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# SlowFast - Transformer - Recordings split
!python scripts/run.py step3 --split recordings --backbone slowfast --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/slowfast/error_recognition_Transformer_slowfast_recordings_epoch_XX.pt \
  --threshold 0.4