In this notebook, I generate and inspect the PTX of my numba-cuda version of `matmul_2`, and compare it with the PTX that nvcc generates for the equivalent CUDA C kernel.

In [1]:
import os

from numba import cuda, float32, int32
from util import cdiv, measure_runtime

In [2]:
@cuda.jit()
def matmul_2(a,b,c,m,n,k,bs):
    """Naive matmul c = a @ b with one thread per output element.

    numba port of the CUDA C `matmul_global_mem_coalesce` kernel in this notebook.
    Indexing fixes the shapes: a is (m, k), b is (k, n), c is (m, n); bs is the
    edge of the square bs x bs tile of c that each block covers. Like the CUDA C
    driver, it expects bs*bs threads along threadIdx.x and a 2-D grid over tiles.
    """
    # Split the 1-D thread index into (row, col) inside the block's tile:
    # consecutive threadIdx.x share the row x and have consecutive columns y,
    # so neighbouring threads read b[i, y] / write c[x, y] next to each other
    # (adjacent in memory for row-major arrays).
    x = cuda.blockIdx.x * bs + (cuda.threadIdx.x // bs)
    y = cuda.blockIdx.y * bs + (cuda.threadIdx.x % bs)
    if x >= m or y >= n:
        return
    # float32 accumulator: with a plain `0` numba unifies the int64 literal with
    # the float32 products to float64, so the loop would run in double precision
    # (unlike the CUDA C version's `float tmp`).
    tmp = float32(0.)
    for i in range(k):
        tmp += a[x, i] * b[i, y]
    c[x, y] = tmp

In [3]:
# Explicit argument types for compile_ptx: the three 2-D float32 arrays
# (a, b, c) followed by the four int32 scalars (m, n, k, bs).
# NOTE: float32[:, :] is numba's any-layout ('A') 2-D array type;
# float32[:, ::1] would declare C-contiguous arrays instead.
sig = tuple(float32[:, :] for _ in range(3)) + (int32,) * 4

In [4]:
# Compile matmul_2 for the explicit signature `sig` to PTX for the current device
# (the output further down shows `.target sm_75`). The result is a 2-tuple,
# inspected and unpacked in the next cells.
ptx = cuda.compile_ptx_for_current_device(matmul_2, sig)

In [5]:
# Inspect what compile_ptx_for_current_device returned.
type(ptx), len(ptx)

(tuple, 2)

In [6]:
# The second element is a numba `none` type here; the PTX text is the first element.
type(ptx[1]), ptx[1]

(numba.core.types.misc.NoneType, none)

In [7]:
# Keep only the PTX text (first element of the returned 2-tuple).
# Guarded so that re-running this cell doesn't replace the PTX string with its first character.
ptx = ptx[0] if isinstance(ptx, tuple) else ptx

In [8]:
# After unpacking, `ptx` holds the PTX source as a plain string.
type(ptx)

str

In [9]:
# print() so the newlines render; only the first 500 characters are shown.
print(ptx[:500])

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32688072
// Cuda compilation tools, release 12.1, V12.1.105
// Based on NVVM 7.0.1
//

.version 8.1
.target sm_75
.address_size 64

	// .globl	_ZN6cudapy8__main__8matmul_2B2v1B94cw51cXTLSUwv1sCUt9Uw11Ew1dRRQPKzLTg4gaGKFsG2oMQGEYakJSQB1PQBk0Bynm21OiwU1a0UoLGhDpQE8oxrNQE_3dE5ArrayIfLi2E1A7mutable7alignedE5ArrayIfLi2E1A7mutable7alignedE5ArrayIfLi2E1A7mutable7alignedEiiii
.visible .global .align 4 .u32 _ZN6cudapy8__main__8matmul_2B2


The PTX contains symbols (kernel, globals, parameters) with very long, mangled names (e.g. `_ZN6cudapy8__main__8matmul_2B2v1B94cw51cXTLS ...`). Let's replace them with short placeholders for better readability.

In [10]:
import re

def replace_mangled_variables(ptx_str, min_len=45, prefix="mangled_var_"):
    """Replace long mangled symbol names (`_Z...`) in PTX text with short placeholders.

    Every token consisting of `_Z` followed by at least `min_len` word characters
    (letters, digits, underscore) is replaced by f"{prefix}{n}". Here n numbers the
    distinct names in order of first appearance, so every occurrence of the same name
    gets the same placeholder.

    Note: `_` and digits are word characters, so a parameter name such as
    `<kernel>_param_0` is one token and gets its own placeholder. It does not become
    `<placeholder>_param_0`.

    Args:
        ptx_str: PTX source text.
        min_len: minimum number of characters after the `_Z` prefix for a name to be
            replaced. The default of 45 means only names of 47+ characters are replaced.
        prefix: prefix of the generated placeholder names.

    Returns:
        The PTX text with the long names replaced.
    """
    # `\w` already includes digits, so no separate `\d` is needed.
    pattern = re.compile(r"_Z\w{%d,}" % min_len)
    replacements = {}  # mangled name -> placeholder, in order of first appearance

    def replace_match(match):
        mangled_name = match.group(0)
        if mangled_name not in replacements:
            replacements[mangled_name] = f"{prefix}{len(replacements) + 1}"
        return replacements[mangled_name]

    return pattern.sub(replace_match, ptx_str)

In [11]:
# Keep the raw numba PTX (mangled names) next to a readable copy with short names.
ptx_numba_orig = ptx
ptx_numba = replace_mangled_variables(ptx_numba_orig)

In [12]:
# Same header as above, now with the long mangled names replaced by mangled_var_N.
print(ptx_numba[:500])

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32688072
// Cuda compilation tools, release 12.1, V12.1.105
// Based on NVVM 7.0.1
//

.version 8.1
.target sm_75
.address_size 64

	// .globl	mangled_var_1
.visible .global .align 4 .u32 mangled_var_2;
.visible .global .align 4 .u32 mangled_var_3;
.visible .global .align 4 .u32 mangled_var_4;
.visible .global .align 4 .u32 mangled_var_5;
.visible .global .align 4 .u32 mangled_var_6;
.visible .global .align 4 .u32 mangled_var_7;


Let's now also generate the PTX for the equivalent CUDA C kernel, so that we can compare the two.

In [13]:
import os
import subprocess
from pathlib import Path

def generate_ptx_from_cudac(code, out_dir='tmp', out_cu='temp.cu', out_ptx='outp.ptx', cleanup=False,
                            arch='sm_75', nvcc='nvcc'):
    """Compile CUDA C source to PTX with nvcc and return the PTX text.

    Args:
        code: CUDA C/C++ source to compile.
        out_dir: directory for the intermediate files; created if it does not exist.
        out_cu: file name (inside out_dir) the source is written to.
        out_ptx: file name (inside out_dir) nvcc writes the PTX to.
        cleanup: if True, delete both files afterwards (on success and on failure).
        arch: value passed to `nvcc -arch`. The default 'sm_75' targets a T4 GPU
            (compute capability 7.5).
        nvcc: compiler executable name or path.

    Returns:
        The PTX as a string, or None if nvcc cannot be run or compilation fails
        (an error message is printed in that case).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cu_path, ptx_path = out_dir / out_cu, out_dir / out_ptx
    cu_path.write_text(code)
    nvcc_command = [nvcc, '-ptx', str(cu_path), '-o', str(ptx_path), '-arch', arch]
    try:
        try:
            # Capture output: a child process's stderr goes to the kernel's terminal,
            # not to the notebook cell, so nvcc diagnostics would otherwise be invisible.
            subprocess.run(nvcc_command, check=True, capture_output=True, text=True)
        except OSError as e:  # e.g. nvcc not installed / not on PATH
            print(f"Compilation failed with error: {e}")
            return None
        except subprocess.CalledProcessError as e:
            print(f"Compilation failed with error: {e}")
            if e.stderr:
                print(e.stderr)
            return None
        return ptx_path.read_text()
    finally:
        if cleanup:
            cu_path.unlink(missing_ok=True)
            ptx_path.unlink(missing_ok=True)

In [14]:
# CUDA C reference kernel with the same thread -> element mapping as matmul_2.
# The host function `matmul` launches matmul_global_mem_coalesce<32>, which instantiates
# the template so that nvcc emits PTX for it. Host code does not appear in the `-ptx` output.
matmul_2_cudac = '''
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

template <const uint BLOCKSIZE>
__global__ void matmul_global_mem_coalesce(const float *A, const float *B, float *C, int M, int N, int K) {
  const int cRow = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
  const int cCol = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);

  if (cRow < M && cCol < N) {
    float tmp = 0.0;
    for (int i = 0; i < K; ++i) { tmp += A[cRow * K + i] * B[i * N + cCol]; }
    C[cRow * N + cCol] = tmp;
  }
}

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }

void matmul(int M, int N, int K) {
    constexpr uint bs = 32;

    // Allocate memory for A,B,C on device
    float *d_A, *d_B, *d_C;
    cudaMalloc((void **)&d_A, M * K * sizeof(float));
    cudaMalloc((void **)&d_B, K * N * sizeof(float));
    cudaMalloc((void **)&d_C, M * N * sizeof(float));

    // Initialize A,B to ones. cudaMemset sets every *byte* to the given value, so it
    // cannot produce 1.0f (byte 0x01 gives ~2.4e-38); fill on the host and copy instead.
    std::vector<float> h_A(M * K, 1.0f), h_B(K * N, 1.0f);
    cudaMemcpy(d_A, h_A.data(), M * K * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B.data(), K * N * sizeof(float), cudaMemcpyHostToDevice);

    // Initialize C to zeros (all-zero bytes are 0.0f, so cudaMemset is fine here)
    cudaMemset(d_C, 0, M * N * sizeof(float));

    // Configure the grid and block dimensions
    dim3 tpb(bs * bs);
    dim3 blocks(cdiv(M, bs), cdiv(N, bs));

    // Launch the matrix multiplication kernel
    matmul_global_mem_coalesce<bs><<<blocks, tpb>>>(d_A, d_B, d_C, M, N, K);

    // Free device memory
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
}
'''

In [15]:
# generate_ptx_from_cudac returns None (after printing an error) when compilation fails;
# stop here with a clear message instead of a TypeError in the next cell.
ptx_cudac = generate_ptx_from_cudac(matmul_2_cudac)
assert ptx_cudac is not None, "nvcc failed to compile matmul_2_cudac to PTX (see message above)"

In [16]:
# First 500 characters of the nvcc PTX (symbol names still mangled).
print(ptx_cudac[:500])

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32688072
// Cuda compilation tools, release 12.1, V12.1.105
// Based on NVVM 7.0.1
//

.version 8.1
.target sm_75
.address_size 64

	// .globl	_Z26matmul_global_mem_coalesceILj32EEvPKfS1_Pfiii

.visible .entry _Z26matmul_global_mem_coalesceILj32EEvPKfS1_Pfiii(
	.param .u64 _Z26matmul_global_mem_coalesceILj32EEvPKfS1_Pfiii_param_0,
	.param .u64 _Z26matmul_global_mem_coalesceILj32EEvPKfS1_Pfiii_param_1,
	.param .u64 _Z26matmul_glo


In [17]:
# Keep the raw nvcc PTX (mangled names) next to a readable copy with short names.
ptx_cudac_orig = ptx_cudac
ptx_cudac = replace_mangled_variables(ptx_cudac_orig)

In [18]:
# The nvcc PTX after shortening the mangled names.
print(ptx_cudac[:500])

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32688072
// Cuda compilation tools, release 12.1, V12.1.105
// Based on NVVM 7.0.1
//

.version 8.1
.target sm_75
.address_size 64

	// .globl	mangled_var_1

.visible .entry mangled_var_1(
	.param .u64 mangled_var_2,
	.param .u64 mangled_var_3,
	.param .u64 mangled_var_4,
	.param .u32 mangled_var_5,
	.param .u32 mangled_var_6,
	.param .u32 mangled_var_7
)
{
	.reg .pred 	%p<9>;
	.reg .f32 	%f<30>;
	.reg .b32 	%r<34>;
	.reg .b64 	


In [19]:
# Save readable and original PTX for both implementations (e.g. for diffing).
os.makedirs('ptx', exist_ok=True)

ptx_outputs = {
    'ptx/matmul_2_numba.ptx': ptx_numba,
    'ptx/matmul_2_cudac.ptx': ptx_cudac,
    'ptx/matmul_2_numba_orig.ptx': ptx_numba_orig,
    'ptx/matmul_2_cudac_orig.ptx': ptx_cudac_orig,
}
for out_path, ptx_text in ptx_outputs.items():
    with open(out_path, 'w') as f:
        f.write(ptx_text)