## Exercise 3.3: Fully fused Ozaki-I Scheme

Our next optimization fuses the remaining portions of the epilogue into the GEMM kernel: casting the anti-diagonal accumulators to FP64, scaling the FP64 values by each diagonal's weight, accumulating the anti-diagonals, and undoing the exponent shifts that were applied during slicing.
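The four epilogue steps can be sketched on the host as a plain Python reference for a single output tile. This is a minimal sketch rather than the kernel itself: `fused_epilogue_reference` and `SLICE_BITS` are hypothetical names, the slice width of 8 bits is assumed from the `get_width(np.uint8)` call used later in this notebook, and `math.ldexp` stands in for `epilogue_ldexp`, whose exact behavior is assumed here.

```python
import math

SLICE_BITS = 8  # assumed slice width in bits


def fused_epilogue_reference(products, shift_a, shift_b, c, alpha, beta):
    """Host-side reference for one output tile (hypothetical helper).

    products[i][j] is the integer accumulator tile (a list of rows) for
    A slice i times B slice j. Only entries with i + j < slices are read.
    """
    slices = len(products)
    rows, cols = len(c), len(c[0])
    out = [row[:] for row in c]
    # Highest diagonal index first, mirroring the loop order of the solution kernels.
    for diag in range(slices - 1, -1, -1):
        # beta scales the original C only once, on the first diagonal processed.
        beta_used = beta if diag == slices - 1 else 1.0
        for r in range(rows):
            for col in range(cols):
                # Accumulate the anti-diagonal i + j == diag in integer arithmetic.
                acc = sum(products[t][diag - t][r][col] for t in range(diag + 1))
                # Cast to FP64 and scale by the diagonal weight 2^(-diag * SLICE_BITS).
                value = math.ldexp(float(acc), -diag * SLICE_BITS)
                # Undo the exponent shifts of this row of A and this column of B.
                value = math.ldexp(value, -(shift_a[r] + shift_b[col]))
                out[r][col] = alpha * value + beta_used * out[r][col]
    return out


# Tiny example: two slices and a 1x1 tile.
products = [
    [[[3]], [[5]]],  # A slice 0 times B slices 0 and 1
    [[[7]], [[0]]],  # A slice 1 times B slice 0 (the second entry is unused)
]
result = fused_epilogue_reference(products, [1], [2], [[4.0]], alpha=2.0, beta=0.5)
print(result)
```

In the fused kernel the same arithmetic runs per thread on register fragments, so the intermediate integer tiles never need to be written to global memory.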

Once you are done, spend some time profiling the kernels across a few different problem shapes.  How do the results change, and which factors cause the change?  Hint: look at the grid dimensions and consider hardware resources.
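One way to reason about the hint is to compare the number of CTAs in the grid with the number of CTAs the GPU can run at once. The sketch below is a rough estimate under stated assumptions: `wave_estimate` is a hypothetical helper, the SM count of 132 is a placeholder that you should replace with the value for your device, and one resident CTA per SM is assumed rather than measured.

```python
import math


def wave_estimate(m, n, tile_m=128, tile_n=128, num_sms=132, ctas_per_sm=1):
    """Estimate how many 'waves' of CTAs a problem shape needs (hypothetical helper)."""
    grid = (m // tile_m, n // tile_n)   # same grid shape as gemm_grid in setup_func
    ctas = grid[0] * grid[1]
    slots = num_sms * ctas_per_sm       # CTAs that can be resident at the same time
    waves = ctas / slots
    # Fraction of the available slots that do useful work across all waves.
    utilization = waves / math.ceil(waves)
    return grid, ctas, waves, utilization


for shape in [(256, 256), (2048, 2048), (8192, 8192)]:
    grid, ctas, waves, util = wave_estimate(*shape)
    print(f"{shape}: grid={grid}, CTAs={ctas}, waves={waves:.2f}, utilization={util:.2f}")
```

Small problems leave most SMs idle under these assumptions, while large problems amortize the partially filled last wave, which is one factor to keep in mind when comparing profiles.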

Some other questions you can consider:
1. Where does fusion seem to help the most?
2. Can you find cases where fusion does not help?

<img src="Images/Ozaki-I-Multiplications-Fused.png" width="800" height="auto"/>

### C++ CMake Configuration

In [None]:
import sys, os
sys.path.append(os.sep.join(["..", "utilities", "python"]))
from common_cuda import setup_cmake_project
setup_cmake_project()

### Python Imports

In [None]:
import sys
import os

import numpy as np
import cupy as cp
import nvmath
import math

from nvmath.device import Matmul
from nvmath.device.cublasdx import DevicePipeline, SharedStorageCalc, MAX_ALIGNMENT
from nvmath.device.cublasdx_numba import pipeline_extensions
from nvmath.device.common import axpby, clear, copy, copy_fragment, copy_wait, make_tensor, make_fragment_like
from numba import cuda

sys.path.append(os.sep.join(["..", "utilities", "python"]))

from benchmark import *
from emulation_utils import get_width, epilogue_ldexp

### C++

In [None]:
%%writefile cpp/2c_fully_fused_emulation/parameters.hpp.inc

    // ===================================
    // Problem configuration
    // ===================================

    // (gemm_m, gemm_n, gemm_k, alpha, beta)
    std::vector<tutorial::gemm_problem_t> problems = {
        {2048, 2048, 2048, 0.9, 1.1}
    };
    

    // ===================================
    // Global GEMM configuration
    // ===================================

    // The number of slices used in emulation algorithm
    // More slices = higher precision but more computation
    constexpr unsigned slices = 7;

In [None]:
%%writefile cpp/2c_fully_fused_emulation/cublasdx_config.hpp.inc

    using slice_value_type       = int8_t;  // Precision for individual slices
    using accumulator_value_type = int32_t; // Precision for accumulation

    // The shape of data tile processed by a single CTA block
    constexpr int tile_m = 128;
    constexpr int tile_n = 128;
    constexpr int tile_k = 128;

    // The shape of CTA block (number of threads)
    constexpr int cta_shape_x = 128;
    constexpr int cta_shape_y = 1;
    constexpr int cta_shape_z = 1;

    using BLAS = decltype(cublasdx::Size<tile_m, tile_n, tile_k>() +
                          cublasdx::Precision<slice_value_type, slice_value_type, accumulator_value_type>() +
                          cublasdx::Type<cublasdx::type::real>() + cublasdx::Function<cublasdx::function::MM>() +
                          cublasdx::Arrangement<arrangement_a, arrangement_b, arrangement_c>() + cublasdx::Block() +
                          cublasdx::BlockDim<cta_shape_x, cta_shape_y, cta_shape_z>() + cublasdx::StaticBlockDim() +
                          cublasdx::WithPipeline() + cublasdx::MaxAlignment() + cublasdx::EnableInputStreaming() +
                          cublasdx::SM<SM_VALUE, SM_MODIFIER_VALUE>());

In [None]:
%%writefile cpp/2c_fully_fused_emulation/pipeline_config.hpp.inc

        constexpr int pipeline_depth = 3;
        auto device_pipeline = cublasdx::suggest_device_pipeline<pipeline_depth, BLAS, cublasdx::external_accumulation>(
                                   tensor_slice_a, tensor_slice_b)
                                   .value();

In [None]:
%%writefile cpp/2c_fully_fused_emulation/fused_kernel.hpp.inc

template<class BLAS,
         class DevicePipeline,
         class Alpha,
         class Beta,
         class CTensor,
         class AShiftTensor,
         class BShiftTensor,
         int32_t Slices>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               Alpha                                  alpha,
                               Beta                                   beta,
                               CTensor                                gmem_c_fp64,
                               AShiftTensor const                     gmem_shift_a,
                               BShiftTensor const                     gmem_shift_b) {
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    /* 
     * EXERCISE --> Complete the kernel to compute all products, accumulate along diagonals, and convert back to FP64
     */
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // ================================
        // 1. SETUP AND TILE PREPARATION
        // ================================

        constexpr int tile_m = cublasdx::size_of_v_m<BLAS>;
        constexpr int tile_n = cublasdx::size_of_v_n<BLAS>;

        // EXERCISE --> Choose the diagonal and term along the diagonal that you'd like to start with
        constexpr auto initial_diag = 
        constexpr auto initial_term = 

        auto [pipeline_smem, smem_shift_a, smem_shift_b] =
            cublasdx::shared_memory::slice<char, int32_t, int32_t>(smem,
                                                                   device_pipeline.buffer_alignment(),
                                                                   device_pipeline.buffer_size(),
                                                                   cublasdx::alignment_of_v_a<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_m>()),
                                                                   cublasdx::alignment_of_v_b<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_n>()));

        // Copy general purpose data
        cublasdx::copy<BLAS, 16>(gmem_shift_a(cute::_, blockIdx.x), smem_shift_a);
        cublasdx::copy<BLAS, 16>(gmem_shift_b(cute::_, blockIdx.y), smem_shift_b);
        cublasdx::copy_wait();
            
        // Get pipeline tile
        auto tile_pipeline = device_pipeline.get_tile(pipeline_smem,
                                                      cublasdx::make_coord(blockIdx.x, initial_term),
                                                      cublasdx::make_coord(blockIdx.y, initial_diag));

        auto accumulator = tile_pipeline.get_accumulator();

        // ================================
        // 2. FP64 C INPUT / OUTPUT TILE SETUP
        // ================================

        auto tile_c_fp64_gmem = cublasdx::get_tile(gmem_c_fp64, BLAS::c_shape, blockIdx.x, blockIdx.y);

        // ============================================
        // 3. OZAKI SCHEME DIAGONAL ITERATION
        // ============================================
#    pragma unroll 1
        for (int diag = initial_diag; /* for loop over diagonals */) {

            // Initialize accumulator for this diagonal
            accumulator.clear();

            // ==========================================
            // 4. SLICE COMBINATION COMPUTATION
            // ==========================================
#    pragma unroll 1
            for (int term = initial_term; /* for loop to iterate along the diagonal */) {
                // =========================================
                // 5. N-STAGE MEMORY PIPELINE FOR GEMM
                // =========================================

                tile_pipeline.execute(accumulator);

                const auto next_slice_row = // A slice index
                const auto next_slice_col = // B slice index
                device_pipeline.reset_tile(tile_pipeline,
                                           cublasdx::make_coord(blockIdx.x, next_slice_row),
                                           cublasdx::make_coord(blockIdx.y, next_slice_col));
            }

            // ========================================
            // 6. RESULT RECONSTRUCTION AND EPILOGUE
            // ========================================

            auto c_fp64_frag = accumulator.make_partition_and_copy(tile_c_fp64_gmem);

            if (accumulator.is_thread_active()) {
                // Convert accumulated int32_t results back to double precision
                // and apply appropriate scaling based on slice positions
                auto gemm_results = accumulator.get_results();

                // Allocate an FP64 fragment with the same shape as the accumulator results
                auto d_fp64_frag = cublasdx::make_fragment_like<double>(gemm_results);

                // At this point of the computation, we can no longer do tile-based operations.  When we convert back to
                // FP64 we need to know the shifts associated with the row of A and column of B that produced this value.  The
                // cublasDx library gives us the ability to figure out the relative index within the tile.  We can use this to
                // find our shifts, do some intermediate computations, and then proceed with more tile computations.
                    
                # pragma unroll
                for (int i = 0; i < cublasdx::size(d_fp64_frag); ++i) {
                    const auto [global_x, global_y] = accumulator.map_fragment_index(i);

                    // EXERCISE --> Use shared memory to get the shifts for this particular element
                    const auto shift_a_elem = 
                    const auto shift_b_elem = 

                    // Convert int32_t slice result back to double precision
                    // with appropriate scaling for this diagonal and element
                    d_fp64_frag(i) = nth_slice_to_fp64<int32_t, int8_t>(diag, gemm_results(i), shift_a_elem + shift_b_elem);
                }

                // Apply alpha/beta scaling and accumulate into C
                // Use beta only for the first diagonal we process, then just add (beta=1.0)
                double beta_used = beta;
                if (/* EXERCISE --> Figure out when to use 1.0 for beta */) {
                    beta_used = 1.0;
                }
                cublasdx::axpby(alpha, d_fp64_frag, beta_used, c_fp64_frag);                
            }
            
            // Store results back to global memory
            accumulator.partition_and_copy(c_fp64_frag, tile_c_fp64_gmem);
        }
    }
#endif
}

In [None]:
!cmake --build ./build -t 2c_fully_fused_emulation

In [None]:
!./build/2c_fully_fused_emulation

#### Solution

We will now overwrite the kernel file with the solution and recompile it. If you want to restart the exercise, make sure you write the exercise kernel back to the file and recompile.
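Before reading the solution, it can help to look at the order in which it visits slice pairs. The sketch below reproduces that order in plain Python, based on the loop bounds and the next-tile expressions of the solution kernel; `ozaki_schedule` and `next_tile` are hypothetical names used only for illustration.

```python
def ozaki_schedule(slices):
    """List the (A slice, B slice) pairs in the order the fused kernel multiplies them."""
    order = []
    for diag in range(slices - 1, -1, -1):   # highest diagonal index first
        for term in range(diag + 1):         # walk along the anti-diagonal
            order.append((term, diag - term))
    return order


def next_tile(diag, term):
    """Slice pair requested after finishing `term` on `diag` (same expressions as the kernel)."""
    row = 0 if term == diag else term + 1
    col = (diag - 1) if term == diag else diag - row
    return row, col


schedule = ozaki_schedule(7)
print(len(schedule), schedule[:3], schedule[-3:])
```

With `slices = 7` the schedule contains 7 * 8 / 2 = 28 products. `next_tile` returns the following entry of the schedule at every step except the last one, where the expressions yield `(0, -1)` because no further product remains.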

In [None]:
%%writefile cpp/2c_fully_fused_emulation/fused_kernel.hpp.inc

template<class BLAS,
         class DevicePipeline,
         class Alpha,
         class Beta,
         class CTensor,
         class AShiftTensor,
         class BShiftTensor,
         int32_t Slices>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               Alpha                                  alpha,
                               Beta                                   beta,
                               CTensor                                gmem_c_fp64,
                               AShiftTensor const                     gmem_shift_a,
                               BShiftTensor const                     gmem_shift_b) {
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // ================================
        // 1. SETUP AND TILE PREPARATION
        // ================================

        constexpr int tile_m = cublasdx::size_of_v_m<BLAS>;
        constexpr int tile_n = cublasdx::size_of_v_n<BLAS>;

        constexpr auto initial_diag = Slices - 1;
        constexpr auto initial_term = 0;

        auto [pipeline_smem, smem_shift_a, smem_shift_b] =
            cublasdx::shared_memory::slice<char, int32_t, int32_t>(smem,
                                                                   device_pipeline.buffer_alignment(),
                                                                   device_pipeline.buffer_size(),
                                                                   cublasdx::alignment_of_v_a<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_m>()),
                                                                   cublasdx::alignment_of_v_b<BLAS>,
                                                                   cute::make_layout(cute::Int<tile_n>()));

        // Copy general purpose data
        cublasdx::copy<BLAS, 16>(gmem_shift_a(cute::_, blockIdx.x), smem_shift_a);
        cublasdx::copy<BLAS, 16>(gmem_shift_b(cute::_, blockIdx.y), smem_shift_b);
        cublasdx::copy_wait();


        // Get pipeline tile
        auto tile_pipeline = device_pipeline.get_tile(pipeline_smem,
                                                      cublasdx::make_coord(blockIdx.x, initial_term),
                                                      cublasdx::make_coord(blockIdx.y, initial_diag));

        auto accumulator = tile_pipeline.get_accumulator();

        // ================================
        // 2. FP64 C INPUT / OUTPUT TILE SETUP
        // ================================

        auto tile_c_fp64_gmem = cublasdx::get_tile(gmem_c_fp64, BLAS::c_shape, blockIdx.x, blockIdx.y);

        // ============================================
        // 3. OZAKI SCHEME DIAGONAL ITERATION
        // ============================================

        // Iterate over diagonals in reverse order (highest diagonal index first, i.e. smallest scale factor first)
        // Accumulating the smallest contributions first helps numerical accuracy
#    pragma unroll 1
        for (auto diag = initial_diag; diag >= 0; --diag) {

            // Initialize accumulator for this diagonal
            accumulator.clear();

            // ==========================================
            // 4. SLICE COMBINATION COMPUTATION
            // ==========================================

            // Compute all slice combinations that contribute to this diagonal
            // For diagonal d, we compute: A_slice[i] * B_slice[d-i] for i = 0 to d
#    pragma unroll 1
            for (auto term = initial_term; term <= diag; ++term) {
                // =========================================
                // 5. N-STAGE MEMORY PIPELINE FOR GEMM
                // =========================================

                tile_pipeline.execute(accumulator);

                const auto next_slice_row = (term == diag) ? 0 : term + 1;                         // A slice index
                const auto next_slice_col = (term == diag) ? (diag - 1) : (diag - next_slice_row); // B slice index
                device_pipeline.reset_tile(tile_pipeline,
                                           cublasdx::make_coord(blockIdx.x, next_slice_row),
                                           cublasdx::make_coord(blockIdx.y, next_slice_col));
            } /* end of slice combination loop */

            // ========================================
            // 6. RESULT RECONSTRUCTION AND EPILOGUE
            // ========================================

            // Load existing C values
            auto c_fp64_frag = accumulator.make_partition_and_copy(tile_c_fp64_gmem);

            if (accumulator.is_thread_active()) {
                // Convert accumulated int32_t results back to double precision
                // and apply appropriate scaling based on slice positions
                auto gemm_results = accumulator.get_results();

                auto d_fp64_frag = cublasdx::make_fragment_like<double>(gemm_results);

                // At this point of the computation, we can no longer do tile-based operations.  When we convert back to
                // FP64 we need to know the shifts associated with the row of A and column of B that produced this value.  The
                // cublasDx library gives us the ability to figure out the relative index within the tile.  We can use this to
                // find our shifts, do some intermediate computations, and then proceed with more tile computations.

                # pragma unroll
                for (int i = 0; i < cublasdx::size(d_fp64_frag); ++i) {
                    const auto [global_x, global_y] = accumulator.map_fragment_index(i);
                    const auto shift_a_elem         = smem_shift_a(global_x);
                    const auto shift_b_elem         = smem_shift_b(global_y);

                    // Convert int32_t slice result back to double precision
                    // with appropriate scaling for this diagonal and element
                    d_fp64_frag(i) =
                        nth_slice_to_fp64<int32_t, int8_t>(diag, gemm_results(i), shift_a_elem + shift_b_elem);
                }

                // Apply alpha/beta scaling and accumulate into C
                // Use beta only for the first diagonal (highest order), then just add (beta=1.0)
                cublasdx::axpby(alpha, d_fp64_frag, (diag == Slices - 1) ? beta : 1.0, c_fp64_frag);
            }

            // Store results back to global memory
            accumulator.partition_and_copy(c_fp64_frag, tile_c_fp64_gmem);       
        }
    }
#endif
}

In [None]:
!cmake --build ./build -t 2c_fully_fused_emulation

In [None]:
!./build/2c_fully_fused_emulation

### Python

In [None]:
problems = [
  (2048, 2048, 2048, 0.9, 1.1),
]

In [None]:
def get_emulated_dgemm_kernel(BLAS):

    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    tile_m, tile_n = BLAS.c_dim

    uint8_width = get_width(np.uint8)
    
    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        return epilogue_ldexp(value, -exponent_shift)
    
    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLAS.block_size, 1))
    def dgemm_kernel(slices, shift_a_tensor, shift_b_tensor, alpha, beta, tensor_c, device_pipeline: DevicePipeline):
        m, n = tensor_c.shape

        # EXERCISE --> Complete the kernel to compute all products, accumulate along diagonals, and convert back to FP64

        # ================================
        # 1. SETUP AND TILE PREPARATION
        # ================================

        block_m = cuda.blockIdx.x
        block_n = cuda.blockIdx.y
        
        # EXERCISE --> Choose the diagonal and term along the diagonal that you'd like to start with
        initial_diag = -1
        initial_term = -1

        smem = cuda.shared.array(shape=(0,), dtype=np.int8, alignment=device_pipeline.buffer_alignment)
        smem_pipeline, smem = smem[:device_pipeline.buffer_size], smem[device_pipeline.buffer_size:].view(np.int32)
        smem_shift_a, smem = smem[:tile_m], smem[tile_m:]
        smem_shift_b, smem = smem[:tile_n], smem[tile_n:]

        # Copy general purpose data
        block_start_m = block_m * tile_m
        block_end_m = (block_m + 1) * tile_m

        block_start_n = block_n * tile_n
        block_end_n = (block_n + 1) * tile_n

        if block_start_m >= m or block_start_n >= n:
            return

        shift_a_view = shift_a_tensor[block_start_m : block_end_m]
        shift_b_view = shift_b_tensor[block_start_n : block_end_n]

        tid = cuda.threadIdx.x
        if tid < tile_m:
            smem_shift_a[tid] = shift_a_view[tid]
        if tid < tile_n:
            smem_shift_b[tid] = shift_b_view[tid]
        cuda.syncthreads()

        # Get pipeline tile
        tile_pipeline = device_pipeline.get_tile(smem_pipeline,
                                                 (block_m, np.int32(initial_term)),
                                                 (block_n, np.int32(initial_diag)))
        
        accumulator = BLAS.suggest_accumulator()

        # ================================
        # 2. FP64 C INPUT / OUTPUT TILE SETUP
        # ================================

        c_view = tensor_c[
            block_start_m : block_end_m,
            block_start_n : block_end_n,
        ]

        ldc = max(c_view.strides) // c_view.itemsize
        gmem_c = make_tensor(c_view, BLAS.get_layout_gmem_c(ldc))
        
        # ============================================
        # 3. OZAKI SCHEME DIAGONAL ITERATION
        # ============================================
        for diag in range(-1): # EXERCISE --> for loop over diagonals
            
            accumulator.clear()

            # ==========================================
            # 4. SLICE COMBINATION COMPUTATION
            # ==========================================
            for term in range(-1): # EXERCISE --> for loop to iterate along the diagonal
                # =========================================
                # 5. N-STAGE MEMORY PIPELINE FOR GEMM
                # =========================================
                tile_pipeline.execute(accumulator)

                # EXERCISE --> Determine which slice of A and slice of B to multiply
                next_slice_row = -1
                next_slice_col = -1

                device_pipeline.reset_tile(tile_pipeline,
                                           (block_m, np.int32(next_slice_row)),
                                           (block_n, np.int32(next_slice_col)))

            # ========================================
            # 6. RESULT RECONSTRUCTION AND EPILOGUE
            # ========================================
            if accumulator.is_thread_active():
                # Convert accumulated int32_t results back to double precision
                # and apply appropriate scaling based on slice positions
                gemm_results = accumulator.get_results()

                # Load existing C values
                c_fp64_frag = accumulator.make_partition_and_copy(gmem_c)
                d_fp64_frag = make_fragment_like(gemm_results, np.float64)

                # At this point of the computation, we can no longer do tile-based operations.  When we convert back to
                # FP64 we need to know the shifts associated with the row of A and column of B that produced this value.  The
                # nvmath-python cublasDx bindings give us the ability to figure out the relative index within the tile.  We can
                # use this to find our shifts, do some intermediate computations, and then proceed with more tile computations.
                for i in range(c_fp64_frag.layout.size):
                    # Get the elements offsets within the output tile
                    (global_x, global_y) = accumulator.map_fragment_index(i)
                    # EXERCISE --> Use shared memory to get the shifts for this particular element
                    shift_a = -1
                    shift_b = -1

                    # Convert int32_t slice result back to double precision
                    # with appropriate scaling for this diagonal and element
                    d_fp64_frag[i] = nth_slice_to_fp64(diag, gemm_results[i], shift_a + shift_b)

                # Apply alpha/beta scaling and accumulate into C
                # Use beta only for the first diagonal we process, then just add (beta=1.0)
                beta_used = beta
                if True: # EXERCISE --> Figure out when to use 1.0 for beta
                    beta_used = 1.0
                axpby(alpha, d_fp64_frag, beta_used, c_fp64_frag)

                accumulator.partition_and_copy(c_fp64_frag, gmem_c)

        tile_pipeline._del()

    return dgemm_kernel

In [None]:
def fused_dgemm_ozaki(tensor_slicedA_cupy, tensor_slicedB_cupy, tensor_c_cupy, tensor_shift_a_cupy, tensor_shift_b_cupy, alpha, beta, context, warmup=True):
    m, n           = tensor_c_cupy.shape
    _, k, slices   = tensor_slicedA_cupy.shape

    BLAS = context["BLAS"]
    PIPELINE_DEPTH = context["PIPELINE_DEPTH"]
    gemm_kernel = context["gemm_kernel"]
    grid = context["gemm_grid"]
    block = context["gemm_block"]

    TILE_M, TILE_N = BLAS.c_dim

    tensor_slicedA = cuda.as_cuda_array(tensor_slicedA_cupy)
    tensor_slicedB = cuda.as_cuda_array(tensor_slicedB_cupy)
    tensor_shift_a = cuda.as_cuda_array(tensor_shift_a_cupy)
    tensor_shift_b = cuda.as_cuda_array(tensor_shift_b_cupy)
    tensor_c = cuda.as_cuda_array(tensor_c_cupy)

    device_pipeline = BLAS.suggest_device_pipeline(PIPELINE_DEPTH, tensor_slicedA, tensor_slicedB)

    smem_size = device_pipeline.buffer_size + (TILE_M + TILE_N) * np.dtype(np.int32).itemsize
    if warmup:
        set_max_dynamic_shared_size_bytes(gemm_kernel, smem_size,
                                            slices, tensor_shift_a, tensor_shift_b, alpha, beta, tensor_c, device_pipeline)
    gemm_kernel[grid, block, 0, smem_size](slices, tensor_shift_a, tensor_shift_b, alpha, beta, tensor_c, device_pipeline)

In [None]:
def setup_func(m, n, k):
    tile_m = 128
    tile_n = 128
    tile_k = 128
    
    pipeline_depth = 3
    block_size = 128

    assert m % tile_m == 0, "Unsupported dimension m for TILE_M"
    assert n % tile_n == 0, "Unsupported dimension n for TILE_N"
    assert k % tile_k == 0, "Unsupported dimension k for TILE_K"
    assert k >= (tile_k * pipeline_depth), "Unsupported pipeline depth for k"
    
    BLAS = Matmul(size=(tile_m, tile_n, tile_k),
                  precision=(np.int8, np.int8, np.int32),
                  data_type="real",
                  alignment=MAX_ALIGNMENT,
                  arrangement=("row_major", "col_major", "col_major"), # Do not change
                  execution="Block",
                  block_size=block_size,
                  with_pipeline=True,
                  enable_input_streaming=True,
                  static_block_dim=True)

    gemm_grid = (m // tile_m, n // tile_n)
    gemm_block = BLAS.block_dim

    return {
        "BLAS": BLAS,
        "PIPELINE_DEPTH": pipeline_depth,
        "gemm_kernel" : get_emulated_dgemm_kernel(BLAS),
        "gemm_grid": gemm_grid,
        "gemm_block": gemm_block,
    }

In [None]:
benchmark_fused_emulated_dgemm(problems, setup_func, fused_dgemm_ozaki)

#### Solution

In [None]:
def get_emulated_dgemm_kernel_solution(BLAS):

    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    tile_m, tile_n = BLAS.c_dim

    uint8_width = get_width(np.uint8)
    
    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        return epilogue_ldexp(value, -exponent_shift)
    
    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLAS.block_size, 1))
    def dgemm_kernel(slices, shift_a_tensor, shift_b_tensor, alpha, beta, tensor_c, device_pipeline: DevicePipeline):
        m, n = tensor_c.shape

        block_m = cuda.blockIdx.x
        block_n = cuda.blockIdx.y

        smem = cuda.shared.array(shape=(0,), dtype=np.int8, alignment=device_pipeline.buffer_alignment)
        smem_pipeline, smem = smem[:device_pipeline.buffer_size], smem[device_pipeline.buffer_size:].view(np.int32)
        smem_shift_a, smem = smem[:tile_m], smem[tile_m:]
        smem_shift_b, smem = smem[:tile_n], smem[tile_n:]

        block_start_m = block_m * tile_m
        block_end_m = (block_m + 1) * tile_m

        block_start_n = block_n * tile_n
        block_end_n = (block_n + 1) * tile_n

        if block_start_m >= m or block_start_n >= n:
            return

        shift_a_view = shift_a_tensor[block_start_m : block_end_m]
        shift_b_view = shift_b_tensor[block_start_n : block_end_n]

        tid = cuda.threadIdx.x
        if tid < tile_m:
            smem_shift_a[tid] = shift_a_view[tid]
        if tid < tile_n:
            smem_shift_b[tid] = shift_b_view[tid]
        cuda.syncthreads()

        c_view = tensor_c[
            block_start_m : block_end_m,
            block_start_n : block_end_n,
        ]

        ldc = max(c_view.strides) // c_view.itemsize
        gmem_c = make_tensor(c_view, BLAS.get_layout_gmem_c(ldc))
        
        initial_diag = slices - 1
        initial_term = 0

        tile_pipeline = device_pipeline.get_tile(smem_pipeline,
                                                 (block_m, np.int32(initial_term)),
                                                 (block_n, np.int32(initial_diag)))
        
        accumulator = BLAS.suggest_accumulator()
        for diag in range(initial_diag, -1, -1):
            
            accumulator.clear()
            for term in range(initial_term, diag + 1):
                tile_pipeline.execute(accumulator)

                next_slice_row =          0 if term == diag else term + 1
                next_slice_col = (diag - 1) if term == diag else diag - next_slice_row

                device_pipeline.reset_tile(tile_pipeline,
                                           (block_m, np.int32(next_slice_row)),
                                           (block_n, np.int32(next_slice_col)))

            if accumulator.is_thread_active():
                gemm_results = accumulator.get_results()
        
                acc_fp64_frag = make_fragment_like(gemm_results, np.float64)
                c_fp64_frag = accumulator.make_partition_and_copy(gmem_c)
        
                # At this point of the computation, we can no longer do tile-based operations.  When we convert back to
                # FP64 we need to know the shifts associated with the row of A and column of B that produced this value.  The
                # nvmath-python cublasDx bindings give us the ability to figure out the relative index within the tile.  We can
                # use this to find our shifts, do some intermediate computations, and then proceed with more tile computations.
                for i in range(c_fp64_frag.layout.size):
                    (global_x, global_y) = accumulator.map_fragment_index(i)
                    shift_a = smem_shift_a[global_x]
                    shift_b = smem_shift_b[global_y]
        
                    acc_fp64_frag[i] = nth_slice_to_fp64(diag, gemm_results[i], shift_a + shift_b)
        
                beta_used = beta if diag == slices - 1 else 1.0
                axpby(alpha, acc_fp64_frag, beta_used, c_fp64_frag)
                accumulator.partition_and_copy(c_fp64_frag, gmem_c)

        tile_pipeline._del()

    return dgemm_kernel

In [None]:
def setup_func_solution(m, n, k):
    ctx = setup_func(m, n, k)
    BLAS = ctx["BLAS"]
    ctx["gemm_kernel"] = get_emulated_dgemm_kernel_solution(BLAS)
    return ctx

In [None]:
benchmark_fused_emulated_dgemm(problems, setup_func_solution, fused_dgemm_ozaki)

## Conclusion

In this notebook, we finished fusing the epilogue kernel into our emulated GEMM kernel.  The core technique was an API that returns each element's relative coordinates within the tile, which we used to make element-specific updates.  From there, we kept using tile-based APIs for efficiency and simplicity.
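The coordinate-lookup idea can be illustrated without a GPU. The sketch below is a toy model under loud assumptions: the round-robin thread-to-element layout in `map_fragment_index` is hypothetical and is not the layout cuBLASDx actually uses, and the tile and thread counts are chosen only to keep the example small.

```python
import math

TILE_M, TILE_N, NUM_THREADS = 4, 4, 4
FRAG_SIZE = TILE_M * TILE_N // NUM_THREADS  # elements owned by each thread


def map_fragment_index(tid, i):
    """Hypothetical layout: element i of thread tid is linear element tid + i * NUM_THREADS."""
    linear = tid + i * NUM_THREADS
    return linear // TILE_N, linear % TILE_N  # (row, col) within the tile


# Integer accumulator tile and per-row / per-column exponent shifts.
acc = [[r * TILE_N + c for c in range(TILE_N)] for r in range(TILE_M)]
shift_a = [0, 1, 2, 3]
shift_b = [3, 2, 1, 0]

out = [[None] * TILE_N for _ in range(TILE_M)]
for tid in range(NUM_THREADS):
    # Each thread holds only its own fragment of the accumulator.
    coords = [map_fragment_index(tid, i) for i in range(FRAG_SIZE)]
    frag = [acc[r][c] for r, c in coords]
    for i in range(FRAG_SIZE):
        # The coordinates tell the thread which shifts belong to element i.
        row, col = map_fragment_index(tid, i)
        out[row][col] = math.ldexp(float(frag[i]), -(shift_a[row] + shift_b[col]))

# Direct element-wise computation for comparison.
expected = [
    [math.ldexp(float(acc[r][c]), -(shift_a[r] + shift_b[c])) for c in range(TILE_N)]
    for r in range(TILE_M)
]
print(out == expected)
```

Whatever the real layout is, the pattern stays the same: the fragment index is translated into tile coordinates, and those coordinates select the per-row and per-column data needed for the element.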