In this nb, I run the numba-cuda version of matmul_2 in debugging mode, and check the types.

**Hypothesis:** Unwanted typecasting is done, which might be the reason the code is slow.<br/>
**Result:** Yes! `tmp` is cast to fp64, which makes the kernel slow. Defining it as fp32 makes the kernel fast.

In [1]:
import os
os.environ['NUMBA_ENABLE_CUDASIM']='1'
os.environ['CUDA_LAUNCH_BLOCKING']='1'

import numpy as np
from numba import cuda
from util import cdiv, to_d, to_h

from fastcore.basics import strcat, get_class

dtype='float32'

In [2]:
def bits_per_item(mat): return mat._item.nbytes/(mat._item.shape[0]*mat._item.shape[1]) # to check if matrices are fp32 (4bytes per item)

def types(*os): print(strcat((type(o) for o in os), ', '))

In [3]:
m,k,n = 2,3,4
bs=2

a = to_d(np.ones((m,k), dtype=dtype))
b = to_d(np.ones((k,n), dtype=dtype))
c = to_d(np.empty((m,n), dtype=dtype))

nthreads = bs*bs
nblocks = cdiv(c.shape, (bs,bs))

print(m,k,n,' # ',bs,' # ',nblocks,nthreads)
print(a)
print(b)

2 3 4  #  2  #  (1, 2) 4
[[1. 1. 1.]
 [1. 1. 1.]]
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]


