@jonahsamost

This kernel gets most of its speedup from two changes: it computes the softmax online (in a single pass) in both the forward and backward passes, and it does its arithmetic in 32-bit rather than 64-bit precision. It also recomputes intermediate values in the backward pass instead of writing them out during the forward pass.
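For readers unfamiliar with the term, here is a minimal Python sketch of an online softmax. It is not the kernel's code: the function names `online_softmax` and `two_pass_softmax` are hypothetical, the input is assumed to be a short list of finite logits, and Python floats are 64-bit, so this only illustrates the single-pass recurrence, not the 32-bit arithmetic.

```python
import math

def online_softmax(logits):
    """Softmax using one pass to get both the max and the normalizer.

    Illustrative only; assumes every logit is finite.
    """
    running_max = -math.inf
    running_sum = 0.0
    for x in logits:
        new_max = max(running_max, x)
        # Rescale the sum accumulated so far to the new max, then add this term.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    # Normalize each element with the final max and sum.
    return [math.exp(x - running_max) / running_sum for x in logits]

def two_pass_softmax(logits):
    """Reference version: one pass for the max, a second for the sum."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

if __name__ == "__main__":
    logits = [0.5, -1.25, 3.0, 2.0]  # four made-up logits
    print(online_softmax(logits))
    print(two_pass_softmax(logits))
```

Because the running max and running sum are enough to rebuild the probabilities from the logits, a backward pass can in principle recompute them rather than read stored copies, which is the trade the description refers to.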

ppo_loss (NT=32768, 512x64, A=4)
    forward  (original)     17.3 us   1897.58 M elem/s
    backward (original)     14.4 us   2279.77 M elem/s
    forward  (optimized)     3.3 us   9848.82 M elem/s
    backward (optimized)     3.3 us   9787.83 M elem/s
    forward  (torch)        17.3 us   1890.14 M elem/s
    backward (torch)       141.5 us    231.54 M elem/s
    forward  (cpp)         220.1 us    148.87 M elem/s
    backward (cpp)        1084.5 us     30.21 M elem/s
    forward  (graph)        49.2 us    666.60 M elem/s

@jsuarez5341 jsuarez5341 merged commit 705acc9 into PufferAI:4.0 Jan 24, 2026