Reference: https://zhuanlan.zhihu.com/p/410278370, https://zhuanlan.zhihu.com/p/435908830, https://github.com/yzhaiustc/Optimizing-SGEMM-on-NVIDIA-Turing-GPUs

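matrix_multiplication1: Tile the K dimension: iteratively load SUBK-wide blocks of the input matrices into shared memory.  
matrix_multiplication2: Compute multiple output elements per thread.  
matrix_multiplication3: Stage the shared-memory operands in registers (no improvement observed).  
matrix_multiplication4: Prefetch the next tile (double buffering in shared memory).  
matrix_multiplication5: Same as matrix_multiplication3, with the shared-memory tiles declared as 2D arrays instead of 1D.  
matrix_multiplication6: Same as matrix_multiplication4 (prefetch), with the shared-memory tiles declared as 3D arrays instead of 1D.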

In [None]:
%%writefile matrix_multiplication.cu
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <random>
#include <type_traits>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

#define TYPE int
// #define M 4
// #define K 3
// #define N 2

#define M 67
#define K 66
#define N 75

#define BLOCK_SIZE 32

#define BLOCK_SIZE_M BLOCK_SIZE
#define BLOCK_SIZE_N BLOCK_SIZE
#define NUM_PER_THREAD_M 8
#define NUM_PER_THREAD_N 1
#define DIVIDE_M (BLOCK_SIZE_M / NUM_PER_THREAD_M)
#define DIVIDE_N (BLOCK_SIZE_N / NUM_PER_THREAD_N)

// #define SUBM 1
// #define SUBM_NUM (M / SUBM)
#define SUBK BLOCK_SIZE
// #define SUBK_NUM (K / SUBK)
// #define SUBN 1
// #define SUBN_NUM (N / SUBN)

// M 4 K 3
//  1  2  3
//  4  5  6
//  7  8  9
// 10 11 12

// K 3 N 2
// 1 4
// 2 5
// 3 6

// M 4 N 2
// 1*1 + 2*2 + 3*3 = 14         1*4 + 2*5 + 3*6 = 32
// 4*1 + 5*2 + 6*3 = 32         4*4 + 5*5 + 6*6 = 77
// 7*1 + 8*2 + 9*3 = 50         7*4 + 8*5 + 9*6 = 122
// 10*1 + 11*2 + 12*3 = 68      10*4 + 11*5 + 12*6 = 167
//  x  x
// y
// y
// y
// y

// Throw-away kernel: takes no arguments and touches no memory. Threads whose
// 2D global index lies inside the (M, N) window perform a single float
// addition whose result is discarded.
__global__ void
warm_up()
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads outside the window have nothing to do.
    if (col >= M || row >= N)
        return;

    const float lhs = 0.0f;
    const float rhs = 1.0f;
    const float sum = lhs + rhs;
    (void)sum; // the value is intentionally unused
}

// Naive GEMM: c = a * b for row-major a (M x K), b (K x N), c (M x N).
// Each thread whose global (x, y) index falls inside the output computes one
// element c[y][x] as a K-long dot product read straight from global memory.
// Threads outside the M x N output return immediately.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication0(const T *a, const T *b, T *c)
{
    const int col = threadIdx.x + blockDim.x * blockIdx.x;
    const int row = threadIdx.y + blockDim.y * blockIdx.y;
    if (row >= M || col >= N)
        return;

    const T *aRow = a + row * K; // start of row `row` of a
    const T *bCol = b + col;     // top of column `col` of b (stride N between rows)

    T acc = 0;
    for (int k = 0; k < K; ++k)
        acc += aRow[k] * bCol[k * N];

    c[row * N + col] = acc;
}

// Shared-memory tiled GEMM: c = a * b for row-major a (M x K), b (K x N),
// c (M x N). One thread produces one element of c.
//
// The K dimension is walked in steps of SUBK. In every step the block stages a
// BLOCK_SIZE_M x SUBK tile of a and a SUBK x BLOCK_SIZE_N tile of b in shared
// memory, then each thread accumulates the partial dot product from the tiles.
//
// Launch requirements: blockDim must be (BLOCK_SIZE_N, BLOCK_SIZE_M) because
// threadIdx is used directly as the tile index, and the grid must cover the
// M x N output. Static shared memory: (BLOCK_SIZE_M + BLOCK_SIZE_N) * SUBK
// elements of T.
//
// Tile slots that fall outside a or b are filled with zero. Every shared word
// that is read has therefore been written in the same iteration, and the inner
// product loop needs no bounds check.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication1(const T *a, const T *b, T *c)
{
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "matrix_multiplication1 indexes both tiles with threadIdx and needs SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    int idX = blockIdx.x * blockDim.x + threadIdx.x; // output column
    int idY = blockIdx.y * blockDim.y + threadIdx.y; // output row

    int tidX = threadIdx.x;
    int tidY = threadIdx.y;

    __shared__ T as[BLOCK_SIZE_M][SUBK];
    __shared__ T bs[SUBK][BLOCK_SIZE_N];

    T cTmp = 0;

    for (int i = 0; i < K; i += SUBK)
    {
        // Stage one element of each tile per thread; pad with zero past the
        // matrix edges so the padded products contribute nothing.
        int aCol = i + tidX;
        as[tidY][tidX] = (idY < M && aCol < K) ? a[idY * K + aCol] : T(0);

        int bRow = i + tidY;
        bs[tidY][tidX] = (bRow < K && idX < N) ? b[bRow * N + idX] : T(0);

        // All stores must be visible before any thread reads the tiles.
        __syncthreads();

        // as[tidY][j] does not depend on tidX and bs[j][tidX] walks
        // consecutive words as tidX increases. Assuming a warp spans one row
        // of a 32-wide block, that is a broadcast plus a conflict-free access.
        for (int j = 0; j < SUBK; ++j)
            cTmp += as[tidY][j] * bs[j][tidX];

        // Do not let a fast thread overwrite the tiles while others still read.
        __syncthreads();
    }

    if (idY < M && idX < N)
        c[idY * N + idX] = cTmp;
}

// Tiled GEMM where each thread produces NUM_PER_THREAD_M x NUM_PER_THREAD_N
// elements of c (c = a * b; row-major a is M x K, b is K x N, c is M x N).
//
// A block owns a BLOCK_SIZE_M x BLOCK_SIZE_N patch of c. Its threads
// cooperatively copy the matching tiles of a and b into shared memory, striding
// by DIVIDE_M rows / DIVIDE_N columns, and then accumulate the outputs at
// offsets (m * DIVIDE_M, n * DIVIDE_N) from the thread's own position.
// Only in-range elements are copied into the tiles, and only in-range outputs
// are written back.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication2(const T *a, const T *b, T *c)
{
    const int localCol = threadIdx.x;
    const int localRow = threadIdx.y;
    const int globalCol = blockIdx.x * (blockDim.x * NUM_PER_THREAD_N) + localCol;
    const int globalRow = blockIdx.y * (blockDim.y * NUM_PER_THREAD_M) + localRow;

    __shared__ T as[BLOCK_SIZE_M][SUBK];
    __shared__ T bs[SUBK][BLOCK_SIZE_N];

    // Per-thread accumulators, zero-initialised.
    T acc[NUM_PER_THREAD_M][NUM_PER_THREAD_N] = {0};

    for (int kBase = 0; kBase < K; kBase += SUBK)
    {
        // Copy the tile of a: rows globalRow + dy, columns kBase + localCol + dx.
#pragma unroll
        for (int dy = 0; dy < BLOCK_SIZE_M; dy += DIVIDE_M)
        {
            const int srcRow = globalRow + dy;
#pragma unroll
            for (int dx = 0; dx < BLOCK_SIZE_N; dx += DIVIDE_N)
            {
                const int srcCol = kBase + localCol + dx;
                if (srcRow < M && srcCol < K)
                    as[localRow + dy][localCol + dx] = a[srcRow * K + srcCol];
            }
        }

        // Copy the tile of b: rows kBase + localRow + dy, columns globalCol + dx.
#pragma unroll
        for (int dx = 0; dx < BLOCK_SIZE_N; dx += DIVIDE_N)
        {
            const int srcCol = globalCol + dx;
#pragma unroll
            for (int dy = 0; dy < BLOCK_SIZE_M; dy += DIVIDE_M)
            {
                const int srcRow = kBase + localRow + dy;
                if (srcRow < K && srcCol < N)
                    bs[localRow + dy][localCol + dx] = b[srcRow * N + srcCol];
            }
        }

        __syncthreads();

        // Only the first `valid` k-slices of this tile lie inside the matrices.
        const int remaining = K - kBase;
        const int valid = (remaining < SUBK) ? remaining : SUBK;
        for (int kk = 0; kk < valid; ++kk)
        {
#pragma unroll
            for (int m = 0; m < NUM_PER_THREAD_M; ++m)
            {
                const int tileRow = localRow + m * DIVIDE_M;
#pragma unroll
                for (int n = 0; n < NUM_PER_THREAD_N; ++n)
                {
                    const int tileCol = localCol + n * DIVIDE_N;
                    acc[m][n] += as[tileRow][kk] * bs[kk][tileCol];
                }
            }
        }

        // Keep the tiles intact until every thread has finished reading them.
        __syncthreads();
    }

    // Write back the outputs that fall inside c.
#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
        const int outRow = globalRow + m * DIVIDE_M;
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            const int outCol = globalCol + n * DIVIDE_N;
            if (outRow < M && outCol < N)
                c[outRow * N + outCol] = acc[m][n];
        }
    }
}

// Tiled GEMM (c = a * b; row-major a is M x K, b is K x N, c is M x N) with
// NUM_PER_THREAD_M x NUM_PER_THREAD_N outputs per thread. The shared-memory
// tiles are flat 1D arrays and the operands of each product are first copied
// into the locals `ar` / `br`.
//
// Launch requirements: blockDim must be (DIVIDE_N, DIVIDE_M) and the grid must
// provide one block per BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
// Static shared memory: (BLOCK_SIZE_M + BLOCK_SIZE_N) * SUBK elements of T.
//
// Fixes relative to the previous revision:
//  * idX / idY used the M and N per-thread factors swapped, so blocks did not
//    line up with their BLOCK_SIZE_M x BLOCK_SIZE_N patch and parts of c were
//    never written.
//  * The tile loads compared the K index against M (for a) and N (for b) and
//    did not check the other coordinate, reading outside a and b.
//  * Tile slots outside the matrices are now zero-filled, so no shared word is
//    read before it is written and the product loop needs no bounds check.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication3(const T *a, const T *b, T *c)
{
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "the tile-load loops of matrix_multiplication3 need SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    // x walks the columns of c (N direction), y walks the rows (M direction).
    int idX = blockIdx.x * (blockDim.x * NUM_PER_THREAD_N) + threadIdx.x;
    int idY = blockIdx.y * (blockDim.y * NUM_PER_THREAD_M) + threadIdx.y;

    int tidX = threadIdx.x;
    int tidY = threadIdx.y;

    __shared__ T as[BLOCK_SIZE_M * SUBK];
    __shared__ T bs[SUBK * BLOCK_SIZE_N];

    T cTmp[NUM_PER_THREAD_M][NUM_PER_THREAD_N];
#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            cTmp[m][n] = 0;
        }
    }

    for (int i = 0; i < K; i += SUBK)
    {
        // Tile of a: global rows idY + m, global columns i + tidX + n.
#pragma unroll
        for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
        {
#pragma unroll
            for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
            {
                int aRow = idY + m;
                int aCol = i + tidX + n;
                as[(tidY + m) * SUBK + (tidX + n)] =
                    (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
            }
        }

        // Tile of b: global rows i + tidY + m, global columns idX + n.
#pragma unroll
        for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
        {
#pragma unroll
            for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
            {
                int bRow = i + tidY + m;
                int bCol = idX + n;
                bs[(tidY + m) * BLOCK_SIZE_N + (tidX + n)] =
                    (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
            }
        }

        // Tiles must be complete before anyone reads them.
        __syncthreads();

        for (int j = 0; j < SUBK; ++j)
        {
            T ar;
            T br;
#pragma unroll
            for (int m = 0; m < NUM_PER_THREAD_M; ++m)
            {
                ar = as[(tidY + m * DIVIDE_M) * SUBK + j];
#pragma unroll
                for (int n = 0; n < NUM_PER_THREAD_N; ++n)
                {
                    br = bs[j * BLOCK_SIZE_N + (tidX + n * DIVIDE_N)];
                    // Padded slots are zero, so k >= K contributes nothing.
                    cTmp[m][n] += ar * br;
                }
            }
        }

        // Tiles must stay intact until every thread is done with them.
        __syncthreads();
    }

#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            if ((idY + m * DIVIDE_M) < M && (idX + n * DIVIDE_N) < N)
                c[(idY + m * DIVIDE_M) * N + idX + n * DIVIDE_N] = cTmp[m][n];
        }
    }
}

// Tiled GEMM (c = a * b; row-major a is M x K, b is K x N, c is M x N) with
// NUM_PER_THREAD_M x NUM_PER_THREAD_N outputs per thread and two shared-memory
// buffers per operand. While the block multiplies the tiles held in one
// buffer, it stores the tiles of the next K step into the other buffer.
//
// Launch requirements: blockDim must be (DIVIDE_N, DIVIDE_M) and the grid must
// provide one block per BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
// Static shared memory: 2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * SUBK elements of T.
//
// Synchronisation: the single __syncthreads() at the top of each K step both
// publishes the buffer filled during the previous step and guarantees that
// every thread has finished reading the buffer about to be overwritten.
//
// Fixes relative to the previous revision:
//  * idX / idY used the M and N per-thread factors swapped, so blocks did not
//    line up with their BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
//  * The tile loads compared the K index against M (for a) and N (for b) and
//    did not check the other coordinate, reading outside a and b.
//  * Tile slots outside the matrices are now zero-filled, so no shared word is
//    read before it is written and the product loop needs no bounds check.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication4(const T *a, const T *b, T *c)
{
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "the tile-load loops of matrix_multiplication4 need SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    // x walks the columns of c (N direction), y walks the rows (M direction).
    int idX = blockIdx.x * (blockDim.x * NUM_PER_THREAD_N) + threadIdx.x;
    int idY = blockIdx.y * (blockDim.y * NUM_PER_THREAD_M) + threadIdx.y;

    int tidX = threadIdx.x;
    int tidY = threadIdx.y;

    // Two back-to-back buffers per operand; the *Buffer* offsets select one.
    __shared__ T as[BLOCK_SIZE_M * SUBK * 2];
    __shared__ T bs[SUBK * BLOCK_SIZE_N * 2];

    T cTmp[NUM_PER_THREAD_M][NUM_PER_THREAD_N];
#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            cTmp[m][n] = 0;
        }
    }

    // Preload the tiles of the first K step into buffer 0.
#pragma unroll
    for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
    {
#pragma unroll
        for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
        {
            int aRow = idY + m;
            int aCol = tidX + n;
            as[(tidY + m) * SUBK + (tidX + n)] =
                (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
        }
    }

#pragma unroll
    for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
    {
#pragma unroll
        for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
        {
            int bRow = tidY + m;
            int bCol = idX + n;
            bs[(tidY + m) * BLOCK_SIZE_N + (tidX + n)] =
                (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
        }
    }

    for (int i = 0; i < K; i += SUBK)
    {
        // Publish the buffer written last step / by the preload, and make sure
        // nobody still reads the buffer this step is going to overwrite.
        __syncthreads();

        int asBufferLoad = (((i / SUBK) % 2) * BLOCK_SIZE_M * SUBK);
        int bsBufferLoad = (((i / SUBK) % 2) * SUBK * BLOCK_SIZE_N);
        for (int j = 0; j < SUBK; ++j)
        {
            T ar;
            T br;
#pragma unroll
            for (int m = 0; m < NUM_PER_THREAD_M; ++m)
            {
                ar = as[(tidY + m * DIVIDE_M) * SUBK + j + asBufferLoad];
#pragma unroll
                for (int n = 0; n < NUM_PER_THREAD_N; ++n)
                {
                    br = bs[j * BLOCK_SIZE_N + (tidX + n * DIVIDE_N) + bsBufferLoad];
                    // Padded slots are zero, so k >= K contributes nothing.
                    cTmp[m][n] += ar * br;
                }
            }
        }

        // Fill the other buffer with the tiles of the next K step, if any.
        int next = i + SUBK;
        if (next < K)
        {
            int asBufferStore = ((((i / SUBK) + 1) % 2) * BLOCK_SIZE_M * SUBK);
#pragma unroll
            for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
            {
#pragma unroll
                for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
                {
                    int aRow = idY + m;
                    int aCol = next + tidX + n;
                    as[(tidY + m) * SUBK + (tidX + n) + asBufferStore] =
                        (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
                }
            }

            int bsBufferStore = ((((i / SUBK) + 1) % 2) * SUBK * BLOCK_SIZE_N);
#pragma unroll
            for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
            {
#pragma unroll
                for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
                {
                    int bRow = next + tidY + m;
                    int bCol = idX + n;
                    bs[(tidY + m) * BLOCK_SIZE_N + (tidX + n) + bsBufferStore] =
                        (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
                }
            }
        }
    }

#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            if ((idY + m * DIVIDE_M) < M && (idX + n * DIVIDE_N) < N)
                c[(idY + m * DIVIDE_M) * N + idX + n * DIVIDE_N] = cTmp[m][n];
        }
    }
}

// Tiled GEMM (c = a * b; row-major a is M x K, b is K x N, c is M x N) with
// NUM_PER_THREAD_M x NUM_PER_THREAD_N outputs per thread. The shared-memory
// tiles are declared as 2D arrays and the operands of each product are first
// copied into the locals `ar` / `br`.
//
// Launch requirements: blockDim must be (DIVIDE_N, DIVIDE_M) and the grid must
// provide one block per BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
// Static shared memory: (BLOCK_SIZE_M + BLOCK_SIZE_N) * SUBK elements of T.
//
// Fixes relative to the previous revision:
//  * idX / idY used the M and N per-thread factors swapped, so blocks did not
//    line up with their BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
//  * The tile loads compared the K index against M (for a) and N (for b) and
//    did not check the other coordinate, reading outside a and b.
//  * Tile slots outside the matrices are now zero-filled, so no shared word is
//    read before it is written and the product loop needs no bounds check.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication5(const T *a, const T *b, T *c)
{
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "the tile-load loops of matrix_multiplication5 need SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    // x walks the columns of c (N direction), y walks the rows (M direction).
    int idX = blockIdx.x * (blockDim.x * NUM_PER_THREAD_N) + threadIdx.x;
    int idY = blockIdx.y * (blockDim.y * NUM_PER_THREAD_M) + threadIdx.y;

    int tidX = threadIdx.x;
    int tidY = threadIdx.y;

    __shared__ T as[BLOCK_SIZE_M][SUBK];
    __shared__ T bs[SUBK][BLOCK_SIZE_N];

    T cTmp[NUM_PER_THREAD_M][NUM_PER_THREAD_N];
#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            cTmp[m][n] = 0;
        }
    }

    for (int i = 0; i < K; i += SUBK)
    {
        // Tile of a: global rows idY + m, global columns i + tidX + n.
#pragma unroll
        for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
        {
#pragma unroll
            for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
            {
                int aRow = idY + m;
                int aCol = i + tidX + n;
                as[tidY + m][tidX + n] =
                    (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
            }
        }

        // Tile of b: global rows i + tidY + m, global columns idX + n.
#pragma unroll
        for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
        {
#pragma unroll
            for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
            {
                int bRow = i + tidY + m;
                int bCol = idX + n;
                bs[tidY + m][tidX + n] =
                    (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
            }
        }

        // Tiles must be complete before anyone reads them.
        __syncthreads();

        for (int j = 0; j < SUBK; ++j)
        {
            T ar;
            T br;
#pragma unroll
            for (int m = 0; m < NUM_PER_THREAD_M; ++m)
            {
                ar = as[tidY + m * DIVIDE_M][j];
#pragma unroll
                for (int n = 0; n < NUM_PER_THREAD_N; ++n)
                {
                    br = bs[j][tidX + n * DIVIDE_N];
                    // Padded slots are zero, so k >= K contributes nothing.
                    cTmp[m][n] += ar * br;
                }
            }
        }

        // Tiles must stay intact until every thread is done with them.
        __syncthreads();
    }

#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            if ((idY + m * DIVIDE_M) < M && (idX + n * DIVIDE_N) < N)
                c[(idY + m * DIVIDE_M) * N + idX + n * DIVIDE_N] = cTmp[m][n];
        }
    }
}

// Tiled GEMM (c = a * b; row-major a is M x K, b is K x N, c is M x N) with
// NUM_PER_THREAD_M x NUM_PER_THREAD_N outputs per thread and two shared-memory
// buffers per operand, declared as 3D arrays indexed [buffer][row][column].
// While the block multiplies the tiles held in one buffer, it stores the tiles
// of the next K step into the other buffer.
//
// Launch requirements: blockDim must be (DIVIDE_N, DIVIDE_M) and the grid must
// provide one block per BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
// Static shared memory: 2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * SUBK elements of T.
//
// Synchronisation: the single __syncthreads() at the top of each K step both
// publishes the buffer filled during the previous step and guarantees that
// every thread has finished reading the buffer about to be overwritten.
//
// Fixes relative to the previous revision:
//  * idX / idY used the M and N per-thread factors swapped, so blocks did not
//    line up with their BLOCK_SIZE_M x BLOCK_SIZE_N patch of c.
//  * The tile loads compared the K index against M (for a) and N (for b) and
//    did not check the other coordinate, reading outside a and b.
//  * Tile slots outside the matrices are now zero-filled, so no shared word is
//    read before it is written and the product loop needs no bounds check.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication6(const T *a, const T *b, T *c)
{
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "the tile-load loops of matrix_multiplication6 need SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    // x walks the columns of c (N direction), y walks the rows (M direction).
    int idX = blockIdx.x * (blockDim.x * NUM_PER_THREAD_N) + threadIdx.x;
    int idY = blockIdx.y * (blockDim.y * NUM_PER_THREAD_M) + threadIdx.y;

    int tidX = threadIdx.x;
    int tidY = threadIdx.y;

    __shared__ T as[2][BLOCK_SIZE_M][SUBK];
    __shared__ T bs[2][SUBK][BLOCK_SIZE_N];

    T cTmp[NUM_PER_THREAD_M][NUM_PER_THREAD_N];
#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            cTmp[m][n] = 0;
        }
    }

    // Preload the tiles of the first K step into buffer 0.
#pragma unroll
    for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
    {
#pragma unroll
        for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
        {
            int aRow = idY + m;
            int aCol = tidX + n;
            as[0][tidY + m][tidX + n] =
                (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
        }
    }

#pragma unroll
    for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
    {
#pragma unroll
        for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
        {
            int bRow = tidY + m;
            int bCol = idX + n;
            bs[0][tidY + m][tidX + n] =
                (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
        }
    }

    // i is the first k index of the current step, l counts the steps.
    for (int i = 0, l = 0; i < K; i += SUBK, ++l)
    {
        int iBufferLoad = (l % 2);

        // Publish the buffer written last step / by the preload, and make sure
        // nobody still reads the buffer this step is going to overwrite.
        __syncthreads();

        for (int j = 0; j < SUBK; ++j)
        {
            T ar;
            T br;
#pragma unroll
            for (int m = 0; m < NUM_PER_THREAD_M; ++m)
            {
                ar = as[iBufferLoad][tidY + m * DIVIDE_M][j];
#pragma unroll
                for (int n = 0; n < NUM_PER_THREAD_N; ++n)
                {
                    br = bs[iBufferLoad][j][tidX + n * DIVIDE_N];
                    // Padded slots are zero, so k >= K contributes nothing.
                    cTmp[m][n] += ar * br;
                }
            }
        }

        // Fill the other buffer with the tiles of the next K step, if any.
        int next = i + SUBK;
        int iBufferStore = (l + 1) % 2;
        if (next < K)
        {
#pragma unroll
            for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
            {
#pragma unroll
                for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
                {
                    int aRow = idY + m;
                    int aCol = next + tidX + n;
                    as[iBufferStore][tidY + m][tidX + n] =
                        (aRow < M && aCol < K) ? a[aRow * K + aCol] : T(0);
                }
            }

#pragma unroll
            for (int n = 0; n < BLOCK_SIZE_N; n += DIVIDE_N)
            {
#pragma unroll
                for (int m = 0; m < BLOCK_SIZE_M; m += DIVIDE_M)
                {
                    int bRow = next + tidY + m;
                    int bCol = idX + n;
                    bs[iBufferStore][tidY + m][tidX + n] =
                        (bRow < K && bCol < N) ? b[bRow * N + bCol] : T(0);
                }
            }
        }
    }

#pragma unroll
    for (int m = 0; m < NUM_PER_THREAD_M; ++m)
    {
#pragma unroll
        for (int n = 0; n < NUM_PER_THREAD_N; ++n)
        {
            if ((idY + m * DIVIDE_M) < M && (idX + n * DIVIDE_N) < N)
                c[(idY + m * DIVIDE_M) * N + idX + n * DIVIDE_N] = cTmp[m][n];
        }
    }
}

// Prints the three row-major matrices a (M x K), b (K x N) and c (M x N) to
// std::cout. Each matrix row starts on a new line, every element is followed
// by a single space, and each matrix is terminated by a line break.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
void print_output(T *a, T *b, T *c)
{
    // Dumps `count` elements, breaking the line before every `width`-th one.
    auto dump = [](const T *data, int count, int width)
    {
        for (int idx = 0; idx < count; ++idx)
        {
            const bool rowStart = (idx % width == 0);
            if (rowStart)
                std::cout << std::endl;
            std::cout << data[idx] << " ";
        }
        std::cout << std::endl;
    };

    dump(a, M * K, K);
    dump(b, K * N, N);
    dump(c, M * N, N);
}

// Host reference GEMM: c = a * b for row-major a (M x K), b (K x N),
// c (M x N), computed with the plain triple loop.
template <typename T>
void matrix_multiplication_cpu(const T *a, const T *b, T *c)
{
    for (int row = 0; row < M; ++row)
    {
        const T *aRow = a + row * K; // row `row` of a
        T *cRow = c + row * N;       // row `row` of c

        for (int col = 0; col < N; ++col)
        {
            T acc = 0;
            int k = 0;
            while (k < K)
            {
                acc += aRow[k] * b[k * N + col];
                ++k;
            }
            cRow[col] = acc;
        }
    }
}

// Test driver: fills a (M x K) and b (K x N) with random digits, multiplies
// them on the GPU with matrix_multiplication2 and on the CPU with
// matrix_multiplication_cpu, and compares the two results element by element.
//
// Returns 0 when the results match, EXIT_FAILURE when a CUDA call fails or the
// results differ.
int main()
{
    // Host copies of the operands and of both results.
    thrust::host_vector<TYPE> a(M * K);
    thrust::host_vector<TYPE> b(K * N);
    thrust::host_vector<TYPE> c(M * N);
    thrust::host_vector<TYPE> c_cpu(M * N);

    // Randomly initialize a and b with values in [0, 9].
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<TYPE> dis(0, 9);

    for (int i = 0; i < M * K; ++i)
    {
        a[i] = dis(gen);
        // a[i] = 1;
    }

    for (int i = 0; i < K * N; ++i)
    {
        b[i] = dis(gen);
        // b[i] = 1;
    }

    // Device copies of a and b, and a zeroed device result.
    thrust::device_vector<TYPE> d_a = a;
    thrust::device_vector<TYPE> d_b = b;
    thrust::device_vector<TYPE> d_c(M * N, 0);

    const TYPE *d_a_ptr = thrust::raw_pointer_cast(d_a.data());
    const TYPE *d_b_ptr = thrust::raw_pointer_cast(d_b.data());
    TYPE *d_c_ptr = thrust::raw_pointer_cast(d_c.data());

    // Kernel launches do not return a status themselves; failures have to be
    // collected with cudaGetLastError() / a synchronizing call.
    auto cuda_ok = [](cudaError_t status, const char *what) -> bool
    {
        if (status == cudaSuccess)
            return true;
        std::cerr << what << " failed: " << cudaGetErrorString(status) << std::endl;
        return false;
    };

    // One thread per output element (kernels 0 and 1).
    dim3 threads_per_block(BLOCK_SIZE_N, BLOCK_SIZE_M, 1); // x y z
    dim3 no_of_blocks((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M, 1);

    warm_up<<<no_of_blocks, threads_per_block>>>();
    if (!cuda_ok(cudaGetLastError(), "warm_up launch"))
        return EXIT_FAILURE;
    // matrix_multiplication0<<<no_of_blocks, threads_per_block>>>(d_a_ptr, d_b_ptr, d_c_ptr);
    // matrix_multiplication1<<<no_of_blocks, threads_per_block>>>(d_a_ptr, d_b_ptr, d_c_ptr);

    // NUM_PER_THREAD_M x NUM_PER_THREAD_N output elements per thread (kernels 2-6).
    dim3 threads_per_block_multi(BLOCK_SIZE_N / NUM_PER_THREAD_N, BLOCK_SIZE_M / NUM_PER_THREAD_M, 1); // x y z
    dim3 no_of_blocks_multi((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M, 1);
    matrix_multiplication2<<<no_of_blocks_multi, threads_per_block_multi>>>(d_a_ptr, d_b_ptr, d_c_ptr);

    // matrix_multiplication3<<<no_of_blocks_multi, threads_per_block_multi>>>(d_a_ptr, d_b_ptr, d_c_ptr);
    // matrix_multiplication4<<<no_of_blocks_multi, threads_per_block_multi>>>(d_a_ptr, d_b_ptr, d_c_ptr);
    // matrix_multiplication5<<<no_of_blocks_multi, threads_per_block_multi>>>(d_a_ptr, d_b_ptr, d_c_ptr);
    // matrix_multiplication6<<<no_of_blocks_multi, threads_per_block_multi>>>(d_a_ptr, d_b_ptr, d_c_ptr);

    // Bad launch configurations show up here, in-kernel faults at the sync.
    if (!cuda_ok(cudaGetLastError(), "matrix multiplication kernel launch"))
        return EXIT_FAILURE;
    if (!cuda_ok(cudaDeviceSynchronize(), "matrix multiplication kernel execution"))
        return EXIT_FAILURE;

    thrust::copy(d_c.begin(), d_c.end(), c.begin());

    // Perform CPU matrix multiplication
    matrix_multiplication_cpu(a.data(), b.data(), c_cpu.data());

    // print_output(a.data(), b.data(), c.data());

    // Verify the results: locate the first element that differs, if any.
    int first_mismatch = -1;
    for (int i = 0; i < M * N; ++i)
    {
        if (c[i] != c_cpu[i])
        {
            first_mismatch = i;
            break;
        }
    }

    if (first_mismatch < 0)
    {
        std::cout << "Results match!" << std::endl;
        return 0;
    }

    std::cout << "Results do not match!" << std::endl;
    std::cerr << "first mismatch at row " << first_mismatch / N
              << ", column " << first_mismatch % N
              << ": gpu = " << c[first_mismatch]
              << ", cpu = " << c_cpu[first_mismatch] << std::endl;
    return EXIT_FAILURE;
}

In [None]:
!nvcc -o matrix_multiplication -lineinfo matrix_multiplication.cu

In [None]:
!./matrix_multiplication

In [None]:
!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/nsight-systems-2024.2.3_2024.2.3.38-1_amd64.deb
!apt update
!apt install ./nsight-systems-2024.2.3_2024.2.3.38-1_amd64.deb
!apt --fix-broken install

In [None]:
# nsys options must come before the application; anything after the executable is passed to it. --force-overwrite takes an explicit true/false value.
!nsys profile --force-overwrite true -o report_nsys_matrix_multiplication ./matrix_multiplication

In [None]:
!ncu --set full --replay-mode kernel --target-processes all -o report_ncu_matrix_multiplication -f ./matrix_multiplication