In [44]:
@cuda.jit()
def matmul_2(a,b,c,m,n,k,bs):
    # we defined blocks of size bs*bs
    x = cuda.blockIdx.x * bs + (cuda.threadIdx.x // bs)
    y = cuda.blockIdx.y * bs + (cuda.threadIdx.x % bs)
    if x>=m or y>=n: return 
    tmp = 0
    for i in range(k): tmp += a[x,i] * b[i,y]
    c[x, y] = tmp
    if (x,y)==(0,0): raise RuntimeError('Let us inspect')

In [45]:
matmul_2[nblocks,nthreads](a,b,c,m,n,k,bs)

RuntimeError: tid=[0, 0, 0] ctaid=[0, 0, 0]: Let us inspect

This is the debug output:

```py
%debug

ipdb>  print(a._item.nbytes)
24
ipdb>  bits_per_item(a), bits_per_item(b), bits_per_item(c)
(4.0, 4.0, 4.0) ## note: fp32 = 4bytes, so this is good
ipdb>  type(tmp)
<class 'numpy.float64'>
ipdb>  i=k-1
ipdb>  i
2
ipdb>  types(a[x,i], b[i,y], a[x,i] * b[i,y])
<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>
ipdb>  type(tmp+a[x,i] * b[i,y])
<class 'numpy.float64'>
ipdb>  type(0+a[x,i] * b[i,y])
<class 'numpy.float64'>
ipdb>  type(0.0+a[x,i] * b[i,y])
<class 'numpy.float64'>
```
The accumulator is casted to fp64!

In [7]:
@cuda.jit()
def matmul_2_explicit_typedef(a,b,c,m,n,k,bs):
    # we defined blocks of size bs*bs
    x = cuda.blockIdx.x * bs + (cuda.threadIdx.x // bs)
    y = cuda.blockIdx.y * bs + (cuda.threadIdx.x % bs)
    if x>=m or y>=n: return 
    tmp = np.float32(0)  # Initialize tmp explicitly as fp32
    for i in range(k): tmp += a[x,i] * b[i,y]
    c[x, y] = tmp
    if (x,y)==(0,0): raise RuntimeError('Let us inspect')

In [8]:
matmul_2_explicit_typedef[nblocks,nthreads](a,b,c,m,n,k,bs)

RuntimeError: tid=[0, 0, 0] ctaid=[0, 0, 0]: Let us inspect

This is the debug output:
```
%debug

ipdb>  type(tmp)
<class 'numpy.float32'>
```

___

Even with `tmp` as fp32, matmul_3 is slow. Let's check it's types

In [4]:
from numba import float32

In [5]:
@cuda.jit()
def matmul_3_bs32(a,b,c,m,n,k):
    bs = 32
    bx,by = cuda.blockIdx.x, cuda.blockIdx.y
    tx,ty = cuda.threadIdx.x//bs, cuda.threadIdx.x%bs
    sh_a, sh_b = cuda.shared.array((bs,bs), float32), cuda.shared.array((bs,bs), float32)        
    tmp = np.float32(0)
    nk = (k+bs-1)//bs
    for bk in range(nk):
        for i in range(bs):
            if (bx*bs+tx<m) and (bk*bs+i<k) and (by*bs+ty<n):
                sh_a[tx,ty] = a[bx*bs+tx, bk*bs+i ]
                sh_b[tx,ty] = b[bk*bs+i , by*bs+ty]
        cuda.syncthreads()
        for i in range(bs):
            if (bx*bs+tx <m) and (bk*bs+i<k) and (by*bs+ty<n):
                tmp += sh_a[tx,i]*sh_b[i,ty]
        cuda.syncthreads()
    if bx*bs+tx<m and by*bs+ty<n: c[bx*bs+tx,by*bs+ty] = tmp
    if (bx,by,tx,ty)==(0,0,0,0): raise RuntimeError('Let us inspect')

In [6]:
matmul_3_bs32[nblocks,nthreads](a,b,c,m,n,k)

RuntimeError: tid=[0, 0, 0] ctaid=[0, 0, 0]: Let us inspect

In [None]:
%debug

> [0;32m/tmp/ipykernel_111165/2261411280.py[0m(20)[0;36mmatmul_3_bs32[0;34m()[0m
[0;32m     16 [0;31m            [0;32mif[0m [0;34m([0m[0mbx[0m[0;34m*[0m[0mbs[0m[0;34m+[0m[0mtx[0m [0;34m<[0m[0mm[0m[0;34m)[0m [0;32mand[0m [0;34m([0m[0mbk[0m[0;34m*[0m[0mbs[0m[0;34m+[0m[0mi[0m[0;34m<[0m[0mk[0m[0;34m)[0m [0;32mand[0m [0;34m([0m[0mby[0m[0;34m*[0m[0mbs[0m[0;34m+[0m[0mty[0m[0;34m<[0m[0mn[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m                [0mtmp[0m [0;34m+=[0m [0msh_a[0m[0;34m[[0m[0mtx[0m[0;34m,[0m[0mi[0m[0;34m][0m[0;34m*[0m[0msh_b[0m[0;34m[[0m[0mi[0m[0;34m,[0m[0mty[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m        [0mcuda[0m[0;34m.[0m[0msyncthreads[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m    [0;32mif[0m [0mbx[0m[0;34m*[0m[0mbs[0m[0;34m+[0m[0mtx[0m[0;34m<[0m[0mm[0m [0;32man

ipdb>  bits_per_item(a),bits_per_item(b),bits_per_item(c),


(4.0, 4.0, 4.0)


ipdb>  sh_a.dtype, sh_b.dtype


(dtype('float32'), dtype('float32'))


ipdb>  types(bx,by,tx,ty)


<class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>


ipdb>  type(tmp)


<class 'numpy.float32'>


ipdb>  type(c[bx*bs+tx,by*bs+ty])


<class 'numpy.float32'>


ipdb>  types(a[0,0], b[0,0], c[0,0])


<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>


ipdb>  nk


1


```
%debug

ipdb>  bits_per_item(a),bits_per_item(b),bits_per_item(c)
(4.0, 4.0, 4.0)
ipdb>  types(a[0,0], b[0,0], c[0,0])
<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>
ipdb>  sh_a.dtype, sh_b.dtype
(dtype('float32'), dtype('float32'))
ipdb>  types(bx,by,tx,ty)
<class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>
```