# MECH_INTERP_PHYSICS_REASONING - Google Colab Evaluation

This notebook allows you to run the PaliGemma evaluation on Google Colab with GPU support.

## ⚠️ IMPORTANT: Setup GPU Runtime First!

Before running this notebook:
1. Go to `Runtime` → `Change runtime type`
2. Select `GPU` as Hardware accelerator (T4 is fine)
3. Click `Save`

## 1. Check GPU Availability

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to this runtime.
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if not cuda_ready:
    print("⚠️ No GPU detected! Please enable GPU in Runtime settings.")
else:
    # Device 0 is the only GPU a standard Colab runtime exposes.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

## 2. Mount Google Drive (Optional)

In [None]:
# Mount Google Drive to save results or access data.
# Optional: the mount is only needed for the Drive-based copy in "Setup Repository"
# and for exporting results to Drive at the end of the notebook.
from google.colab import drive
drive.mount('/content/drive')

## 3. Setup Repository

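Choose one of the following methods to get your code. Method 4 (zip upload) is active by default; to use Method 2 or 3 instead, uncomment its lines and comment out the Method 4 block: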

In [None]:
# Method 1: Upload files directly
# You can drag and drop your project folder into the Colab file browser

# Method 2: Clone from GitHub (uncomment and modify)
# !git clone https://github.com/YOUR_USERNAME/mech_interp_physics_reasoning.git
# %cd mech_interp_physics_reasoning

# Method 3: Copy from Google Drive (uncomment and modify)
# !cp -r /content/drive/MyDrive/mech_interp_physics_reasoning .
# %cd mech_interp_physics_reasoning

# Method 4: Upload a zip file
from google.colab import files
print("Upload your project as a zip file:")
uploaded = files.upload()
if uploaded:
    !unzip -q *.zip
    # Change to the project directory
    import os
    dirs = [d for d in os.listdir('.') if os.path.isdir(d) and 'mech_interp' in d]
    if dirs:
        %cd {dirs[0]}
    !ls -la

## 4. Install Dependencies

In [None]:
# Install required packages
!pip install -q torch torchvision transformers>=4.36.0
!pip install -q peft accelerate bitsandbytes
!pip install -q Pillow numpy pyyaml
!pip install -q wandb pytz

# Verify installations
import transformers
import peft
print(f"✓ Transformers version: {transformers.__version__}")
print(f"✓ PEFT version: {peft.__version__}")
print("✓ All packages installed successfully!")

## 5. Check and Fix Configuration

In [None]:
# Check current directory structure. Later cells use paths relative to the
# working directory (base_eval_config.yaml, scripts/eval_colab.py), so this
# should show the project root.
!pwd
!ls -la

In [None]:
# Fix the test_size issue in config
import yaml
import os

config_path = "base_eval_config.yaml"

if os.path.exists(config_path):
    # Read current config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print(f"Current test_size: {config['data_config']['test_size']}")

    # Update test_size if needed
    if config['data_config']['test_size'] < 10:
        config['data_config']['test_size'] = 100  # or 0.2 for 20%
        
        with open(config_path, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        
        print(f"✓ Updated test_size to: {config['data_config']['test_size']}")
    else:
        print(f"✓ test_size is already set to: {config['data_config']['test_size']}")
else:
    print(f"❌ Config file not found at {config_path}")
    print("Available files:")
    !ls -la *.yaml

## 6. Create Colab-Optimized Evaluation Script

In [None]:
%%writefile scripts/eval_colab.py
import os
# Use GPU 0 in Colab
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import sys
import torch
from datetime import datetime
import argparse
import yaml
import numpy as np
import json
from transformers import TrainingArguments
from peft import PeftModel
from collections import defaultdict
import glob


def get_device_map():
    """Return ``"cuda"`` when PyTorch can see a CUDA device, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def _find_latest_checkpoint(checkpoint_dir):
    """Return the ``checkpoint-<step>`` entry in ``checkpoint_dir`` with the highest step.

    ``glob.glob`` returns paths in arbitrary order, and even a lexicographic
    sort would rank ``checkpoint-900`` after ``checkpoint-1000``, so the numeric
    suffix is compared instead.

    Raises:
        FileNotFoundError: if no ``checkpoint-*`` entry exists in ``checkpoint_dir``.
    """
    candidates = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint-* entries found in {checkpoint_dir}")

    def step_of(path):
        suffix = os.path.basename(path).rsplit("-", 1)[-1]
        # Entries without a numeric suffix rank below every numbered checkpoint.
        return int(suffix) if suffix.isdigit() else -1

    return max(candidates, key=step_of)


def main():
    """Evaluate a PaliGemma model (base or LoRA-adapted) and write accuracy reports.

    Command line:
        checkpoint_dir  Path, relative to the project root, of a folder holding
                        ``config.yaml`` and ``checkpoint-*`` entries. Required
                        unless ``--base`` is given.
        --base          Evaluate the base model without a LoRA adapter, using
                        ``base_eval_config.yaml`` from the project root.

    The project root is ``$HOME_DIR`` when that variable is set, otherwise the
    parent of this script's directory. Results go to a timestamped folder
    (``artifacts/BASE/eval_*`` for ``--base``, ``<checkpoint_dir>/eval_*``
    otherwise) containing ``eval_results.txt`` and ``eval_details.json``.

    Raises:
        ValueError: if neither ``checkpoint_dir`` nor ``--base`` is supplied.
        FileNotFoundError: if ``checkpoint_dir`` holds no ``checkpoint-*`` entry.
    """
    parser = argparse.ArgumentParser(description="Evaluate a PaliGemma model on CLEVRER test set.")
    parser.add_argument("checkpoint_dir", type=str, nargs="?", help="Relative path to checkpoint folder")
    parser.add_argument("--base", action="store_true", help="Evaluate base model without LoRA")
    args = parser.parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    HOME_DIR = os.environ.get("HOME_DIR", os.path.abspath(os.path.join(script_dir, "..")))
    # Make the project's `src` package importable regardless of the caller's cwd.
    sys.path.insert(0, HOME_DIR)

    if args.base:
        config_path = os.path.join(HOME_DIR, "base_eval_config.yaml")
    else:
        if not args.checkpoint_dir:
            raise ValueError("checkpoint_dir must be provided unless --base is used")
        checkpoint_dir = os.path.join(HOME_DIR, args.checkpoint_dir)
        config_path = os.path.join(checkpoint_dir, "config.yaml")
        # Resolve the adapter checkpoint up front so a bad path fails before
        # the dataset and the base model are loaded.
        last_checkpoint = _find_latest_checkpoint(checkpoint_dir)

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Project imports are deferred until HOME_DIR has been put on sys.path.
    from src.processing_paligemma import PaliGemmaProcessor
    from src.modeling_paligemma import PaliGemmaForConditionalGeneration
    from src.utils import make_clevrer_collate_fn, compute_accuracy, CLEVRERTrainer
    from src.data import ClevrerDataset

    model_config = config["model_train"]
    data_config = config["data_config"]
    model_id = model_config["model"]
    question_type = data_config.get("question_type", "all")

    test_frames_dir = os.path.join(HOME_DIR, data_config.get("data_path", "test_frames"))
    annotations_path = os.path.join(HOME_DIR, data_config.get("json_path", "miscellaneous/validation.json"))

    print(f"Loading dataset from: {test_frames_dir}")
    print(f"Annotations from: {annotations_path}")

    dataset = ClevrerDataset(
        frames_root=test_frames_dir,
        json_path=annotations_path,
        question_type=question_type,
        transform=None,
        shuffle=False
    )

    print(f"Total samples in dataset: {len(dataset)}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.base:
        results_root = os.path.join(HOME_DIR, "artifacts", "BASE", f"eval_{question_type}_{timestamp}")
    else:
        results_root = os.path.join(checkpoint_dir, f"eval_{question_type}_{timestamp}")

    os.makedirs(results_root, exist_ok=True)
    # NOTE(review): the cache path lives inside the timestamped results folder,
    # so a previous run's split is never picked up here — confirm this is intended.
    split_cache_path = os.path.join(results_root, "split_indices.json")

    # Only the test half of the split is evaluated.
    _, test_ds = dataset.train_test_split(
        test_size=data_config['test_size'],
        cache_path=split_cache_path
    )

    print(f"Test dataset size: {len(test_ds)}")

    device = get_device_map()
    print(f"Using device: {device}")
    model_dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print("Loading model...")
    if args.base:
        # NOTE(review): unlike the LoRA branch, `target_length` is not forwarded
        # here — confirm this is intended.
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=model_dtype,
            attn_implementation="eager",
            token_compression=model_config.get('token_compression')
        )
    else:
        base_model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=model_dtype,
            attn_implementation="eager",
            token_compression=model_config.get("token_compression"),
            target_length=model_config.get("target_length")
        )
        model = PeftModel.from_pretrained(base_model, last_checkpoint)

    print("Model loaded successfully!")

    processor = PaliGemmaProcessor.from_pretrained(model_id)
    collate_fn = make_clevrer_collate_fn(
        model=model,
        processor=processor,
        model_config=model_config,
        data_config=data_config,
        dtype=model.dtype
    )

    # TrainingArguments rejects bf16=True on GPUs without bf16 support (for
    # example pre-Ampere cards), so only request it where it is available.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()

    eval_args = TrainingArguments(
        output_dir="/tmp/eval",
        per_device_eval_batch_size=model_config.get("eval_batch_size", 4),
        eval_accumulation_steps=1,
        dataloader_pin_memory=False,
        bf16=use_bf16,
        remove_unused_columns=False,
        report_to=[],
        save_strategy="no",
        logging_strategy="no"
    )

    trainer = CLEVRERTrainer(
        model=model,
        args=eval_args,
        eval_dataset=test_ds,
        data_collator=collate_fn,
        compute_metrics=compute_accuracy,
        processing_class=processor
    )

    print("Starting evaluation...")
    pred_output = trainer.predict(test_ds)

    # Process results
    preds = pred_output.predictions.tolist()
    labels = pred_output.label_ids.tolist()

    # Token ids dropped before comparison.
    # NOTE(review): presumably padding/special ids plus the -100 label mask — confirm.
    special_tokens = {0, 1, 2, 3, -100, 257152}
    correct_flags = []
    per_sample_results = []
    type_correct = defaultdict(int)
    type_total = defaultdict(int)

    for i, (pred_row, label_row) in enumerate(zip(preds, labels)):
        item = test_ds[i]
        qtype = item["question_type"]

        filtered_pred = [x for x in pred_row if x not in special_tokens]
        filtered_label = [x for x in label_row if x not in special_tokens]

        # Order-insensitive match: both sides are sorted before comparing.
        correct = sorted(filtered_pred) == sorted(filtered_label)
        correct_flags.append(correct)

        type_total[qtype] += 1
        type_correct[qtype] += int(correct)

        per_sample_results.append({
            "question_type": qtype,
            "question_id": item["question_id"],
            "video_filename": item["video_filename"],
            "predicted_token_ids": filtered_pred,
            "label_token_ids": filtered_label,
            "correct": correct
        })

    # Per-type accuracy, computed once and reused for the file and the console.
    type_rows = []
    for qtype, total in type_total.items():
        acc = type_correct[qtype] / total if total > 0 else 0.0
        type_rows.append((qtype, acc, type_correct[qtype], total))

    # Save results
    results_file = os.path.join(results_root, "eval_results.txt")
    details_path = os.path.join(results_root, "eval_details.json")

    with open(details_path, "w") as f:
        json.dump(per_sample_results, f, indent=2)

    accuracy = np.mean(correct_flags) if correct_flags else 0.0
    with open(results_file, "w") as f:
        f.write(f"accuracy: {accuracy:.4f}\n\n")
        f.write("Question Type Accuracies:\n")
        for qtype, acc, n_correct, total in type_rows:
            f.write(f"{qtype}: {acc:.4f} ({n_correct}/{total})\n")

    print(f"\n{'='*50}")
    print("Evaluation Complete!")
    print(f"{'='*50}")
    print(f"Overall Accuracy: {accuracy:.4f}")
    print("\nQuestion Type Accuracies:")
    for qtype, acc, n_correct, total in type_rows:
        print(f"  {qtype}: {acc:.4f} ({n_correct}/{total})")
    print(f"\nResults saved to: {results_root}")

# Run the evaluation only when this file is executed as a script
# (e.g. `python scripts/eval_colab.py --base`), not when it is imported.
if __name__ == "__main__":
    main()

## 7. Test Data Availability

In [None]:
# Report which of the expected data folders and annotation files are present,
# relative to the current working directory.
import os

print("Checking data directories...")
data_dirs = ['test_frames', 'train_frames', 'miscellaneous']
for folder in data_dirs:
    if not os.path.exists(folder):
        print(f"❌ {folder}: NOT FOUND")
        continue
    n_entries = len(os.listdir(folder))
    print(f"✓ {folder}: {n_entries} items")

# Annotation files the dataset loaders read
ann_files = ['miscellaneous/validation.json', 'miscellaneous/train.json']
for ann_path in ann_files:
    if not os.path.exists(ann_path):
        print(f"❌ {ann_path}: NOT FOUND")
    else:
        print(f"✓ {ann_path}: EXISTS")

## 8. Run Evaluation

In [None]:
# Run base model evaluation (no LoRA adapter). The script reads
# base_eval_config.yaml and writes its results under artifacts/BASE/.
!python scripts/eval_colab.py --base

In [None]:
# If you have a checkpoint to evaluate, run this instead:
# !python scripts/eval_colab.py path/to/your/checkpoint

## 9. View Results

In [None]:
import json
import glob
import os

# Locate the most recently created base-model results folder.
result_dirs = glob.glob("artifacts/BASE/eval_*")
if not result_dirs:
    print("No results found yet. Run the evaluation first!")
else:
    latest_dir = max(result_dirs, key=os.path.getctime)
    print(f"Latest results directory: {latest_dir}")

    # Plain-text summary written by the evaluation script
    results_file = os.path.join(latest_dir, "eval_results.txt")
    if os.path.exists(results_file):
        banner = "="*50
        with open(results_file, 'r') as f:
            print("\n" + banner)
            print("EVALUATION RESULTS")
            print(banner)
            print(f.read())

    # Per-sample records written by the evaluation script
    details_file = os.path.join(latest_dir, "eval_details.json")
    if os.path.exists(details_file):
        with open(details_file, 'r') as f:
            details = json.load(f)

        print(f"\nTotal samples evaluated: {len(details)}")

        # Preview the first few records; token ids are shown only for misses.
        print("\nSample predictions (first 5):")
        print("-" * 50)
        for sample_no, result in enumerate(details[:5], start=1):
            print(f"\nSample {sample_no}:")
            print(f"  Video: {result['video_filename']}")
            print(f"  Question Type: {result['question_type']}")
            print(f"  Correct: {'✓' if result['correct'] else '✗'}")
            if result['correct']:
                continue
            print(f"  Predicted tokens: {result['predicted_token_ids'][:10]}...")
            print(f"  Expected tokens: {result['label_token_ids'][:10]}...")

## 10. Save Results to Google Drive

In [None]:
# Save results to Google Drive
import glob
import os
import shutil

# Look for the mounted folder itself instead of testing `'drive' in globals()`:
# that only shows the module was imported, and it misses a Drive that is still
# mounted after a kernel restart.
drive_root = "/content/drive/MyDrive"

# Re-scan for results here so the cell does not depend on variables left
# behind by the "View Results" cell.
result_dirs = glob.glob("artifacts/BASE/eval_*")

if os.path.isdir(drive_root) and result_dirs:
    save_path = "/content/drive/MyDrive/mech_interp_results"
    os.makedirs(save_path, exist_ok=True)

    latest_dir = max(result_dirs, key=os.path.getctime)
    dest_dir = os.path.join(save_path, os.path.basename(latest_dir))

    # dirs_exist_ok=True lets the cell be re-run for the same results folder.
    shutil.copytree(latest_dir, dest_dir, dirs_exist_ok=True)
    print(f"✓ Results saved to Google Drive: {dest_dir}")
else:
    print("Google Drive not mounted or no results to save.")

## Troubleshooting

### Out of Memory (OOM) Error
If you encounter OOM errors, try these solutions:

In [None]:
# Solution 1: Reduce batch size
import yaml

with open('base_eval_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Set smaller batch size (the evaluation script reads it from
# model_train.eval_batch_size and falls back to 4 when it is absent).
config['model_train']['eval_batch_size'] = 1

with open('base_eval_config.yaml', 'w') as f:
    # Same dump options as the test_size fix cell; sort_keys=False keeps the
    # key order of the file that was read instead of re-sorting it.
    yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print("✓ Batch size reduced to 1")

In [None]:
# Solution 2: Clear GPU memory
import gc

import torch

# Collect garbage first: objects freed by the collector return their memory to
# PyTorch's caching allocator, and only then can empty_cache() release those
# cached blocks. Calling them in the opposite order leaves that memory cached.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("✓ GPU memory cleared")

### Debug Single Sample

In [None]:
# Smoke test: load the descriptive-question subset and print its first sample.
import sys
sys.path.insert(0, '.')  # make the project's `src` package importable from the cwd

from src.data import ClevrerDataset

dataset = ClevrerDataset(
    frames_root="test_frames",
    json_path="miscellaneous/validation.json",
    question_type="descriptive"
)

if len(dataset) == 0:
    print("Dataset is empty!")
else:
    sample = dataset[0]
    print(f"Sample question: {sample['question']}")
    print(f"Expected answer: {sample['answer']}")
    print(f"Question type: {sample['question_type']}")
    print(f"Number of frames: {len(sample['frames'])}")