In [None]:
Reference: https://zhuanlan.zhihu.com/p/410278370

In [None]:
%%writefile matrix_multiplication.cu
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <type_traits>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

#define TYPE int
#define M 4
#define K 3
#define N 2
#define BLOCK_SIZE 32
#define NUM_PER_THREAD 8

// M 4 K 3
//  1  2  3
//  4  5  6
//  7  8  9
// 10 11 12

// K 3 N 2
// 1 4
// 2 5
// 3 6

// M 4 N 2
// 1*1 + 2*2 + 3*3 = 14         1*4 + 2*5 + 3*6 = 32
// 4*1 + 5*2 + 6*3 = 32         4*4 + 5*5 + 6*6 = 77
// 7*1 + 8*2 + 9*3 = 50         7*4 + 8*5 + 9*6 = 122
// 10*1 + 11*2 + 12*3 = 68      10*4 + 11*5 + 12*6 = 167
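// Thread-to-element mapping used by the kernel: the x index selects one of
// the N output columns and the y index selects one of the M output rows.
//  x  x
//y
//y
//y
//y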

// Throw-away kernel that is launched once before the real computation.
// It reads no arguments and writes no memory; the small amount of arithmetic
// exists only so the launch has a body. Presumably its purpose is to absorb
// one-time start-up cost before matrix_multiplication runs.
//
// Launch layout: 2D grid of 2D blocks, x covering the N columns and y covering
// the M rows (the same configuration main() uses for matrix_multiplication).
__global__ void warm_up()
{
    const int indexX = threadIdx.x + blockIdx.x * blockDim.x;
    const int indexY = threadIdx.y + blockIdx.y * blockDim.y;
    // Same guard as matrix_multiplication: x spans the N columns, y the M rows.
    if (indexY < M && indexX < N)
    {
        const float a = 0.0f;
        const float b = 1.0f;
        const float c = a + b;
        (void)c; // result is deliberately discarded; silences the unused-variable warning
    }
}

// Naive dense matrix product c = a * b, one thread per element of c.
//
//   a : M x K input, stored row-major (element (r, k) at a[r * K + k])
//   b : K x N input, stored row-major (element (k, col) at b[k * N + col])
//   c : M x N output, stored row-major (element (r, col) at c[r * N + col])
//
// M, K and N are the compile-time macros defined at the top of the file.
// Launch layout: 2D grid of 2D blocks whose x extent covers at least N threads
// and whose y extent covers at least M threads; surplus threads exit early.
// T is restricted to arithmetic types by the defaulted template parameter.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
__global__ void matrix_multiplication(const T *a, const T *b, T *c) {
    const int col = threadIdx.x + blockDim.x * blockIdx.x;
    const int row = threadIdx.y + blockDim.y * blockIdx.y;

    // Threads that fall outside the M x N output have nothing to do.
    if (row >= M || col >= N) {
        return;
    }

    const T *a_row = a + row * K; // start of row `row` in a
    const T *b_col = b + col;     // top of column `col` in b, stride N between rows

    T acc = 0;
    for (int k = 0; k < K; ++k) {
        acc += a_row[k] * b_col[k * N];
    }
    c[row * N + col] = acc;
}


// Writes `count` consecutive values to std::cout, emitting a line break before
// every `width`-th value so the data appears as rows of `width` entries, and
// finishes with a newline (flushed) after the last value.
template <typename T>
static void print_row_major(const T *values, int count, int width)
{
    for (int i = 0; i < count; ++i)
    {
        if (i % width == 0)
        {
            std::cout << '\n';
        }
        std::cout << values[i] << " ";
    }
    std::cout << std::endl;
}

// Prints the three matrices of the product c = a * b to standard output,
// in the order a (M x K), b (K x N), c (M x N), each preceded by a blank line.
//
//   a, b, c : host-readable pointers to row-major data of M*K, K*N and M*N
//             elements respectively (M, K, N are the file-level macros).
//
// The pointers are only read, so they are taken as pointer-to-const. Unlike the
// earlier version, the output of c is also terminated with a newline so the
// last row does not run into whatever is printed next.
template <typename T, typename = std::enable_if_t<std::is_arithmetic<T>::value>>
void print_output(const T *a, const T *b, const T* c)
{
    print_row_major(a, M * K, K);
    print_row_major(b, K * N, N);
    print_row_major(c, M * N, N);
}

// Terminates the program with a diagnostic on stderr when a CUDA runtime call
// or kernel launch reported anything other than cudaSuccess.
//   status : value returned by the runtime (or by cudaGetLastError())
//   what   : short description of the step, used in the error message
static void check_cuda(cudaError_t status, const char *what)
{
    if (status != cudaSuccess)
    {
        fprintf(stderr, "CUDA error during %s: %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

// Multiplies an M x K matrix of ones by a K x N matrix of ones on the GPU and
// prints both inputs and the M x N result.
int main()
{
    // Allocate space for host copies of a, b and c
    thrust::host_vector<TYPE> a(M * K);
    thrust::host_vector<TYPE> b(K * N);
    thrust::host_vector<TYPE> c(M * N);

    // Allocate device copies of a, b (filled with 1) and c (filled with 0)
    thrust::device_vector<TYPE> d_a(M * K, 1);
    thrust::device_vector<TYPE> d_b(K * N, 1);
    thrust::device_vector<TYPE> d_c(M * N, 0);

    // One thread per output element: x covers the N columns, y the M rows.
    dim3 threads_per_block(BLOCK_SIZE, BLOCK_SIZE, 1);
    dim3 no_of_blocks((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE, 1);

    // Kernel launches do not return a status; launch-configuration errors must
    // be fetched explicitly with cudaGetLastError() right after each launch.
    warm_up<<<no_of_blocks, threads_per_block>>>();
    check_cuda(cudaGetLastError(), "warm_up launch");

    matrix_multiplication<<<no_of_blocks, threads_per_block>>>(thrust::raw_pointer_cast(d_a.data()), thrust::raw_pointer_cast(d_b.data()), thrust::raw_pointer_cast(d_c.data()));
    check_cuda(cudaGetLastError(), "matrix_multiplication launch");

    // Faults raised while the kernels execute are reported by the next
    // synchronizing call, so its return value must be checked as well.
    check_cuda(cudaDeviceSynchronize(), "kernel execution");

    // Bring inputs and result back to the host for printing.
    thrust::copy(d_a.begin(), d_a.end(), a.begin());
    thrust::copy(d_b.begin(), d_b.end(), b.begin());
    thrust::copy(d_c.begin(), d_c.end(), c.begin());

    print_output(a.data(), b.data(), c.data());

    return 0;
}

In [None]:
!nvcc -o matrix_multiplication -lineinfo matrix_multiplication.cu

In [None]:
!./matrix_multiplication