In [1]:
from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_isa
from kernels.matmul_batch_invariant import nki_matmul_kernel_isa

In [4]:
import torch_xla
torch_xla.device()

2026-Feb-27 15:48:33.0366 3428:3621 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol
2026-Feb-27 15:48:33.0368 3428:3621 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2026-Feb-27 15:48:33.0370 3428:3621 [0] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed
2026-Feb-27 15:48:33.0372 3428:3621 [0] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?


device(type='xla', index=0)

In [2]:
import os
os.environ['NEURON_PLATFORM_TARGET_OVERRIDE']='trn2'
os.environ['NEURON_CC_FLAGS'] = os.environ.get('NEURON_CC_FLAGS', '') + ' --cache_dir=/var/tmp/neuron-compile-cache'

# Determinism checks

In [3]:
import torch

def test_determinism(kernel_fn, a, b, deterministic, iterations=1000):
    """Test kernel produces identical results across 1000 iterations."""
    ref = kernel_fn(a, b, deterministic=deterministic)
    
    for i in range(iterations):
        result = kernel_fn(a, b, deterministic=deterministic)
        max_diff = (result - ref).abs().max().item()
        
        if max_diff != 0:
            print(f"  FAILED at iteration {i}: max_diff={max_diff}")
            return False
    
    print(f"  PASSED: {iterations} iterations identical")
    return True

In [11]:
device = 'xla'
iterations = 5
K, M, N = 512, 256, 512

A = torch.randn(K, M, device=device, dtype=torch.bfloat16)
B = torch.randn(K, N, device=device, dtype=torch.bfloat16)

print(f"Testing {iterations} iterations...")
pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=iterations)

print("\n" + "=" * 60)
print(f"deterministic=True:  {'PASS' if pass_det else 'FAIL'}")

Testing 5 iterations...
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_6hia3ssb/nki_matmul_kernel_isa30nv0i4__python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_6hia3ssb/nki_matmul_kernel_isa9tkog97m.klir'
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_3d2ogu6z/nki_matmul_kernel_isaadz8zlut_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_3d2ogu6z/nki_matmul_kernel_isabe8s0u6y.klir'
.
Compiler status PASS
2026-02-27 16:01:15.000805:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_9473861346067690811+fad94d7c.hlo_module.pb
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_ia_kgfst/nki_matmul_kernel_isa3bbxgs97_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_ia_kgfst/nki_matmul_kernel_isa1l510mjh.klir'
.
Compiler status PASS
202

## Numerical parity checks

In [None]:
import torch
import torch_neuronx

def test_matmul_parity():
    """Verify NKI matmul matches PyTorch."""
    M, K, N = 256, 512, 512

    a = torch.randn(M, K, dtype=torch.float32)
    b = torch.randn(K, N, dtype=torch.float32)

    # PyTorch reference
    ref = torch.matmul(a, b)

    # NKI kernel (expects [K, M] layout)
    a_xla = a.T.to('xla')  # [K, M]
    b_xla = b.to('xla')    # [K, N]
    result = nki_matmul_kernel_isa(a_xla, b_xla, deterministic=True).cpu()

    assert torch.allclose(ref, result, atol=1e-3, rtol=1e-2), \
        f"MatMul mismatch: max diff = {torch.max(torch.abs(ref - result))}"
    print("✓ MatMul parity test passed")

def test_rmsnorm_parity():
    """Verify NKI RMSNorm matches PyTorch."""
    batch, hidden = 128, 512
    eps = 1e-6

    x = torch.randn(batch, hidden, dtype=torch.float32)
    g = torch.ones(hidden, dtype=torch.float32)

    # PyTorch reference
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    ref = (x / rms) * g

    # NKI kernel
    x_xla = x.to('xla')
    g_xla = g.to('xla')
    result = nki_rmsnorm_kernel_isa(x_xla, g_xla, deterministic=True).cpu()

    assert torch.allclose(ref, result, atol=1e-3, rtol=1e-2), \
        f"RMSNorm mismatch: max diff = {torch.max(torch.abs(ref - result))}"
    print("✓ RMSNorm parity test passed")

In [16]:
test_matmul_parity()
test_rmsnorm_parity()

The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_tybflq0s/nki_matmul_kernel_isao5zwhphv_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_tybflq0s/nki_matmul_kernel_isa0xaf7fzu.klir'
.
Compiler status PASS
2026-02-27 16:07:37.000643:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_13037584473499484256+fad94d7c.hlo_module.pb
✓ MatMul parity test passed
The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_j22ttxrd/nki_rmsnorm_kernel_isabljud138_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_j22ttxrd/nki_rmsnorm_kernel_isa6638lr1l.klir'
.
Compiler status PASS
2026-02-27 16:07:40.000774:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_7997940888169041779+fad94d7c.hlo_module.pb
✓ RMSNorm parity test passed


# Tile size invariance tests
## Matmul Kernel

In [12]:
def test_tiling_invariance(determinism=True, dtype=torch.bfloat16):
    device = 'xla'
    M, K, N = 512, 512, 512
    
    # ISA expects [K, M] @ [K, N]
    a = torch.linspace(-1, 1, K * M, device=device, dtype=dtype).reshape(K, M)
    b = torch.linspace(-1, 1, K * N, device=device, dtype=dtype).reshape(K, N)
    
    out_det = nki_matmul_kernel_isa(a, b, deterministic=True)   # K_TILE=128
    out_adp = nki_matmul_kernel_isa(a, b, deterministic=determinism)  # K_TILE=64
    
    diff = (out_det - out_adp).abs().max().item()
    
    return {"dtype": str(dtype), "diff": diff, "invariant": diff == 0.0}

