In [1]:
import cudnn
import torch

# Report the cuDNN frontend and backend versions so the recorded outputs below
# can be tied to a specific install.
print(f"cudnn frontend version: {cudnn.__version__}")
print(f"cudnn backend version: {cudnn.backend_version()}")

# Fixed seed so the random inputs drawn later are the same on every run.
torch.manual_seed(42)

# Every later cell allocates on device="cuda"; stop here with an actionable
# message instead of a bare AssertionError when no GPU is visible.
assert torch.cuda.is_available(), (
    "CUDA is not available: this notebook needs an NVIDIA GPU and a "
    "CUDA-enabled PyTorch build"
)


cudnn frontend version: 1.14.1
cudnn backend version: 91301


In [2]:
# A single cuDNN handle is created here; later cells pass it to graph.execute
# and finally destroy it.
handle = cudnn.create_handle()

# Graph-level defaults: HALF for I/O tensors, FLOAT for compute.
graph = cudnn.pygraph(
    name="matmul_graph",
    handle=handle,
    compute_data_type=cudnn.data_type.FLOAT,
    io_data_type=cudnn.data_type.HALF,
)


####  Tensors and stride

The cuDNN front-end docs define a "fully-packed" tensor in a way that is not immediately intuitive (https://docs.nvidia.com/deeplearning/cudnn/frontend/v1.14.1/developer/core-concepts.html#fully-packed-tensors). It is easiest to understand through strides, using the tensors in this notebook as an example:


The stride of a dimension is the number of elements you must step over in the flattened (linear) memory layout to reach the next index along that dimension. For example, for our batch of B matrices of size M x K, the stride is [M x K, K, 1]. To get to the same index (m, k) in the next matrix of the batch, we need to move forward an entire matrix's worth of elements (M x K). To get to the element in the same column of the next row of a given matrix, we need to move forward K elements. To get to the element in the next column, we need to move forward 1 element. Because there are no gaps between elements when the tensor is written out sequentially in memory, this tensor is "fully packed".

In [3]:
# Problem size for the batched product C[B, M, N] = A[B, M, K] @ B[B, K, N].
B, M, K, N = 2, 4, 8, 6

# Both operands are declared with row-major strides: the last dimension is
# contiguous, a row step skips one row of elements, and a batch step skips one
# whole matrix.
A = graph.tensor(
    name="A",
    data_type=cudnn.data_type.HALF,
    dim=[B, M, K],
    stride=[M * K, K, 1],
)
B_tensor = graph.tensor(
    name="B",
    data_type=cudnn.data_type.HALF,
    dim=[B, K, N],
    stride=[K * N, N, 1],
)

# Add the matmul node with FLOAT compute, then mark its result as a graph output.
# C is declared without dim/stride here; they appear filled in when the built
# graph is printed.
C = graph.matmul(A, B_tensor, compute_data_type=cudnn.data_type.FLOAT)
C.set_output(True)


[{"data_type":null,"dim":[],"is_pass_by_value":false,"is_virtual":false,"name":"0::C","pass_by_value":null,"reordering_type":"NONE","stride":[],"uid":0,"uid_assigned":false}]

In [4]:
# Build graph
# Heuristic mode A is queried first; FALLBACK is listed as a second source of
# engine configurations so the build does not fail outright on a setup where
# mode A proposes nothing usable.
graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])

# Dump the built graph; the output tensor's dim and stride are populated by now.
print(graph)


{
    "context": {
        "compute_data_type": "FLOAT",
        "intermediate_data_type": null,
        "io_data_type": "HALF",
        "name": "",
        "sm_count": -1
    },
    "cudnn_backend_version": "9.13.1",
    "cudnn_frontend_version": 11401,
    "json_version": "1.0",
    "nodes": [
        {
            "compute_data_type": "FLOAT",
            "inputs": {
                "A": "A",
                "B": "B"
            },
            "name": "0",
            "outputs": {
                "C": "0::C"
            },
            "padding_value": 0.0,
            "tag": "MATMUL"
        }
    ],
    "tensors": {
        "0::C": {
            "data_type": "HALF",
            "dim": [2,4,6],
            "is_pass_by_value": false,
            "is_virtual": false,
            "name": "0::C",
            "pass_by_value": null,
            "reordering_type": "NONE",
            "stride": [24,6,1],
            "uid": 3,
            "uid_assigned": true
        },
        "A": {
      

In [5]:
# Create GPU tensors
# Shapes reuse the B, M, K, N constants that the graph tensors were declared
# with, so the device buffers cannot drift out of sync with the graph if the
# problem size is edited in one place.
A_gpu = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
B_gpu = torch.randn(B, K, N, device="cuda", dtype=torch.float16)
C_gpu = torch.empty(B, M, N, device="cuda", dtype=torch.float16)

# Byte buffer sized to what the built graph reports as its workspace size.
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

# Execute
# The variant pack maps each graph tensor to the device buffer that backs it.
graph.execute({A: A_gpu, B_tensor: B_gpu, C: C_gpu}, workspace, handle=handle)


In [6]:
# Print output and verify against torch
print("CUDNN output:")
print(C_gpu)

print("\nTorch reference:")
C_ref = torch.bmm(A_gpu, B_gpu)
print(C_ref)

print(f"\nMax diff: {(C_gpu - C_ref).abs().max().item()}")

# Make the comparison an actual check rather than a visual one, so that
# Restart & Run All fails loudly if the cuDNN result ever diverges from torch.
# Tolerances are loose enough for FP16 rounding differences between the two
# implementations.
torch.testing.assert_close(C_gpu, C_ref, rtol=1e-2, atol=1e-2)

CUDNN output:
tensor([[[-1.7402,  1.4990,  0.7959,  2.1348, -4.0508, -5.3633],
         [-0.3716,  0.9287,  1.0342, -0.7207, -0.2500,  0.7520],
         [ 1.1025, -2.2871,  0.0893, -5.5469,  0.5498,  2.8105],
         [ 2.0742,  0.4377, -1.0098,  0.8345, -3.5254,  2.3359]],

        [[ 4.1719,  0.3274,  2.6504,  2.3926,  1.7188, -1.6426],
         [-0.8267, -1.5254, -0.2050, -1.7324,  3.1934, -1.9678],
         [-0.5469,  2.1465, -0.0419,  1.1553, -1.7520, -0.1475],
         [-0.6406,  0.6064, -1.8369, -0.5371, -0.2876, -1.0762]]],
       device='cuda:0', dtype=torch.float16)

Torch reference:
tensor([[[-1.7402,  1.4990,  0.7959,  2.1348, -4.0508, -5.3633],
         [-0.3716,  0.9287,  1.0342, -0.7207, -0.2500,  0.7520],
         [ 1.1025, -2.2871,  0.0893, -5.5469,  0.5498,  2.8105],
         [ 2.0742,  0.4377, -1.0098,  0.8345, -3.5254,  2.3359]],

        [[ 4.1719,  0.3274,  2.6504,  2.3926,  1.7188, -1.6426],
         [-0.8267, -1.5254, -0.2050, -1.7324,  3.1934, -1.9678],
       

In [7]:
# Cleanup
# Destroy the cuDNN handle created at the top of the notebook.
# NOTE(review): presumably `handle` must not be passed to graph.execute after this
# call; re-run the handle-creation cell before executing again.
cudnn.destroy_handle(handle)
