# Disocraft Colab Training

This notebook clones the repo, initializes its submodules, installs dependencies,
applies the DiscoRL patch, runs training, and plots returns from the training log.

In [16]:
!git clone https://github.com/Maharishiva/disocraft.git
%cd disocraft
!git submodule update --init --recursive

Cloning into 'disocraft'...
remote: Enumerating objects: 44, done.
remote: Counting objects: 100% (44/44), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 44 (delta 17), reused 35 (delta 8), pack-reused 0 (from 0)
Receiving objects: 100% (44/44), 17.91 KiB | 17.91 MiB/s, done.
Resolving deltas: 100% (17/17), done.
/content/disocraft/disocraft
Submodule 'external/Craftax' (https://github.com/MichaelTMatthews/Craftax.git) registered for path 'external/Craftax'
Submodule 'external/disco_rl' (https://github.com/google-deepmind/disco_rl.git) registered for path 'external/disco_rl'
Cloning into '/content/disocraft/disocraft/external/Craftax'...
Cloning into '/content/disocraft/disocraft/external/disco_rl'...
Submodule path 'external/Craftax': checked out 'bcc0fae62060579de4dee63dbb4e8dc729588173'
Submodule path 'external/disco_rl': checked out '9059a29f7121d60948f25ef165e08e050e9399c8'
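
The `/content/disocraft/disocraft` path in the output above suggests the clone cell was run from inside an earlier checkout, which nests one clone in another. A minimal sketch of a re-runnable alternative follows. `ensure_repo` is a hypothetical helper, not part of the repo, and it assumes a checkout is recognizable by its `.git` entry.

```python
import os
import subprocess


def ensure_repo(url: str, name: str, root: str = ".") -> str:
    """Clone `url` into `root/name` unless a checkout is already there.

    Hypothetical helper for illustration; returns the absolute repo path.
    """
    target = os.path.abspath(os.path.join(root, name))
    if not os.path.exists(os.path.join(target, ".git")):
        # --recurse-submodules also fetches external/Craftax and external/disco_rl.
        subprocess.run(
            ["git", "clone", "--recurse-submodules", url, target],
            check=True,
        )
    return target


# Usage in Colab (commented out so the sketch has no side effects on import):
# repo = ensure_repo("https://github.com/Maharishiva/disocraft.git",
#                    "disocraft", root="/content")
# os.chdir(repo)
```

Passing an absolute `root` keeps the result independent of the notebook's current working directory, which is what changes when the `%cd` cell is re-run.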


In [17]:
# Install JAX with CUDA 12 support (A100 runtime).
!pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [18]:
!pip install -r requirements.txt

Obtaining file:///content/disocraft/disocraft/external/disco_rl (from -r requirements.txt (line 1))
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Obtaining file:///content/disocraft/disocraft/external/Craftax (from -r requirements.txt (line 2))
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: disco_rl, craftax
  Building editable for disco_rl (pyproject.toml) ... done
  Created wheel for disco_rl: filename=disco_rl-1.0.0-0.editable-py3-none-any.whl size=8825 sha256=e04c06cd7ad8eb40627a9216f78fdce3152db85cb11b10d41ddb0655ded8b

In [19]:
!python scripts/patch_disco_rl.py

Patched /content/disocraft/disocraft/external/disco_rl/disco_rl/networks/meta_nets.py


In [20]:
!mkdir -p runs

In [None]:
# Short test run (smoke test).
# Uses >=4 envs and batch_size=4 to trigger a learner update on iter 1.
!XLA_PYTHON_CLIENT_PREALLOCATE=false MPLCONFIGDIR=./.mplconfig python train.py \
  --num_iterations 5 \
  --num_envs 4 \
  --rollout_len 32 \
  --batch_size 4 \
  --buffer_capacity_transitions 10000 \
  --log_every 1

2026-01-11 01:00:28.695595: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768093228.716898   25804 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768093228.723427   25804 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768093228.740032   25804 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768093228.740058   25804 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768093228.740061   25804 computation_placer.cc:177] computation placer alr

In [None]:
# ~1M env steps.
# steps ~= num_iterations * rollout_len * num_envs = 1078 * 29 * 32 = 1,000,384
!XLA_PYTHON_CLIENT_PREALLOCATE=false MPLCONFIGDIR=./.mplconfig python train.py \
  --num_iterations 1078 \
  --num_envs 32 \
  --rollout_len 29 \
  --batch_size 24 \
  --replay_fraction 0.99 \
  --buffer_capacity_transitions 400000 \
  --learning_rate 3e-4 \
  --log_every 10 |& tee runs/train_1m.log


In [None]:
# ~20M env steps.
# steps ~= num_iterations * rollout_len * num_envs = 21552 * 29 * 32 = 20,000,256
!XLA_PYTHON_CLIENT_PREALLOCATE=false MPLCONFIGDIR=./.mplconfig python train.py \
  --num_iterations 21552 \
  --num_envs 32 \
  --rollout_len 29 \
  --batch_size 24 \
  --replay_fraction 0.99 \
  --buffer_capacity_transitions 400000 \
  --learning_rate 3e-4 \
  --log_every 10 |& tee runs/train_20m.log
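
The iteration counts used for the 1M and 20M runs follow from the relation steps ~= num_iterations * rollout_len * num_envs. A small sketch of that arithmetic, useful when choosing `--num_iterations` for a different step budget; `iterations_for` is a hypothetical helper name, and rounding up is an assumption so the run does not fall short of the target.

```python
import math


def iterations_for(target_steps: int, rollout_len: int, num_envs: int) -> int:
    """Smallest iteration count whose total env steps reach target_steps."""
    steps_per_iter = rollout_len * num_envs  # env steps collected per iteration
    return math.ceil(target_steps / steps_per_iter)


for target in (1_000_000, 20_000_000):
    n = iterations_for(target, rollout_len=29, num_envs=32)
    print(f"target={target} -> num_iterations={n}, total_steps={n * 29 * 32}")
# target=1000000 -> num_iterations=1078, total_steps=1000384
# target=20000000 -> num_iterations=21552, total_steps=20000256
```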


In [None]:
# Plot returns.
import re
import matplotlib.pyplot as plt

xs, ys = [], []
log_path = 'runs/train_1m.log'
# Set log_path to runs/train_20m.log for the longer run.
with open(log_path) as f:
    for line in f:
        # Also accept negative and scientific-notation returns.
        m = re.search(r'steps=(\d+).*avg_return=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)', line)
        if m:
            xs.append(int(m.group(1)))
            ys.append(float(m.group(2)))

plt.plot(xs, ys)
plt.xlabel('env steps')
plt.ylabel('avg_return')
plt.title('Craftax Disco103')
plt.grid(True, alpha=0.3)
plt.show()
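
To compare the 1M and 20M logs on one figure, the parsing loop above can be factored into a function. This is a sketch only: `parse_returns` is a hypothetical name, and the sample lines are made up to match the `steps=... avg_return=...` pattern the plotting cell searches for, not copied from real `train.py` output.

```python
import re

# Matches e.g. "steps=9280 ... avg_return=-0.5"; the number may be negative
# or in scientific notation.
LINE_RE = re.compile(
    r"steps=(\d+).*?avg_return=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
)


def parse_returns(lines):
    """Return (steps, avg_returns) for every line that matches LINE_RE."""
    xs, ys = [], []
    for line in lines:
        m = LINE_RE.search(line)
        if m:
            xs.append(int(m.group(1)))
            ys.append(float(m.group(2)))
    return xs, ys


# Hypothetical log lines for illustration; non-matching lines are skipped.
sample = [
    "iter=10 steps=9280 avg_return=1.25",
    "W0000 some unrelated XLA warning",
    "iter=20 steps=18560 avg_return=-0.5",
]
print(parse_returns(sample))  # ([9280, 18560], [1.25, -0.5])
```

With real logs, each file can be opened and passed in directly, for example `parse_returns(open('runs/train_1m.log'))`, and the resulting series plotted with a separate `plt.plot` call per run.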