Reference: https://zhuanlan.zhihu.com/p/410278370, https://zhuanlan.zhihu.com/p/435908830, https://github.com/yzhaiustc/Optimizing-SGEMM-on-NVIDIA-Turing-GPUs

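matrix_multiplication1: walk the K dimension in SUBK-wide slices. For each slice, load a BLOCK_SIZE_M x SUBK tile of A and a SUBK x BLOCK_SIZE_N tile of B into shared memory, then accumulate the partial dot products from shared memory instead of global memory.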

In [None]:
%%writefile matrix_multiplication.cu
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <random>
#include <type_traits>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>

// Element type of all matrices. main() draws inputs from uniform_int_distribution<TYPE>
// and compares against the CPU result with exact equality, so keep this integral.
#define TYPE int
// Small sizes matching the worked example below:
// #define M 4
// #define K 3
// #define N 2

// Problem size: C (M x N) = A (M x K) * B (K x N), all stored row-major.
// None of these is a multiple of BLOCK_SIZE, so the kernels' edge handling is exercised.
#define M 100
#define K 90
#define N 110

// Edge length of the square thread block used for every kernel launch.
#define BLOCK_SIZE 32

#define BLOCK_SIZE_M BLOCK_SIZE // rows of C per block (blockDim.y)
#define BLOCK_SIZE_N BLOCK_SIZE // columns of C per block (blockDim.x)
#define NUM_PER_THREAD 8        // not referenced by the kernels in this cell

// #define SUBM 1
// #define SUBM_NUM (M / SUBM)
// Width of the K-slice staged in shared memory per iteration of matrix_multiplication1.
#define SUBK BLOCK_SIZE
// Number of SUBK-wide slices needed to cover K. Ceiling division, so a partial last
// slice is counted (K = 90, SUBK = 32 -> 3 slices, not 2).
#define SUBK_NUM ((K + SUBK - 1) / SUBK)
// #define SUBN 1
// #define SUBN_NUM (N / SUBN)

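// Worked example for M = 4, K = 3, N = 2 (illustrative values; main() uses random data).
//
// A: M 4 x K 3 (row-major)
//  1  2  3
//  4  5  6
//  7  8  9
// 10 11 12

// B: K 3 x N 2 (row-major)
// 1 4
// 2 5
// 3 6

// C = A * B: M 4 x N 2
// 1*1 + 2*2 + 3*3 = 14         1*4 + 2*5 + 3*6 = 32
// 4*1 + 5*2 + 6*3 = 32         4*4 + 5*5 + 6*6 = 77
// 7*1 + 8*2 + 9*3 = 50         7*4 + 8*5 + 9*6 = 122
// 10*1 + 11*2 + 12*3 = 68      10*4 + 11*5 + 12*6 = 167
//
// Thread mapping used by the kernels: the x index runs along the N columns of C,
// the y index runs down the M rows of C:
//  x  x
// y
// y
// y
// y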

// Trivial kernel launched in main() before the multiplication kernels, presumably so that
// one-time CUDA context/module initialization is not charged to the first real kernel.
// Uses the same index layout as the other kernels: x indexes columns of C (< N),
// y indexes rows of C (< M).
__global__ void
warm_up()
{
    int indexX = threadIdx.x + blockIdx.x * blockDim.x;
    int indexY = threadIdx.y + blockIdx.y * blockDim.y;
    if (indexY < M && indexX < N)
    {
        // Placeholder work with no observable effect; the compiler may drop it entirely.
        float a = 0.0f;
        float b = 1.0f;
        float c = a + b;
        (void)c; // silence the "declared but never referenced" warning
    }
}

// Naive kernel: one thread per element of C = A * B (A is M x K, B is K x N, row-major).
// The x thread index selects the column of C and the y index selects the row; threads
// outside the M x N output do nothing. Every operand is read straight from global memory.
// Launch with a 2D grid that covers at least N threads in x and M threads in y.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication0(const T *a, const T *b, T *c)
{
    const int col = threadIdx.x + blockIdx.x * blockDim.x;
    const int row = threadIdx.y + blockIdx.y * blockDim.y;
    if (row >= M || col >= N)
        return;

    const T *aRow = a + row * K; // start of row `row` of A
    const T *bCol = b + col;     // top of column `col` of B (stride N)
    T acc = 0;
    for (int k = 0; k < K; ++k)
        acc += aRow[k] * bCol[k * N];
    c[row * N + col] = acc;
}

// Tiled kernel: C = A * B using shared memory (A is M x K, B is K x N, C is M x N, row-major).
// Launch: blockDim = (BLOCK_SIZE_N, BLOCK_SIZE_M); the grid covers N columns in x, M rows in y.
// Each block walks K in SUBK-wide slices. Per slice it stages a BLOCK_SIZE_M x SUBK tile of A
// and a SUBK x BLOCK_SIZE_N tile of B in shared memory, then each thread accumulates its
// partial dot product from the staged tiles.
// Elements outside A or B (row >= M, column >= N, or k >= K) are staged as zero. Partial
// tiles therefore contribute nothing, the inner loop needs no bounds test, and no thread
// reads global memory out of bounds.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication1(const T *a, const T *b, T *c)
{
    // Each thread stages exactly one element of each tile, which requires square tiles.
    static_assert(SUBK == BLOCK_SIZE_M && SUBK == BLOCK_SIZE_N,
                  "matrix_multiplication1 assumes SUBK == BLOCK_SIZE_M == BLOCK_SIZE_N");

    const int idX = blockIdx.x * blockDim.x + threadIdx.x; // column of C
    const int idY = blockIdx.y * blockDim.y + threadIdx.y; // row of C

    const int tidX = threadIdx.x;
    const int tidY = threadIdx.y;

    __shared__ T as[BLOCK_SIZE_M * SUBK]; // A tile, row-major with row stride SUBK
    __shared__ T bs[SUBK * BLOCK_SIZE_N]; // B tile, row-major with row stride BLOCK_SIZE_N

    T cTmp = 0;
    for (int i = 0; i < K; i += SUBK)
    {
        const int aCol = i + tidX; // k index of the A element this thread stages
        const int bRow = i + tidY; // k index of the B element this thread stages

        as[tidY * SUBK + tidX] = (idY < M && aCol < K) ? a[idY * K + aCol] : T(0);
        bs[tidY * BLOCK_SIZE_N + tidX] = (bRow < K && idX < N) ? b[bRow * N + idX] : T(0);
        // Not inside divergent control flow: the loop bound depends only on K, so every
        // thread of the block reaches each barrier the same number of times.
        __syncthreads();

        // No shared-memory bank conflicts (for 4-byte T): with blockDim.x == 32 a warp has a
        // single tidY and 32 consecutive tidX. as[...] is then one address broadcast to all
        // lanes, and bs[...] touches 32 consecutive words, i.e. 32 distinct banks.
        #pragma unroll
        for (int j = 0; j < SUBK; ++j)
            cTmp += as[tidY * SUBK + j] * bs[j * BLOCK_SIZE_N + tidX];

        // Keep the next slice's loads from overwriting tiles other threads are still reading.
        __syncthreads();
    }

    if (idY < M && idX < N)
        c[idY * N + idX] = cTmp;
}

// Prints the rows x cols row-major matrix m to std::cout, with values separated by spaces.
// A newline is emitted before each row and once after the last, so consecutively printed
// matrices are separated by a blank line.
template <typename T>
static void print_matrix(const T *m, int rows, int cols)
{
    for (int i = 0; i < rows * cols; ++i)
    {
        if (i % cols == 0)
        {
            std::cout << std::endl;
        }
        std::cout << m[i] << " ";
    }
    std::cout << std::endl;
}

// Prints A (M x K), B (K x N) and C (M x N) to std::cout, in that order.
// The pointers are only read; callers may pass non-const pointers as before.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
void print_output(const T *a, const T *b, const T *c)
{
    print_matrix(a, M, K);
    print_matrix(b, K, N);
    print_matrix(c, M, N);
}

// Host reference implementation: c = a * b, where a is M x K, b is K x N and c is M x N,
// all row-major. Each output is accumulated in T in increasing k order, and every element
// of c is overwritten.
template <typename T>
void matrix_multiplication_cpu(const T *a, const T *b, T *c)
{
    for (int row = 0; row < M; ++row)
    {
        const T *aRow = a + row * K; // row `row` of a
        T *cRow = c + row * N;       // row `row` of c
        for (int col = 0; col < N; ++col)
        {
            T acc = 0;
            for (int k = 0; k < K; ++k)
                acc += aRow[k] * b[k * N + col];
            cRow[col] = acc;
        }
    }
}

// Host driver. Fills A and B with random digits and runs each GPU kernel in turn. Each
// kernel's result is checked against matrix_multiplication_cpu, then the matrices are printed.
// Returns 0 if every kernel matched, 1 otherwise; CUDA errors abort with a message.
int main()
{
    // The random fill below uses uniform_int_distribution and the check uses exact equality.
    static_assert(std::is_integral<TYPE>::value, "TYPE must be an integral type for this test driver");

    // Abort with a readable message if a CUDA call (or an earlier asynchronous kernel) failed.
    auto check = [](cudaError_t err, const char *what) {
        if (err != cudaSuccess)
        {
            fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(err));
            exit(EXIT_FAILURE);
        }
    };

    // Allocate space for host copies of a, b and the results
    thrust::host_vector<TYPE> a(M * K);
    thrust::host_vector<TYPE> b(K * N);
    thrust::host_vector<TYPE> c(M * N);
    thrust::host_vector<TYPE> c_cpu(M * N);

    // Randomly initialize a and b with values in [0, 9]
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<TYPE> dis(0, 9);

    for (int i = 0; i < M * K; ++i)
    {
        a[i] = dis(gen);
    }

    for (int i = 0; i < K * N; ++i)
    {
        b[i] = dis(gen);
    }

    // Device copies of a, b and the output buffer
    thrust::device_vector<TYPE> d_a = a;
    thrust::device_vector<TYPE> d_b = b;
    thrust::device_vector<TYPE> d_c(M * N, 0);

    const TYPE *p_a = thrust::raw_pointer_cast(d_a.data());
    const TYPE *p_b = thrust::raw_pointer_cast(d_b.data());
    TYPE *p_c = thrust::raw_pointer_cast(d_c.data());

    // x covers the N columns of C, y covers the M rows.
    dim3 threads_per_block(BLOCK_SIZE_N, BLOCK_SIZE_M, 1);
    dim3 no_of_blocks((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M, 1);

    warm_up<<<no_of_blocks, threads_per_block>>>();
    check(cudaGetLastError(), "warm_up launch");

    // CPU reference result
    matrix_multiplication_cpu(a.data(), b.data(), c_cpu.data());

    // Waits for the preceding kernel, copies d_c into c, compares it with c_cpu and reports.
    auto verify = [&](const char *name) -> bool {
        check(cudaDeviceSynchronize(), name);
        thrust::copy(d_c.begin(), d_c.end(), c.begin());
        bool match = true;
        for (int i = 0; i < M * N; ++i)
        {
            if (c[i] != c_cpu[i])
            {
                match = false;
                break;
            }
        }
        std::cout << name << ": " << (match ? "Results match!" : "Results do not match!") << std::endl;
        return match;
    };

    // Each kernel is verified on its own. d_c is cleared in between, so a kernel cannot
    // pass by leaving the previous kernel's output in place.
    bool all_match = true;

    matrix_multiplication0<<<no_of_blocks, threads_per_block>>>(p_a, p_b, p_c);
    check(cudaGetLastError(), "matrix_multiplication0 launch");
    all_match = verify("matrix_multiplication0") && all_match;

    check(cudaMemset(p_c, 0, M * N * sizeof(TYPE)), "cudaMemset d_c");
    matrix_multiplication1<<<no_of_blocks, threads_per_block>>>(p_a, p_b, p_c);
    check(cudaGetLastError(), "matrix_multiplication1 launch");
    all_match = verify("matrix_multiplication1") && all_match;

    // c now holds the matrix_multiplication1 result.
    print_output(a.data(), b.data(), c.data());

    return all_match ? 0 : 1;
}

In [None]:
!nvcc -o matrix_multiplication -lineinfo matrix_multiplication.cu

In [None]:
!./matrix_multiplication