From 7232608c951f5baef481c58ad66bf20c57ea31fd Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 10 Dec 2025 12:16:57 -0800 Subject: [PATCH 1/3] Update silu_and_mul.py remove dead code --- src/tilegym/ops/cutile/silu_and_mul.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tilegym/ops/cutile/silu_and_mul.py b/src/tilegym/ops/cutile/silu_and_mul.py index 4a6ac3d..e0932ce 100644 --- a/src/tilegym/ops/cutile/silu_and_mul.py +++ b/src/tilegym/ops/cutile/silu_and_mul.py @@ -45,7 +45,6 @@ def silu_and_mul_kernel_row_wise( row_idx = bid a_col_idx = offsets # First half: [0, hidden_size) b_col_idx = offsets + hidden_size # Second half: [hidden_size, 2*hidden_size) - out_offsets = bid * hidden_size + offsets # Load tiles using gather with 2D indices # gather broadcasts (scalar, tile) to (tile,) From 26e8fdf278213285513380b952bf0d218b76f573 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 10 Dec 2025 17:41:55 -0800 Subject: [PATCH 2/3] remove n_elements param --- src/tilegym/ops/cutile/silu_and_mul.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tilegym/ops/cutile/silu_and_mul.py b/src/tilegym/ops/cutile/silu_and_mul.py index e0932ce..728b7db 100644 --- a/src/tilegym/ops/cutile/silu_and_mul.py +++ b/src/tilegym/ops/cutile/silu_and_mul.py @@ -34,7 +34,6 @@ def silu_and_mul_kernel_row_wise( input, output, TILE_SIZE: ConstInt, - n_elements: ConstInt, hidden_size: ConstInt, ): bid = ct.bid(0) # this gives us our row From 887d5c2f6ed0bc792f4ef1afca6547ce89f4b6a9 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 11 Dec 2025 05:59:15 +0000 Subject: [PATCH 3/3] remove all n_elements references, verify working --- src/tilegym/ops/cutile/silu_and_mul.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/tilegym/ops/cutile/silu_and_mul.py b/src/tilegym/ops/cutile/silu_and_mul.py index 728b7db..55eb2ee 100644 --- a/src/tilegym/ops/cutile/silu_and_mul.py +++ b/src/tilegym/ops/cutile/silu_and_mul.py @@ -90,9 +90,6 @@ def silu_and_mul( # Flatten input to 2D: (batch_size, 2 * hidden_size) input_flat = input.view(-1, original_shape[-1]) batch_size = input_flat.shape[0] - n_elements = ( - batch_size * hidden_size - ) # Total elements to process in output # Get final output shape output_shape = list(original_shape) @@ -124,7 +121,6 @@ def silu_and_mul( input_flat, output, TILE_SIZE, - n_elements, hidden_size ), )