In [14]:
from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_isa
from kernels.matmul_batch_invariant import nki_matmul_kernel_isa

In [15]:
import os
os.environ['NEURON_PLATFORM_TARGET_OVERRIDE']='trn2'
os.environ['NEURON_CC_FLAGS'] = os.environ.get('NEURON_CC_FLAGS', '') + ' --cache_dir=/var/tmp/neuron-compile-cache'

In [16]:
import torch

def test_determinism(kernel_fn, a, b, deterministic, iterations=1000):
    """Test kernel produces identical results across 1000 iterations."""
    ref = kernel_fn(a, b, deterministic=deterministic)
    
    for i in range(iterations):
        result = kernel_fn(a, b, deterministic=deterministic)
        max_diff = (result - ref).abs().max().item()
        
        if max_diff != 0:
            print(f"  FAILED at iteration {i}: max_diff={max_diff}")
            return False
    
    print(f"  PASSED: {iterations} iterations identical")
    return True

In [17]:
device = 'xla'
iterations = 5
K, M, N = 512, 256, 512

A = torch.randn(K, M, device=device, dtype=torch.bfloat16)
B = torch.randn(K, N, device=device, dtype=torch.bfloat16)

print(f"Testing {iterations} iterations...")
pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=iterations)

print("\n" + "=" * 60)
print(f"deterministic=True:  {'PASS' if pass_det else 'FAIL'}")

Testing 5 iterations...
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isakwd3zeb1_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isaipi2zfzy.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isa6iw_7h3v_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isa19y5d6sl.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead




.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:42:25.000268:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_15779616349351854341+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isalspleisu_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isarlm5g87k.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead




.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:42:27.000097:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_7247367884336743177+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isa2mxf5kez_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isagqzmhb6b.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead




.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:42:28.000913:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_8233131819476911003+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isa75iyit_w_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isa3cvf1_dj.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead




.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:42:30.000746:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_15969074187069853100+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isauy80mkvf_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isa6ftu1b15.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead




.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:42:32.000570:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_3973624669412523081+fad94d7c.hlo_module.pb
  PASSED: 5 iterations identical

deterministic=True:  PASS


In [34]:
def test_tiling_invariance(determinism=True, dtype=torch.bfloat16):
    device = 'xla'
    M, K, N = 512, 512, 512
    
    # ISA expects [K, M] @ [K, N]
    a = torch.linspace(-1, 1, K * M, device=device, dtype=dtype).reshape(K, M)
    b = torch.linspace(-1, 1, K * N, device=device, dtype=dtype).reshape(K, N)
    
    out_det = nki_matmul_kernel_isa(a, b, deterministic=True)   # K_TILE=128
    out_adp = nki_matmul_kernel_isa(a, b, deterministic=determinism)  # K_TILE=64
    
    diff = (out_det - out_adp).abs().max().item()
    
    return {"dtype": str(dtype), "diff": diff, "invariant": diff == 0.0}

# ISA kernel deterministic vs non

In [35]:
test_tiling_invariance()
test_tiling_invariance(determinism=False)



The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isaambjh2uz_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isa5u7f6xwt.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isamj468t33_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isa41bxy4ut.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:52:40.000570:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_11295845753885402139+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.n



.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:52:42.000137:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_7224974460960840183+fad94d7c.hlo_module.pb


{'dtype': 'torch.bfloat16', 'diff': 0.0, 'invariant': True}

# ISA kernel deterministic vs non with float32

In [36]:
test_tiling_invariance(dtype=torch.float32)
test_tiling_invariance(determinism=False, dtype=torch.float32)



The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isa641psffg_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isahwa5v2s2.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead
The Python AST is located at: python_ast_tmp.name='/tmp/nki_matmul_kernel_isaf_k97gyh_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_matmul_kernel_isai_91rlkj.klir'
/home/ubuntu/nki-samples/contributed/batch_invariance/kernels/matmul_batch_invariant.py:19:4:Mutating assignment (a[...] = foo(...)) form is deprecated. Use foo(..., dst=a[...]) instead
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:53:19.000728:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_8697539303033536320+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.na



.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:53:21.000292:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_12152672823795625970+fad94d7c.hlo_module.pb


{'dtype': 'torch.float32', 'diff': 6.103515625e-05, 'invariant': False}

In [37]:
def test_rmsnorm_tiling_invariance(determinism=True, dtype=torch.bfloat16):
    """
    Test RMSNorm kernel for tiling invariance.
    Compares deterministic=True vs deterministic=False to see if different
    HIDDEN_TILE sizes produce different numerical results.
    """
    device = 'xla'
    batch_size = 128
    hidden_dim = 512

    a = torch.linspace(-1, 1, batch_size * hidden_dim, device=device, dtype=dtype).reshape(batch_size, hidden_dim)
    g = torch.ones(hidden_dim, device=device, dtype=dtype)

    out_det = nki_rmsnorm_kernel_isa(a, g, deterministic=True)
    out_adp = nki_rmsnorm_kernel_isa(a, g, deterministic=determinism)

    diff = (out_det - out_adp).abs().max().item()

    return {"dtype": str(dtype), "diff": diff, "invariant": diff == 0.0}

In [38]:
test_rmsnorm_tiling_invariance()
test_rmsnorm_tiling_invariance(determinism=False)



The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isasvoh90iu_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isablzbcgq8.klir'
The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isajpb7j6a6_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isapugmd09u.klir'
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:54:15.000169:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_10803138165116680494+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isawvoeu55k_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isa8e_myyw8.klir'




The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isa5i76a0dm_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isala7qox6l.klir'
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:54:16.000730:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_657465572967042995+fad94d7c.hlo_module.pb


{'dtype': 'torch.bfloat16', 'diff': 0.0, 'invariant': True}

In [39]:
test_rmsnorm_tiling_invariance(dtype=torch.float32)
test_rmsnorm_tiling_invariance(determinism=False, dtype=torch.float32)



The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isanwwsbsck_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isagx3j2hkl.klir'
The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isani2sy3wg_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isa9do8tdht.klir'
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:54:22.000184:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_15777384063707193226+fad94d7c.hlo_module.pb
The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isah9bxuvb7_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isava3a38ue.klir'




The Python AST is located at: python_ast_tmp.name='/tmp/nki_rmsnorm_kernel_isaz64pbi6s_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/nki_rmsnorm_kernel_isane1s1wc9.klir'
.Completed run_backend_driver.

Compiler status PASS
2026-02-26 10:54:23.000744:  3402  [INFO]: Compilation Successfully Completed for model.MODULE_3828762567022385588+fad94d7c.hlo_module.pb


{'dtype': 'torch.float32', 'diff': 2.384185791015625e-07, 'invariant': False}