deterministic vs non-deterministic (bfloat16)

In [13]:
test_tiling_invariance()
test_tiling_invariance(determinism=False)

The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_phpbl_66/nki_matmul_kernel_isasepvmdz2_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_phpbl_66/nki_matmul_kernel_isajz1xuo19.klir'
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_f92wuxuw/nki_matmul_kernel_isawptt543e_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_f92wuxuw/nki_matmul_kernel_isaulx1whcr.klir'
.
Compiler status PASS
2026-02-27 16:01:31.000226:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_1766330591526900260+fad94d7c.hlo_module.pb
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_yeeav7hs/nki_matmul_kernel_isa094goyv9_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_yeeav7hs/nki_matmul_kernel_isaz425zx7q.klir'
The Python AST is located at: /tmp/klir_binaries/n

{'dtype': 'torch.bfloat16', 'diff': 0.0, 'invariant': True}

deterministic vs non-deterministic with float32

In [14]:
test_tiling_invariance(dtype=torch.float32)
test_tiling_invariance(determinism=False, dtype=torch.float32)

The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_n6lafb2g/nki_matmul_kernel_isar_nzcsld_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_n6lafb2g/nki_matmul_kernel_isall8f6oiu.klir'
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_2wt8vlli/nki_matmul_kernel_isai8aweift_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_2wt8vlli/nki_matmul_kernel_isagt2pcrka.klir'
.
Compiler status PASS
2026-02-27 16:01:38.000733:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_10769978250524783468+fad94d7c.hlo_module.pb
The Python AST is located at: /tmp/klir_binaries/nki_matmul_kernel_isa_mthaoopc/nki_matmul_kernel_isauvor3s85_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_matmul_kernel_isa_mthaoopc/nki_matmul_kernel_isae90ejwxu.klir'
The Python AST is located at: /tmp/klir_binaries/

{'dtype': 'torch.float32', 'diff': 6.103515625e-05, 'invariant': False}

## RMSNorm kernel

In [6]:
def test_rmsnorm_tiling_invariance(determinism=True, dtype=torch.bfloat16):
    """
    Test RMSNorm kernel for tiling invariance.
    Compares deterministic=True vs deterministic=False to see if different
    HIDDEN_TILE sizes produce different numerical results.
    """
    device = 'xla'
    batch_size = 128
    hidden_dim = 512

    a = torch.linspace(-1, 1, batch_size * hidden_dim, device=device, dtype=dtype).reshape(batch_size, hidden_dim)
    g = torch.ones(hidden_dim, device=device, dtype=dtype)

    out_det = nki_rmsnorm_kernel_isa(a, g, deterministic=True)
    out_adp = nki_rmsnorm_kernel_isa(a, g, deterministic=determinism)

    diff = (out_det - out_adp).abs().max().item()

    return {"dtype": str(dtype), "diff": diff, "invariant": diff == 0.0}

deterministic vs non-deterministic (bfloat16)

In [9]:
test_rmsnorm_tiling_invariance()
test_rmsnorm_tiling_invariance(determinism=False)

The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_t92galw_/nki_rmsnorm_kernel_isatr7yukyv_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_t92galw_/nki_rmsnorm_kernel_isa_uz7r3w7.klir'
The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_1bc56dl_/nki_rmsnorm_kernel_isa2zul72uw_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_1bc56dl_/nki_rmsnorm_kernel_isan3zqr8zy.klir'
....
Compiler status PASS
2026-02-27 15:57:03.000070:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_9950062464119990324+fad94d7c.hlo_module.pb
The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_7me51j3i/nki_rmsnorm_kernel_isaxef6x2_c_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_7me51j3i/nki_rmsnorm_kernel_isahi5g7s75.klir'
The Python AST is located at: /tmp/

{'dtype': 'torch.bfloat16', 'diff': 0.0, 'invariant': True}

deterministic vs non-deterministic (float32)

In [10]:
test_rmsnorm_tiling_invariance(dtype=torch.float32)
test_rmsnorm_tiling_invariance(determinism=False, dtype=torch.float32)

The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_rbpnxx1y/nki_rmsnorm_kernel_isac6p2nv1__python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_rbpnxx1y/nki_rmsnorm_kernel_isai71o9lcj.klir'
The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa_ipndb477/nki_rmsnorm_kernel_isaso5l1taj_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa_ipndb477/nki_rmsnorm_kernel_isa8tmfzk2t.klir'
.
Compiler status PASS
2026-02-27 15:57:13.000923:  3428  [INFO]: Compilation Successfully Completed for model.MODULE_6527901568736549946+fad94d7c.hlo_module.pb
The Python AST is located at: /tmp/klir_binaries/nki_rmsnorm_kernel_isa__0a8edij/nki_rmsnorm_kernel_isaylk9_elw_python_ast.klir
The KLR format is located at: final_klir_filepath='/tmp/klir_binaries/nki_rmsnorm_kernel_isa__0a8edij/nki_rmsnorm_kernel_isa9h_wyeae.klir'
The Python AST is located at: /tmp/kli

{'dtype': 'torch.float32', 'diff': 2.384185791015625e-07, 'invariant': False}