Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pufferlib/ocean/breakout/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#define NUM_ATNS 1
#define ACT_SIZES {3}
#define OBS_TENSOR_T FloatTensor
#define ACT_TYPE DOUBLE

#define Env Breakout
#include "vecenv.h"
Expand Down
4 changes: 2 additions & 2 deletions pufferlib/ocean/breakout/breakout.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ typedef struct Breakout {
Client* client;
Log log;
float* observations;
double* actions;
float* actions;
float* rewards;
float* terminals;
int num_agents;
Expand Down Expand Up @@ -121,7 +121,7 @@ void init(Breakout* env) {
void allocate(Breakout* env) {
init(env);
env->observations = (float*)calloc(11 + env->num_bricks, sizeof(float));
env->actions = (double*)calloc(1, sizeof(double));
env->actions = (float*)calloc(1, sizeof(float));
env->rewards = (float*)calloc(1, sizeof(float));
env->terminals = (float*)calloc(1, sizeof(float));
}
Expand Down
4 changes: 0 additions & 4 deletions pufferlib/src/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,6 @@ PYBIND11_MODULE(_C, m) {
.def("__repr__", [](const PrecisionTensor& t) { return std::string(puf_repr(&t)); })
.def("ndim", [](const PrecisionTensor& t) { return ndim(t.shape); })
.def("numel", [](const PrecisionTensor& t) { return numel(t.shape); });
py::class_<DoubleTensor>(m, "DoubleTensor")
.def("__repr__", [](const DoubleTensor& t) { return std::string(puf_repr(&t)); })
.def("ndim", [](const DoubleTensor& t) { return ndim(t.shape); })
.def("numel", [](const DoubleTensor& t) { return numel(t.shape); });
py::class_<FloatTensor>(m, "FloatTensor")
.def("__repr__", [](const FloatTensor& t) { return std::string(puf_repr(&t)); })
.def("ndim", [](const FloatTensor& t) { return ndim(t.shape); })
Expand Down
19 changes: 0 additions & 19 deletions pufferlib/src/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,6 @@ __global__ void transpose_102(precision_t* __restrict__ dst,
dst[b * A * C + a * C + c] = src[idx];
}

// This exists for actions (currently fp64)
__global__ void transpose_102(double* __restrict__ dst,
const double* __restrict__ src, int A, int B, int C) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = A * B * C;
if (idx >= total) {
return;
}
int a = idx / (B * C), rem = idx % (B * C), b = rem / C, c = rem % C;
dst[b * A * C + a * C + c] = src[idx];
}

__global__ void fill_precision_kernel(precision_t* __restrict__ dst, precision_t val, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
Expand Down Expand Up @@ -247,10 +235,6 @@ inline const char* puf_repr(const PrecisionTensor* t) {
return _puf_repr_impl("PrecisionTensor", USE_BF16 ? "bf16" : "f32",
t->shape, ndim(t->shape), numel(t->shape), !t->data);
}
inline const char* puf_repr(const DoubleTensor* t) {
return _puf_repr_impl("DoubleTensor", "f64",
t->shape, ndim(t->shape), numel(t->shape), !t->data);
}
inline const char* puf_repr(const FloatTensor* t) {
return _puf_repr_impl("FloatTensor", "f32",
t->shape, ndim(t->shape), numel(t->shape), !t->data);
Expand Down Expand Up @@ -431,9 +415,6 @@ void alloc_register(Allocator* a, PrecisionTensor* t) {
void alloc_register(Allocator* a, FloatTensor* t) {
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(float));
}
void alloc_register(Allocator* a, DoubleTensor* t) {
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(double));
}
void alloc_register(Allocator* a, LongTensor* t) {
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(long));
}
Expand Down
36 changes: 13 additions & 23 deletions pufferlib/src/pufferlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ enum LossIdx {

struct RolloutBuf {
PrecisionTensor observations; // (horizon, segments, input_size)
DoubleTensor actions; // (horizon, segments, num_atns)
PrecisionTensor actions; // (horizon, segments, num_atns)
PrecisionTensor values; // (horizon, segments)
PrecisionTensor logprobs; // (horizon, segments)
PrecisionTensor rewards; // (horizon, segments)
Expand All @@ -29,7 +29,7 @@ struct RolloutBuf {
struct TrainGraph {
PrecisionTensor mb_obs; // (S, H, input_size)
PrecisionTensor mb_state; // (L, S, 1, hidden)
DoubleTensor mb_actions; // (S, H, num_atns)
PrecisionTensor mb_actions; // (S, H, num_atns)
PrecisionTensor mb_logprobs; // (S, H)
FloatTensor mb_advantages; // (S, H) f32
PrecisionTensor mb_prio; // (S, 1)
Expand All @@ -44,7 +44,7 @@ struct TrainGraph {
struct PPOGraphArgs {
precision_t* out_ratio;
precision_t* out_newvalue;
const double* actions;
const precision_t* actions;
const precision_t* old_logprobs;
const float* advantages;
const precision_t* prio;
Expand Down Expand Up @@ -76,7 +76,7 @@ struct PPOKernelArgs {
// Pre-allocated buffers for PPO loss
struct PPOBuffersPuf {
FloatTensor loss_output, grad_loss;
DoubleTensor saved_for_bwd;
FloatTensor saved_for_bwd;
FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
};

Expand Down Expand Up @@ -185,19 +185,10 @@ inline PrecisionTensor puf_slice(PrecisionTensor& p, int t, int start, int count
return {.data = p.data + (t*S + start), .shape = {count}};
}
}
inline DoubleTensor puf_slice(DoubleTensor& p, int t, int start, int count) {
if (ndim(p.shape) == 3) {
long S = p.shape[1], F = p.shape[2];
return {.data = p.data + (t*S + start)*F, .shape = {count, F}};
} else {
long S = p.shape[1];
return {.data = p.data + (t*S + start), .shape = {count}};
}
}

struct EnvBuf {
OBS_TENSOR_T obs; // (total_agents, obs_size) — type defined per-env in binding.c
DoubleTensor actions; // (total_agents, num_atns) f64
FloatTensor actions; // (total_agents, num_atns) f64
FloatTensor rewards; // (total_agents,) f32
FloatTensor terminals;// (total_agents,) f32
};
Expand All @@ -210,7 +201,7 @@ StaticVec* create_environments(int num_buffers, int total_agents,
.shape = {total_agents, get_obs_size()},
};
env.actions = {
.data = (double*)vec->gpu_actions,
.data = (float*)vec->gpu_actions,
.shape = {total_agents, get_num_atns()},
};
env.rewards = {
Expand Down Expand Up @@ -392,7 +383,7 @@ __global__ void sample_logits_kernel(
PrecisionTensor dec_out, // (B, fused_cols) fused logits+value from decoder
PrecisionTensor logstd_puf, // (1, od) log std for continuous, or empty
IntTensor act_sizes_puf, // (num_atns,) action head sizes
double* __restrict__ actions, // (B, num_atns) output
precision_t* __restrict__ actions, // (B, num_atns) output
precision_t* __restrict__ logprobs, // (B,) output
precision_t* __restrict__ value_out, // (B,) output
uint64_t seed,
Expand Down Expand Up @@ -443,7 +434,7 @@ __global__ void sample_logits_kernel(
float normalized = (action - mean) / std;
float log_prob = -0.5f * normalized * normalized - 0.5f * LOG_2PI - log_std;

actions[idx * num_atns + h] = double(action);
actions[idx * num_atns + h] = from_float(action);
total_log_prob += log_prob;
}
} else {
Expand Down Expand Up @@ -514,7 +505,7 @@ __global__ void sample_logits_kernel(
float log_prob = sampled_logit - logsumexp;

// Write action for this head
actions[idx * num_atns + h] = double(sampled_action);
actions[idx * num_atns + h] = from_float(sampled_action);
total_log_prob += log_prob;

// Advance to next action head
Expand Down Expand Up @@ -584,7 +575,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
PrecisionTensor dec_puf = policy_forward(&pufferl->policy, pufferl->weights, pufferl->buffer_activations[buf], obs_dst, state_puf, stream);

// Sample actions, logprobs, values into rollout buffer
DoubleTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
PrecisionTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
PrecisionTensor lp_slice = puf_slice(rollouts.logprobs, t, start, block_size);
PrecisionTensor val_slice = puf_slice(rollouts.values, t, start, block_size);

Expand All @@ -604,9 +595,8 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {

// Copy actions to env
long act_cols = env.actions.shape[1];
cudaMemcpyAsync(
env.actions.data + start * act_cols,
act_slice.data, numel(act_slice.shape) * sizeof(double), cudaMemcpyDeviceToDevice, stream);
cast_kernel<<<grid_size(numel(act_slice.shape)), BLOCK_SIZE, 0, stream>>>(
env.actions.data + start * act_cols, act_slice.data, numel(act_slice.shape));

if (capturing) {
cudagraph_capture_end(&pufferl->fused_rollout_cudagraphs[graph], cap_stream_raw);
Expand Down Expand Up @@ -1307,7 +1297,7 @@ __global__ void select_copy_kernel(

// Compute row byte counts from tensor shapes
int obs_row_bytes = (numel(rollouts.observations.shape) / rollouts.observations.shape[0]) * sizeof(precision_t);
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(double);
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(precision_t);
int lp_row_bytes = (numel(rollouts.logprobs.shape) / rollouts.logprobs.shape[0]) * sizeof(precision_t);
int horizon = rollouts.values.shape[1];

Expand Down
5 changes: 0 additions & 5 deletions pufferlib/src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ typedef struct {
int64_t shape[PUF_MAX_DIMS];
} FloatTensor;

typedef struct {
double* data;
int64_t shape[PUF_MAX_DIMS];
} DoubleTensor;

typedef struct {
unsigned char* data;
int64_t shape[PUF_MAX_DIMS];
Expand Down
16 changes: 8 additions & 8 deletions pufferlib/src/vecenv.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ typedef struct StaticVec {
int* buffer_env_starts;
int* buffer_env_counts;
void* observations;
double* actions;
float* actions;
float* rewards;
float* terminals;
void* gpu_observations;
double* gpu_actions;
float* gpu_actions;
float* gpu_rewards;
float* gpu_terminals;
cudaStream_t* streams;
Expand Down Expand Up @@ -252,7 +252,7 @@ static void* static_omp_threadmanager(void* arg) {
cudaMemcpyAsync(
&vec->actions[agent_start * NUM_ATNS],
&vec->gpu_actions[agent_start * NUM_ATNS],
agents_per_buffer * NUM_ATNS * sizeof(double),
agents_per_buffer * NUM_ATNS * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
clock_gettime(CLOCK_MONOTONIC, &t1);
Expand Down Expand Up @@ -384,17 +384,17 @@ StaticVec* create_static_vec(int total_agents, int num_buffers, Dict* vec_kwargs

size_t obs_elem_size = obs_element_size();
cudaHostAlloc((void**)&vec->observations, total_agents * OBS_SIZE * obs_elem_size, cudaHostAllocPortable);
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(double), cudaHostAllocPortable);
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(float), cudaHostAllocPortable);
cudaHostAlloc((void**)&vec->rewards, total_agents * sizeof(float), cudaHostAllocPortable);
cudaHostAlloc((void**)&vec->terminals, total_agents * sizeof(float), cudaHostAllocPortable);

cudaMalloc((void**)&vec->gpu_observations, total_agents * OBS_SIZE * obs_elem_size);
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(double));
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(float));
cudaMalloc((void**)&vec->gpu_rewards, total_agents * sizeof(float));
cudaMalloc((void**)&vec->gpu_terminals, total_agents * sizeof(float));

cudaMemset(vec->gpu_observations, 0, total_agents * OBS_SIZE * obs_elem_size);
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(double));
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(float));
cudaMemset(vec->gpu_rewards, 0, total_agents * sizeof(float));
cudaMemset(vec->gpu_terminals, 0, total_agents * sizeof(float));

Expand Down Expand Up @@ -483,7 +483,7 @@ void static_vec_close(StaticVec* vec) {

cudaDeviceSynchronize();
size_t obs_bytes = vec->total_agents * OBS_SIZE * obs_element_size();
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(double);
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(float);
size_t rew_bytes = vec->total_agents * sizeof(float);
size_t term_bytes = vec->total_agents * sizeof(float);
cudaFree(vec->gpu_observations);
Expand Down Expand Up @@ -578,7 +578,7 @@ size_t get_obs_elem_size(void) { return obs_element_size(); }
void static_vec_step(StaticVec* vec) {
// D2H: copy GPU actions to CPU pinned memory so envs can read them
cudaMemcpy(vec->actions, vec->gpu_actions,
(size_t)vec->total_agents * NUM_ATNS * sizeof(double),
(size_t)vec->total_agents * NUM_ATNS * sizeof(float),
cudaMemcpyDeviceToHost);

memset(vec->rewards, 0, vec->total_agents * sizeof(float));
Expand Down
Loading