Profiling cleanup and optimized kernels #485

Merged
jsuarez5341 merged 12 commits into PufferAI:4.0 from jonahsamost:jonah_train_select_copy on Feb 13, 2026
Conversation

jonahsamost commented on Feb 11, 2026

This PR does several things.

First, we add a few composite training profiles. In profile_train.cu, trainforward profiles the forward pass, loss, and backward pass together, and trainstep profiles a full end-to-end training step. profile_rolloutcopy.cu profiles data preparation (i.e. advantage computation, priority sampling, and select + copy).

You can also test both float32 and bf16 (via --precision=float or --precision=bf16).
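
For reference, per-phase numbers like the ones in the output below are typically gathered by bracketing each kernel launch with CUDA events. A minimal sketch of such a helper (the name `time_phase_us` and its shape are hypothetical, not the actual profiler code in profile_train.cu):

```cuda
#include <cuda_runtime.h>

// Hypothetical phase-timing helper; the real profiler code may differ.
template <typename Launch>
float time_phase_us(Launch&& launch, int iters = 100) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    launch();                          // warmup launch
    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms * 1000.0f / iters;       // average microseconds per iteration
}
```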

Next, we optimize the compute_advantage kernel (~2.25x), the compute_prio code (~2.5x), and the train_select_and_copy code (~10x).

The main speedup in the compute_advantage kernel comes from 128-bit loads.
The main speedup in compute_prio comes from fusing its ops into dedicated kernels.
The main speedup in select_and_copy comes from fusing all the copies into a single kernel. The two sketches below illustrate these patterns.
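
The advantage kernel's actual GAE recurrence isn't reproduced here; this is only a sketch of the 128-bit load pattern (names and the toy math are illustrative). Loading a float4 (4 x 32-bit = 128 bits) per instruction cuts the number of global-memory transactions by 4x versus scalar loads, which is where most of the ~2.25x comes from:

```cuda
#include <cuda_runtime.h>

// Illustrative only: one 128-bit load and one 128-bit store per thread.
// Buffers must be 16-byte aligned; a scalar tail loop handles any remainder.
__global__ void scale_vec4(const float4* __restrict__ in,
                           float4* __restrict__ out,
                           float scale, int n_vec4) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_vec4) return;
    float4 v = in[i];                  // single 128-bit load
    v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
    out[i] = v;                        // single 128-bit store
}
```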

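The fused select+copy (and, in the same spirit, the fused prio math) replaces a chain of per-tensor torch gathers, each of which is its own kernel launch plus a full global-memory round trip, with one kernel that walks the sampled indices once. A hypothetical sketch showing just two of the six copied fields (obs and values); layouts, names, and signatures are illustrative, not the PR's actual kernel:

```cuda
// Each block handles one selected minibatch row and copies every field for it,
// so there is a single launch instead of one index_select/copy per tensor.
__global__ void fused_select_copy_sketch(
        const int*   __restrict__ idx,        // sampled source row indices
        const float* __restrict__ src_obs,    // [rows, obs_dim]
        const float* __restrict__ src_val,    // [rows]
        float*       __restrict__ dst_obs,    // [mb, obs_dim]
        float*       __restrict__ dst_val,    // [mb]
        int obs_dim, int mb) {
    int row = blockIdx.x;
    if (row >= mb) return;
    int src = idx[row];
    // Strided copy of the wide field; the narrow scalar field piggybacks.
    for (int c = threadIdx.x; c < obs_dim; c += blockDim.x)
        dst_obs[row * obs_dim + c] = src_obs[src * obs_dim + c];
    if (threadIdx.x == 0) dst_val[row] = src_val[src];
}
```

The same pattern extends to actions, logprobs, advantages, and returns, which is why all six correctness checks in the output below are covered by one kernel.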
If you build and run:

python setup.py build_profiler --precision=bf16
./profile_torch rolloutcopy
...
========================================
rolloutcopy (S=8192, T=64, mb_segs=512, in=96, A=4, H=128)
  rollout_rows=8192, minibatch=32768, using bf16
========================================

--- Advantage Kernel Comparison (S=8192, T=64) ---
  advantage (vectorized)           21.2 us    386.14 M elem/s
  advantage (scalar)               47.7 us    171.62 M elem/s
  advantage (torch)              8619.5 us      0.95 M elem/s
  vectorized vs scalar:  2.25x
  vectorized vs torch:   406.29x
  correctness (vec vs scalar): PASS (max_diff=0.00e+00)
  correctness (vec vs torch):  PASS (max_diff=0.00e+00)

--- Prio Kernel Comparison (S=8192, T=64, mb=512) ---
  prio (torch)                    195.5 us     41.90 M elem/s
  prio (kernel)                    75.7 us    108.27 M elem/s
  kernel vs torch:       2.58x
  correctness (mb_prio):  PASS (max_diff=1.19e-07)

--- Select+Copy Kernel Comparison (mb=512) ---
  select+copy (torch)             143.0 us      3.58 M elem/s
  select+copy (kernel)             10.4 us     49.04 M elem/s
  kernel vs torch:       13.70x
  obs              PASS (max_diff=0.00e+00)
  actions          PASS (max_diff=0.00e+00)
  logprobs         PASS (max_diff=0.00e+00)
  advantages       PASS (max_diff=0.00e+00)
  values           PASS (max_diff=0.00e+00)
  returns          PASS (max_diff=0.00e+00)

--- Per-Phase Timing ---
  compute_advantage                21.2 us    386.14 M elem/s
  compute_prio                     75.7 us    108.27 M elem/s
  train_select_and_copy            10.4 us     49.04 M elem/s

--- Full Rollout Copy (one minibatch iteration) ---
  rolloutcopy (full)              105.8 us     77.44 M elem/s

--- Proportional Breakdown ---
  compute_advantage                21.2 us   19.8%
  compute_prio                     75.7 us   70.5%
  train_select_and_copy            10.4 us    9.7%
  total (sum of phases)           107.3 us  100.0%
  full rolloutcopy actual         105.8 us  (measured)

--- trainforward ---
========================================
trainforward (N=512, T=64, in=96, H=128, A=4, layers=1)
  minibatch=32768, using bf16
========================================

--- Forward + Loss (no backward) ---
  forward+loss (kernel)           112.6 us    291.06 M elem/s
  forward+loss (cpp)              691.7 us     47.37 M elem/s

--- Forward + Loss + Backward ---
  fwd+loss+bwd (kernel)          2818.6 us     11.63 M elem/s
  fwd+loss+bwd (cpp)             7986.8 us      4.10 M elem/s

--- Phase Breakdown (forward only, no autograd) ---
  encoder (linear)                 19.0 us   18.3%
  rnn (fused_scan)                 69.4 us   66.9%
  decoder (linear)                 15.3 us   14.8%
  total (sum of phases)           103.7 us  100.0%
  forward+loss actual             112.6 us  (measured)

--- trainstep ---
========================================
trainstep (N=512, T=64, in=96, H=128, A=4, layers=1)
  minibatch=32768, using bf16, optimizer=Muon
========================================

--- Full Training Step: fwd + loss + bwd + clip + Muon + sync ---
  trainstep (kernel)             5297.2 us      6.19 M elem/s
  trainstep (cpp)               11213.5 us      2.92 M elem/s

--- Training Step Breakdown (instrumented) ---
  forward                         357.2 us    6.0%
  loss                            198.5 us    3.3%
  backward                       1843.1 us   30.9%
  grad_sync+clip                  594.6 us   10.0%
  Muon step                      2803.9 us   47.0%
  weight_sync                     171.4 us    2.9%
  total step                     5968.6 us  100.0%
  trainstep measured             5297.2 us  (profile_kernel)

--- CUDA Graph Training Step ---
  trainstep (graph)               504.3 us     64.98 M elem/s
  graph speedup vs eager:    10.50x
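
The ~10.5x over eager execution comes from capturing the whole training step once and replaying it, which removes per-kernel launch overhead. A minimal sketch using the CUDA runtime graph API (not the PR's actual capture code; `launch_step` is a stand-in for whatever enqueues the step's kernels on the stream):

```cuda
#include <cuda_runtime.h>

// Capture one training step into a graph, then replay it for `iters` steps.
void run_with_graph(cudaStream_t stream, void (*launch_step)(cudaStream_t),
                    int iters) {
    cudaGraph_t graph;
    cudaGraphExec_t exec;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    launch_step(stream);                       // record all kernel launches
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

    for (int i = 0; i < iters; ++i)
        cudaGraphLaunch(exec, stream);         // replay without launch overhead
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
}
```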

jsuarez5341 merged commit fef2254 into PufferAI:4.0 on Feb 13, 2026
0 of 12 checks passed