# cuDNN JIT Batch MatMul

## What is cuDNN JIT?

cuDNN JIT (Just-In-Time compilation) is a configuration of cuDNN that generates GPU kernels at runtime instead of shipping pre-compiled kernels. This provides:

- **Significantly smaller binary size** compared to cuDNN FULL
- **Support for runtime fusion engines** through the graph API
- **Optimized kernels for specific problem sizes** generated on-the-fly

### Requirements
- GPUs with compute capability $\geq 8.0$ (NVIDIA Ampere and later)
- Only dynamic libraries supported (no static linking)

### Limitations
- Narrower support surface than cuDNN FULL
- Pre-compiled single-operation engines NOT supported
- Specialized pre-compiled fusion engines NOT supported

Docs: https://docs.nvidia.com/deeplearning/cudnn/frontend/v1.14.1/developer/graph-api.html#cudnn-jit


## How JIT Achieves Performance Benefits

### The CUDA Compilation Pipeline

CUDA code goes through several stages:

$$\text{CUDA C++} \xrightarrow{\text{nvcc}} \text{PTX} \xrightarrow{\text{ptxas}} \text{SASS}$$

- **PTX (Parallel Thread Execution)**: Virtual ISA, portable across GPU architectures
- **SASS (Streaming ASSembler)**: Actual machine code for a specific GPU architecture (sm_80, sm_90, etc.)
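The sketch below makes the PTX/SASS split concrete. It builds the `-gencode` flags you would pass to nvcc for a given compute capability. `gencode_flags` is a hypothetical helper. The `arch=compute_XY,code=sm_XY` form is nvcc's standard syntax: `compute_XY` names the virtual (PTX) target and `sm_XY` names the real (SASS) target. Embedding PTX lets the CUDA driver JIT-compile it for newer GPUs, which is a different mechanism from cuDNN's JIT.

```python
def gencode_flags(major: int, minor: int, embed_ptx: bool = False) -> list:
    """Build nvcc -gencode flags for one compute capability, e.g. (8, 6)."""
    arch = f"{major}{minor}"
    # Compile to the compute_XY virtual ISA (PTX), then assemble SASS for sm_XY
    flags = [f"-gencode=arch=compute_{arch},code=sm_{arch}"]
    if embed_ptx:
        # Also ship the PTX itself so the driver can JIT-compile it for newer GPUs
        flags.append(f"-gencode=arch=compute_{arch},code=compute_{arch}")
    return flags


# Compute capability 8.6 (the A10G used later in this notebook)
print(gencode_flags(8, 6))
# ['-gencode=arch=compute_86,code=sm_86']
print(gencode_flags(9, 0, embed_ptx=True))
# ['-gencode=arch=compute_90,code=sm_90', '-gencode=arch=compute_90,code=compute_90']
```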

### Traditional cuDNN (Pre-compiled)

Ships **SASS binaries** for every supported combination of:
- GPU architecture (sm_70, sm_75, sm_80, sm_86, sm_89, sm_90...)
- Operation variant (different tile sizes, data types, layouts...)
- Problem size class

This creates **massive library sizes** (several GB) because:
$$\text{Binary Size} \propto |\text{architectures}| \times |\text{operations}| \times |\text{variants}|$$
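The proportionality comes from enumerating every combination. The catalog sizes below are hypothetical and chosen only to show the multiplicative growth. They are not cuDNN's real counts.

```python
from itertools import product

# Hypothetical catalog sizes, chosen only to illustrate the scaling
architectures = ["sm_70", "sm_75", "sm_80", "sm_86", "sm_89", "sm_90"]
operations = [f"op{i}" for i in range(20)]
variants = [f"variant{i}" for i in range(50)]

# Each (architecture, operation, variant) triple needs its own SASS kernel
kernels = list(product(architectures, operations, variants))
print(len(kernels))  # 6 * 20 * 50 = 6000

# Supporting one more architecture adds |operations| x |variants| kernels
print(len(operations) * len(variants))  # 1000
```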

### cuDNN JIT (Runtime Compilation)

Instead of shipping pre-compiled SASS, JIT:

1. **Stores parameterized kernel templates** (much smaller)
2. **At runtime**: Generates optimized PTX/SASS for the *exact* problem
3. **Caches compiled kernels** for reuse
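Steps 1-3 can be modeled as a compile-on-miss cache. This toy sketch assumes a cache keyed by target architecture, exact problem shape, and data type. `KernelCache` and `fake_compile` are hypothetical and are not cuDNN's internal cache.

```python
from typing import Callable, Dict, Tuple

# Hypothetical cache key: target architecture, exact (M, N, K), and data type
Key = Tuple[str, int, int, int, str]


class KernelCache:
    """Toy model of a JIT kernel cache: compile once per key, then reuse."""

    def __init__(self, compile_fn: Callable[[Key], str]) -> None:
        self._compile_fn = compile_fn
        self._kernels: Dict[Key, str] = {}
        self.compilations = 0

    def get(self, key: Key) -> str:
        if key not in self._kernels:
            # Cache miss: pay the (expensive) compilation cost once
            self._kernels[key] = self._compile_fn(key)
            self.compilations += 1
        return self._kernels[key]


def fake_compile(key: Key) -> str:
    # Stand-in for "instantiate a template and emit PTX/SASS"; returns a label only
    arch, m, n, k, dtype = key
    return f"gemm_{arch}_{m}x{n}x{k}_{dtype}"


cache = KernelCache(fake_compile)
for _ in range(3):
    kernel = cache.get(("sm_86", 64, 64, 128, "fp16"))
print(kernel, cache.compilations)  # gemm_sm_86_64x64x128_fp16 1
```

A new shape, such as a different `M`, is a new key and triggers one more compilation. Repeated calls with the same shape are pure lookups.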

### Why This Can Be Faster

1. **Specialization**: Kernel is compiled for exact dimensions $M, N, K$, not a size class
   - Loop bounds become compile-time constants → unrolling
   - Dead code elimination for unused branches

2. **Fusion**: Multiple operations compiled into a single kernel
   - Eliminates intermediate memory traffic
   - For $C = \text{ReLU}(\alpha \cdot AB + \beta \cdot C)$: one kernel instead of three

3. **Register Allocation**: Compiler can optimize for the specific tile sizes and data flow
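Here is a back-of-the-envelope count of global-memory traffic for the fusion example. It uses this notebook's shapes (BATCH=4, M=N=64, K=128, fp16) and writes the result to a separate `D`, so `C` is only an input. The model is deliberately simplified: it ignores caches and assumes the intermediates `T1` and `T2` are stored in fp16. It is an illustration, not a measurement.

```python
def global_memory_elements(batch: int, m: int, n: int, k: int, fused: bool) -> int:
    """Elements read from plus written to global memory for
    D = ReLU(alpha * A @ B + beta * C), with A: [m, k], B: [k, n], C and D: [m, n].

    Simplified model: every kernel reads each input once and writes its output once.
    """
    read_a_b = m * k + k * n
    mn = m * n
    if fused:
        # One kernel: read A, B, C; write D
        per_batch = read_a_b + mn + mn
    else:
        # Kernel 1: read A, B; write T1 = A @ B
        # Kernel 2: read T1, C; write T2 = alpha * T1 + beta * C
        # Kernel 3: read T2; write D = ReLU(T2)
        per_batch = (read_a_b + mn) + (mn + mn + mn) + (mn + mn)
    return batch * per_batch


FP16_BYTES = 2
unfused = global_memory_elements(4, 64, 64, 128, fused=False) * FP16_BYTES
fused = global_memory_elements(4, 64, 64, 128, fused=True) * FP16_BYTES
print(unfused, fused)  # 327680 196608
```

The difference is $4MN$ element accesses per batch entry: the write and re-read of `T1` plus the write and re-read of `T2`.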

### The Trade-off

| Aspect | Pre-compiled | JIT |
|--------|-------------|-----|
| First execution | Fast | Slow (compilation) |
| Library size | Large (GB) | Small (MB) |
| Kernel performance | Good (generic) | Potentially better (specialized) |
| Supported ops | Full | Runtime fusion only |

The compilation overhead is amortized when:
- Same graph executed many times (training loops)
- Kernel cache is persisted across runs
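The amortization argument reduces to a break-even count. Suppose compilation costs $T_{\text{compile}}$ once and saves $\Delta = t_{\text{pre}} - t_{\text{jit}}$ per run. JIT then pays off after $\lceil T_{\text{compile}} / \Delta \rceil$ runs. The timings below are hypothetical placeholders, not measurements.

```python
import math


def break_even_runs(compile_ms: float, precompiled_ms: float, jit_ms: float) -> float:
    """Executions needed before a one-time JIT compile pays for itself.

    Returns math.inf if the JIT kernel is not faster per run.
    """
    saving_per_run = precompiled_ms - jit_ms
    if saving_per_run <= 0:
        return math.inf
    return math.ceil(compile_ms / saving_per_run)


# Hypothetical timings, not measurements: 500 ms compile, 0.50 ms vs 0.25 ms per run
print(break_even_runs(500.0, 0.50, 0.25))  # 2000
```

With a persisted kernel cache, $T_{\text{compile}}$ is paid once across processes rather than once per process. This shrinks the effective break-even point further.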


## Implementation: Batch MatMul with @cudnn.jit

The `@cudnn.jit` decorator provides a clean way to define and build cuDNN graphs:

```python
@cudnn.jit(heur_modes=[cudnn.heur_mode.A])
def my_matmul_graph():
    with cudnn.graph(handle, ...) as (g, _):
        # define tensors and operations, e.g. A, B and C = g.matmul(A, B)
        return g, [A, B, C]  # tensors whose UIDs the decorator returns
```

The decorator:
1. Calls your function to get the graph definition
2. Automatically builds with specified heuristic modes
3. Returns `(graph, [tensor_uids])` for execution
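The three steps can be sketched with a plain-Python decorator factory. Everything here (`toy_jit`, `ToyGraph`, `ToyTensor`) is hypothetical and only mirrors the behavior described above. The real `@cudnn.jit` operates on actual cuDNN graphs, and its internals may differ.

```python
import functools
from typing import Callable, List, Sequence, Tuple


class ToyTensor:
    """Hypothetical stand-in for a graph tensor that carries a UID."""

    def __init__(self, uid: int) -> None:
        self.uid = uid


class ToyGraph:
    """Hypothetical stand-in for a graph; records the heuristic modes it was built with."""

    def __init__(self) -> None:
        self.built_with: List[str] = []

    def build(self, heur_modes: Sequence[str]) -> None:
        self.built_with = list(heur_modes)


def toy_jit(heur_modes: Sequence[str]) -> Callable:
    """Decorator factory mimicking the three steps above (not cuDNN's implementation)."""

    def decorator(fn: Callable[[], Tuple[ToyGraph, List[ToyTensor]]]) -> Callable:
        @functools.wraps(fn)
        def wrapper() -> Tuple[ToyGraph, List[int]]:
            graph, tensors = fn()                   # 1. get the graph definition
            graph.build(heur_modes)                 # 2. build with the given heuristic modes
            return graph, [t.uid for t in tensors]  # 3. return the graph and tensor UIDs

        return wrapper

    return decorator


@toy_jit(heur_modes=["A"])
def make_toy_graph() -> Tuple[ToyGraph, List[ToyTensor]]:
    g = ToyGraph()
    return g, [ToyTensor(10), ToyTensor(11), ToyTensor(12)]


toy_graph, toy_uids = make_toy_graph()
print(toy_graph.built_with, toy_uids)  # ['A'] [10, 11, 12]
```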


In [1]:
import cudnn
import torch

print(f"cuDNN frontend version: {cudnn.__version__}")
print(f"cuDNN backend version: {cudnn.backend_version()}")

# Check GPU compute capability (need >= 8.0 for JIT)
capability = torch.cuda.get_device_capability()
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Compute capability: {capability[0]}.{capability[1]}")
assert capability[0] >= 8, "cuDNN JIT requires compute capability >= 8.0 (Ampere+)"

_ = torch.manual_seed(42)  # assign to suppress the Generator repr in the cell output

cuDNN frontend version: 1.14.1
cuDNN backend version: 91301
GPU: NVIDIA A10G
Compute capability: 8.6


In [2]:
handle = cudnn.create_handle()

# Define dimensions: A[BATCH,M,K] @ B[BATCH,K,N] = C[BATCH,M,N]
BATCH, M, K, N = 4, 64, 128, 64

@cudnn.jit(heur_modes=[cudnn.heur_mode.A])
def create_matmul_graph():
    with cudnn.graph(
        handle=handle,
        name="jit_matmul_graph",
        io_data_type=cudnn.data_type.HALF,
        compute_data_type=cudnn.data_type.FLOAT,
    ) as (g, _):
        # Input tensors with row-major strides
        A = g.tensor(
            name="A",
            dim=[BATCH, M, K],
            stride=[M * K, K, 1],
            data_type=cudnn.data_type.HALF,
        )
        B_mat = g.tensor(
            name="B",
            dim=[BATCH, K, N],
            stride=[K * N, N, 1],
            data_type=cudnn.data_type.HALF,
        )
        
        # Matmul operation
        C = g.matmul(A, B_mat, compute_data_type=cudnn.data_type.FLOAT)
        C.set_output(True)
        io_tensors = [A, B_mat, C]  # the decorator returns these tensors' UIDs
        return g, io_tensors

# Build the graph: execution plans are created and built here (runtime-compiled engines are compiled at this step)
graph, (A_uid, B_uid, C_uid) = create_matmul_graph()
print(f"Graph built. Tensor UIDs: A={A_uid}, B={B_uid}, C={C_uid}")

Graph built. Tensor UIDs: A=2, B=1, C=3
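The strides passed to `g.tensor(...)` are ordinary row-major strides. The small helper below makes the rule explicit and reproduces the values used above. It is hypothetical and not part of the cuDNN API. For contiguous PyTorch tensors, `tensor.stride()` returns the same numbers; for example, `A_gpu.stride()` for the tensor allocated later is `(8192, 128, 1)`.

```python
from typing import List


def row_major_strides(dims: List[int]) -> List[int]:
    """Element strides of a densely packed, row-major (C-contiguous) tensor."""
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        # Moving one step along dim i skips a full slice of all later dims
        strides[i] = strides[i + 1] * dims[i + 1]
    return strides


BATCH, M, K, N = 4, 64, 128, 64
print(row_major_strides([BATCH, M, K]))  # [8192, 128, 1], i.e. [M * K, K, 1] for A
print(row_major_strides([BATCH, K, N]))  # [8192, 64, 1], i.e. [K * N, N, 1] for B
print(row_major_strides([BATCH, M, N]))  # [4096, 64, 1], C's stride in the graph JSON
```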


In [3]:
# Check whether the default (heuristics-selected) plan uses JIT
behavior_notes = graph.get_behavior_notes()
print(f"Behavior notes: {behavior_notes}")

is_jit = cudnn.behavior_note.RUNTIME_COMPILATION in behavior_notes
print(f"Using JIT (RUNTIME_COMPILATION): {is_jit}")

# Show available execution plans
n_plans = graph.get_execution_plan_count()
print(f"\nExecution plans available: {n_plans}")
for i in range(n_plans):
    notes = graph.get_behavior_notes_for_plan_at_index(i)
    name = graph.get_plan_name_at_index(i)
    is_jit = cudnn.behavior_note.RUNTIME_COMPILATION in notes
    print(f"  Plan {i}: {name[:50]}... JIT={is_jit}")


Behavior notes: []
Using JIT (RUNTIME_COMPILATION): False

Execution plans available: 15
  Plan 0: eng0... JIT=False
  Plan 1: eng7_k24=35... JIT=True
  Plan 2: eng7_k24=36... JIT=True
  Plan 3: eng7_k24=20... JIT=True
  Plan 4: eng1_k24=7... JIT=True
  Plan 5: eng7_k24=12... JIT=True
  Plan 6: eng7_k24=21... JIT=True
  Plan 7: eng7_k24=22... JIT=True
  Plan 8: eng1_k24=8... JIT=True
  Plan 9: eng1_k24=2... JIT=True
  Plan 10: eng1_k24=4... JIT=True
  Plan 11: eng1_k24=11... JIT=True
  Plan 12: eng1_k24=10... JIT=True
  Plan 13: eng1_k24=9... JIT=True
  Plan 14: eng1_k24=41... JIT=True


## cuDNN JIT is not necessarily used by default
With the full cuDNN library, both pre-compiled and JIT engines are available, and the heuristics rank execution plans by expected speed.

The output above shows this:
>Plan 0: eng0... JIT=False        ← Selected by default (pre-compiled)\
>Plan 1: eng7_k24=35... JIT=True  ← Available but not chosen\
>Plan 2: eng7_k24=36... JIT=True  ← Available but not chosen

The heuristics ranked a pre-compiled engine as "best" for this simple standalone matmul.

### When does JIT get chosen by default?
JIT engines tend to be preferred for:
- Fused operations (e.g., matmul + bias + activation), which JIT can compile into a single kernel
- Unusual dimensions for which no pre-compiled kernel exists
- Data type/layout combinations not covered by pre-compiled kernels

TL;DR: JIT engines are available when the operations fall within the JIT support surface, but they are not necessarily selected by default. The heuristics optimize for speed, and pre-compiled kernels can be faster for common patterns like this one.
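To time a JIT engine against the default choice, one approach is to pick the highest-ranked plan whose behavior notes include `cudnn.behavior_note.RUNTIME_COMPILATION`. The selection itself is plain Python. The tuples below are copied from the first rows of the listing in `In [3]`, where they come from `get_plan_name_at_index` and `get_behavior_notes_for_plan_at_index`. Building and executing that single plan is version-dependent. Recent cuDNN frontend releases expose per-index methods such as `build_plan_at_index` and `execute_plan_at_index`, but check the API reference for your installed version.

```python
from typing import Iterable, Optional, Tuple

# (plan index, plan name, uses RUNTIME_COMPILATION), copied from the listing above;
# plans appear in the order the heuristics returned them
plans = [
    (0, "eng0", False),
    (1, "eng7_k24=35", True),
    (2, "eng7_k24=36", True),
]


def first_jit_plan(plans: Iterable[Tuple[int, str, bool]]) -> Optional[int]:
    """Index of the highest-ranked plan that uses runtime compilation, if any."""
    for index, _name, is_jit in plans:
        if is_jit:
            return index
    return None


print(first_jit_plan(plans))  # 1
```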


In [4]:
# Print graph structure
print(graph)


{
    "context": {
        "compute_data_type": "FLOAT",
        "intermediate_data_type": "FLOAT",
        "io_data_type": "HALF",
        "name": "",
        "sm_count": -1
    },
    "cudnn_backend_version": "9.13.1",
    "cudnn_frontend_version": 11401,
    "json_version": "1.0",
    "nodes": [
        {
            "compute_data_type": "FLOAT",
            "inputs": {
                "A": "A",
                "B": "B"
            },
            "name": "0",
            "outputs": {
                "C": "0::C"
            },
            "padding_value": 0.0,
            "tag": "MATMUL"
        }
    ],
    "tensors": {
        "0::C": {
            "data_type": "HALF",
            "dim": [4,64,64],
            "is_pass_by_value": false,
            "is_virtual": false,
            "name": "0::C",
            "pass_by_value": null,
            "reordering_type": "NONE",
            "stride": [4096,64,1],
            "uid": 3,
            "uid_assigned": true
        },
        "A": 

In [5]:
# Allocate GPU tensors
A_gpu = torch.randn(BATCH, M, K, device="cuda", dtype=torch.float16)
B_gpu = torch.randn(BATCH, K, N, device="cuda", dtype=torch.float16)
C_gpu = torch.empty(BATCH, M, N, device="cuda", dtype=torch.float16)
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

print(f"Workspace size: {graph.get_workspace_size()} bytes")

# Create variant pack using UIDs (returned by @cudnn.jit decorator)
variant_pack = {A_uid: A_gpu, B_uid: B_gpu, C_uid: C_gpu}


Workspace size: 4194560 bytes


In [6]:
# Execute the graph using UIDs
graph.execute(variant_pack, workspace, handle=handle)

# Verify against PyTorch
C_ref = torch.bmm(A_gpu, B_gpu)
max_diff = (C_gpu - C_ref).abs().max().item()

print(f"Output shape: {C_gpu.shape}")
print(f"Max diff vs torch.bmm: {max_diff}")
assert max_diff < 1e-2, f"Results don't match! Max diff: {max_diff}"
print("✓ Results match!")


Output shape: torch.Size([4, 64, 64])
Max diff vs torch.bmm: 0.0
✓ Results match!
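For reference, `torch.bmm` and the cuDNN graph both compute an independent matrix product per batch entry: $C_{p,i,j} = \sum_{t} A_{p,i,t} B_{p,t,j}$. The deliberately naive pure-Python version below shows the indexing on a tiny input. It is only meant for checking small cases by hand, not for performance.

```python
from typing import List

Matrix = List[List[float]]


def naive_bmm(a: List[Matrix], b: List[Matrix]) -> List[Matrix]:
    """Reference batched matmul: C[p][i][j] = sum over t of A[p][i][t] * B[p][t][j]."""
    out = []
    for a_p, b_p in zip(a, b):  # each batch entry is an independent matmul
        m, k, n = len(a_p), len(b_p), len(b_p[0])
        assert all(len(row) == k for row in a_p), "inner dimensions must match"
        out.append([[sum(a_p[i][t] * b_p[t][j] for t in range(k)) for j in range(n)]
                    for i in range(m)])
    return out


a_demo = [[[1.0, 2.0], [3.0, 4.0]]]  # shape [1, 2, 2]
b_demo = [[[5.0, 6.0], [7.0, 8.0]]]  # shape [1, 2, 2]
print(naive_bmm(a_demo, b_demo))  # [[[19.0, 22.0], [43.0, 50.0]]]
```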


### Timing Comparison

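The graph is built once, before the timed loop, so any compilation cost is paid up front and excluded from the measurements below. Note that the plan being timed is the heuristics' default choice. The behavior notes above show it is a pre-compiled engine rather than a JIT one.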


In [7]:
import time

# Warmup
for _ in range(10):
    graph.execute(variant_pack, workspace, handle=handle)
torch.cuda.synchronize()

# Time cuDNN
n_iters = 100
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
    graph.execute(variant_pack, workspace, handle=handle)
torch.cuda.synchronize()
cudnn_time = (time.perf_counter() - start) / n_iters * 1000

# Time PyTorch
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
    C_ref = torch.bmm(A_gpu, B_gpu)
torch.cuda.synchronize()
torch_time = (time.perf_counter() - start) / n_iters * 1000

print(f"cuDNN:   {cudnn_time:.4f} ms")
print(f"PyTorch: {torch_time:.4f} ms")
print(f"Speedup: {torch_time/cudnn_time:.2f}x")


cuDNN:   0.0199 ms
PyTorch: 0.0238 ms
Speedup: 1.20x
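Two caveats apply when reading these numbers. First, this matmul is small: about $2 \cdot 4 \cdot 64 \cdot 64 \cdot 128 \approx 4.2$ MFLOP. A time of roughly 0.02 ms per call is therefore likely dominated by launch and Python-call overhead rather than kernel math, and the 1.20x ratio says little about kernel quality. Second, the loop above explicitly warms up only the cuDNN path. The hypothetical stdlib-only helper below applies the same warmup and synchronization to both sides. For device-side timing, `torch.cuda.Event(enable_timing=True)` is an alternative.

```python
import time
from typing import Callable


def bench_ms(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
             warmup: int = 10, iters: int = 100) -> float:
    """Average wall-clock milliseconds per call of fn, after a warmup phase.

    For asynchronous GPU work, pass a sync callback such as torch.cuda.synchronize
    so the clock stops only after all queued work has finished.
    """
    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters * 1000


def bmm_flops(batch: int, m: int, n: int, k: int) -> int:
    # One multiply and one add for every (batch, i, j, t) term
    return 2 * batch * m * n * k


print(bmm_flops(4, 64, 64, 128))  # 4194304
```

In this notebook, both sides could then be timed identically with `bench_ms(lambda: graph.execute(variant_pack, workspace, handle=handle), sync=torch.cuda.synchronize)` and `bench_ms(lambda: torch.bmm(A_gpu, B_gpu), sync=torch.cuda.synchronize)`.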


In [8]:
# Cleanup
cudnn.destroy_handle(handle)
