# MAIN-XR-MD Phase-0 · Colab Launcher
Use this notebook to spin up a Google Colab session with GPU support, install CUDA-enabled JAX, and run a short validation rollout for the project.

## 1. Select GPU Runtime
Before running any cells below, open **Runtime → Change runtime type**, pick a GPU under **Hardware accelerator**, and save. For a CPU-only smoke test, keep the default CPU runtime and make sure `JAX_PLATFORMS=cpu` is set in the environment before the training cell in step 4 runs.

In [None]:
!nvidia-smi || echo 'No NVIDIA GPU attached; set JAX_PLATFORMS=cpu for a CPU smoke test.'
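If `nvidia-smi` fails, the CPU fallback from step 1 still has to be switched on by hand. The cell below is a minimal sketch that does it automatically. It treats a missing or failing `nvidia-smi` as "no GPU" (a heuristic, not a JAX-level check) and then sets `JAX_PLATFORMS=cpu` in the kernel's environment, which later `%%bash` and `!` cells inherit. The helper name `gpu_available` is hypothetical.

```python
import os
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if `nvidia-smi` is on PATH and exits cleanly (heuristic)."""
    smi = shutil.which("nvidia-smi")
    if smi is None:
        return False
    # A non-zero exit code usually means the driver cannot see a GPU.
    result = subprocess.run([smi], capture_output=True, text=True, check=False)
    return result.returncode == 0


if not gpu_available():
    # Child processes started by %%bash / ! inherit os.environ, so the
    # training run in step 4 will see this too. Set it before any
    # `import jax` in this kernel.
    os.environ["JAX_PLATFORMS"] = "cpu"

print("JAX_PLATFORMS =", os.environ.get("JAX_PLATFORMS", "<unset>"))
```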

## 2. Install Dependencies
Installs CUDA-enabled JAX and this repository from Git (change the URL if you are working from a fork).

In [None]:
%%bash
set -euxo pipefail
python -m pip install --upgrade pip setuptools wheel
python -m pip install --upgrade "jax[cuda12]>=0.4.30"
python -m pip install --upgrade git+https://github.com/krisztiaan/main_xr_md_jax_phase0.git
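To confirm that the install cell actually upgraded JAX, you can read the installed versions from package metadata without importing JAX. This is a standard-library sketch: `MIN_JAX` simply mirrors the `>=0.4.30` pin above, and `version_tuple` is a hypothetical, simplified parser (not `packaging.version`) that keeps only the leading numeric release components. It does not prove the GPU backend works; after importing `jax`, `jax.devices()` is the direct check.

```python
from importlib import metadata
from typing import Optional, Tuple

MIN_JAX = (0, 4, 30)  # mirrors the ">=0.4.30" pin in the install cell


def version_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '0.4.31.dev1' -> (0, 4, 31)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # e.g. 'dev1': no leading number, stop here
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '30rc1': keep 30, then stop
            break
    return tuple(parts)


def installed_version(dist: str) -> Optional[str]:
    """Installed version of distribution `dist`, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


jax_version = installed_version("jax")
print("jax:", jax_version, "| jaxlib:", installed_version("jaxlib"))
if jax_version is None or version_tuple(jax_version) < MIN_JAX:
    print("jax is missing or older than the >=0.4.30 pin; re-run the install cell.")
```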

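## 3. Clone the Repository
The install cell above already provides the package, but the rollout in step 4 runs from this checkout (it `cd`s into it and installs its `requirements.txt`), so run this cell either way. The clone is also where you can edit code inline.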

In [None]:
%%bash
set -euxo pipefail
if [ ! -d main_xr_md_jax_phase0 ]; then
  git clone https://github.com/krisztiaan/main_xr_md_jax_phase0.git
fi
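Step 2 installs a regular (non-editable) copy into site-packages, so edits to the clone only take effect if Python resolves `mxrmd_jax` to the clone. `python -m` puts the current directory first on `sys.path`, so whether the `cd` in step 4 picks up the checkout depends on the repository layout. The sketch below reports which file would be imported. It assumes the clone lives at `/content/main_xr_md_jax_phase0` (Colab's default working directory is `/content`); `module_origin` is a hypothetical helper.

```python
import importlib
import importlib.util
import sys
from pathlib import Path
from typing import Optional


def module_origin(name: str, search_dir: Optional[str] = None) -> Optional[str]:
    """Return the file Python would import for top-level module `name`.

    If `search_dir` is given, it is placed first on sys.path for the lookup,
    mimicking `python -m ...` run from that directory. If `name` is already
    imported in this kernel, find_spec returns the cached module's spec.
    """
    saved = list(sys.path)
    try:
        if search_dir is not None:
            sys.path.insert(0, str(Path(search_dir).resolve()))
        importlib.invalidate_caches()
        spec = importlib.util.find_spec(name)
        return None if spec is None else spec.origin
    finally:
        sys.path[:] = saved  # leave sys.path exactly as we found it


# A path inside the clone means local edits will be used by step 4.
print(module_origin("mxrmd_jax", search_dir="/content/main_xr_md_jax_phase0"))
```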

## 4. Quick GPU Rollout
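Runs a short training loop (256 environments, 16-step unrolls, 262,144 frames in total) to verify that the install, the accelerator, and the environment are wired together correctly. Increase `--total-frames` and adjust the other flags for longer jobs.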

In [None]:
%%bash
set -euxo pipefail
cd main_xr_md_jax_phase0
python -m pip install -r requirements.txt
python -m mxrmd_jax.train_jax \
  --env craftax \
  --env-id craftax-classic-v1 \
  --num-envs 256 \
  --unroll 16 \
  --total-frames 262144 \
  --r-reset 20 \
  --run-dir /content/runs/colab_demo
ls -R /content/runs/colab_demo
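To size longer jobs, it helps to know how many updates a frame budget buys. The sketch below assumes, without confirmation from `train_jax`, that each update collects `--num-envs × --unroll` frames and that `--total-frames` counts environment frames. Under that assumption the demo run uses 256 × 16 = 4,096 frames per update and 262,144 / 4,096 = 64 updates. The helper names are hypothetical, and `--r-reset` is left out because its meaning is not documented here.

```python
def frames_per_update(num_envs: int, unroll: int) -> int:
    """Frames per update, assuming one frame per env per unroll step."""
    return num_envs * unroll


def updates_for(total_frames: int, num_envs: int, unroll: int) -> int:
    """Number of whole updates that fit in `total_frames`."""
    return total_frames // frames_per_update(num_envs, unroll)


def round_up_frames(target_frames: int, num_envs: int, unroll: int) -> int:
    """Smallest multiple of frames_per_update that is >= target_frames."""
    step = frames_per_update(num_envs, unroll)
    return -(-target_frames // step) * step  # integer ceiling division


print(frames_per_update(256, 16))            # 4096
print(updates_for(262_144, 256, 16))         # 64
print(round_up_frames(10_000_000, 256, 16))  # 10002432
```

With these settings, rounding a 10M-frame target up to whole updates gives 10,002,432 frames (2,442 updates), which you could pass as `--total-frames`.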

## 5. Save Artifacts
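Files under `/content` are lost when the Colab runtime is recycled, so mount Google Drive and copy the run directory there if you want to keep checkpoints and other run outputs beyond this session.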

In [None]:
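from google.colab import drive

# Mount Drive at /content/drive (prompts for authorization).
drive.mount('/content/drive')
# No trailing slash on the source, so the run lands in mxrmd_colab_runs/colab_demo.
!rsync -a /content/runs/colab_demo /content/drive/MyDrive/mxrmd_colab_runs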
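If `rsync` is unavailable or you prefer to stay in Python, `shutil.copytree` with `dirs_exist_ok=True` (Python 3.8+) gives a similar one-way copy. This is a sketch, not a drop-in equivalent: unlike `rsync`, it rewrites every file on each run instead of skipping unchanged ones. The helper names are hypothetical; `backup_dest` mirrors the rsync behavior where a source without a trailing slash is copied into a subdirectory of the destination.

```python
import shutil
from pathlib import Path
from typing import Optional


def backup_dest(src: str, dest_root: str) -> Path:
    """Target path, matching `rsync -a src dest_root` with no trailing slash on src."""
    return Path(dest_root) / Path(src).name


def backup_run(src: str, dest_root: str) -> Optional[Path]:
    """Copy the run directory into dest_root/<run name>; return None if src is missing."""
    src_path = Path(src)
    if not src_path.is_dir():
        return None
    dest = backup_dest(src, dest_root)
    # Creates missing parent directories; overwrites existing files, never deletes.
    shutil.copytree(src_path, dest, dirs_exist_ok=True)
    return dest


# Only attempt the copy once Drive is mounted.
if Path("/content/drive/MyDrive").is_dir():
    print(backup_run("/content/runs/colab_demo", "/content/drive/MyDrive/mxrmd_colab_runs"))
```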