In [1]:
from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_lang, nki_rmsnorm_kernel_isa
from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang
import torch
import torch_neuronx 

In [7]:
import torch

def test_determinism(kernel_fn, a, b, deterministic, iterations=1000):
    """Test kernel produces identical results across 1000 iterations."""
    ref = kernel_fn(a, b, deterministic=deterministic)
    
    for i in range(iterations):
        result = kernel_fn(a, b, deterministic=deterministic)
        max_diff = (result - ref).abs().max().item()
        
        if max_diff != 0:
            print(f"  FAILED at iteration {i}: max_diff={max_diff}")
            return False
    
    print(f"  PASSED: {iterations} iterations identical")
    return True

In [None]:
device = 'xla'
K, M, N = 512, 256, 512

A = torch.randn(K, M, device=device, dtype=torch.bfloat16)
B = torch.randn(K, N, device=device, dtype=torch.bfloat16)

print("Testing 10000 iterations...")

print("\ndeterministic=True:")
pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=10000)

print("\n" + "=" * 60)
print(f"deterministic=True:  {'PASS' if pass_det else 'FAIL'}")

Testing 1000 iterations...

deterministic=True:
.Completed run_backend_driver.

Compiler status PASS
2026-01-30 21:55:07.000869:  13220  [INFO]: Compilation Successfully Completed for model.MODULE_11646591744998724192+fad94d7c.hlo_module.pb
  PASSED: 10000 iterations identical

deterministic=True:  PASS


In [2]:
def test_tiling_invariance(kernel_fn, is_isa=False, determinism=True, dtype=torch.bfloat16):
    device = 'xla'
    M, K, N = 512, 512, 512
    
    if is_isa:
        # ISA expects [K, M] @ [K, N]
        a = torch.linspace(-1, 1, K * M, device=device, dtype=dtype).reshape(K, M)
    else:
        # Lang expects [M, K] @ [K, N]
        a = torch.linspace(-1, 1, M * K, device=device, dtype=dtype).reshape(M, K)
    
    b = torch.linspace(-1, 1, K * N, device=device, dtype=dtype).reshape(K, N)
    
    out_det = kernel_fn(a, b, deterministic=True)   # K_TILE=128
    out_adp = kernel_fn(a, b, deterministic=determinism)  # K_TILE=64
    
    diff = (out_det - out_adp).abs().max().item()
    
    name = "ISA" if is_isa else "Lang"
    print(f"{name}: deterministic=True vs {determinism} → diff={diff:.6f}")
    print(f"  Tiling affects numerics: {'YES' if diff > 0 else 'NO'}")
    

# Lang kernel deterministic vs non

In [3]:
test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)
test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False)

2026-Jan-30 21:50:02.0908 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol
2026-Jan-30 21:50:02.0911 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2026-Jan-30 21:50:02.0913 13220:13274 [1] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed
2026-Jan-30 21:50:02.0916 13220:13274 [1] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?
.Completed run_backend_driver.

Compiler status PASS
2026-01-30 21:50:04.000403:  13220  [INFO]: Compilation Successfully Completed for model.MODULE_11522224973351651600+fad94d7c.hlo_module.pb
Lang: deterministic=True vs True → diff=0.000000
  Tiling affects numerics: NO
.Completed run_backend_driver.

Compiler status PASS
2026-01-30 21:50:05.000978:  13220  [INFO]: Compilation Successfully Completed for model.MODU

In [4]:
test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)
test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False, dtype=torch.float32)

Lang: deterministic=True vs True → diff=0.000000
  Tiling affects numerics: NO
.Completed run_backend_driver.

Compiler status PASS
2026-01-30 21:50:10.000417:  13220  [INFO]: Compilation Successfully Completed for model.MODULE_6421119283783150616+fad94d7c.hlo_module.pb
Lang: deterministic=True vs False → diff=0.000046
  Tiling affects numerics: YES


# ISA kernel deterministic vs non

In [5]:
test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)
test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False)

2026-01-30 21:50:24.000003:  13220  [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_5313299922059221254+fad94d7c/model.neff
ISA: deterministic=True vs True → diff=0.000000
  Tiling affects numerics: NO
2026-01-30 21:50:24.000047:  13220  [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_16718627453147721994+fad94d7c/model.neff
ISA: deterministic=True vs False → diff=0.000000
  Tiling affects numerics: NO


# ISA kernel deterministic vs non with float32

In [6]:
test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)
test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False, dtype=torch.float32)

ISA: deterministic=True vs True → diff=0.000000
  Tiling affects numerics: NO
2026-01-30 21:50:27.000813:  13220  [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_11375411469173762114+fad94d7c/model.neff
ISA: deterministic=True vs False → diff=0.000061
  Tiling affects numerics: YES
