# BabyAI PyTorch Training Setup (Colab)

This notebook contains all the setup steps needed to run Act-PRM training on BabyAI using PyTorch in Google Colab.

## Cell Info:
1. **Cell 1**: Clone act-prm-tinker repository (`pytorch` branch + BabyAI files from `madison/babyai`)
2. **Cell 2**: Install PyTorch ecosystem (may require runtime restart)
3. **Cell 3**: Install act-prm package and dependencies
4. **Cell 4**: Install BabyAI dependencies (gym, babyai-text, babyai, gym-minigrid)
5. **Cell 5**: Authenticate with Hugging Face and W&B
6. **Cell 6**: Fix model config (remove FlashAttention2)
7. **Cell 7**: Post-restart setup: BabyAI sys.path (run after every kernel restart)
8. **Cell 8**: Run training
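
Cell 6 edits the model config with `sed`. The same transformation can be sketched in Python as a simple line filter. The function name and the sample YAML keys below are made up for illustration, and the real config file may contain other fields.

```python
def drop_flash_attention(yaml_text: str) -> str:
    """Return yaml_text without any line that requests flash_attention_2."""
    kept = [
        line
        for line in yaml_text.splitlines(keepends=True)
        if 'attn_implementation: "flash_attention_2"' not in line
    ]
    return "".join(kept)


# Hypothetical config contents, used only to show the effect of the filter.
sample = (
    "name: demo\n"
    'attn_implementation: "flash_attention_2"\n'
    "torch_dtype: bfloat16\n"
)
cleaned = drop_flash_attention(sample)
```

Like the `sed` command, the filter is safe to run more than once: a second pass finds nothing to remove and returns the text unchanged.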

## Important Notes:
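- The PyTorch training infrastructure lives on the `pytorch` branch. The BabyAI environment and configs live on `madison/babyai`. Cell 1 checks out `pytorch` and then copies the BabyAI files over with `git checkout origin/madison/babyai -- <paths>`, which is a path-level checkout rather than a true `git cherry-pick`.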
- BabyAI depends on the **old** `gym` (not `gymnasium`) and is **incompatible with NumPy >= 2.0** out of the box. A compatibility shim in `env.py` patches `np.bool8`; see `src/act_prm/environments/babyai_text/README.md` for details.
- After installing PyTorch (Cell 2), you may need to **restart the runtime** and then continue from Cell 3.
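
For context on the NumPy note above: `np.bool8` was an alias of `np.bool_` that NumPy 2.0 removed, and older `gym` code may still reference it. A shim for this usually restores the missing attribute before `gym` is imported. The sketch below shows the pattern on a stand-in object so that it runs without NumPy. `ensure_alias` is a hypothetical helper, and the actual shim in `env.py` may be written differently.

```python
from types import SimpleNamespace


def ensure_alias(module, alias: str, target: str) -> bool:
    """Add module.<alias> pointing at module.<target> if the alias is missing.

    Returns True when the alias had to be created, False if it already existed.
    """
    if hasattr(module, alias):
        return False
    setattr(module, alias, getattr(module, target))
    return True


# Stand-in for a NumPy >= 2.0 module: it has bool_ but no bool8.
fake_np = SimpleNamespace(bool_=bool)
created = ensure_alias(fake_np, "bool8", "bool_")
created_again = ensure_alias(fake_np, "bool8", "bool_")
```

With the real library the call would be `ensure_alias(np, "bool8", "bool_")` after `import numpy as np`, run before anything imports `gym`.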

In [None]:
# Clone act-prm-tinker Repository
# Checks out 'pytorch' branch, then cherry-picks BabyAI files from 'madison/babyai'
import os

# Read the token from the GITHUB_TOKEN environment variable if set; otherwise prompt for it
github_token = os.environ.get("GITHUB_TOKEN", "")

if not github_token:
    from getpass import getpass
    github_token = getpass("Paste your GitHub Personal Access Token: ")

repo_path = '/content/act-prm-tinker'
if os.path.exists(repo_path):
    print(f"✓ Repository already exists at {repo_path}")
    %cd {repo_path}
    print("Pulling latest changes...")
    !git pull origin pytorch || echo "Note: Could not pull pytorch branch"
else:
    print(f"Cloning repository to {repo_path}...")
    !git clone https://{github_token}@github.com/HazyResearch/act-prm-tinker.git {repo_path}
    %cd {repo_path}
    !git checkout pytorch
    print("✓ Checked out pytorch branch")

# Cherry-pick BabyAI environment files and configs from madison/babyai
print("\nAdding BabyAI files from madison/babyai branch...")
!git fetch origin madison/babyai
!git checkout origin/madison/babyai -- \
    src/act_prm/environments/babyai_text/ \
    configs/environments/babyai/ \
    configs/environments/act_prm/babyai.yaml

# Patch __init__.py to register the babyai_text environment
init_path = 'src/act_prm/environments/__init__.py'
with open(init_path, 'r') as f:
    content = f.read()

if 'babyai_text' not in content:
    babyai_block = '''
    elif name == "babyai_text":
        if is_async:
            from .babyai_text import AsyncBabyAiTextEnv
            return AsyncBabyAiTextEnv(**kwargs)
        else:
            from .babyai_text import BabyAiTextEnv
            return BabyAiTextEnv(**kwargs)

    raise NotImplementedError'''

    content = content.replace(
        '    raise NotImplementedError',
        babyai_block,
        1  # replace only the first occurrence
    )
    with open(init_path, 'w') as f:
        f.write(content)
    print("✓ Registered babyai_text in environment factory")
else:
    print("✓ babyai_text already registered in environment factory")

print(f"\nCurrent directory: {os.getcwd()}")
print(f"Repository exists: {os.path.exists(repo_path)}")

In [None]:
# Install PyTorch Ecosystem (CUDA 12.1)
# ⚠️ This cell may require a runtime restart after installation. After the restart, continue from the next cell (Cell 3).

import subprocess
import sys

def check_package_installed(package_name, version_check=None):
    """Check if a package is installed and optionally verify version"""
    try:
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'show', package_name],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            if version_check:
                # Extract version from output
                for line in result.stdout.split('\n'):
                    if line.startswith('Version:'):
                        version = line.split(':', 1)[1].strip()
                        print(f"  Found {package_name} version: {version}")
                        return True
            return True
        return False
    except Exception:
        return False

print("Checking current PyTorch installation...")
torch_installed = check_package_installed('torch')
torchvision_installed = check_package_installed('torchvision')
transformers_installed = check_package_installed('transformers')

if torch_installed and torchvision_installed and transformers_installed:
    print("✓ PyTorch ecosystem already installed")
    print("\nVerifying versions...")
    try:
        import torch
        import torchvision
        import transformers
        print(f"  PyTorch: {torch.__version__}")
        print(f"  Torchvision: {torchvision.__version__}")
        print(f"  Transformers: {transformers.__version__}")
        print(f"  CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"  GPU: {torch.cuda.get_device_name(0)}")
        print("\n✓ PyTorch is ready. You can skip reinstalling if versions are compatible.")
        print("  To force a reinstall, run `pip uninstall -y torch` and rerun this cell.")
    except ImportError as e:
        print(f"⚠ Import error: {e}")
        print("  Proceeding with installation...")
        torch_installed = False

if not (torch_installed and torchvision_installed and transformers_installed):
    print("\nInstalling PyTorch ecosystem for CUDA 12.1...")
    print("⚠️  This may take a few minutes and may require a runtime restart.")

    # Uninstall existing versions
    !pip uninstall -y torch torchvision transformers 2>/dev/null || true

    # Install PyTorch and torchvision
    !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir

    # Install transformers and accelerate
    !pip install transformers accelerate --no-cache-dir --upgrade

    print("\n✓ Installation complete!")
    print("⚠️  IMPORTANT: You may need to restart the runtime now.")
    print("   After restart, continue from Cell 3 (skip this cell).")

    # Verify installation
    try:
        import torch
        import torchvision
        import transformers
        print(f"\n✓ Verification:")
        print(f"  PyTorch: {torch.__version__}")
        print(f"  Torchvision: {torchvision.__version__}")
        print(f"  Transformers: {transformers.__version__}")
        print(f"  CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"  GPU: {torch.cuda.get_device_name(0)}")
    except ImportError as e:
        print(f"\n⚠️  Import failed: {e}")
        print("   Please restart the runtime and continue from Cell 3.")

In [None]:
# Install Act-PRM Package and Dependencies

import os
import subprocess
import sys

# Ensure we're in the repo directory
repo_path = '/content/act-prm-tinker'
if not os.path.exists(repo_path):
    print(f"✗ Repository not found at {repo_path}")
    print("  Please run Cell 1 first to clone the repository.")
else:
    %cd {repo_path}
    print(f"✓ Working directory: {os.getcwd()}")

    # Check if act-prm is already installed
    try:
        import act_prm
        print(f"✓ act-prm already installed at: {act_prm.__file__}")
    except ImportError:
        print("Installing act-prm package...")
        !pip install tinker-cookbook
        !pip install -e .
        print("✓ act-prm installed successfully")

    # Verify installation
    try:
        import act_prm
        from act_prm.environments import get_env
        print(f"✓ act-prm can be imported")
        print(f"  Location: {act_prm.__file__}")
    except ImportError as e:
        print(f"✗ Import failed: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Install BabyAI Dependencies
import os
import sys

%cd /content

# Clone BabyAI repository if needed
if not os.path.exists('/content/Grounding_LLMs_with_online_RL'):
    print("Cloning BabyAI repository...")
    !git clone https://github.com/flowersteam/Grounding_LLMs_with_online_RL.git
    print("✓ Repository cloned successfully")
else:
    print("✓ BabyAI repository already exists")

# Install gym (old API, NOT gymnasium) and supporting libraries
print("\nInstalling gym and BabyAI dependencies...")
!pip install "gym>=0.21,<0.27" blosc colorama termcolor matplotlib

# Install babyai-text
print("\nInstalling babyai-text...")
%cd /content/Grounding_LLMs_with_online_RL/babyai-text
!pip install -e . --no-deps

# Install babyai
print("\nInstalling babyai...")
%cd babyai
!pip install -e . --no-deps

# Install gym-minigrid
print("\nInstalling gym-minigrid...")
%cd ../gym-minigrid
!pip install -e . --no-deps

# Add to Python path (important for imports after runtime restart)
print("\nAdding BabyAI paths to sys.path...")
babyai_paths = [
    '/content/Grounding_LLMs_with_online_RL/babyai-text',
    '/content/Grounding_LLMs_with_online_RL/babyai-text/babyai',
    '/content/Grounding_LLMs_with_online_RL/babyai-text/gym-minigrid',
]

for path in babyai_paths:
    if os.path.exists(path) and path not in sys.path:
        sys.path.insert(0, path)
        print(f"  ✓ Added: {path}")
    elif path in sys.path:
        print(f"  ✓ Already in path: {path}")
    else:
        print(f"  ⚠ Path does not exist: {path}")

# Test imports
print("\nTesting BabyAI imports...")
try:
    import babyai_text
    import babyai
    import gym_minigrid
    print("✓ All BabyAI packages can be imported")
    print(f"  babyai_text location: {babyai_text.__file__}")
    print(f"  babyai location: {babyai.__file__}")
    print(f"  gym_minigrid location: {gym_minigrid.__file__}")
except ImportError as e:
    print(f"✗ Import failed: {e}")
    import traceback
    traceback.print_exc()
    print("\n⚠️  If imports fail, make sure all paths are correct and packages are installed.")

In [None]:
from huggingface_hub import login
from getpass import getpass
import wandb

# Hugging Face authentication
print("Hugging Face Login:")
token = getpass("Paste your Hugging Face token: ")
login(token=token)  # register the token with huggingface_hub so model downloads are authorized
print("✓ Hugging Face authenticated")

# Weights & Biases authentication
print("\nWeights & Biases Login:")
wandb_key = getpass("Paste your W&B API key: ")
wandb.login(key=wandb_key)
print("✓ Weights & Biases authenticated")

In [None]:
%cd /content/act-prm-tinker

# Remove the flash_attention_2 requirement
!sed -i '/attn_implementation: "flash_attention_2"/d' configs/model/hf_llama3_1_8b_inst.yaml

# Verify the change
!cat configs/model/hf_llama3_1_8b_inst.yaml

In [None]:
%cd /content/act-prm-tinker

# Ensure BabyAI paths are in sys.path (in case kernel was restarted)
import sys
import os

babyai_paths = [
    '/content/Grounding_LLMs_with_online_RL/babyai-text',
    '/content/Grounding_LLMs_with_online_RL/babyai-text/babyai',
    '/content/Grounding_LLMs_with_online_RL/babyai-text/gym-minigrid',
]

for path in babyai_paths:
    if os.path.exists(path) and path not in sys.path:
        sys.path.insert(0, path)


In [None]:
%cd /content/act-prm-tinker

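import os

# Prepend the repo's src/ and the BabyAI packages to PYTHONPATH for the training subprocess
extra_paths = [
    "src",
    "/content/Grounding_LLMs_with_online_RL/babyai-text",
    "/content/Grounding_LLMs_with_online_RL/babyai-text/babyai",
    "/content/Grounding_LLMs_with_online_RL/babyai-text/gym-minigrid",
]
existing = os.environ.get("PYTHONPATH", "")
# Only append the previous value when it is non-empty, to avoid a trailing separator
os.environ["PYTHONPATH"] = os.pathsep.join(extra_paths + ([existing] if existing else []))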

# Run training
!python main_pytorch.py \
  --env_config act_prm/babyai \
  --base_env_config babyai/default \
  --eval_env_config babyai/eval \
  --generator_config aprm_qwen3_ap \
  --trainer_config aprm_for_sft100 \
  --replay_buffer_config default \
  --model_config hf_llama3_1_8b_inst \
  --lora_config r32_a32_qkvo \
  --log_path ./logs \
  --save_rollouts_every 10 \
  --seed 42 \
  --replicate 0 \
  --verbose