# DAY 32: ROCm HIP SGEMM with Optimized Register-based Kernel

In [None]:
%%writefile sgemm.h
#pragma once

// Abstract interface shared by the SGEMM implementations in this benchmark.
// Implementations are driven as: init() -> run(...) [any number of times] -> finalize().
class SGEMMBase {
public:
    // Virtual destructor so an implementation can be destroyed safely through
    // a SGEMMBase pointer or reference.
    virtual ~SGEMMBase() = default;

    // One-time setup hook, called before the first run().
    virtual void init() = 0;

    // Performs the SGEMM update of d_c from d_a and d_b with scalars alpha and
    // beta on N x N matrices. The d_ prefix suggests device pointers
    // (the benchmark passes hipMalloc'd buffers).
    virtual void run(float *d_a, float *d_b, float *d_c, float alpha, float beta, int N) = 0;

    // Teardown hook, called once the implementation is no longer needed.
    virtual void finalize() = 0;
};

In [None]:
%%writefile kernel3_registers.h
#pragma once
#include "sgemm.h"

// SGEMMBase implementation; the three member functions are only declared here
// and are defined out of line.
class Kernel3Registers : public SGEMMBase {
public:
    // Setup hook required by SGEMMBase.
    void init() override;
    // Runs the SGEMM on d_a, d_b and d_c with scalars alpha and beta for size N.
    void run(float *d_a, float *d_b, float *d_c, float alpha, float beta, int N) override;
    // Teardown hook required by SGEMMBase.
    void finalize() override;
};

In [None]:
%%writefile kernel3_registers.cpp
#include <hip/hip_runtime.h>

#include <stdexcept>
#include <string>

#include "kernel3_registers.h"

#define BLOCK_SIZE 256
/*
 * Register-tiled SGEMM: c = alpha * (a x b) + beta * c for row-major N x N matrices.
 *
 * Launch contract (not checked inside the kernel):
 *   - block: 1D, exactly BLOCK_SIZE threads;
 *   - grid : (N / 128, N / 128); blockIdx.x selects the 128-wide column tile of c,
 *            blockIdx.y the 128-tall row tile;
 *   - N must be a multiple of 128 (and therefore of BK): there are no bounds checks;
 *   - static shared memory: two BK x 128 float tiles (8 KiB in total).
 *
 * Each thread accumulates a (2 x 4) x (2 x 4) patch of c in c_regs. When beta == 0
 * the previous contents of c are never read, so c may be uninitialised on entry
 * (reading it anyway would let NaN/Inf garbage leak into the result as 0 * NaN).
 */
__global__ void __launch_bounds__(BLOCK_SIZE)
kernel3_registers(float *a, float *b, float *c, int N, float alpha, float beta)
{
    // Block tile size (output elements computed by one thread block)
    constexpr int BN = 128;
    constexpr int BM = 128;
    // Depth of the K-slice staged in shared memory per outer iteration
    constexpr int BK = 8;

    // Thread tile size
    constexpr int TN = 4;
    constexpr int TM = 4;

    // Logical wave width used for the thread -> output mapping. It is only used
    // for index arithmetic (no wave intrinsics), so it does not have to match
    // the hardware wavefront size.
    constexpr int waveSize = 32;
    constexpr int nbWaves = BLOCK_SIZE / waveSize;
    // Wave tile size
    constexpr int WN = 64;
    constexpr int WM = BN * BM / nbWaves / WN;

    // Number of waves along X & Y in the block tile
    constexpr int nbWaveX = BN / WN;
    constexpr int nbWaveY = BM / WM;

    // Threads of one wave are laid out as 8 (x) by 4 (y)
    constexpr int nbThreadXPerWave = 8;
    constexpr int nbThreadYPerWave = 4;

    // Number of sub-tiles a wave walks through along each axis
    constexpr int nbIterWaveN = WN / (nbThreadXPerWave * TN);
    constexpr int nbIterWaveM = WM / (nbThreadYPerWave * TM);

    // Wave sub-tile size
    constexpr int SUBWN = WN / nbIterWaveN;
    constexpr int SUBWM = WM / nbIterWaveM;

    // Cooperative global -> shared load schedule
    constexpr int strideReadB = BLOCK_SIZE / BN;
    constexpr int strideReadA = BLOCK_SIZE / BK;
    constexpr int nbReadsB = BN * BK / BLOCK_SIZE;
    constexpr int nbReadsA = BM * BK / BLOCK_SIZE;

    // Width of the per-thread accumulator patch
    constexpr int regsPerRow = TN * nbIterWaveN;

    // The index arithmetic below silently breaks if the tiling constants stop
    // dividing each other, so pin the relationships at compile time.
    static_assert(BLOCK_SIZE % waveSize == 0, "BLOCK_SIZE must be a multiple of the wave width");
    static_assert(nbWaveX * nbWaveY == nbWaves, "waves must exactly tile the block tile");
    static_assert(nbThreadXPerWave * nbThreadYPerWave == waveSize, "thread layout must cover one wave");
    static_assert(nbIterWaveN * nbThreadXPerWave * TN == WN, "wave tile width must be covered exactly");
    static_assert(nbIterWaveM * nbThreadYPerWave * TM == WM, "wave tile height must be covered exactly");
    static_assert(BLOCK_SIZE % BN == 0 && nbReadsB * strideReadB == BK, "B tile load schedule must cover BK x BN");
    static_assert(BLOCK_SIZE % BK == 0 && nbReadsA * strideReadA == BM, "A tile load schedule must cover BM x BK");

    const int waveIndex = threadIdx.x / waveSize;
    const int waveIdx = waveIndex % nbWaveX;
    const int waveIdy = waveIndex / nbWaveX;
    const int indexInWave = threadIdx.x % waveSize;

    // Thread coordinates inside its wave
    const int idxInWave = indexInWave % nbThreadXPerWave;
    const int idyInWave = indexInWave / nbThreadXPerWave;

    // Thread mapping to read the BM x BK block of A (rAIdx walks K)
    const int rAIdx = threadIdx.x % BK;
    const int rAIdy = threadIdx.x / BK;
    // Thread mapping to read the BK x BN block of B (rBIdx walks the columns)
    const int rBIdx = threadIdx.x % BN;
    const int rBIdy = threadIdx.x / BN;

    // 64-bit row stride so that N * row cannot overflow int for large N
    const size_t rowStride = static_cast<size_t>(N);
    // First row of A / first column of B handled by this block
    const size_t blockRowA = static_cast<size_t>(BM) * blockIdx.y;
    const size_t blockColB = static_cast<size_t>(BN) * blockIdx.x;

    float A_col[nbIterWaveM * TM];
    float B_row[nbIterWaveN * TN];

    // As is stored K-major (i.e. transposed) so the compute loop reads As[k][row]
    __shared__ float As[BK][BM];
    __shared__ float Bs[BK][BN];

    float c_regs[TM * nbIterWaveM * TN * nbIterWaveN] = {0.0f};

    // Iterate over the K dimension in slices of BK
    for (int kId = 0; kId < N; kId += BK)
    {
        // Stage the current BK x BN slice of B in shared memory
        for (int i = 0; i < nbReadsB; i++)
        {
            const int tileRow = rBIdy + i * strideReadB; // in [0, BK)
            const size_t globalRow = static_cast<size_t>(kId + tileRow);
            Bs[tileRow][rBIdx] = b[rowStride * globalRow + blockColB + rBIdx];
        }

        // Stage the current BM x BK slice of A in shared memory (transposed)
        for (int i = 0; i < nbReadsA; i++)
        {
            const int tileRow = rAIdy + i * strideReadA; // in [0, BM)
            const size_t globalRow = blockRowA + tileRow;
            As[rAIdx][tileRow] = a[rowStride * globalRow + static_cast<size_t>(kId + rAIdx)];
        }

        // Every thread must see the complete tiles before consuming them
        __syncthreads();
        for (int k = 0; k < BK; k += 1)
        {
            // Cache the B values this thread needs for step k
            for (int iterWave = 0; iterWave < nbIterWaveN; iterWave++)
            {
                for (int i = 0; i < TN; i++)
                {
                    const int index = waveIdx * WN +     // wave origin
                                      iterWave * SUBWN + // wave sub-tile
                                      TN * idxInWave +   // thread origin
                                      i;
                    B_row[iterWave * TN + i] = Bs[k][index];
                }
            }

            // Cache the A values this thread needs for step k
            for (int iterWave = 0; iterWave < nbIterWaveM; iterWave++)
            {
                for (int i = 0; i < TM; i++)
                {
                    const int index = waveIdy * WM +     // wave origin
                                      iterWave * SUBWM + // wave sub-tile
                                      TM * idyInWave +   // thread origin
                                      i;
                    A_col[iterWave * TM + i] = As[k][index];
                }
            }

            // Rank-1 update of the per-thread accumulators
            for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++)
            {
                for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++)
                {
                    for (int yt = 0; yt < TM; yt++)
                    {
                        for (int xt = 0; xt < TN; xt++)
                        {
                            const int x = iterWaveN * TN + xt;
                            const int y = iterWaveM * TM + yt;
                            c_regs[y * regsPerRow + x] += A_col[y] * B_row[x];
                        }
                    }
                }
            }
        }
        // Do not overwrite the shared tiles while other threads still read them
        __syncthreads();
    }

    // Uniform across the grid: skip the read of c entirely when beta == 0
    const bool readC = (beta != 0.0f);

    for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++)
    {
        for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++)
        {
            const size_t xOut = blockColB + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave;
            const size_t yOut = blockRowA + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave;
            for (int yt = 0; yt < TM; yt++)
            {
                for (int xt = 0; xt < TN; xt++)
                {
                    const size_t indexC = rowStride * (yOut + yt) + xOut + xt;
                    const float acc = c_regs[regsPerRow * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)];
                    c[indexC] = readC ? beta * c[indexC] + alpha * acc
                                      : alpha * acc;
                }
            }
        }
    }
}

// Setup hook of the SGEMMBase interface. Intentionally empty: nothing is
// allocated or configured here.
void Kernel3Registers::init()
{
}

void Kernel3Registers::run(float *d_a, float *d_b, float *d_c, float alpha, float beta, int N)
{
    auto threadsPerBlock = dim3(BLOCK_SIZE);
    auto blocksPerGrid = dim3(N / 128, N / 128);
    hipLaunchKernelGGL(kernel3_registers, blocksPerGrid, threadsPerBlock, 0, 0, d_a, d_b, d_c, N, alpha, beta);
}

// Teardown hook of the SGEMMBase interface. Intentionally empty: there is
// nothing to release.
void Kernel3Registers::finalize()
{
}

In [None]:
%%writefile main.cpp
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <iostream>
#include <vector>
#include <chrono>
#include <random>
#include <iomanip>
#include "sgemm.h"
#include "kernel3_registers.h"

// Fills the first size * size elements of `matrix` with random values drawn
// uniformly from [-1, 1).
//
// The element count is computed in the vector's size type (so size * size
// cannot overflow int) and is clamped to matrix.size(), so a vector that is
// too short is never written out of bounds. A non-positive size leaves the
// vector untouched. The generator is seeded from std::random_device, so the
// values differ from run to run.
void initialize_matrix(std::vector<float>& matrix, int size) {
    if (size <= 0) {
        return;
    }

    using index_t = std::vector<float>::size_type;
    const index_t side = static_cast<index_t>(size);
    index_t count = side * side;
    if (count > matrix.size()) {
        count = matrix.size();
    }

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(-1.0f, 1.0f);
    for (index_t i = 0; i < count; ++i) {
        matrix[i] = dis(gen);
    }
}

// Returns the largest absolute element-wise difference between the first
// size * size elements of `a` and `b`.
//
// If any compared pair yields NaN (either input is NaN, or both are infinities
// of the same sign), NaN is returned immediately: an ordinary max would drop
// the NaN and report a misleadingly small difference.
//
// The element count is computed in the vector's size type (no int overflow)
// and clamped to the shorter vector, so no element is read out of bounds.
// A non-positive size yields 0.
float matrix_diff(const std::vector<float>& a, const std::vector<float>& b, int size) {
    if (size <= 0) {
        return 0.0f;
    }

    using index_t = std::vector<float>::size_type;
    const index_t side = static_cast<index_t>(size);
    index_t count = side * side;
    if (count > a.size()) {
        count = a.size();
    }
    if (count > b.size()) {
        count = b.size();
    }

    float max_diff = 0.0f;
    for (index_t i = 0; i < count; ++i) {
        const float diff = std::abs(a[i] - b[i]);
        if (diff != diff) {
            return diff; // NaN compares unequal to itself; propagate it
        }
        if (diff > max_diff) {
            max_diff = diff;
        }
    }
    return max_diff;
}

// Converts the wall-clock time of one N x N x N matrix multiplication into a
// throughput figure in GFLOPS (10^9 floating-point operations per second).
//   N       - matrix dimension
//   time_ms - elapsed time in milliseconds
double calculate_gflops(int N, double time_ms) {
    // A dense matrix product costs 2 * N^3 floating-point operations.
    const double flop_count = 2.0 * std::pow(N, 3);
    const double elapsed_seconds = time_ms / 1000.0;
    const double flops_per_second = flop_count / elapsed_seconds;
    return flops_per_second / 1e9;
}

// Helper macro to check HIP errors: evaluates `cmd` exactly once and, if the
// result is not hipSuccess, prints the HIP error string with the file and line
// of the call site to std::cerr and terminates the process with exit code 1.
// The argument is parenthesised and the local uses a reserved-looking name so
// that the expression passed in can neither be mis-parsed nor captured by the
// macro's own variable.
#define CHECK_HIP(cmd) \
    do { \
        const hipError_t check_hip_status_ = (cmd); \
        if (check_hip_status_ != hipSuccess) { \
            std::cerr << "HIP error: " << hipGetErrorString(check_hip_status_) << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
            exit(1); \
        } \
    } while(0)

int main() {
    std::cout << "Starting SGEMM benchmark..." << std::endl;

    // Test different matrix sizes
    std::vector<int> matrix_sizes = {1024, 2048, 4096, 8192};
    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Initialize rocBLAS
    rocblas_handle handle;
    rocblas_status status = rocblas_create_handle(&handle);
    if (status != rocblas_status_success) {
        std::cerr << "Failed to initialize rocBLAS" << std::endl;
        return 1;
    }

    std::cout << std::setw(10) << "Size" 
              << std::setw(20) << "Custom (GFLOPS)"
              << std::setw(20) << "rocBLAS (GFLOPS)"
              << std::setw(15) << "Max Diff"
              << std::setw(15) << "Time (ms)" << std::endl;
    std::cout << std::string(80, '-') << std::endl;

    for (int N : matrix_sizes) {
        std::cout << "Testing size " << N << "x" << N << "..." << std::endl;

        // Allocate host memory
        std::vector<float> h_a(N * N);
        std::vector<float> h_b(N * N);
        std::vector<float> h_c_custom(N * N, 0.0f);
        std::vector<float> h_c_rocblas(N * N, 0.0f);

        std::cout << "Initializing matrices..." << std::endl;
        // Initialize matrices
        initialize_matrix(h_a, N);
        initialize_matrix(h_b, N);

        // Allocate device memory
        float *d_a, *d_b, *d_c_custom, *d_c_rocblas;
        std::cout << "Allocating device memory..." << std::endl;
        CHECK_HIP(hipMalloc(&d_a, N * N * sizeof(float)));
        CHECK_HIP(hipMalloc(&d_b, N * N * sizeof(float)));
        CHECK_HIP(hipMalloc(&d_c_custom, N * N * sizeof(float)));
        CHECK_HIP(hipMalloc(&d_c_rocblas, N * N * sizeof(float)));

        // Copy data to device
        std::cout << "Copying data to device..." << std::endl;
        CHECK_HIP(hipMemcpy(d_a, h_a.data(), N * N * sizeof(float), hipMemcpyHostToDevice));
        CHECK_HIP(hipMemcpy(d_b, h_b.data(), N * N * sizeof(float), hipMemcpyHostToDevice));

        // Benchmark custom kernel
        std::cout << "Running custom kernel..." << std::endl;
        Kernel3Registers kernel;
        kernel.init();

        auto start = std::chrono::high_resolution_clock::now();
        kernel.run(d_a, d_b, d_c_custom, alpha, beta, N);
        CHECK_HIP(hipDeviceSynchronize());
        auto end = std::chrono::high_resolution_clock::now();
        auto custom_duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
        double custom_time_ms = custom_duration.count() / 1000.0;
        double custom_gflops = calculate_gflops(N, custom_time_ms);

        // Benchmark rocBLAS
        std::cout << "Running rocBLAS..." << std::endl;
        start = std::chrono::high_resolution_clock::now();
        status = rocblas_sgemm(handle, rocblas_operation_none, rocblas_operation_none,
                      N, N, N,
                      &alpha,
                      d_a, N,
                      d_b, N,
                      &beta,
                      d_c_rocblas, N);
        if (status != rocblas_status_success) {
            std::cerr << "rocBLAS SGEMM failed" << std::endl;
            return 1;
        }
        CHECK_HIP(hipDeviceSynchronize());
        end = std::chrono::high_resolution_clock::now();
        auto rocblas_duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
        double rocblas_time_ms = rocblas_duration.count() / 1000.0;
        double rocblas_gflops = calculate_gflops(N, rocblas_time_ms);

        // Copy results back to host
        std::cout << "Copying results back to host..." << std::endl;
        CHECK_HIP(hipMemcpy(h_c_custom.data(), d_c_custom, N * N * sizeof(float), hipMemcpyDeviceToHost));
        CHECK_HIP(hipMemcpy(h_c_rocblas.data(), d_c_rocblas, N * N * sizeof(float), hipMemcpyDeviceToHost));

        // Calculate difference
        float max_diff = matrix_diff(h_c_custom, h_c_rocblas, N);

        // Print results
        std::cout << std::setw(10) << N
                  << std::setw(20) << std::fixed << std::setprecision(2) << custom_gflops
                  << std::setw(20) << rocblas_gflops
                  << std::setw(15) << std::scientific << std::setprecision(3) << max_diff
                  << std::setw(15) << std::fixed << std::setprecision(2) << custom_time_ms << std::endl;

        // Cleanup
        CHECK_HIP(hipFree(d_a));
        CHECK_HIP(hipFree(d_b));
        CHECK_HIP(hipFree(d_c_custom));
        CHECK_HIP(hipFree(d_c_rocblas));
    }

    rocblas_destroy_handle(handle);
    return 0;
}

In [None]:
# Compile and run the ROCm HIP SGEMM benchmark
!hipcc -O3 -std=c++17 main.cpp kernel3_registers.cpp -lrocblas -o sgemm_benchmark
!./sgemm_benchmark

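## Output:

Note: in the sample below, the `Time (ms)` column matches the reported GFLOPS only if its values are read as microseconds (for example, 2·1024³ FLOP / 475.23 µs ≈ 4.5 TFLOP/s), so treat the figures as illustrative.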
```
Starting SGEMM benchmark...
      Size     Custom (GFLOPS)    rocBLAS (GFLOPS)        Max Diff      Time (ms)
--------------------------------------------------------------------------------
Testing size 1024x1024...
Initializing matrices...
Allocating device memory...
Copying data to device...
Running custom kernel...
Running rocBLAS...
Copying results back to host...
      1024               4523.45              8945.67        1.234e-06         475.23
Testing size 2048x2048...
Initializing matrices...
Allocating device memory...
Copying data to device...
Running custom kernel...
Running rocBLAS...
Copying results back to host...
      2048               6789.12             12456.78        2.567e-06        2534.89
Testing size 4096x4096...
Initializing matrices...
Allocating device memory...
Copying data to device...
Running custom kernel...
Running rocBLAS...
Copying results back to host...
      4096               8234.56             15678.90        3.890e-06       16723.45
Testing size 8192x8192...
Initializing matrices...
Allocating device memory...
Copying data to device...
Running custom kernel...
Running rocBLAS...
Copying results back to host...
      8192               9456.78             18234.56        4.123e-06      116789.23
```