## Challenge Exercise 4.1: SYRK Emulation

Now that we understand how to build a performant emulated GEMM kernel, we can start to think about how we can apply the same algorithm and techniques to other routines.

A very closely related routine is the symmetric rank-k update, often referred to as SYRK.  For a given $n \times k$ matrix $\mathbf{A}$, we can compute an $n \times n$ matrix $\mathbf{C}$ at row $i$, column $j$ as follows:

$$
\mathbf{C}_{i, j} = \alpha \sum_{l=0}^{k-1}\left( \mathbf{A}_{i, l} \mathbf{A}^{T}_{l, j} \right) + \beta \mathbf{C}_{i, j}
$$

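You may notice that this follows the same definition as GEMM, except that we are multiplying $\mathbf{A}$ with its own transpose.  This makes the product $\mathbf{A}\mathbf{A}^{T}$ symmetric, so the output matrix stays symmetric (i.e. $\mathbf{C} = \mathbf{C}^{T}$) whenever the input $\mathbf{C}$ is symmetric.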

That matrix property, along with the knowledge that we are multiplying $\mathbf{A}$ by its own transpose, allows us to make more problem-specific optimizations that reduce the number of math operations and access memory more efficiently.

### C++ CMake Setup

In [None]:
# Make the shared tutorial helpers under ../utilities/python importable.
import sys, os
sys.path.append(os.sep.join(["..", "utilities", "python"]))
from common_cuda import setup_cmake_project
# Project helper; presumably prepares the CMake build tree used by the
# `cmake --build ./build` cells below — confirm in common_cuda.
setup_cmake_project()

### Python Imports

In [None]:
import sys
import os

import numpy as np
import cupy as cp
import nvmath

from nvmath.device import Matmul
from nvmath.device.cublasdx import DevicePipeline, SharedStorageCalc
from nvmath.device.cublasdx_numba import pipeline_extensions
from nvmath.device.common import axpby, clear, copy, copy_fragment, copy_wait, make_tensor
from numba import cuda

sys.path.append(os.sep.join(["..", "utilities", "python"]))

from benchmark import *

### C++

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/parameters.hpp.inc

    // Benchmark repetition counts.
    // NOTE(review): presumably warm-up launches are untimed and kernel_runs
    // launches are timed — confirm against the benchmark harness.
    int const warm_up_runs = 10;
    int const kernel_runs = 100;

    // ===================================
    // Problem configuration
    // ===================================

    // (syrk_n, syrk_k, alpha, beta, uplo)
    // One entry per triangle so both output halves are exercised.
    std::vector<tutorial::syrk_problem_t> problems = {
        {8192, 8192, 0.9, 1.1, tutorial::matrix_half::upper},
        {8192, 8192, 0.9, 1.1, tutorial::matrix_half::lower}
    };
    

    // ===================================
    // Global SYRK configuration
    // ===================================

    // The number of slices used in emulation algorithm
    // More slices = higher precision but more computation
    constexpr unsigned slices = 7;

    // NOTE(review): consumed outside this fragment; presumably enables extra
    // diagnostic output — confirm.
    bool const   debug        = false;

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/cublasdx_config.hpp.inc

    using slice_value_type       = int8_t;  // Precision for individual slices
    using accumulator_value_type = int32_t; // Precision for accumulation

    // The shape of data tile processed by a single CTA block
    constexpr int tile_m = 128;
    constexpr int tile_n = 128;
    constexpr int tile_k = 128;

    // The shape of CTA block (number of threads)
    constexpr int cta_shape_x = 128;
    constexpr int cta_shape_y = 1;
    constexpr int cta_shape_z = 1;

    // cuBLASDx descriptor for one block-level real matrix multiply over a
    // tile_m x tile_n x tile_k tile, with slice_value_type inputs and
    // accumulator_value_type accumulation. Both operands use arrangements of A
    // (arrangement_a / arrangement_a_t, defined outside this fragment) because
    // SYRK multiplies A by its own transpose.
    using BLAS = decltype(cublasdx::Size<tile_m, tile_n, tile_k>() +
                          cublasdx::Precision<slice_value_type, slice_value_type, accumulator_value_type>() +
                          cublasdx::Type<cublasdx::type::real>() + cublasdx::Function<cublasdx::function::MM>() +
                          cublasdx::Arrangement<arrangement_a, arrangement_a_t, arrangement_c>() + cublasdx::Block() +
                          cublasdx::BlockDim<cta_shape_x, cta_shape_y, cta_shape_z>() + cublasdx::StaticBlockDim() +
                          cublasdx::MaxAlignment() + cublasdx::EnableInputStreaming() + cublasdx::WithPipeline() +
                          cublasdx::SM<SM_VALUE, SM_MODIFIER_VALUE>());

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/pipeline_config.hpp.inc

        // Depth requested from suggest_device_pipeline.
        // NOTE(review): presumably the number of in-flight shared-memory
        // stages — confirm in the cuBLASDx pipeline documentation.
        constexpr unsigned pipeline_depth = 3;

        // Build the device pipeline over the sliced A tensor and its transposed
        // view. external_accumulation is requested because the kernel passes
        // its own accumulator to tile_pipeline.execute().
        auto device_pipeline = cublasdx::suggest_device_pipeline<pipeline_depth, BLAS, cublasdx::external_accumulation>(
                                   tensor_slice_a, tensor_slice_at)
                                   .value();

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/kernel_config.hpp.inc

        // One CTA per tile_m x tile_n tile of the square output; both grid
        // dimensions are derived from shape_a_rows.
        dim3      grid(shape_a_rows / static_tile_m(), shape_a_rows / static_tile_n());
        // Instantiate the kernel with double for both the Alpha and Beta types.
        auto kernel = fused_epilogue_kernel<BLAS,
                                            decltype(device_pipeline),
                                            double,
                                            double,
                                            CTensor,
                                            decltype(tensor_shift_a),
                                            decltype(tensor_shift_at),
                                            Slices>;

        // Dynamic shared memory: the pipeline buffers plus one int32 shift per
        // tile row (A) and one per tile column (A^T).
        auto shared_memory_size = cublasdx::make_shared_storage_calculator()
                                      .add(device_pipeline.buffer_alignment(), device_pipeline.buffer_size())
                                      .add(16, sizeof(int32_t), static_tile_m()) // shift_a
                                      .add(16, sizeof(int32_t), static_tile_n()) // shift_at
                                      .get();

        // Opt the kernel in to the requested amount of dynamic shared memory.
        CUDA_CHECK_AND_EXIT(
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/fused_kernel.hpp.inc

// Exercise skeleton for the fused emulated-SYRK kernel.
//
// The body is intentionally empty: the reader is expected to implement the
// slice-product accumulation and the alpha/beta epilogue here.
//
// Parameters:
//   device_pipeline - pipeline object; also supplies the shared-memory alignment
//   alpha, beta     - scalars of the SYRK update
//   gmem_c_fp64     - output matrix C
//   output_half     - which triangle of C is requested
//   gmem_shift_a    - shift tensor for A
//   gmem_shift_at   - shift tensor for A^T
template<class BLAS,
         class DevicePipeline,
         class Alpha,
         class Beta,
         class CTensor,
         class AShiftTensor,
         class AtShiftTensor,
         int32_t Slices>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               Alpha                                  alpha,
                               Beta                                   beta,
                               CTensor                                gmem_c_fp64,
                               tutorial::matrix_half                  output_half,
                               AShiftTensor const                     gmem_shift_a,
                               AtShiftTensor const                    gmem_shift_at) {
    // Dynamic shared memory; its size is supplied at launch.
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    // Only emit a body when compiling for the architecture BLAS was configured for.
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // CHALLENGE EXERCISE --> Implement a fused SYRK
        // HINT --> Start with emulated GEMM kernel from 3.3
    }
#endif
}

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/kernel_launch.hpp.inc

        // Launch with the block shape reported by the pipeline and the dynamic
        // shared-memory size computed for this kernel.
        kernel<<<grid, device_pipeline.get_block_dim(), shared_memory_size, stream>>>(
            device_pipeline, alpha, beta, tensor_c, output_half, tensor_shift_a, tensor_shift_at);

In [None]:
!cmake --build ./build -t 3a_fused_syrk_emulation

In [None]:
!./build/3a_fused_syrk_emulation

#### Solution

We will now overwrite the kernel file with the solution and recompile it. If you want to restart the exercise, make sure you write the original kernel skeleton back and recompile it.

In [None]:
%%writefile cpp/3a_fused_syrk_emulation/fused_kernel.hpp.inc

// Fused emulated-SYRK kernel: C = alpha * A * A^T + beta * C on one triangle.
//
// Each CTA owns one tile_m x tile_n tile of C. For every diagonal of the
// slice-product table it accumulates the int8 slice products in int32, converts
// the sum to double, and folds it into C in the same kernel. Only elements in
// the requested triangle (output_half) are modified.
//
// Parameters:
//   device_pipeline - pipeline over the sliced A tensor and its transposed view
//   alpha, beta     - scalars of the SYRK update (beta is applied only once)
//   gmem_c_fp64     - C, read and written in place
//   output_half     - triangle of C to update (upper or lower)
//   gmem_shift_a    - shift tensor for A, indexed by row of C
//   gmem_shift_at   - shift tensor for A^T, indexed by column of C
template<class BLAS,
         class DevicePipeline,
         class Alpha,
         class Beta,
         class CTensor,
         class AShiftTensor,
         class AtShiftTensor,
         int32_t Slices>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               Alpha                                  alpha,
                               Beta                                   beta,
                               CTensor                                gmem_c_fp64,
                               tutorial::matrix_half                  output_half,
                               AShiftTensor const                     gmem_shift_a,
                               AtShiftTensor const                    gmem_shift_at) {
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // ================================
        // 1. SETUP AND TILE PREPARATION
        // ================================

        constexpr int tile_m = cublasdx::size_of_v_m<BLAS>;
        constexpr int tile_n = cublasdx::size_of_v_n<BLAS>;

        auto const block_offset_m = blockIdx.x * tile_m;
        auto const block_offset_n = blockIdx.y * tile_n;

        // Skip tiles that contain no element of the requested triangle.
        // A tile lies strictly above the diagonal when its first column is past
        // its last row (offset_n >= offset_m + tile_m), and strictly below when
        // its first row is past its last column. The comparison must be >=:
        // with > the tiles adjacent to the diagonal would run the full
        // pipeline only to have every element masked out in the epilogue.
        // The decision is uniform across the CTA, so returning here is safe.
        bool const tile_strictly_above = block_offset_n >= (block_offset_m + tile_m);
        bool const tile_strictly_below = block_offset_m >= (block_offset_n + tile_n);
        if ((tile_strictly_above and output_half == tutorial::matrix_half::lower) or
            (tile_strictly_below and output_half == tutorial::matrix_half::upper)) {
            return;
        }

        constexpr auto initial_diag = Slices - 1;
        constexpr auto initial_term = 0;

        // Carve the dynamic shared memory into the pipeline buffer and the two
        // per-tile shift vectors (tile_m entries for A, tile_n for A^T).
        auto [pipeline_smem, smem_shift_a, smem_shift_at] =
            cublasdx::shared_memory::slice<char, int32_t, int32_t>(smem,
                                                                   device_pipeline.buffer_alignment(),
                                                                   device_pipeline.buffer_size(),
                                                                   cublasdx::alignment_of_v_a<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_m>()),
                                                                   cublasdx::alignment_of_v_b<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_n>()));

        // Copy general purpose data
        cublasdx::copy<BLAS, 16>(gmem_shift_a(cute::_, blockIdx.x), smem_shift_a);
        cublasdx::copy<BLAS, 16>(gmem_shift_at(cute::_, blockIdx.y), smem_shift_at);
        cublasdx::copy_wait();


        // Get pipeline tile
        // Starts at the first product of the first diagonal:
        // A slice `initial_term` against A^T slice `initial_diag`.
        auto tile_pipeline = device_pipeline.get_tile(pipeline_smem,
                                                      cublasdx::make_coord(blockIdx.x, initial_term),
                                                      cublasdx::make_coord(blockIdx.y, initial_diag));

        auto accumulator = tile_pipeline.get_accumulator();

        // ================================
        // 2. FP64 C INPUT / OUTPUT TILE SETUP
        // ================================

        auto tile_c_fp64_gmem = cublasdx::get_tile(gmem_c_fp64, BLAS::c_shape, blockIdx.x, blockIdx.y);

        // ============================================
        // 3. OZAKI SCHEME DIAGONAL ITERATION
        // ============================================

        // Iterate over diagonals from the highest index (Slices - 1) down to 0.
        // NOTE(review): assuming the scale applied by nth_slice_to_fp64 shrinks
        // as diag grows, this adds the smallest contributions first — confirm.
#    pragma unroll 1
        for (auto diag = initial_diag; diag >= 0; --diag) {

            // Initialize accumulator for this diagonal
            accumulator.clear();

            // ==========================================
            // 4. SLICE COMBINATION COMPUTATION
            // ==========================================

            // Compute all slice combinations that contribute to this diagonal
            // For diagonal d, we compute: A_slice[i] * B_slice[d-i] for i = 0 to d
#    pragma unroll 1
            for (auto term = initial_term; term <= diag; ++term) {
                // =========================================
                // 5. N-STAGE MEMORY PIPELINE FOR GEMM
                // =========================================

                tile_pipeline.execute(accumulator);

                // Point the pipeline at the next product: the next term of this
                // diagonal, or term 0 of the next diagonal once this one is done.
                const auto next_slice_row = (term == diag) ? 0 : term + 1;                         // A slice index
                const auto next_slice_col = (term == diag) ? (diag - 1) : (diag - next_slice_row); // B slice index
                device_pipeline.reset_tile(tile_pipeline,
                                           cublasdx::make_coord(blockIdx.x, next_slice_row),
                                           cublasdx::make_coord(blockIdx.y, next_slice_col));
            } /* end of slice combination loop */

            // ========================================
            // 6. RESULT RECONSTRUCTION AND EPILOGUE
            // ========================================


            // Load existing C values
            auto d_fp64_frag = accumulator.make_partition_and_copy(tile_c_fp64_gmem);

            if(accumulator.is_thread_active()) {
                // Convert accumulated int32_t results back to double precision
                // and apply appropriate scaling based on slice positions
                auto gemm_results = accumulator.get_results();

                // Process each element in the register fragment
#    pragma unroll
                for (int i = 0; i < cublasdx::size(d_fp64_frag); ++i) {
                    auto const [global_x, global_y] = accumulator.map_fragment_index(i);
                    auto const shift_a_elem         = smem_shift_a(global_x);
                    auto const shift_at_elem        = smem_shift_at(global_y);

                    // Position of this element in the full matrix; elements
                    // outside the requested triangle are left untouched.
                    int const total_x = block_offset_m + global_x;
                    int const total_y = block_offset_n + global_y;
                    bool const is_in_bounds = (output_half == tutorial::matrix_half::lower and (total_x >= total_y)) or
                                              (output_half == tutorial::matrix_half::upper and (total_y >= total_x));

                    // Convert int32_t slice result back to double precision
                    // with appropriate scaling for this diagonal and element
                    double const val = nth_slice_to_fp64<int32_t, int8_t>(diag, gemm_results(i), shift_a_elem + shift_at_elem);
                    d_fp64_frag(i) = is_in_bounds ? (alpha * val + beta * d_fp64_frag(i)) : d_fp64_frag(i);
                }
            }

            accumulator.partition_and_copy(d_fp64_frag, tile_c_fp64_gmem);
            // The original C is scaled by beta only on the first diagonal;
            // later diagonals add onto the partial result already stored in C.
            beta = 1.0;
        }
    }
#endif
}

In [None]:
!cmake --build ./build -t 3a_fused_syrk_emulation

In [None]:
!./build/3a_fused_syrk_emulation

### Python

In [None]:
import sys
import os

import numpy as np
import cupy as cp
import nvmath
import math

from nvmath.device import Matmul
from nvmath.device.cublasdx import DevicePipeline, SharedStorageCalc
from nvmath.device.cublasdx_numba import pipeline_extensions
from nvmath.device.common import axpby, clear, copy, copy_fragment, copy_wait, make_tensor, make_fragment_like
from numba import cuda

sys.path.append(os.sep.join(["..", "utilities", "python"]))

from benchmark import *
from emulation_utils import get_width, epilogue_ldexp, MatrixHalf

In [None]:
# One tuple per benchmarked problem.
# NOTE(review): presumably (n, k, alpha, beta, uplo) with 'L'/'U' selecting the
# lower/upper triangle — confirm against benchmark_fused_emulated_dsyrk.
problems = [
  (8192, 8192, 0.9, 1.1, 'L'),
  (8192, 8192, 0.9, 1.1, 'U'),
]

In [None]:
def get_emulated_dsyrk_kernel(BLAS, matrix_half):
    """Build the exercise skeleton of the emulated DSYRK kernel.

    The returned kernel is intentionally empty (``pass``); the constants and
    the ``nth_slice_to_fp64`` helper defined here are the building blocks the
    reader is expected to use when implementing it.

    Args:
        BLAS: Matmul descriptor; its A and B value types must match and its
            C tile must be square.
        matrix_half: Requested triangle, compared against ``MatrixHalf.lower``.

    Returns:
        The numba CUDA kernel ``dsyrk_kernel``.
    """
    
    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    # Tile geometry and launch parameters taken from the BLAS descriptor.
    TILE_M, TILE_N = BLAS.c_dim
    TILE_K = BLAS.a_dim[1]
    BLOCK_SIZE = BLAS.block_size
    ALIGNMENT = min(BLAS.alignment.a, min(BLAS.alignment.b, BLAS.alignment.c))

    uint8_width = get_width(np.uint8)

    # SYRK produces a square output, so the C tile must be square as well.
    assert TILE_M == TILE_N, "Invalid SYRK configuration"
    is_lower = (matrix_half == MatrixHalf.lower)
    
    # Converts the accumulated result of diagonal `nth` to float64: scales by
    # 2^(-nth * uint8_width), then applies -exponent_shift via epilogue_ldexp.
    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        return epilogue_ldexp(value, -exponent_shift)
    
    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLOCK_SIZE, 1))
    def dsyrk_kernel(slices, shift_a_tensor, alpha, beta, tensor_c, device_pipeline: DevicePipeline):
        # CHALLENGE EXERCISE --> Implement a fused SYRK kernel
        pass

    return dsyrk_kernel

In [None]:
def fused_dsyrk_ozaki(tensor_slicedA_cupy, tensor_c_cupy, tensor_shift_a_cupy, alpha, beta, context, warmup=True):
    """Launch the fused emulated DSYRK kernel stored in ``context``.

    Args:
        tensor_slicedA_cupy: Sliced A as a 3-D CuPy array (rows, k, slices).
        tensor_c_cupy: 2-D CuPy array holding C; passed to the kernel as-is.
        tensor_shift_a_cupy: CuPy array with the shifts of A.
        alpha: Scalar forwarded to the kernel.
        beta: Scalar forwarded to the kernel.
        context: Dict with the keys "BLAS", "PIPELINE_DEPTH", "syrk_kernel",
            "syrk_grid" and "syrk_block".
        warmup: When true, first raise the kernel's dynamic shared-memory
            limit to the size this launch needs.
    """
    rows_c, _ = tensor_c_cupy.shape
    _, depth_k, num_slices = tensor_slicedA_cupy.shape

    BLAS = context["BLAS"]
    pipeline_depth = context["PIPELINE_DEPTH"]
    kernel = context["syrk_kernel"]
    launch_grid = context["syrk_grid"]
    launch_block = context["syrk_block"]

    _, TILE_N = BLAS.c_dim

    # A^T is a zero-copy view over A's memory: same buffer, with the shape and
    # strides of the first two axes exchanged.
    row_stride, col_stride, slice_stride = tensor_slicedA_cupy.strides
    tensor_slicedAT_cupy = cp.ndarray(
        shape=(depth_k, rows_c, num_slices),
        dtype=np.int8,
        memptr=tensor_slicedA_cupy.data,
        strides=(col_stride, row_stride, slice_stride),
    )

    # Hand every CuPy array to numba as a CUDA array.
    tensor_slicedA = cuda.as_cuda_array(tensor_slicedA_cupy)
    tensor_slicedAT = cuda.as_cuda_array(tensor_slicedAT_cupy)
    tensor_shift_a = cuda.as_cuda_array(tensor_shift_a_cupy)
    tensor_c = cuda.as_cuda_array(tensor_c_cupy)

    device_pipeline = BLAS.suggest_device_pipeline(pipeline_depth, tensor_slicedA, tensor_slicedAT)

    # Pipeline buffers plus two int32 shift vectors of TILE_N entries each.
    shift_bytes = 2 * TILE_N * np.dtype(np.int32).itemsize
    smem_size = device_pipeline.buffer_size + shift_bytes

    kernel_args = (num_slices, tensor_shift_a, alpha, beta, tensor_c, device_pipeline)

    if warmup:
        set_max_dynamic_shared_size_bytes(kernel, smem_size, *kernel_args)
    kernel[launch_grid, launch_block, 0, smem_size](*kernel_args)

In [None]:
def setup_func(n, k, matrix_half):
    """Create the launch context for the emulated DSYRK exercise kernel.

    Args:
        n: Size of the square output; must be a multiple of the tile size.
        k: Inner dimension; must be a multiple of TILE_K and at least
            TILE_K * PIPELINE_DEPTH.
        matrix_half: Requested triangle, forwarded to the kernel factory.

    Returns:
        Dict with the keys "BLAS", "PIPELINE_DEPTH", "syrk_kernel",
        "syrk_grid" and "syrk_block".

    Raises:
        AssertionError: If n or k violates the constraints above.
    """
    TILE_N = 128
    TILE_K = 128
    PIPELINE_DEPTH = 3
    BLOCK_SIZE = 128
    ALIGNMENT = 16
    DATA_TYPE = "real"

    assert n % TILE_N == 0, "Unsupported dimension n for TILE_N"
    # This message previously named n/TILE_N although the check is on k.
    assert k % TILE_K == 0, "Unsupported dimension k for TILE_K"
    assert k >= (TILE_K * PIPELINE_DEPTH), "Unsupported pipeline depth for k"
    
    # Square TILE_N x TILE_N output tile: SYRK multiplies A by its own transpose.
    BLAS = Matmul(size=(TILE_N, TILE_N, TILE_K),
                  precision=(np.int8, np.int8, np.int32),
                  data_type=DATA_TYPE,
                  alignment=(ALIGNMENT, ALIGNMENT, ALIGNMENT),
                  arrangement=("row_major", "col_major", "col_major"), # Do not change
                  execution="Block",
                  block_size=BLOCK_SIZE,
                  with_pipeline=True,
                  enable_input_streaming=True,
                  static_block_dim=True)

    # One block per output tile over the full n x n grid.
    syrk_grid = (n // TILE_N, n // TILE_N)
    syrk_block = BLAS.block_dim

    return {
        "BLAS": BLAS,
        "PIPELINE_DEPTH": PIPELINE_DEPTH,
        "syrk_kernel" : get_emulated_dsyrk_kernel(BLAS, matrix_half),
        "syrk_grid": syrk_grid,
        "syrk_block": syrk_block,
    }

In [None]:
benchmark_fused_emulated_dsyrk(problems, setup_func, fused_dsyrk_ozaki)

#### Solution

In [None]:
def get_emulated_dsyrk_kernel_solution(BLAS, matrix_half):
    """Build the reference emulated DSYRK kernel.

    Each block owns one tile of C. For every diagonal of the slice-product
    table it accumulates the slice products, converts the sum to float64 and
    folds it into C, touching only elements of the requested triangle.

    Args:
        BLAS: Matmul descriptor; its A and B value types must match and its
            C tile must be square.
        matrix_half: Requested triangle, compared against ``MatrixHalf.lower``.

    Returns:
        The numba CUDA kernel ``dsyrk_kernel``.

    Raises:
        AssertionError: If the BLAS descriptor violates the constraints above.
    """

    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    TILE_M, TILE_N = BLAS.c_dim
    BLOCK_SIZE = BLAS.block_size
    ALIGNMENT = min(BLAS.alignment.a, min(BLAS.alignment.b, BLAS.alignment.c))

    uint8_width = get_width(np.uint8)

    assert TILE_M == TILE_N, "Invalid SYRK configuration"
    is_lower = (matrix_half == MatrixHalf.lower)

    # Converts the accumulated result of diagonal `nth` to float64: scales by
    # 2^(-nth * uint8_width), then applies -exponent_shift via epilogue_ldexp.
    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        return epilogue_ldexp(value, -exponent_shift)

    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLOCK_SIZE, 1))
    def dsyrk_kernel(slices, shift_a_tensor, alpha, beta, tensor_c, device_pipeline: DevicePipeline):
        m, n = tensor_c.shape

        block_m = cuda.blockIdx.x
        block_n = cuda.blockIdx.y

        # Zero-length shape requests the dynamic shared-memory buffer.
        smem_pipeline = cuda.shared.array(shape=(0,), dtype=BLAS.a_value_type, alignment=ALIGNMENT)

        # Statically sized shift buffers, separate from the dynamic buffer above.
        smem_shift_a  = cuda.shared.array(shape=(TILE_M), dtype=np.int32)
        smem_shift_at = cuda.shared.array(shape=(TILE_N), dtype=np.int32)

        block_start_m = block_m * TILE_M
        block_end_m = (block_m + 1) * TILE_M

        block_start_n = block_n * TILE_N
        block_end_n = (block_n + 1) * TILE_N

        # Skip blocks outside the triangular region. The decision is uniform
        # across the block, so returning before the barrier below is safe.
        if is_lower:
            # Lower triangular: skip if block_n > block_m
            if block_n > block_m:
                return
        else:
            # Upper triangular: skip if block_m > block_n
            if block_m > block_n:
                return

        if block_start_m >= m or block_start_n >= n:
            return

        # Row shifts and column shifts both come from A's shift vector,
        # because the second operand is A^T.
        shift_a_view = shift_a_tensor[block_start_m : block_end_m]
        shift_at_view = shift_a_tensor[block_start_n : block_end_n]

        # Stage the shifts with a block-strided loop so every entry is loaded
        # even when the block has fewer threads along x than the tile has
        # rows/columns.
        tid = cuda.threadIdx.x
        stride = cuda.blockDim.x
        for idx in range(tid, TILE_M, stride):
            smem_shift_a[idx] = shift_a_view[idx]
        for idx in range(tid, TILE_N, stride):
            smem_shift_at[idx] = shift_at_view[idx]
        cuda.syncthreads()

        c_view = tensor_c[
            block_start_m : block_end_m,
            block_start_n : block_end_n,
        ]

        ldc = max(c_view.strides) // c_view.itemsize
        gmem_c = make_tensor(c_view, BLAS.get_layout_gmem_c(ldc))

        initial_diag = slices - 1
        initial_term = 0

        # Start at the first product of the first diagonal:
        # A slice `initial_term` against A^T slice `initial_diag`.
        tile_pipeline = device_pipeline.get_tile(smem_pipeline,
                                                 (block_m, np.int32(initial_term)),
                                                 (block_n, np.int32(initial_diag)))

        accumulator = BLAS.suggest_accumulator()
        beta_used = beta
        for diag in range(initial_diag, -1, -1):

            accumulator.clear()
            for term in range(initial_term, diag + 1):
                tile_pipeline.execute(accumulator)

                # Next product: the next term of this diagonal, or term 0 of
                # the next diagonal once this one is finished.
                next_slice_row =          0 if term == diag else term + 1
                next_slice_col = (diag - 1) if term == diag else diag - next_slice_row

                device_pipeline.reset_tile(tile_pipeline,
                                           (block_m, np.int32(next_slice_row)),
                                           (block_n, np.int32(next_slice_col)))

            if accumulator.is_thread_active():
                gemm_results = accumulator.get_results()
                c_fp64_frag = make_fragment_like(gemm_results, np.float64)
                copy_fragment(gmem_c, c_fp64_frag)

                for i in range(c_fp64_frag.layout.size):
                    (global_x, global_y) = accumulator.map_fragment_index(i)
                    shift_a = smem_shift_a[global_x]
                    shift_at = smem_shift_at[global_y]

                    # Position in the full matrix; elements outside the
                    # requested triangle keep their current value.
                    syrk_m = block_start_m + global_x
                    syrk_n = block_start_n + global_y
                    if is_lower:
                        in_bounds = (syrk_m >= syrk_n)
                    else:
                        in_bounds = (syrk_m <= syrk_n)

                    value = alpha * nth_slice_to_fp64(diag, gemm_results[i], shift_a + shift_at)
                    c_fp64_frag[i] = value + beta_used * c_fp64_frag[i] if in_bounds else c_fp64_frag[i]

                accumulator.partition_and_copy(c_fp64_frag, gmem_c)

            # The original C is scaled by beta only on the first diagonal;
            # later diagonals add onto the partial result already stored in C.
            beta_used = 1.0

        tile_pipeline._del()

    return dsyrk_kernel

In [None]:
def setup_func_solution(n, k, matrix_half):
    """Return the context from ``setup_func`` with the reference kernel swapped in.

    Args:
        n: Forwarded to ``setup_func``.
        k: Forwarded to ``setup_func``.
        matrix_half: Forwarded to ``setup_func`` and to the kernel factory.

    Returns:
        The context dict, with "syrk_kernel" replaced by the solution kernel.
    """
    context = setup_func(n, k, matrix_half)
    # Reuse the BLAS descriptor already built for the exercise context.
    context["syrk_kernel"] = get_emulated_dsyrk_kernel_solution(context["BLAS"], matrix_half)
    return context

In [None]:
benchmark_fused_emulated_dsyrk(problems, setup_func_solution, fused_dsyrk_ozaki)

## Conclusion

In this chapter you have customized the kernel from 3.3 even further to accelerate a different algorithm, using the underlying flexibility of writing custom kernels with libraries only as building blocks.