Skip to content

Add Snake activation functor for EVT#3184

Open
emre570 wants to merge 1 commit intoNVIDIA:mainfrom
emre570:feat/snake-activation
Open

Add Snake activation functor for EVT#3184
emre570 wants to merge 1 commit intoNVIDIA:mainfrom
emre570:feat/snake-activation

Conversation

@emre570
Copy link
Copy Markdown

@emre570 emre570 commented Apr 24, 2026

Adds cutlass::epilogue::thread::Snake, a two-operand activation functor
implementing the Snake activation from Ziyin, Hartwig, & Ueda
(NeurIPS 2020, arXiv:2006.08195):

Snake_a(x) = x + (1/a) * sin²(a*x)

The per-channel learnable frequency a is passed as a second operand,
flowing through an EVT child (typically Sm90RowBroadcast). This enables
fused GEMM/Conv + Snake epilogues for neural vocoders, where Snake is the
standard activation in the WaveGen path. See issue #3141 for production
benchmarks (~2.1× median speedup on H100 vocoder shapes).

Implementation

  • include/cutlass/epilogue/thread/activation.h — scalar Snake<T> and
    Snake<Array<T, N>> specializations, using cutlass::fast_sin for
    IEEE-precise device/host math.
  • Slots into the generic variadic Sm90Compute primary; no new
    infrastructure needed.
  • Follows house style: CUTLASS_HOST_DEVICE, kIsHeavy = true,
    CUTLASS_PRAGMA_UNROLL on the Array specialization.

Tests

Adds unit tests in test/unit/epilogue/thread/activation.cu:

  • Epilogue_thread_snake.device_f32 — tolerance 1e-5
  • Epilogue_thread_snake.device_bf16 — tolerance 2e-2

Both pass on SM90 (H100). Goldens are generated from a math.sin-based
float64 reference over 256 samples each of x, α, and expected output,
with α ∈ (0.1, 2.0) and x ~ N(0, 1).

Usage

// EVT tree: Snake(acc, alpha_broadcast)
using SnakeEpilogue = cutlass::epilogue::fusion::Sm90EVT<
    cutlass::epilogue::fusion::Sm90Compute<
        cutlass::epilogue::thread::Snake,
        ElementOut, ElementEpi,
        cutlass::FloatRoundStyle::round_to_nearest>,
    cutlass::epilogue::fusion::Sm90AccFetch,
    cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, float>
>;

Closes #3141

Introduces cutlass::epilogue::thread::Snake, a two-operand activation
functor implementing Snake_a(x) = x + (1/a) * sin^2(a*x) from
Ziyin et al. 2020 (arXiv:2006.08195). The per-channel learnable
frequency `a` flows through an EVT child (e.g. Sm90RowBroadcast),
composing into Sm90EVT<Sm90Compute<Snake, ...>, x_node, alpha_node>
for fused GEMM+Snake epilogues used in neural vocoders.

Adds unit tests in test/unit/epilogue/thread/activation.cu covering
f32 and bf16 paths, validated against float64 reference goldens.

Closes NVIDIA#3141
@emre570
Copy link
Copy Markdown
Author

emre570 commented Apr 27, 2026

@hwu36 Hi sir, is it possible to take a look in your free time?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEA] Add Snake activation functor for Epilogue Visitor Tree (EVT)

1 participant