# MAIN-XR-MD Phase-0 · Colab/Local Launcher
Use this notebook to run on Google Colab or locally with GPU, TPU, or CPU. It dynamically installs the right JAX build and optional T5 (Hugging Face Transformers, Flax) dependencies, then runs quick validation steps.

## 1) Runtime Selection
- In Colab, set the runtime you want under: Runtime → Change runtime type → Hardware accelerator: GPU or TPU.
- If left on None (CPU), the notebook will still work.
- You can leave `BACKEND` set to `'auto'` below; the installer detects the available accelerator and adapts.

In [None]:
#@title Backend and Options
# Choose backend: 'auto', 'gpu', 'tpu', or 'cpu'
# ('auto' prefers TPU, then GPU, then CPU; a requested GPU/TPU that is not
# detected falls back to CPU.)
BACKEND = 'auto'  #@param ['auto', 'gpu', 'tpu', 'cpu']
# Install Hugging Face Transformers (Flax) for optional T5 demo
WITH_T5 = False  #@param {type:'boolean'}
# Optional: Git URL of this repo to clone when running on Colab
# (leave empty to make the clone cell skip cloning).
REPO_URL = 'https://github.com/krisztiaan/research-main-xr.git'  #@param {type:'string'}

import os, sys, subprocess, shlex, json, textwrap, platform
# Heuristic: presumably google.colab is already imported by the time user code
# runs inside a Colab kernel. NOTE(review): confirm on your runtime.
IN_COLAB = 'google.colab' in sys.modules
print(f'IN_COLAB={IN_COLAB}, requested BACKEND={BACKEND}, WITH_T5={WITH_T5}')

def _run(cmd, check=True):
    print('→', cmd)
    return subprocess.run(cmd, shell=True, check=check)

def _pip(*args):
    """Run ``pip install --upgrade`` with *args* using this kernel's interpreter.

    *args* are joined verbatim into a shell command line, so callers must
    shell-quote any argument containing metacharacters (e.g. ``'"jax[tpu]"'``).
    """
    prefix = ' '.join((shlex.quote(sys.executable), '-m', 'pip', 'install', '--upgrade'))
    return _run(f"{prefix} {' '.join(args)}")

def detect_backend(requested='auto'):
    """Resolve the requested backend to one that is usable on this machine.

    Args:
        requested: 'auto', 'gpu', 'tpu', or 'cpu' (case-insensitive; surrounding
            whitespace is ignored). ``None`` or '' is treated as 'auto'.

    Returns:
        'gpu', 'tpu', or 'cpu'. 'auto' prefers TPU, then GPU, then CPU. An
        explicit 'gpu'/'tpu' request falls back to 'cpu' with a hint when no
        such device is detected.

    Raises:
        ValueError: if *requested* is not one of the supported names.
        Previously such values were passed through unchanged. The installer
        then treated them as CPU without pinning JAX to CPU.
    """
    requested = (requested or 'auto').strip().lower()
    if requested not in ('auto', 'gpu', 'tpu', 'cpu'):
        raise ValueError(
            f"Unknown backend {requested!r}; expected 'auto', 'gpu', 'tpu', or 'cpu'."
        )
    if requested == 'cpu':
        return 'cpu'  # nothing to probe

    def _has_tpu():
        # COLAB_TPU_ADDR is set by the legacy Colab TPU runtime. Newer TPU-VM
        # runtimes may not set it, so /dev/accel0 is used as an extra
        # best-effort signal. NOTE(review): confirm on your runtime; some TPU
        # generations expose their devices differently.
        return bool(os.environ.get('COLAB_TPU_ADDR')) or os.path.exists('/dev/accel0')

    def _has_gpu():
        # Probe for an NVIDIA GPU via nvidia-smi's exit status.
        return _run('nvidia-smi >/dev/null 2>&1', check=False).returncode == 0

    if requested == 'auto':
        if _has_tpu():
            return 'tpu'
        return 'gpu' if _has_gpu() else 'cpu'
    if requested == 'gpu':
        if _has_gpu():
            return 'gpu'
        print('Requested GPU but none detected — falling back to CPU. Use Runtime → Change runtime type → GPU.')
        return 'cpu'
    # requested == 'tpu'
    if _has_tpu():
        return 'tpu'
    print('Requested TPU but none detected — falling back to CPU. Use Runtime → Change runtime type → TPU.')
    return 'cpu'

def install_stack(backend, with_t5=False):
    """Install JAX for *backend*, then the project's runtime dependencies.

    Args:
        backend: 'gpu' or 'tpu' select the matching JAX build. Any other value
            installs the plain (CPU) ``jax`` package.
        with_t5: Also install the Hugging Face packages used by the optional
            T5 demo.

    Raises:
        subprocess.CalledProcessError: if a required pip step fails (toolchain
            upgrade, JAX, the fallback base deps, or the T5 extras). The
            editable-project and requirements.txt installs are best-effort and
            only print a message on failure.
    """
    print(f'Installing stack for backend={backend} (Colab={IN_COLAB})')
    _pip('pip', 'setuptools', 'wheel')
    if backend == 'gpu':
        # Current JAX releases provide CUDA 12 support through the `cuda12`
        # extra (wheels on PyPI, no find-links page needed). The `cuda12_pip`
        # extra exists only on older releases, and pip merely *warns* about an
        # unknown extra. Requesting it against a new JAX therefore silently
        # installs a CPU-only build and never reaches an except branch.
        # The legacy spelling is kept only as a fallback. JAX no longer
        # publishes CUDA 11 wheels, so there is no cuda11 fallback.
        try:
            _pip('"jax[cuda12]"')
        except subprocess.CalledProcessError:
            print('jax[cuda12] install failed; trying legacy jax[cuda12_pip] wheels…')
            _pip('"jax[cuda12_pip]"', '-f', 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html')
        # Limit the fraction of GPU memory XLA preallocates in this process.
        os.environ.setdefault('XLA_PYTHON_CLIENT_MEM_FRACTION', '0.85')
    elif backend == 'tpu':
        _pip('"jax[tpu]"', '-f', 'https://storage.googleapis.com/jax-releases/libtpu_releases.html')
    else:  # cpu
        _pip('jax')
    # Project deps (lightweight). If a local project is present, try an
    # editable install first. Then install requirements.txt when present;
    # otherwise fall back to a minimal dependency set for a standalone notebook.
    if any(os.path.exists(f) for f in ('pyproject.toml', 'setup.cfg', 'setup.py')):
        try:
            _pip('-e', '.')
        except (subprocess.CalledProcessError, OSError) as e:
            print('Editable install failed; continuing with requirements if present.', e)
    if os.path.exists('requirements.txt'):
        try:
            _pip('-r', 'requirements.txt')
        except (subprocess.CalledProcessError, OSError) as e:
            print('requirements.txt install failed; continuing.', e)
    else:
        # Fallback minimal deps for standalone Colab notebook
        base = [
            'flax>=0.12.0', 'optax>=0.2.6', 'chex>=0.1.91',
            'gymnasium>=1.2.1', 'gymnax>=0.0.9', 'craftax>=1.5.0',
            'orjson>=3.11.3', 'numpy>=2.3.3', 'tqdm>=4.67.1',
            'msgpack>=1.1.1', 'msgpack-numpy>=0.4.8',
        ]
        _pip(*base)
    if with_t5:
        _pip('transformers', 'datasets', 'accelerate', 'sentencepiece', 'huggingface_hub')
    print('✓ Install complete')

BACKEND = detect_backend(BACKEND)
if BACKEND == 'cpu':
    # Pin JAX to CPU for this kernel and for processes it spawns (such as the
    # %%bash training cell). setdefault keeps a value the user set explicitly.
    os.environ.setdefault('JAX_PLATFORMS', 'cpu')
elif os.environ.get('JAX_PLATFORMS') == 'cpu':
    # A value left over from an earlier CPU run in this kernel would hide the
    # accelerator from JAX. Only warn, since it may also be a deliberate choice.
    print(f'Note: JAX_PLATFORMS=cpu is already set, so JAX will not use the '
          f'{BACKEND.upper()}. Unset it (or restart the runtime) to use the accelerator.')
install_stack(BACKEND, with_t5=WITH_T5)

# pip replaced the JAX packages on disk, but a module already imported in this
# kernel keeps running the old code until the runtime is restarted.
if 'jax' in sys.modules:
    print('⚠ jax was already imported in this kernel; restart the runtime/session '
          'so the newly installed JAX build is used.')
print('Environment ready.')


In [None]:
# Sanity check: report which backend and devices JAX picked up.
import os
import jax

backend_name = jax.default_backend()
print('JAX default backend:', backend_name)
print('JAX devices:', jax.devices())
# Echo JAX_PLATFORMS only when it is set to a non-empty value.
platforms_env = os.environ.get('JAX_PLATFORMS')
if platforms_env:
    print('JAX_PLATFORMS=', platforms_env)


## 2) (Optional) Clone Repo for Editing
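If you opened only this notebook in Colab and want to edit the code, set `REPO_URL` in the first cell and run the cell below. The repository is cloned into a directory named after it (e.g. `research-main-xr`).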

In [None]:
import os, subprocess
if REPO_URL:
    # Derive the checkout directory from the last path component, handling both
    # HTTPS ('https://host/owner/name.git') and scp-style
    # ('git@host:owner/name.git') URLs. The result is passed to git explicitly,
    # so the existence check and the clone target always agree.
    repo_name = REPO_URL.rstrip('/').replace(':', '/').split('/')[-1]
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-4]
    if not os.path.isdir(repo_name):
        subprocess.check_call(['git', 'clone', REPO_URL, repo_name])
    print('Repo available at:', os.path.abspath(repo_name))
else:
    print('REPO_URL not set; skipping clone.')


## 3) Quick Rollout (Project)
Runs a short training loop (131,072 frames on Craftax-Classic) to verify that everything is wired up correctly. Adjust the flags to your needs.

In [None]:
%%bash
set -euxo pipefail
# Run from a project checkout if one is present, otherwise from the current dir.
# research-main-xr is where the clone cell puts the default REPO_URL;
# main_xr_md_jax_phase0 is an alternative project directory name.
# NOTE(review): assumes mxrmd_jax is importable from the checkout root
# (or was pip-installed) — confirm for your layout.
for d in research-main-xr main_xr_md_jax_phase0; do
  if [ -d "$d" ]; then cd "$d"; break; fi
done
python -m mxrmd_jax.train_jax \
  --env craftax \
  --env-id craftax-classic-v1 \
  --num-envs 256 \
  --unroll 32 \
  --total-frames 131072 \
  --run-dir runs/colab_quickcheck


## 4) Optional: T5 (Flax) Smoke Test
If `WITH_T5` was enabled in the first cell, this runs a tiny text-to-text generation with FLAN-T5 (small).

In [None]:
if WITH_T5:
    try:
        from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM
    except ImportError as err:
        # NOTE(review): the installer pulls an unpinned, upgraded Transformers.
        # Recent releases have been deprecating/removing Flax model classes, so
        # pin a release that still ships them if this fails. Confirm the bound.
        raise ImportError(
            'FlaxAutoModelForSeq2SeqLM is not available in the installed transformers; '
            'install a Transformers release with Flax support (plus flax) to run this demo.'
        ) from err
    model_name = 'google/flan-t5-small'
    tok = AutoTokenizer.from_pretrained(model_name)
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    # return_tensors='np': the Flax model is fed NumPy arrays.
    inputs = tok('translate English to German: A beautiful day.', return_tensors='np')
    gen = model.generate(**inputs, max_length=40)
    # generate() returns an output object; the token ids are in .sequences.
    print(tok.batch_decode(gen.sequences, skip_special_tokens=True)[0])
else:
    print('WITH_T5 is False; skipping T5 demo.')


## 5) Save Artifacts (Colab)
Mount Google Drive if you want to keep checkpoints after the Colab session ends; the cell below copies run outputs to `MyDrive/mxrmd_colab_runs`.

In [None]:
# Copy run outputs to Google Drive so they survive the Colab session.
# This uses subprocess rather than an IPython `!rsync` escape, because `!`
# ignores non-zero exit codes and a failed copy would go unnoticed.
import os, subprocess

try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception as e:  # e.g. not running on Colab
    print('Drive mount not available or failed; skipping.', e)
else:
    dest = '/content/drive/MyDrive/mxrmd_colab_runs'
    # The training cell writes runs/ relative to the directory it ran from:
    # a project checkout if it found one, otherwise /content.
    # Adjust paths as needed.
    candidates = [
        '/content/runs',
        '/content/research-main-xr/runs',
        '/content/main_xr_md_jax_phase0/runs',
    ]
    sources = [d for d in candidates if os.path.isdir(d)]
    if not sources:
        print('No runs/ directory found; nothing to copy. Looked in:', candidates)
    for src in sources:
        # Trailing slash on the source: copy its *contents* into dest (merged).
        subprocess.run(['rsync', '-a', src + '/', dest], check=True)
        print(f'Copied {src} → {dest}')
