Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ if [ -z "$CUDNN_LFLAG" ]; then
CUDNN_LFLAG=$(python -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || echo "")
fi

# NCCL include/lib fallback (mirrors the cuDNN fallback above).
# Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv.
NCCL_IFLAG=""
NCCL_LFLAG=""
for dir in /usr/include /usr/local/cuda/include; do
if [ -f "$dir/nccl.h" ]; then NCCL_IFLAG="-I$dir"; break; fi
done
for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do
if [ -f "$dir/libnccl.so" ] || [ -f "$dir/libnccl.so.2" ]; then NCCL_LFLAG="-L$dir"; break; fi
done
if [ -z "$NCCL_IFLAG" ]; then
NCCL_IFLAG=$(python -c "import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2>/dev/null || echo "")
fi
if [ -z "$NCCL_LFLAG" ]; then
NCCL_LFLAG=$(python -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "")
fi

export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}"
export CCACHE_BASEDIR="$(pwd)"
export CCACHE_COMPILERCHECK=content
Expand Down Expand Up @@ -240,7 +257,7 @@ if [ -z "$MODE" ]; then
-std=c++17 \
-I. -Isrc \
-I$PYTHON_INCLUDE -I$PYBIND_INCLUDE -I$NUMPY_INCLUDE \
-I$CUDA_HOME/include $CUDNN_IFLAG -I$RAYLIB_NAME/include \
-I$CUDA_HOME/include $CUDNN_IFLAG $NCCL_IFLAG -I$RAYLIB_NAME/include \
-Xcompiler=-fopenmp \
-DOBS_TENSOR_T=$OBS_TENSOR_T \
-DENV_NAME=$ENV \
Expand All @@ -250,7 +267,7 @@ if [ -z "$MODE" ]; then
LINK_CMD=(
${CXX:-g++} -shared -fPIC -fopenmp
build/bindings.o "$STATIC_LIB" "$RAYLIB_A"
-L$CUDA_HOME/lib64 $CUDNN_LFLAG
-L$CUDA_HOME/lib64 $CUDNN_LFLAG $NCCL_LFLAG
-lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn
$OMP_LIB $LINK_OPT
"${SHARED_LDFLAGS[@]}"
Expand Down Expand Up @@ -285,7 +302,7 @@ elif [ "$MODE" = "profile" ]; then
echo "Compiling profile binary ($ARCH)..."
$NVCC $NVCC_OPT -arch=$ARCH -std=c++17 \
-I. -Isrc -I$SRC_DIR -Ivendor \
-I$CUDA_HOME/include $CUDNN_IFLAG -I$RAYLIB_NAME/include \
-I$CUDA_HOME/include $CUDNN_IFLAG $NCCL_IFLAG -I$RAYLIB_NAME/include \
-DOBS_TENSOR_T=$OBS_TENSOR_T \
-DENV_NAME=$ENV \
-Xcompiler=-DPLATFORM_DESKTOP \
Expand Down
12 changes: 12 additions & 0 deletions config/ocean/craftax.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[base]
env_name = craftax

[vec]
total_agents = 8192
num_buffers = 4
num_threads = 16

[env]

[train]
total_timesteps = 200_000_000
34 changes: 34 additions & 0 deletions ocean/craftax/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "craftax.h"

#define OBS_SIZE 1345
#define NUM_ATNS 1
#define ACT_SIZES {17}
#define OBS_TENSOR_T FloatTensor

#define Env Craftax
#include "vecenv.h"

void my_init(Env* env, Dict* kwargs) {
// No per-env kwargs for Craftax-Classic: the 64x64 map, inventory sizes,
// mob caps, etc. are all compile-time constants.
c_init(env);
}

void my_log(Log* log, Dict* out) {
dict_set(out, "perf", log->perf);
dict_set(out, "score", log->score);
dict_set(out, "episode_return", log->episode_return);
dict_set(out, "episode_length", log->episode_length);

static const char* ACH_NAMES[NUM_ACHIEVEMENTS] = {
"collect_wood", "place_table", "eat_cow", "collect_sapling",
"collect_drink", "make_wood_pick", "make_wood_sword","place_plant",
"defeat_zombie", "collect_stone", "place_stone", "eat_plant",
"defeat_skeleton","make_stone_pick","make_stone_sword","wake_up",
"place_furnace", "collect_coal", "collect_iron", "collect_diamond",
"make_iron_pick", "make_iron_sword",
};
for (int i = 0; i < NUM_ACHIEVEMENTS; i++) {
dict_set(out, ACH_NAMES[i], log->achievements[i]);
}
}
Loading