# AML - Step 2 & Step 3 Runner (Colab)

This notebook runs Step 2 (feature sanity check) and Step 3 (evaluation) using `scripts/run.py`.

## Quick Start
1. **Configure your repo** in Section 1
2. **Mount Google Drive** (Section 2) if your data is there
3. **Set data paths** (Section 3) and load features/checkpoints
4. **Run Step 2** to verify features
5. **Run Step 3** to evaluate models

## 1. Setup Environment

In [None]:
# ============================================
# CONFIGURE YOUR REPOSITORY
# ============================================
# Clone URL in .git form (a GitHub tree/branch page URL will not clone)
REPO_URL = "https://github.com/aexomir/AML_mistake_detection.git"

# Branch to check out after cloning; an empty string keeps the default branch
REPO_BRANCH = "feat/step02"

# Local folder the repository is cloned into
REPO_DIR = "aml_repo"

print(f"Repository URL: {REPO_URL}")
print(f"Repository branch: {REPO_BRANCH or 'default'}")
print(f"Repository directory: {REPO_DIR}")

In [None]:
import os
import shutil

# Remove existing directory if it exists
if os.path.exists(REPO_DIR):
    print(f"Removing existing {REPO_DIR} directory...")
    shutil.rmtree(REPO_DIR)

# Clone repository
if REPO_URL:
    print(f"Cloning repository from {REPO_URL}...")
    clone_cmd = f"git clone {REPO_URL} {REPO_DIR}"
    result = os.system(clone_cmd)
    
    if result != 0:
        print(f"⚠ Clone failed. Creating {REPO_DIR} directory for manual upload...")
        os.makedirs(REPO_DIR, exist_ok=True)
    else:
        print("✓ Repository cloned successfully")
        
        # Checkout specific branch if specified
        if REPO_BRANCH:
            print(f"Checking out branch: {REPO_BRANCH}")
            os.chdir(REPO_DIR)
            os.system(f"git checkout {REPO_BRANCH}")
            os.chdir('..')
            print(f"✓ Switched to branch: {REPO_BRANCH}")
else:
    print("No repo URL configured. Creating directory for manual upload...")
    os.makedirs(REPO_DIR, exist_ok=True)

# Change to repository directory
if os.path.exists(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"\n✓ Changed to directory: {os.getcwd()}")
    print(f"\nRepository contents:")
    !ls -la
else:
    print(f"✗ Error: {REPO_DIR} directory not found!")

In [None]:
# Sanity-check that the kernel sits in the repository root: every entry below is looked
# up relative to the current working directory.
import os

print(f"Current working directory: {os.getcwd()}")
print("\nChecking repository structure...")

required_items = [
    'scripts/run.py',
    'core/evaluate.py',
    'dataloader',
    'base.py',
    'constants.py',
]

missing = []
for item in required_items:
    present = os.path.exists(item)
    print(f"{'✓ Found' if present else '✗ Missing'}: {item}")
    if not present:
        missing.append(item)

if not missing:
    print("\n✓ Repository structure looks good!")
    print("\nYou can now proceed to load data and run steps.")
else:
    print("\n⚠ Warning: Some required files/directories are missing!")
    print("Make sure you're in the correct repository directory.")

In [None]:
# Install dependencies.
# Colab already ships PyTorch. Force-installing the pinned 2.2.0+cu118 build would replace
# it, which contradicts the goal of avoiding conflicts with the preinstalled version. So
# only missing torch packages are installed, and any torch/torchvision/torchaudio
# requirement lines are stripped from the requirements file before installing it.
import importlib.util
import os
import re
import subprocess
import sys

TORCH_PACKAGES = ('torch', 'torchvision', 'torchaudio')


def _pip_install(*args):
    """Run ``pip install -q <args>`` with this kernel's interpreter.

    Prints the tail of pip's stderr and returns False on failure; returns True on success.
    """
    result = subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', *args],
                            capture_output=True, text=True)
    if result.returncode != 0:
        print(f"✗ pip install {' '.join(args)} failed:\n{result.stderr[-2000:]}")
        return False
    return True


def _strip_torch_requirements(path):
    """Remove torch/torchvision/torchaudio lines (with any version specifier) from ``path`` in place."""
    with open(path) as fh:
        lines = fh.readlines()
    kept = [
        line for line in lines
        # The package name is everything before the first specifier/extras/marker character.
        if re.split(r'[\s<>=!~;\[@]', line.strip(), maxsplit=1)[0].lower() not in TORCH_PACKAGES
    ]
    with open(path, 'w') as fh:
        fh.writelines(kept)


requirements_file = next(
    (name for name in ('requirements.txt', 'requirements-cpu.txt') if os.path.exists(name)),
    None,
)
if requirements_file:
    _strip_torch_requirements(requirements_file)

all_ok = True

# Only install PyTorch packages that are not importable already
missing_torch = [pkg for pkg in TORCH_PACKAGES if importlib.util.find_spec(pkg) is None]
if missing_torch:
    all_ok &= _pip_install(*missing_torch)
else:
    print("Using the preinstalled PyTorch packages")

# Install torcheval for evaluation metrics
all_ok &= _pip_install('torcheval')

# Install all remaining dependencies from the requirements file, if any
if requirements_file:
    all_ok &= _pip_install('-r', requirements_file)

if all_ok:
    print("✓ All dependencies installed successfully")
else:
    print("⚠ Some dependencies failed to install; see the errors above")

## 2. Mount Google Drive (Optional)

If your features and checkpoints are in Google Drive, mount it here.

In [None]:
# Expose Google Drive under /content/drive so the Drive paths configured in Section 3 resolve.
import google.colab.drive as drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## 3. Load Data & Checkpoints

Specify the Google Drive locations of your feature directories, your checkpoints (a zip file or an unzipped directory), and your annotations.

In [None]:
# ============================================
# UPDATE THESE PATHS TO YOUR DRIVE LOCATIONS
# ============================================
# The defaults point under /content/drive, where Section 2 mounts Google Drive.

# Unzipped Omnivore features: .npz files directly inside, or inside a subdirectory
OMNIVORE_DRIVE_PATH = "/content/drive/MyDrive/AML/AML_MistakeDetection/features/omnivore"

# Unzipped SlowFast features: same layout as the Omnivore folder
SLOWFAST_DRIVE_PATH = "/content/drive/MyDrive/AML/AML_MistakeDetection/features/slowfast"

# error_recognition_best checkpoints, given as either
#   - a zip archive, e.g. "/content/drive/MyDrive/path/to/error_recognition_best.zip", or
#   - an unzipped folder, e.g. "/content/drive/MyDrive/path/to/error_recognition_best"
# Its contents should include MLP/ and Transformer/ subdirectories.
CHECKPOINTS_DRIVE_PATH = "/content/drive/MyDrive/AML/AML_MistakeDetection/error_recognition_best.zip"

# Annotations kept outside the repository; expected to hold annotation_json/ and,
# optionally, data_splits/ and er_annotations/
ANNOTATIONS_DRIVE_PATH = "/content/drive/MyDrive/AML/AML_MistakeDetection/annotations"

_configured_paths = (
    ("Omnivore", OMNIVORE_DRIVE_PATH),
    ("SlowFast", SLOWFAST_DRIVE_PATH),
    ("Checkpoints", CHECKPOINTS_DRIVE_PATH),
    ("Annotations", ANNOTATIONS_DRIVE_PATH),
)

print("Paths configured:")
for _label, _path in _configured_paths:
    print(f"  {_label}: {_path}")

In [None]:
# ============================================
# LOAD ANNOTATIONS
# ============================================
# The evaluation script requires annotation files. You need to either:
# 1. Have them in your repository (they may be gitignored)
# 2. Upload them to Google Drive and copy them here
# 3. Download them from a shared location
#
# Required files:
# - annotations/annotation_json/step_annotations.json
# - annotations/annotation_json/error_annotations.json
# - er_annotations/recordings_combined_splits.json (for step split - should be in repo)
# - annotations/data_splits/{split}_data_split_combined.json (for recordings split)

import os
import shutil

# Drive sub-folder name -> local destination folder.
# ANNOTATIONS_DRIVE_PATH comes from the path-configuration cell in Section 3.
ANNOTATION_DIRS = {
    'annotation_json': 'annotations/annotation_json',
    'data_splits': 'annotations/data_splits',
    'er_annotations': 'er_annotations',
}


def _copy_flat_files(src_dir, dst_dir):
    """Copy the regular files directly inside ``src_dir`` into ``dst_dir``; sub-folders are ignored."""
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        if os.path.isfile(src):
            shutil.copy2(src, os.path.join(dst_dir, name))
            print(f"  ✓ Copied {name}")


# Create annotations directory structure
for dst_dir in ANNOTATION_DIRS.values():
    os.makedirs(dst_dir, exist_ok=True)

# Copy each annotation sub-folder that exists on Drive
if os.path.isdir(ANNOTATIONS_DRIVE_PATH):
    print(f"Found annotations in Drive: {ANNOTATIONS_DRIVE_PATH}")
    for sub_name, dst_dir in ANNOTATION_DIRS.items():
        src_dir = os.path.join(ANNOTATIONS_DRIVE_PATH, sub_name)
        if os.path.isdir(src_dir):
            print(f"Copying {sub_name} files...")
            _copy_flat_files(src_dir, dst_dir)
else:
    print(f"⚠ Annotations not found at {ANNOTATIONS_DRIVE_PATH}")
    print("  If annotations are in your repository, they should already be available.")
    print("  Otherwise, please upload them to Drive or update ANNOTATIONS_DRIVE_PATH.")

# Verify required annotation files exist
print("\nVerifying annotation files...")
required_files = [
    'annotations/annotation_json/step_annotations.json',
    'annotations/annotation_json/error_annotations.json',
    'er_annotations/recordings_combined_splits.json',
]

missing = []
for file in required_files:
    if os.path.exists(file):
        print(f"✓ Found: {file}")
    else:
        print(f"✗ Missing: {file}")
        missing.append(file)

if missing:
    print(f"\n⚠ Warning: {len(missing)} required annotation file(s) are missing!")
    print("  You need to provide these files before running Step 3.")
    print("  Please check your repository or upload them to Google Drive.")
else:
    print("\n✓ All required annotation files are present!")

# Data-split files are not checked above, but the header lists them for the recordings split
split_files = [f for f in os.listdir('annotations/data_splits') if f.endswith('_data_split_combined.json')]
if not split_files:
    print("ℹ No annotations/data_splits/*_data_split_combined.json files found; "
          "the recordings split may need them.")

In [None]:
# Copy pre-extracted video features from Drive into data/video/<backbone>/.
#
# Done in pure Python rather than `!cp ... || find ... -exec cp {} ...`. IPython abandons
# `{var}` interpolation for an entire `!` command that also contains a bare `{}`, so the
# Drive path reached the shell literally and nothing was copied (and `|| true` hid that).
# `cp -r dir/*` also copied sub-folders as-is, so nested .npz files never landed at the
# top level that the count below checks.
import os
import shutil

# (label, Drive source folder, local destination folder)
FEATURE_SOURCES = [
    ('Omnivore', OMNIVORE_DRIVE_PATH, 'data/video/omnivore'),
    ('SlowFast', SLOWFAST_DRIVE_PATH, 'data/video/slowfast'),
]


def _copy_npz_features(src_root, dst_dir):
    """Copy every ``.npz`` file found anywhere under ``src_root`` flat into ``dst_dir``.

    macOS metadata (``__MACOSX`` folders, ``._*`` AppleDouble files) is skipped. Files with
    the same name in different sub-folders overwrite each other. Returns the number copied.
    """
    copied = 0
    for root, dirs, files in os.walk(src_root):
        dirs[:] = [d for d in dirs if d != '__MACOSX']
        for name in files:
            if name.endswith('.npz') and not name.startswith('._'):
                shutil.copy2(os.path.join(root, name), os.path.join(dst_dir, name))
                copied += 1
    return copied


def _count_npz(directory):
    """Number of ``.npz`` files directly inside ``directory`` (0 if it does not exist)."""
    if not os.path.isdir(directory):
        return 0
    return sum(name.endswith('.npz') for name in os.listdir(directory))


# Create data directory structure and copy each backbone's features from Drive
for label, src_root, dst_dir in FEATURE_SOURCES:
    os.makedirs(dst_dir, exist_ok=True)
    print(f"\nCopying {label} features...")
    if not os.path.isdir(src_root):
        print(f"⚠ Warning: {src_root} not found")
        continue
    n_copied = _copy_npz_features(src_root, dst_dir)
    if n_copied:
        print(f"✓ {label} features copied ({n_copied} .npz files)")
    else:
        print(f"⚠ Warning: no .npz files found under {src_root}")

# Verify features were copied
print("\nVerifying features...")
omnivore_count = _count_npz('data/video/omnivore')
slowfast_count = _count_npz('data/video/slowfast')
print(f"Omnivore files: {omnivore_count} .npz files")
print(f"SlowFast files: {slowfast_count} .npz files")

In [None]:
# Populate checkpoints/error_recognition_best from CHECKPOINTS_DRIVE_PATH (a zip archive or
# a directory), then report what was found.
import os
import shutil
import tempfile
import zipfile

CHECKPOINTS_DEST = 'checkpoints/error_recognition_best'


def _locate_checkpoint_root(base):
    """Return the folder under ``base`` whose contents belong in CHECKPOINTS_DEST.

    Walks ``base`` top-down (skipping macOS ``__MACOSX`` metadata folders). At each level a
    child named ``error_recognition_best`` is returned first; otherwise the current folder
    is returned if it directly contains ``MLP/`` or ``Transformer/``. Falls back to ``base``.
    """
    for root, dirs, _files in os.walk(base):
        # Prune zip metadata and make the search order deterministic.
        dirs[:] = sorted(d for d in dirs if d != '__MACOSX')
        if 'error_recognition_best' in dirs:
            return os.path.join(root, 'error_recognition_best')
        if 'MLP' in dirs or 'Transformer' in dirs:
            return root
    return base


def _report_checkpoints(dest, max_depth=3):
    """Print the ``.pt`` count under ``dest``, a few sample paths, and its folder tree to ``max_depth``."""
    pt_files = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(dest)
        for name in files
        if name.endswith('.pt')
    )
    print(f"Found {len(pt_files)} checkpoint files (.pt)")
    if pt_files:
        print("\nSample checkpoint files:")
        for path in pt_files[:5]:
            print(f"  {path}")

    print("\nDirectory structure:")
    for root, dirs, _files in os.walk(dest):
        rel = os.path.relpath(root, dest)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        dirs.sort()
        print(f"{'  ' * depth}{os.path.basename(root)}/")
        if depth >= max_depth:
            dirs[:] = []  # each folder is printed once; stop descending past max_depth


os.makedirs('checkpoints', exist_ok=True)
checkpoint_path = CHECKPOINTS_DRIVE_PATH

if os.path.isdir(checkpoint_path):
    print(f"Detected directory: {checkpoint_path}")
    source_root = _locate_checkpoint_root(checkpoint_path)
    print(f"Copying checkpoints from directory: {source_root}")
    shutil.copytree(source_root, CHECKPOINTS_DEST, dirs_exist_ok=True)
    print("✓ Checkpoints copied")
elif os.path.isfile(checkpoint_path):
    print(f"Detected zip file: {checkpoint_path}")
    # A fresh temporary directory per run: nothing stale from an earlier run gets mixed in,
    # and it is removed even if extraction or copying fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_zip = os.path.join(tmp_dir, 'checkpoints.zip')
        extracted_base = os.path.join(tmp_dir, 'extracted')

        print("Copying zip file to temporary location...")
        shutil.copy(checkpoint_path, local_zip)

        print("Unzipping checkpoints...")
        try:
            with zipfile.ZipFile(local_zip) as archive:
                archive.extractall(extracted_base)
        except zipfile.BadZipFile as exc:
            print(f"✗ Error unzipping: {exc}")
        else:
            source_root = _locate_checkpoint_root(extracted_base)
            print(f"Copying from extracted location: {source_root}")
            shutil.copytree(source_root, CHECKPOINTS_DEST, dirs_exist_ok=True)
            print("✓ Checkpoints extracted and copied")
else:
    print(f"✗ Error: {checkpoint_path} not found!")

# Verify checkpoints were loaded
print("\nVerifying checkpoints...")
if os.path.isdir(CHECKPOINTS_DEST):
    print("✓ Checkpoints directory exists")
    _report_checkpoints(CHECKPOINTS_DEST)
else:
    print("✗ Checkpoints directory not found")

## 4. Step 2: Feature Sanity Check

In [None]:
# Run Step 2 (feature sanity check) using the script's default features root.
# Per the original note that default is data/, which is where the feature-copy cell put
# data/video/omnivore and data/video/slowfast — TODO confirm against scripts/run.py.
!python scripts/run.py step2

In [None]:
# Or specify custom features root
# !python scripts/run.py step2 --features_root /path/to/features

## 5. Step 3: Evaluation Reproduction

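Run evaluations for each combination of backbone (Omnivore, SlowFast), variant (MLP, Transformer), and split (step, recordings). Checkpoint file names include the epoch number, so make sure each `--ckpt` path matches a file that actually exists under `checkpoints/error_recognition_best/`.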

In [None]:
# Omnivore - MLP - Step split
!python scripts/run.py step3 --split step --backbone omnivore --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_step_epoch_43.pt \
  --threshold 0.6

In [None]:
# Omnivore - MLP - Recordings split
!python scripts/run.py step3 --split recordings --backbone omnivore --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/omnivore/error_recognition_MLP_omnivore_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# Omnivore - Transformer - Step split
!python scripts/run.py step3 --split step --backbone omnivore --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# Omnivore - Transformer - Recordings split
!python scripts/run.py step3 --split recordings --backbone omnivore --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/omnivore/error_recognition_Transformer_omnivore_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# SlowFast - MLP - Step split
!python scripts/run.py step3 --split step --backbone slowfast --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/slowfast/error_recognition_MLP_slowfast_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# SlowFast - MLP - Recordings split
!python scripts/run.py step3 --split recordings --backbone slowfast --variant MLP \
  --ckpt checkpoints/error_recognition_best/MLP/slowfast/error_recognition_MLP_slowfast_recordings_epoch_XX.pt \
  --threshold 0.4

In [None]:
# SlowFast - Transformer - Step split
!python scripts/run.py step3 --split step --backbone slowfast --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/slowfast/error_recognition_Transformer_slowfast_step_epoch_XX.pt \
  --threshold 0.6

In [None]:
# SlowFast - Transformer - Recordings split
!python scripts/run.py step3 --split recordings --backbone slowfast --variant Transformer \
  --ckpt checkpoints/error_recognition_best/Transformer/slowfast/error_recognition_Transformer_slowfast_recordings_epoch_XX.pt \
  --threshold 0.4