# Disocraft Colab Training

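This notebook clones the repo, initializes its submodules, installs dependencies
(JAX with CUDA 12 plus `requirements.txt`), applies the DiscoRL patch, runs a short
smoke test followed by a full (~100k env step) training run, and plots the average
return parsed from the training log.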

In [None]:
# Clone the repo and cd into it. Both steps are guarded so that re-running this
# cell neither fails on an existing checkout nor tries to cd into a nested
# `disocraft/disocraft` path when we are already inside the repo.
import os

if os.path.basename(os.getcwd()) != 'disocraft':
    if not os.path.isdir('disocraft'):
        !git clone https://github.com/Maharishiva/disocraft.git
    %cd disocraft
# Check out the submodule commits recorded in the repo (recursively).
!git submodule update --init --recursive

In [None]:
# Install JAX with CUDA 12 support (this notebook targets a Colab A100 runtime).
# Current JAX installation docs install the CUDA 12 wheels from PyPI with a plain
# `pip install -U "jax[cuda12]"`; the old `-f .../jax_cuda_releases.html`
# find-links index is only needed for legacy releases, so it is dropped here.
# `python -m pip` installs into the same interpreter that `!python train.py`
# uses below, rather than whichever `pip` happens to be first on PATH.
!python -m pip install -U "jax[cuda12]"

In [None]:
# Install the project's Python dependencies into the same interpreter that
# runs train.py below (`python -m pip` rather than the first `pip` on PATH).
# NOTE(review): if requirements.txt pins jax/jaxlib, this may replace the CUDA
# build installed in the previous cell — confirm the pins are compatible.
!python -m pip install -r requirements.txt

In [None]:
# Apply the DiscoRL patch via the repo's scripts/patch_disco_rl.py, using the
# same `python` that runs train.py below.
# NOTE(review): confirm the patch script is idempotent before re-running this cell.
!python scripts/patch_disco_rl.py

In [None]:
# Ensure the `runs/` directory for training logs exists; a no-op if it already does.
from pathlib import Path

Path('runs').mkdir(parents=True, exist_ok=True)

In [None]:
# Smoke test: a tiny training run (20 iterations, 32-step rollouts, batch size 8,
# logging every iteration) to check that the install, the patch and train.py all
# work before launching the full run. MPLCONFIGDIR points matplotlib's config/cache
# directory at ./.mplconfig inside the repo for the training process only.
!env MPLCONFIGDIR=./.mplconfig \
  python train.py --num_iterations 20 --num_envs 1 \
  --rollout_len 32 --batch_size 8 --log_every 1

In [None]:
# Full run (~100k env steps): presumably 400 iterations x 1 env x 256-step
# rollouts = 102,400 env steps — confirm against train.py's step accounting.
# The output is mirrored into runs/train_100k.log, which the plot cell reads.
#  - `mkdir -p runs` makes this cell self-contained: without the directory, `tee`
#    cannot open the log and the plot cell later fails.
#  - `python -u` disables stdout buffering. Output piped into `tee` is otherwise
#    block-buffered, so progress shows up in bursts and the log's tail can be
#    lost if the run is interrupted.
#  - `2>&1 | tee` replaces bash-only `|&` so the cell also works when `!`
#    commands run under a POSIX `sh` (e.g. dash).
!mkdir -p runs
!MPLCONFIGDIR=./.mplconfig python -u train.py \
  --num_iterations 400 \
  --num_envs 1 \
  --rollout_len 256 \
  --log_every 10 2>&1 | tee runs/train_100k.log

In [None]:
# Plot average return vs. env steps, parsed from the full run's log.
import re
import matplotlib.pyplot as plt

LOG_PATH = 'runs/train_100k.log'

# A signed decimal / scientific-notation float, or nan/inf. The previous pattern
# `[0-9.]+` silently skipped negative returns and could capture strings like "...".
_FLOAT = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?:nan|inf)'
_LOG_LINE = re.compile(rf'steps=(\d+).*avg_return=({_FLOAT})')


def parse_returns(lines):
    """Extract (env_steps, avg_return) pairs from training-log lines.

    Assumes train.py logs lines containing `steps=<int>` followed later on the
    same line by `avg_return=<float>`.
    NOTE(review): confirm against train.py's actual log format.

    Args:
        lines: iterable of log lines (e.g. an open text file).

    Returns:
        (steps, returns): two equal-length lists, in log order.
    """
    steps, returns = [], []
    for line in lines:
        match = _LOG_LINE.search(line)
        if match:
            steps.append(int(match.group(1)))
            returns.append(float(match.group(2)))
    return steps, returns


with open(LOG_PATH) as f:
    xs, ys = parse_returns(f)

# Fail loudly instead of drawing an empty plot when the log has no matching lines.
if not xs:
    raise ValueError(
        f'No "steps=... avg_return=..." lines found in {LOG_PATH}; '
        'did the full run complete and log in the expected format?'
    )

fig, ax = plt.subplots()
ax.plot(xs, ys)
ax.set_xlabel('env steps')
ax.set_ylabel('avg_return')
ax.set_title('Craftax Disco103')
ax.grid(True, alpha=0.3)
plt.show()