<a href="https://colab.research.google.com/github/aexomir/AML_mistake_detection/blob/feat%2Frnn/notebooks/rnn_baseline_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RNN Baseline Training for CaptainCook4D SupervisedER

This notebook trains the V_RNN (RNN/LSTM) baseline for mistake detection and compares it against V1 (MLP) and V2 (Transformer) baselines.

## What this notebook does:
1. **Setup**: Clone repository and install dependencies
2. **Load Data**: Load features, annotations, and optionally checkpoints from Google Drive
3. **Train**: Train the RNN baseline model
4. **Evaluate**: Evaluate the trained model
5. **Compare**: Compare results against V1 (MLP) and V2 (Transformer) baselines

## Prerequisites:
You need to have:
- Pre-extracted Omnivore and SlowFast features as `.npz` files, either in a folder or packed in a zip archive
- Annotation files (should be in the repository or uploaded separately)
- (Optional) Pre-trained checkpoints for comparison

## Quick Start:
1. Set the repository (`REPO_URL`, `REPO_BRANCH`) and the Google Drive data paths in the two configuration cells below
2. Run all cells sequentially


In [2]:
# ============================================
# CONFIGURE YOUR REPOSITORY
# ============================================
# Option 1 (recommended): clone the project from GitHub.
REPO_URL = "https://github.com/aexomir/AML_mistake_detection.git"
# Branch to check out; an empty string keeps the remote's default branch.
REPO_BRANCH = "feat/rnn"

# Option 2: upload the files by hand instead -- set REPO_URL to "" and use the
# Colab file browser.
# REPO_URL = ""

# Folder the repository lives in, relative to the starting directory.
REPO_DIR = "code"

print(f"Repository URL: {REPO_URL or 'Manual upload mode'}")
print(f"Repository branch: {REPO_BRANCH or 'default'}")
print(f"Repository directory: {REPO_DIR}")


Repository URL: https://github.com/aexomir/AML_mistake_detection.git
Repository branch: feat/rnn
Repository directory: code


In [3]:
# Clone (or prepare) the project repository and make it the working directory.
import os
import shutil
import subprocess

# Re-run safety: a previous run of this cell chdir'd into REPO_DIR. Step back out
# so we neither delete our own working directory nor clone into REPO_DIR/REPO_DIR.
# NOTE: assumes the starting directory is not itself named REPO_DIR.
if os.path.basename(os.getcwd()) == REPO_DIR:
    os.chdir('..')

if REPO_URL:
    # Start from a clean checkout. This only happens when cloning, so files
    # uploaded by hand in manual-upload mode are never wiped.
    if os.path.exists(REPO_DIR):
        print(f"Removing existing {REPO_DIR} directory...")
        shutil.rmtree(REPO_DIR)

    print(f"Cloning repository from {REPO_URL}...")
    # `--branch` clones the requested branch directly, so a bad branch name fails
    # the clone visibly instead of being silently ignored. The argument list
    # (no shell) keeps URLs and branch names from being shell-interpreted.
    clone_cmd = ['git', 'clone']
    if REPO_BRANCH:
        clone_cmd += ['--branch', REPO_BRANCH]
    clone_cmd += [REPO_URL, REPO_DIR]
    clone = subprocess.run(clone_cmd, capture_output=True, text=True)

    if clone.returncode != 0:
        print("⚠ Clone failed. Please check the URL/branch or upload files manually.")
        print(clone.stderr.strip())
        os.makedirs(REPO_DIR, exist_ok=True)
    else:
        print("✓ Repository cloned successfully")
        if REPO_BRANCH:
            print(f"✓ Checked out branch: {REPO_BRANCH}")
else:
    print("Manual upload mode: Creating directory...")
    os.makedirs(REPO_DIR, exist_ok=True)

# Work from the repository root for the rest of the notebook
if os.path.isdir(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"\n✓ Changed to directory: {os.getcwd()}")
    print("\nRepository contents:")
    print(subprocess.run(['ls', '-la'], capture_output=True, text=True).stdout)
else:
    print(f"✗ Error: {REPO_DIR} directory not found!")


Cloning repository from https://github.com/aexomir/AML_mistake_detection.git...
✓ Repository cloned successfully
Checking out branch: feat/rnn
✓ Switched to branch: feat/rnn

✓ Changed to directory: /content/code

Repository contents:
total 6024
drwxr-xr-x 9 root root    4096 Dec 23 20:57 .
drwxr-xr-x 1 root root    4096 Dec 23 20:57 ..
-rw-r--r-- 1 root root 6042142 Dec 23 20:57 3_Mistake_Detection.pdf
drwxr-xr-x 3 root root    4096 Dec 23 20:57 analysis
-rw-r--r-- 1 root root   20480 Dec 23 20:57 base.py
-rw-r--r-- 1 root root    1685 Dec 23 20:57 constants.py
drwxr-xr-x 3 root root    4096 Dec 23 20:57 core
drwxr-xr-x 2 root root    4096 Dec 23 20:57 dataloader
-rw-r--r-- 1 root root    6148 Dec 23 20:57 .DS_Store
drwxr-xr-x 2 root root    4096 Dec 23 20:57 er_annotations
drwxr-xr-x 8 root root    4096 Dec 23 20:57 .git
-rw-r--r-- 1 root root      65 Dec 23 20:57 .gitignore
-rwxr-xr-x 1 root root    1904 Dec 23 20:57 install_deps.py
-rw-r--r-- 1 root root   11357 Dec 23 20:57 LICENS

In [4]:
# Sanity-check that the files the training and evaluation cells rely on exist
import os

print(f"Current working directory: {os.getcwd()}")
print(f"\nChecking repository structure...")

# Paths are relative to the repository root (the current directory)
required_items = [
    'scripts/train_rnn_baseline.py',
    'core/evaluate.py',
    'dataloader',
    'base.py',
    'constants.py'
]

missing = []
for item in required_items:
    present = os.path.exists(item)
    print(f"✓ Found: {item}" if present else f"✗ Missing: {item}")
    if not present:
        missing.append(item)

if not missing:
    print(f"\n✓ Repository structure looks good!")
else:
    print(f"\n⚠ Warning: Some required files/directories are missing!")
    print(f"Please ensure all files are present before proceeding.")


Current working directory: /content/code

Checking repository structure...
✓ Found: scripts/train_rnn_baseline.py
✓ Found: core/evaluate.py
✓ Found: dataloader
✓ Found: base.py
✓ Found: constants.py

✓ Repository structure looks good!


In [5]:
# Install dependencies into the *kernel's* interpreter (`sys.executable -m pip`
# behaves like %pip, unlike !pip which may target a different Python) and report
# failures instead of always claiming success.
# Colab ships with PyTorch pre-installed, so pinned torch/torchvision versions are
# stripped from requirements.txt to avoid conflicting reinstalls.
import os
import re
import subprocess
import sys


def _pip_install(*args):
    """Run `pip install -q <args>` for this kernel; return True on success."""
    proc = subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', *args],
                          capture_output=True, text=True)
    if proc.returncode != 0:
        print(f"✗ pip install {' '.join(args)} failed:\n{proc.stdout}{proc.stderr}")
    return proc.returncode == 0


# Equivalent of `sed -i '/^torch==/d; /^torchvision==/d' requirements.txt`
if os.path.exists('requirements.txt'):
    with open('requirements.txt') as f:
        requirement_lines = f.readlines()
    kept_lines = [line for line in requirement_lines
                  if not re.match(r'(torch|torchvision)==', line)]
    with open('requirements.txt', 'w') as f:
        f.writelines(kept_lines)

# torcheval is required for the evaluation metrics
install_ok = _pip_install('torcheval')

# Remaining project dependencies
if os.path.exists('requirements.txt'):
    install_ok &= _pip_install('-r', 'requirements.txt')
elif os.path.exists('requirements-cpu.txt'):
    install_ok &= _pip_install('-r', 'requirements-cpu.txt')

# Additional dependencies for the RNN baseline
install_ok &= _pip_install('wandb', 'loguru')

if install_ok:
    print("✓ All dependencies installed successfully")
else:
    print("⚠ Some dependencies failed to install; see the errors above.")

# Verify PyTorch installation
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m454.4/454.4 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m50.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.2/144.2 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.5/92.5 kB[0m [31m8.1 MB/s[0m eta [36m

In [6]:
# ============================================
# CONFIGURE DATA PATHS
# ============================================
# Option 1 (recommended for large files): read everything from Google Drive.
# Option 2: set USE_GOOGLE_DRIVE = False and upload the files directly in the next cells.
USE_GOOGLE_DRIVE = True

# Locations inside Drive -- adjust to match your own folder layout.
# Each of the first three may point at a .zip archive or at a directory.
OMNIVORE_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/omnivore.zip"
SLOWFAST_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/slowfast.zip"
CHECKPOINTS_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/error_recognition_best.zip"
# Optional when the annotations already ship with the repository.
ANNOTATIONS_DRIVE_PATH = "/content/drive/MyDrive/AML_mistake_detection/annotations"

print("Data paths configured:")
for _label, _value in (
    ("Use Google Drive", USE_GOOGLE_DRIVE),
    ("Omnivore", OMNIVORE_DRIVE_PATH),
    ("SlowFast", SLOWFAST_DRIVE_PATH),
    ("Checkpoints", CHECKPOINTS_DRIVE_PATH),
    ("Annotations", ANNOTATIONS_DRIVE_PATH),
):
    print(f"  {_label}: {_value}")


Data paths configured:
  Use Google Drive: True
  Omnivore: /content/drive/MyDrive/AML_mistake_detection/omnivore.zip
  SlowFast: /content/drive/MyDrive/AML_mistake_detection/slowfast.zip
  Checkpoints: /content/drive/MyDrive/AML_mistake_detection/error_recognition_best.zip
  Annotations: /content/drive/MyDrive/AML_mistake_detection/annotations


In [7]:
# Mount Google Drive only when the data is read from there
if not USE_GOOGLE_DRIVE:
    print("⚠ Google Drive not mounted. Please upload files directly using the file browser.")
else:
    from google.colab import drive

    drive.mount('/content/drive')
    print("✓ Google Drive mounted")


Mounted at /content/drive
✓ Google Drive mounted


In [8]:
# Lay out the folders the data loaders and training scripts read from / write to
import os

for _subdir in (
    'data/video/omnivore',
    'data/video/slowfast',
    'checkpoints',
    'annotations/annotation_json',
    'annotations/data_splits',
    'er_annotations',
):
    os.makedirs(_subdir, exist_ok=True)

print("✓ Directory structure created")


✓ Directory structure created


In [9]:
# Load features from Google Drive or direct upload
import os
import shutil
import subprocess
import glob
import zipfile


def _extract_npz_from_zip(zip_path, dest_path):
    """Stream each ``.npz`` member of ``zip_path`` straight into ``dest_path``.

    Folders inside the archive are flattened (only the base name is kept).
    Members with a hidden path component (e.g. macOS ``._foo.npz`` resource
    forks, or ``..`` entries) are skipped, mirroring how ``glob`` ignores
    hidden files in the directory branch of ``load_features``.

    Returns:
        Number of distinct ``.npz`` file names written.
    """
    written = set()
    with zipfile.ZipFile(zip_path) as archive:
        for member in archive.infolist():
            parts = [p for p in member.filename.split('/') if p]
            if member.is_dir() or not parts or not parts[-1].endswith('.npz'):
                continue
            if any(p.startswith('.') for p in parts):
                continue
            name = parts[-1]
            if name in written:
                # Flattening would otherwise silently overwrite an earlier file
                print(f"  ⚠ Duplicate file name {name} in archive; keeping the last copy")
            with archive.open(member) as src, open(os.path.join(dest_path, name), 'wb') as dst:
                shutil.copyfileobj(src, dst)
            written.add(name)
    return len(written)


def load_features(source_path, dest_path, feature_name):
    """Copy pre-extracted ``.npz`` feature files into ``dest_path``.

    ``source_path`` may be a zip archive or a directory; ``.npz`` files are
    collected recursively and flattened into ``dest_path`` (created if needed).
    Zip archives are read in place, without a temporary copy or extract folder,
    so an interrupted run leaves nothing stale behind.

    Args:
        source_path: Path to a ``.zip`` archive or a directory of features.
        dest_path: Directory that receives the ``.npz`` files.
        feature_name: Human-readable name used in log messages.

    Returns:
        True if at least one ``.npz`` file was copied. Returns False when the
        source is missing, contains no ``.npz`` files, or fails to extract.
    """
    if not os.path.exists(source_path):
        print(f"⚠ {feature_name}: Source path not found: {source_path}")
        return False

    print(f"Loading {feature_name} features from: {source_path}")
    os.makedirs(dest_path, exist_ok=True)

    # Decide by file type/content, not by 'zip' appearing anywhere in the path
    is_zip = os.path.isfile(source_path) and (
        str(source_path).lower().endswith('.zip') or zipfile.is_zipfile(source_path))

    if is_zip:
        print(f"  Detected zip file, extracting...")
        try:
            copied = _extract_npz_from_zip(source_path, dest_path)
        except Exception as e:  # best-effort: report and let the notebook continue
            print(f"  ✗ Error extracting {feature_name} zip: {e}")
            return False
        if not copied:
            print(f"  ⚠ No .npz files found in extracted zip")
            return False
        print(f"  ✓ Extracted and copied {copied} .npz files")
        return True

    # Otherwise treat it as a directory tree
    print(f"  Detected directory, copying .npz files...")
    npz_files = glob.glob(os.path.join(source_path, '**/*.npz'), recursive=True)
    if not npz_files:
        print(f"  ⚠ No .npz files found in {source_path}")
        return False
    for npz_file in npz_files:
        shutil.copy2(npz_file, dest_path)
    print(f"  ✓ Copied {len(npz_files)} .npz files")
    return True

# Pull both feature sets into data/video/<backbone>/, or explain the manual route
if not USE_GOOGLE_DRIVE:
    print("⚠ Please upload features manually:")
    print("  1. Use the file browser to upload .npz files or zip files")
    print("  2. Extract/copy them to data/video/omnivore/ and data/video/slowfast/")
else:
    load_features(OMNIVORE_DRIVE_PATH, 'data/video/omnivore', 'Omnivore')
    load_features(SLOWFAST_DRIVE_PATH, 'data/video/slowfast', 'SlowFast')


def _count_npz(folder):
    """Count the .npz files directly inside ``folder``; 0 when it does not exist."""
    if not os.path.exists(folder):
        return 0
    return sum(1 for entry in os.listdir(folder) if entry.endswith('.npz'))


# Verify features
omnivore_count = _count_npz('data/video/omnivore')
slowfast_count = _count_npz('data/video/slowfast')
print(f"\nFeature file counts:")
print(f"  Omnivore: {omnivore_count} .npz files")
print(f"  SlowFast: {slowfast_count} .npz files")


Loading Omnivore features from: /content/drive/MyDrive/AML_mistake_detection/omnivore.zip
  Detected zip file, extracting...
  ✓ Extracted and copied 384 .npz files
Loading SlowFast features from: /content/drive/MyDrive/AML_mistake_detection/slowfast.zip
  Detected zip file, extracting...
  ✓ Extracted and copied 384 .npz files

Feature file counts:
  Omnivore: 384 .npz files
  SlowFast: 384 .npz files


In [10]:
# Load annotations from Drive (if available) into the paths the dataloaders read
import os
import shutil

# Drive sub-folder -> local destination expected by the code
ANNOTATION_SUBDIRS = {
    'annotation_json': 'annotations/annotation_json',
    'data_splits': 'annotations/data_splits',
    'er_annotations': 'er_annotations',
}


def _copy_annotation_files(src_dir, dst_dir):
    """Copy the regular files directly inside ``src_dir`` into ``dst_dir``.

    Sub-folders are ignored (non-recursive). A missing ``src_dir`` is skipped,
    because each annotation sub-folder on Drive is optional. ``dst_dir`` is
    created if needed.
    """
    # isdir (not exists): a stray file with this name would crash os.listdir
    if not os.path.isdir(src_dir):
        return
    os.makedirs(dst_dir, exist_ok=True)
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        if os.path.isfile(src):
            shutil.copy2(src, os.path.join(dst_dir, name))
            print(f"  ✓ Copied {name}")


if USE_GOOGLE_DRIVE and os.path.exists(ANNOTATIONS_DRIVE_PATH):
    print(f"Loading annotations from: {ANNOTATIONS_DRIVE_PATH}")
    for drive_subdir, local_dir in ANNOTATION_SUBDIRS.items():
        _copy_annotation_files(os.path.join(ANNOTATIONS_DRIVE_PATH, drive_subdir), local_dir)
elif USE_GOOGLE_DRIVE:
    print("⚠ Annotations not found in Drive. Checking repository...")
else:
    print("Google Drive disabled. Checking repository for annotations...")

# Verify required annotation files
print("\nVerifying annotation files...")
required_files = [
    'annotations/annotation_json/step_annotations.json',
    'annotations/annotation_json/error_annotations.json',
    'er_annotations/recordings_combined_splits.json'
]

missing = []
for file in required_files:
    if os.path.exists(file):
        print(f"✓ Found: {file}")
    else:
        print(f"✗ Missing: {file}")
        missing.append(file)

if missing:
    print(f"\n⚠ Warning: {len(missing)} required annotation file(s) are missing!")
    print("Please ensure these files are available before running training.")
else:
    print("\n✓ All required annotation files are present!")


Loading annotations from: /content/drive/MyDrive/AML_mistake_detection/annotations
  ✓ Copied error_category_idx.json
  ✓ Copied complete_step_annotations.json
  ✓ Copied activity_idx_step_idx.json
  ✓ Copied step_annotations.json
  ✓ Copied step_idx_description.json
  ✓ Copied recording_id_step_idx.json
  ✓ Copied error_annotations.json
  ✓ Copied recordings_data_split_normal.json
  ✓ Copied recipes_data_split_normal.json
  ✓ Copied recipes_data_split_combined.json
  ✓ Copied person_data_split_combined.json
  ✓ Copied environment_data_split_combined.json
  ✓ Copied environment_data_split_normal.json
  ✓ Copied person_data_split_normal.json
  ✓ Copied recordings_data_split_combined.json

Verifying annotation files...
✓ Found: annotations/annotation_json/step_annotations.json
✓ Found: annotations/annotation_json/error_annotations.json
✓ Found: er_annotations/recordings_combined_splits.json

✓ All required annotation files are present!


In [11]:
# Load checkpoints (optional - for comparison with existing baselines)
import os
import shutil
import subprocess
import tempfile
import zipfile

CHECKPOINTS_DEST = 'checkpoints/error_recognition_best'


def _find_checkpoint_root(extracted_base):
    """Return the folder inside an extracted archive that holds the baseline checkpoints.

    Checks, in order:
    1. ``extracted_base/error_recognition_best``.
    2. ``extracted_base`` itself, if it has ``MLP``/``Transformer`` sub-folders.
    3. A top-down walk for either of the above.

    Falls back to ``extracted_base``.
    """
    direct = os.path.join(extracted_base, 'error_recognition_best')
    if os.path.exists(direct):
        return direct
    if any(os.path.exists(os.path.join(extracted_base, d)) for d in ('MLP', 'Transformer')):
        return extracted_base
    for root, dirs, _files in os.walk(extracted_base):
        if 'error_recognition_best' in dirs:
            return os.path.join(root, 'error_recognition_best')
        if 'MLP' in dirs or 'Transformer' in dirs:
            return root
    return extracted_base


checkpoint_path = CHECKPOINTS_DRIVE_PATH if USE_GOOGLE_DRIVE else None

if checkpoint_path and os.path.exists(checkpoint_path):
    print(f"Loading checkpoints from: {checkpoint_path}")

    # Decide by file type/content, not by 'zip' appearing anywhere in the path
    is_zip = os.path.isfile(checkpoint_path) and (
        checkpoint_path.lower().endswith('.zip') or zipfile.is_zipfile(checkpoint_path))

    if is_zip:
        print("Detected zip file, extracting...")
        try:
            # TemporaryDirectory is removed even when extraction fails, so a
            # re-run never trips over stale files from an earlier attempt.
            with tempfile.TemporaryDirectory() as extracted_base:
                with zipfile.ZipFile(checkpoint_path) as archive:
                    archive.extractall(extracted_base)
                extracted_path = _find_checkpoint_root(extracted_base)
                print(f"Copying from: {extracted_path}")
                shutil.copytree(extracted_path, CHECKPOINTS_DEST, dirs_exist_ok=True)
            print("✓ Checkpoints extracted")
        except Exception as e:  # best-effort: checkpoints are optional
            print(f"✗ Error extracting checkpoints: {e}")
    elif os.path.isdir(checkpoint_path):
        # copytree merges the folder's *contents* into CHECKPOINTS_DEST, which is
        # what both directory cases (named error_recognition_best or not) need.
        print("Detected directory, copying...")
        shutil.copytree(checkpoint_path, CHECKPOINTS_DEST, dirs_exist_ok=True)
        print("✓ Checkpoints copied")
    else:
        print(f"⚠ Unsupported checkpoint source (not a zip archive or directory): {checkpoint_path}")
else:
    print("⚠ Checkpoints not found. This is optional - you can still train the RNN baseline.")
    print("   If you want to compare with existing baselines, download checkpoints from:")
    print("   https://utdallas.app.box.com/s/uz3s1alrzucz03sleify8kazhuc1ksl3")

# Verify checkpoints
if os.path.exists(CHECKPOINTS_DEST):
    pt_files = [
        os.path.join(root, f)
        for root, _dirs, files in os.walk(CHECKPOINTS_DEST)
        for f in files
        if f.endswith('.pt')
    ]
    print(f"\n✓ Found {len(pt_files)} checkpoint files")
else:
    print("\n⚠ Checkpoints directory not found (this is optional)")


Loading checkpoints from: /content/drive/MyDrive/AML_mistake_detection/error_recognition_best.zip
Detected zip file, extracting...
Copying from: /tmp/checkpoints_extracted/error_recognition_best
✓ Checkpoints extracted

✓ Found 54 checkpoint files


## 3. Train RNN Baseline with Omnivore Features


In [15]:
# Authenticate with Weights & Biases through the Python API instead of the
# interactive `!wandb login` shell prompt. wandb.login() reuses WANDB_API_KEY or
# a saved key when one exists and only prompts otherwise, so Run All can proceed
# unattended once a key is configured.
import wandb

wandb.login()

[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33maexomir[0m ([33maexomir-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
# Train RNN baseline with Omnivore features
# Default hyperparameters: hidden_size=256, num_layers=2, bidirectional=True, rnn_type=LSTM
import io
import os
import subprocess
import sys

# Ensure we're in the repository root directory
repo_root = os.getcwd()
if not os.path.exists("scripts/train_rnn_baseline.py"):
    print(f"⚠ Error: scripts/train_rnn_baseline.py not found in {repo_root}")
    print("Please make sure you're in the repository root directory.")
else:
    print(f"Running from directory: {repo_root}")

    cmd = [
        sys.executable, "scripts/train_rnn_baseline.py",
        "--variant", "RNN",
        "--backbone", "omnivore",
        "--split", "recordings",
        "--batch_size", "4",
        "--num_epochs", "20",
        "--lr", "1e-3",
        "--weight_decay", "1e-3",
        "--rnn_hidden_size", "256",
        "--rnn_num_layers", "2",
        "--rnn_dropout", "0.2",
        "--rnn_bidirectional", "True",
        "--rnn_type", "LSTM",
        # NOTE(review): the SlowFast cell passes "--segment_features_directory", "data/";
        # this run relies on the script's default instead -- confirm they agree.
    ]

    print("\nRunning command:")
    print(" ".join(cmd))
    print("\n" + "="*60 + "\n")

    # Prepend the repo root to PYTHONPATH without leaving a trailing empty entry
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(p for p in (repo_root, env.get('PYTHONPATH')) if p)

    # Stream output while training runs (subprocess.run would only show it after
    # the process exits). stderr is merged into stdout so errors appear inline.
    with subprocess.Popen(cmd, cwd=repo_root, env=env,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        try:
            # newline='' keeps tqdm's '\r' redraws as carriage returns instead of
            # turning every progress update into a separate output line.
            for chunk in io.TextIOWrapper(proc.stdout, encoding='utf-8', errors='replace', newline=''):
                print(chunk, end='', flush=True)
        except KeyboardInterrupt:
            proc.terminate()  # don't leave training running after interrupting the cell
            raise
    returncode = proc.returncode

    if returncode != 0:
        print(f"\n⚠ Training failed with exit code {returncode}")
        print("The error output is shown above (stderr is merged into stdout).")
    else:
        print("\n✓ Training completed successfully!")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
val Progress: 34/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 35/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 36/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 37/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 38/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 39/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 40/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 41/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 42/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 43/681:   4%|▍         | 28/681 [00:00<00:07, 84.32it/s]
val Progress: 43/681:   6%|▋         | 43/681 [00:00<00:06, 104.90it/s]
val Progress: 44/681:   6%|▋         | 43/681 [00:00<00:06, 104.90it/s]
val Progress: 45/681:   6%|▋         | 43/681 [00:00<00:06, 104.90it/s]
val Progr

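### Patching `core/models/blocks.py` to fix `pack_padded_sequence` error

**Run this cell before the training cells in Sections 3 and 4.** It appears after the Omnivore training cell, but it was executed first (`In [18]` before `In [19]`). `torch.nn.utils.rnn.pack_padded_sequence` requires the `lengths` tensor to be on the CPU, so GPU training fails until this patch is applied.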

In [18]:
import os


def patch_pack_padded_sequence(file_path="core/models/blocks.py"):
    """Move the ``lengths`` argument of ``pack_padded_sequence`` to the CPU.

    ``torch.nn.utils.rnn.pack_padded_sequence`` requires ``lengths`` to be a
    CPU tensor (or list). Passing CUDA lengths raises an error during GPU
    training. The patch is a plain text replacement and is idempotent.

    Args:
        file_path: Python source file containing the call to patch.

    Returns:
        One of four statuses:
        - "patched": the line was rewritten.
        - "already_patched": the CPU version is already present.
        - "not_found": neither form exists (the source changed; patch by hand).
        - "missing_file": ``file_path`` does not exist.
    """
    original_line = "x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)"
    replacement_line = "x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)"

    if not os.path.exists(file_path):
        print(f"Error: {file_path} not found.")
        return "missing_file"

    with open(file_path, 'r') as f:
        content = f.read()

    if original_line in content:
        with open(file_path, 'w') as f:
            f.write(content.replace(original_line, replacement_line))
        print(f"Patched {file_path}: 'lengths' argument for pack_padded_sequence moved to CPU.")
        return "patched"
    if replacement_line in content:
        print(f"{file_path} is already patched; nothing to do.")
        return "already_patched"
    print(f"Original line not found in {file_path}; the source may have changed - patch it manually.")
    return "not_found"


patch_pack_padded_sequence();


Patched core/models/blocks.py: 'lengths' argument for pack_padded_sequence moved to CPU.


## 4. Train RNN Baseline with SlowFast Features


In [None]:
# Train RNN baseline with SlowFast features
# Default hyperparameters: hidden_size=256, num_layers=2, bidirectional=True, rnn_type=LSTM
import io
import os
import subprocess
import sys

# Ensure we're in the repository root directory
repo_root = os.getcwd()
if not os.path.exists("scripts/train_rnn_baseline.py"):
    print(f"⚠ Error: scripts/train_rnn_baseline.py not found in {repo_root}")
    print("Please make sure you're in the repository root directory.")
else:
    print(f"Running from directory: {repo_root}")

    cmd = [
        sys.executable, "scripts/train_rnn_baseline.py",
        "--variant", "RNN",
        "--backbone", "slowfast",
        "--split", "recordings",
        "--batch_size", "4",
        "--num_epochs", "20",
        "--lr", "1e-3",
        "--weight_decay", "1e-3",
        "--rnn_hidden_size", "256",
        "--rnn_num_layers", "2",
        "--rnn_dropout", "0.2",
        "--rnn_bidirectional", "True",
        "--rnn_type", "LSTM",
        "--segment_features_directory", "data/"
    ]

    print("\nRunning command:")
    print(" ".join(cmd))
    print("\n" + "="*60 + "\n")

    # Prepend the repo root to PYTHONPATH without leaving a trailing empty entry
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(p for p in (repo_root, env.get('PYTHONPATH')) if p)

    # Stream output while training runs (subprocess.run would only show it after
    # the process exits). stderr is merged into stdout so errors appear inline.
    with subprocess.Popen(cmd, cwd=repo_root, env=env,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        try:
            # newline='' keeps tqdm's '\r' redraws as carriage returns instead of
            # turning every progress update into a separate output line.
            for chunk in io.TextIOWrapper(proc.stdout, encoding='utf-8', errors='replace', newline=''):
                print(chunk, end='', flush=True)
        except KeyboardInterrupt:
            proc.terminate()  # don't leave training running after interrupting the cell
            raise
    returncode = proc.returncode

    if returncode != 0:
        print(f"\n⚠ Training failed with exit code {returncode}")
        print("The error output is shown above (stderr is merged into stdout).")
    else:
        print("\n✓ Training completed successfully!")


## 5. Evaluate Trained Model


In [None]:
# Evaluate the best RNN checkpoint (update the settings below as needed)
import io
import os
import subprocess
import sys

# Keep these in sync with the training cell and the results-comparison cell
# (the results CSV file name encodes the threshold).
EVAL_VARIANT = "RNN"
EVAL_BACKBONE = "omnivore"
EVAL_SPLIT = "recordings"
EVAL_THRESHOLD = "0.6"


def find_checkpoint(checkpoint_dir, split=EVAL_SPLIT, backbone=EVAL_BACKBONE):
    """Pick the checkpoint to evaluate from ``checkpoint_dir``.

    Candidates are tried in this order:
    1. ``*_best.pt`` files whose name contains both ``split`` and ``backbone``.
    2. Any other ``*_best.pt``.
    3. Any ``*.pt``.

    Within the chosen group the most recently modified file wins. Modification
    time is used instead of name order because names like ``epoch_10`` sort
    before ``epoch_9``.

    Returns:
        The checkpoint path, or None if the directory is missing or has no
        ``.pt`` files.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    paths = [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
    if not paths:
        return None
    best = [p for p in paths if p.endswith('_best.pt')]
    matching = [p for p in best
                if split in os.path.basename(p) and backbone in os.path.basename(p)]
    return max(matching or best or paths, key=os.path.getmtime)


# Ensure we're in the repository root directory
repo_root = os.getcwd()

checkpoint_dir = f"checkpoints/error_recognition/{EVAL_VARIANT}/{EVAL_BACKBONE}"
default_checkpoint = (f"{checkpoint_dir}/error_recognition_{EVAL_SPLIT}_{EVAL_BACKBONE}_"
                      f"{EVAL_VARIANT}_video_best.pt")
checkpoint_path = find_checkpoint(checkpoint_dir) or default_checkpoint

print(f"Using checkpoint: {checkpoint_path}")
if not os.path.exists(checkpoint_path):
    # Running the evaluator without a checkpoint would just fail, so stop here
    print(f"⚠ Warning: Checkpoint not found at {checkpoint_path}")
    print("Please update the checkpoint path in this cell. Skipping evaluation.")
else:
    cmd = [
        sys.executable, "-m", "core.evaluate",
        "--variant", EVAL_VARIANT,
        "--backbone", EVAL_BACKBONE,
        "--split", EVAL_SPLIT,
        "--ckpt", checkpoint_path,
        "--threshold", EVAL_THRESHOLD,
    ]

    print(f"\nRunning from directory: {repo_root}")
    print("Running command:")
    print(" ".join(cmd))
    print("\n" + "="*60 + "\n")

    # Prepend the repo root to PYTHONPATH without leaving a trailing empty entry
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(p for p in (repo_root, env.get('PYTHONPATH')) if p)

    # Relay the evaluator's output into the notebook (stderr merged into stdout)
    with subprocess.Popen(cmd, cwd=repo_root, env=env,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        try:
            for chunk in io.TextIOWrapper(proc.stdout, encoding='utf-8', errors='replace', newline=''):
                print(chunk, end='', flush=True)
        except KeyboardInterrupt:
            proc.terminate()
            raise

    if proc.returncode != 0:
        print(f"\n⚠ Evaluation failed with exit code {proc.returncode}")
        print("Check the error messages above for details.")
    else:
        print("\n✓ Evaluation completed successfully!")


## 6. Compare Results

Results are saved to `results/error_recognition/combined_results/`. Compare V_RNN against V1 (MLP) and V2 (Transformer) using the same CSV file. Ensure you use the same split, backbone, and threshold for fair comparison.


In [None]:
# Display results comparison for one backbone/split across all variants
import os

import pandas as pd

RESULTS_FILE = "results/error_recognition/combined_results/step_True_substep_True_threshold_0.6.csv"
COMPARE_BACKBONE = "omnivore"
COMPARE_SPLIT = "recordings"
METRIC_COLUMNS = ['Variant', 'Backbone', 'Split', 'Step Precision', 'Step Recall',
                  'Step F1', 'Step Accuracy', 'Step AUC']


def compare_results(results_file=RESULTS_FILE, backbone=COMPARE_BACKBONE, split=COMPARE_SPLIT):
    """Return the rows of ``results_file`` for one backbone and split.

    Matching on the ``Backbone``/``Split`` columns is case-insensitive. A filter
    is skipped when its column is absent from the CSV. Only the METRIC_COLUMNS
    present in the file are kept, so a missing metric does not raise KeyError.

    Returns:
        A new DataFrame with a fresh index, or None if ``results_file`` does
        not exist.
    """
    if not os.path.exists(results_file):
        return None
    results = pd.read_csv(results_file)
    mask = pd.Series(True, index=results.index)
    for column, value in (('Backbone', backbone), ('Split', split)):
        if column in results.columns:
            mask &= results[column].astype(str).str.lower() == value.lower()
    columns = [c for c in METRIC_COLUMNS if c in results.columns]
    return results.loc[mask, columns].reset_index(drop=True)


comparison = compare_results()
if comparison is None:
    print(f"Results file not found: {RESULTS_FILE}")
    print("Please run training first.")
elif comparison.empty:
    print(f"No rows for backbone={COMPARE_BACKBONE}, split={COMPARE_SPLIT} in {RESULTS_FILE}")
else:
    print(f"Results Comparison (backbone={COMPARE_BACKBONE}, split={COMPARE_SPLIT}):")
    display(comparison)
