## Exercise 3.2: Optimizing the Ozaki-I Scheme with kernel fusion

Now that we have a working Ozaki-I implementation, we'd like to start fusing the epilogue function into our emulated GEMM kernel. This will help reduce the global memory traffic.

In this exercise, we take the first step: computing all slice products and accumulating them along the anti-diagonals in a single kernel. The final FP64 reconstruction still runs in a separate epilogue kernel:

<img src="Images/Ozaki-I-Multiplications-Partial-Fusion.png" width="800" height="auto"/>

### C++ CMake Configuration

In [None]:
import sys, os
sys.path.append(os.sep.join(["..", "utilities", "python"]))
from common_cuda import setup_cmake_project
# Set up the CMake project used by the C++ cells below (helper from
# utilities/python/common_cuda; presumably configures the ./build tree that
# `cmake --build ./build` uses later — confirm).
setup_cmake_project()

### Python Imports

In [None]:
import sys
import os
import math

import numpy as np
import cupy as cp
import nvmath

from nvmath.device import Matmul
from nvmath.device.cublasdx import DevicePipeline, SharedStorageCalc, MAX_ALIGNMENT
from nvmath.device.cublasdx_numba import pipeline_extensions
from nvmath.device.common import axpby, clear, copy, copy_fragment, copy_wait, make_tensor
from numba import cuda

sys.path.append(os.sep.join(["..", "utilities", "python"]))

from benchmark import *
from emulation_utils import get_width, epilogue_ldexp

### C++

In [None]:
%%writefile cpp/2b_partially_fused_emulation/parameters.hpp.inc

    // ===================================
    // Problem configuration
    // ===================================

    // (gemm_m, gemm_n, gemm_k, alpha, beta)
    // Each entry is evaluated as C = alpha * (A * B) + beta * C (see the epilogue kernel).
    std::vector<tutorial::gemm_problem_t> problems = {
        {2048, 2048, 2048, 0.9, 1.1}
    };
    

    // ===================================
    // Global GEMM configuration
    // ===================================

    // The number of slices used in the emulation algorithm.
    // More slices = higher precision but more computation: the fused kernel
    // computes one int8 GEMM per (A slice, B slice) pair on diagonals
    // 0 .. slices - 1, i.e. slices * (slices + 1) / 2 slice products.
    constexpr unsigned slices = 7;

In [None]:
%%writefile cpp/2b_partially_fused_emulation/cublasdx_config.hpp.inc

    // Operands are multiplied slice by slice in int8 and summed in int32.
    using slice_value_type       = int8_t;  // Precision for individual slices
    using accumulator_value_type = int32_t; // Precision for accumulation

    // The shape of data tile processed by a single CTA block
    // (the fused kernel maps one CTA to output tile (blockIdx.x, blockIdx.y)).
    constexpr int tile_m = 128;
    constexpr int tile_n = 128;
    constexpr int tile_k = 128;

    // The shape of CTA block (number of threads)
    constexpr int cta_shape_x = 128;
    constexpr int cta_shape_y = 1;
    constexpr int cta_shape_z = 1;

    // Block-level int8 x int8 -> int32 GEMM description with pipeline support and
    // input streaming enabled. arrangement_a/b/c and SM_VALUE / SM_MODIFIER_VALUE
    // are expected to be defined by the file that includes this fragment — confirm.
    using BLAS = decltype(cublasdx::Size<tile_m, tile_n, tile_k>() +
                          cublasdx::Precision<slice_value_type, slice_value_type, accumulator_value_type>() +
                          cublasdx::Type<cublasdx::type::real>() + cublasdx::Function<cublasdx::function::MM>() +
                          cublasdx::Arrangement<arrangement_a, arrangement_b, arrangement_c>() + cublasdx::Block() +
                          cublasdx::BlockDim<cta_shape_x, cta_shape_y, cta_shape_z>() + cublasdx::StaticBlockDim() +
                          cublasdx::WithPipeline() + cublasdx::MaxAlignment() + cublasdx::EnableInputStreaming() +
                          cublasdx::SM<SM_VALUE, SM_MODIFIER_VALUE>());

In [None]:
%%writefile cpp/2b_partially_fused_emulation/pipeline_config.hpp.inc

        // Number of pipeline stages (presumably K tiles kept in flight in shared memory — confirm).
        constexpr int pipeline_depth = 3;
        // external_accumulation: the kernel owns the accumulator (get_accumulator / clear), so a
        // single accumulator can sum several execute() calls along one diagonal.
        // tensor_slice_a / tensor_slice_b come from the including file (assumed to be the sliced
        // int8 operands with the slice index as the second tile coordinate — confirm).
        auto device_pipeline = cublasdx::suggest_device_pipeline<pipeline_depth, BLAS, cublasdx::external_accumulation>(
                                   tensor_slice_a, tensor_slice_b)
                                   .value();

In [None]:
%%writefile cpp/2b_partially_fused_emulation/fused_kernel.hpp.inc

// Fused slice-product kernel (EXERCISE version — the blanks below must be filled in
// before this compiles).
// Each CTA owns the output tile (blockIdx.x, blockIdx.y). For every diagonal `diag`
// it sums the int8 slice products A_slice[i] * B_slice[j] with i + j == diag into one
// int32 accumulator and stores that sum into slice_product_tensor(:, :, diag).
// Despite its name, the FP64 reconstruction still happens in the separate epilogue kernel.
template<int Slices, class BLAS, class DevicePipeline, class SliceProductTensor>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               SliceProductTensor                     slice_product_tensor) {
    // Dynamic shared memory backing the pipeline stages.
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    /* 
     * EXERCISE --> Complete the kernel to compute all products and accumulate along diagonals in the same kernel
     */

    // Only emit the body for the SM architecture BLAS was configured for.
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // ================================
        // 1. SETUP AND TILE PREPARATION
        // ================================

        // EXERCISE --> Choose your starting diagonal and term along the diagonal
        // (the first coordinate pair selects the A slice, the second the B slice).
        constexpr auto initial_diag = ;
        constexpr auto initial_term = ;

        // Get pipeline tile
        tile_pipeline_note:;
        auto tile_pipeline = device_pipeline.get_tile(
            smem, cublasdx::make_coord(blockIdx.x, initial_term), cublasdx::make_coord(blockIdx.y, initial_diag));

        auto accumulator = tile_pipeline.get_accumulator();

        // ============================================
        // 2. OZAKI SCHEME DIAGONAL ITERATION
        // ============================================
#    pragma unroll 1
        for (int diag = initial_diag; /* EXERCISE --> for loop over diagonals */) {

            // Initialize accumulator for this diagonal
            accumulator.clear();

            // ==========================================
            // 3. SLICE COMBINATION COMPUTATION
            // ==========================================
#    pragma unroll 1
            for (int term = initial_term; /* EXERCISE --> for loop to iterate along the diagonal */) {
                // =========================================
                // 4. N-STAGE MEMORY PIPELINE FOR GEMM
                // =========================================

                // Accumulate the currently loaded (A slice, B slice) product.
                tile_pipeline.execute(accumulator);

                // EXERCISE --> Determine which slice of A and slice of B to multiply
                // next (either the next term on this diagonal or the first term of the next one).
                const auto next_slice_row = ;
                const auto next_slice_col = ;
                device_pipeline.reset_tile(tile_pipeline,
                                           cublasdx::make_coord(blockIdx.x, next_slice_row),
                                           cublasdx::make_coord(blockIdx.y, next_slice_col));
            }

            // ========================================
            // 5. RESULT RECONSTRUCTION AND EPILOGUE
            // ========================================

            if (accumulator.is_thread_active()) {
                // Choose output tensor for this slice iteration
                auto this_slice_output = slice_product_tensor(cublasdx::slice, cublasdx::slice, diag);
                // Get output tile for this block
                auto slice_output_tile = cublasdx::get_tile(this_slice_output, BLAS::c_shape, blockIdx.x, blockIdx.y);
                // Store results
                accumulator.partition_and_store(slice_output_tile);
            }
        }
    }
#endif
}

In [None]:
%%writefile cpp/2b_partially_fused_emulation/epilogue_config.hpp.inc

        // Tile handled by one epilogue CTA; presumably used as the epilogue kernel's
        // block dimensions (one thread per output element), mirroring epilogue_tile_m/n
        // in the Python setup_func — confirm against the including file.
        constexpr int epilogue_kernel_tile_m = 16;
        constexpr int epilogue_kernel_tile_n = 16;

In [None]:
%%writefile cpp/2b_partially_fused_emulation/epilogue_kernel.hpp.inc

// Epilogue kernel (EXERCISE version — `result` and the diagonal loop are left for
// the reader to write, so this does not compile as-is).
// One thread per output element (tid_m, tid_n): convert each int32 diagonal sum
// back to FP64, undo the A-row / B-column exponent shifts, and apply
// C = alpha * (A * B) + beta * C.
template<int BlockSize, int Slices, class ProductTensor, class ShiftTensorA, class ShiftTensorB, class OutTensor>
__launch_bounds__(BlockSize, 1) __global__ void epilogue_kernel(double        alpha,
                                                                double        beta,
                                                                ProductTensor product_tensor,
                                                                ShiftTensorA  shift_tensor_a,
                                                                ShiftTensorB  shift_tensor_b,
                                                                OutTensor     out_tensor) {
    using product_datatype = tutorial::tensor_value_type_t<ProductTensor>;
    using shift_datatype   = tutorial::tensor_value_type_t<ShiftTensorA>;
    using out_datatype     = tutorial::tensor_value_type_t<OutTensor>;

    const auto tid_m = threadIdx.x + blockIdx.x * blockDim.x;
    const auto tid_n = threadIdx.y + blockIdx.y * blockDim.y;

    // Exponent shifts: one per row of A and one per column of B.
    int shift_a = shift_tensor_a(tid_m);
    int shift_b = shift_tensor_b(tid_n);

    /*
     * EXERCISE --> Complete the implementation of the epilogue kernel
     * (declare an FP64 `result` and loop `diag` over all Slices diagonals)
     */
    #pragma unroll
    for (/* for loop over diagonals */) {
        product_datatype diag_acc = product_tensor(tid_m, tid_n, diag);
        result += nth_slice_to_fp64<int32_t, int8_t>(diag, diag_acc, shift_a + shift_b);
    }

    out_tensor(tid_m, tid_n) = alpha * result + beta * out_tensor(tid_m, tid_n);
}

In [None]:
!cmake --build ./build -t 2b_partially_fused_emulation

In [None]:
!./build/2b_partially_fused_emulation

#### Solution

We will now overwrite the kernel files with the solution and recompile. If you want to go back to the exercise, make sure you write your own kernel back into these files and recompile.

In [None]:
%%writefile cpp/2b_partially_fused_emulation/fused_kernel.hpp.inc

// Fused slice-product kernel (solution).
// Each CTA owns the output tile (blockIdx.x, blockIdx.y). For every diagonal `diag`
// (from Slices - 1 down to 0) it sums A_slice[term] * B_slice[diag - term] for
// term = 0 .. diag into one int32 accumulator and stores that sum into
// slice_product_tensor(:, :, diag). The FP64 reconstruction happens in the
// separate epilogue kernel.
template<int Slices, class BLAS, class DevicePipeline, class SliceProductTensor>
__launch_bounds__(DevicePipeline::max_threads_per_block, 1) __global__
    void fused_epilogue_kernel(__grid_constant__ DevicePipeline const device_pipeline,
                               SliceProductTensor                     slice_product_tensor) {
    // Dynamic shared memory backing the pipeline stages.
    extern __shared__ __align__(device_pipeline.buffer_alignment()) char smem[];
#ifdef __CUDA_ARCH__
    // Only emit the body for the SM architecture BLAS was configured for.
    if constexpr (cublasdx::sm_of_v<BLAS> == __CUDA_ARCH__) {
        // ================================
        // 1. SETUP AND TILE PREPARATION
        // ================================

        // First product computed: A slice 0 times B slice Slices - 1 (first term of the last diagonal).
        constexpr auto initial_diag = Slices - 1;
        constexpr auto initial_term = 0;

        // Get pipeline tile
        auto tile_pipeline = device_pipeline.get_tile(
            smem, cublasdx::make_coord(blockIdx.x, initial_term), cublasdx::make_coord(blockIdx.y, initial_diag));

        auto accumulator = tile_pipeline.get_accumulator();

        // ============================================
        // 2. OZAKI SCHEME DIAGONAL ITERATION
        // ============================================

        // Iterate over diagonals in reverse order (highest power of 2 first)
        // This ensures proper accumulation order for numerical stability
#    pragma unroll 1
        for (auto diag = initial_diag; diag >= 0; --diag) {

            // Initialize accumulator for this diagonal
            accumulator.clear();

            // ==========================================
            // 3. SLICE COMBINATION COMPUTATION
            // ==========================================

            // Compute all slice combinations that contribute to this diagonal
            // For diagonal d, we compute: A_slice[i] * B_slice[d-i] for i = 0 to d
#    pragma unroll 1
            for (auto term = initial_term; term <= diag; ++term) {
                // =========================================
                // 4. N-STAGE MEMORY PIPELINE FOR GEMM
                // =========================================

                tile_pipeline.execute(accumulator);

                // Point the pipeline at the next product: the next term on this diagonal,
                // or, after the last term, (A slice 0, B slice diag - 1) on the next diagonal.
                // NOTE(review): after the very last product (diag == 0) this requests B slice -1;
                // confirm reset_tile does not issue out-of-bounds loads for that trailing prefetch.
                const auto next_slice_row = (term == diag) ? 0 : term + 1;                         // A slice index
                const auto next_slice_col = (term == diag) ? (diag - 1) : (diag - next_slice_row); // B slice index
                device_pipeline.reset_tile(tile_pipeline,
                                           cublasdx::make_coord(blockIdx.x, next_slice_row),
                                           cublasdx::make_coord(blockIdx.y, next_slice_col));
            } /* end of slice combination loop */

            // ========================================
            // 5. RESULT RECONSTRUCTION AND EPILOGUE
            // ========================================

            if (accumulator.is_thread_active()) {
                // Choose output tensor for this slice iteration
                auto this_slice_output = slice_product_tensor(cublasdx::slice, cublasdx::slice, diag);
                // Get output tile for this block
                auto slice_output_tile = cublasdx::get_tile(this_slice_output, BLAS::c_shape, blockIdx.x, blockIdx.y);
                // Store results
                accumulator.partition_and_store(slice_output_tile);
            }
        }
    }
#endif
}

In [None]:
%%writefile cpp/2b_partially_fused_emulation/epilogue_kernel.hpp.inc

// Epilogue kernel (solution).
// One thread per output element (tid_m, tid_n): converts each of the Slices int32
// diagonal sums back to FP64, undoes the A-row / B-column exponent shifts, and
// applies C = alpha * (A * B) + beta * C.
template<int BlockSize, int Slices, class ProductTensor, class ShiftTensorA, class ShiftTensorB, class OutTensor>
__launch_bounds__(BlockSize, 1) __global__ void epilogue_kernel(double        alpha,
                                                                double        beta,
                                                                ProductTensor product_tensor,
                                                                ShiftTensorA  shift_tensor_a,
                                                                ShiftTensorB  shift_tensor_b,
                                                                OutTensor     out_tensor) {
    using product_datatype = tutorial::tensor_value_type_t<ProductTensor>;
    using shift_datatype   = tutorial::tensor_value_type_t<ShiftTensorA>;
    using out_datatype     = tutorial::tensor_value_type_t<OutTensor>;

    const auto tid_m = threadIdx.x + blockIdx.x * blockDim.x;
    const auto tid_n = threadIdx.y + blockIdx.y * blockDim.y;

    // Exponent shifts: one per row of A and one per column of B.
    int shift_a = shift_tensor_a(tid_m);
    int shift_b = shift_tensor_b(tid_n);

    double result = 0.0;

    // Sum the diagonals from Slices - 1 down to 0, matching the order used by the fused kernel.
#pragma unroll
    for (auto diag = Slices - 1; diag >= 0; diag--) {
        product_datatype diag_acc = product_tensor(tid_m, tid_n, diag);
        result += nth_slice_to_fp64<int32_t, int8_t>(diag, diag_acc, shift_a + shift_b);
    }

    out_tensor(tid_m, tid_n) = alpha * result + beta * out_tensor(tid_m, tid_n);
}

In [None]:
!cmake --build ./build -t 2b_partially_fused_emulation

In [None]:
!./build/2b_partially_fused_emulation

### Python

In [None]:
# GEMM problems to benchmark, each as (m, n, k, alpha, beta);
# mirrors the problem list in the C++ parameters.hpp.inc cell.
problems = [(2048, 2048, 2048, 0.9, 1.1)]

In [None]:
def get_emulated_gemm_kernel(BLAS):
    """Build the EXERCISE version of the fused slice-product GEMM kernel.

    The returned Numba kernel is launched with one CTA per (tile_m, tile_n) output
    tile and should write, for every diagonal ``diag``, the int32 sum of all slice
    products A_slice[i] * B_slice[j] with i + j == diag into ``tensor_c[..., diag]``.

    While the ``-1`` placeholders are in place, both ``range(-1)`` loops are empty,
    so the kernel sets up its pipeline tile but computes and stores nothing.

    Args:
        BLAS: nvmath ``Matmul`` description; its A and B value types must match.

    Returns:
        The ``cuda.jit`` kernel ``gemm_kernel(tensor_c, device_pipeline)``.
    """

    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    tile_m, tile_n = BLAS.c_dim
    
    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLAS.block_size, 1))
    def gemm_kernel(tensor_c, device_pipeline: DevicePipeline):
        # Last axis of tensor_c indexes the diagonal.
        _, _, slices = tensor_c.shape

        block_m = cuda.blockIdx.x
        block_n = cuda.blockIdx.y

        # shape=(0,): dynamic shared memory, sized by the launch (device_pipeline.buffer_size).
        smem = cuda.shared.array(shape=(0,), dtype=BLAS.a_value_type, alignment=device_pipeline.buffer_alignment)

        block_start_m = block_m * tile_m
        block_end_m = (block_m + 1) * tile_m

        block_start_n = block_n * tile_n
        block_end_n = (block_n + 1) * tile_n
 
        # EXERCISE --> Complete the kernel to compute all products and accumulate along diagonals in the same kernel

        # ================================
        # 1. SETUP AND TILE PREPARATION
        # ================================

        # EXERCISE --> Choose your starting diagonal and term along the diagonal
        initial_diag = -1
        initial_term = -1

        # Get pipeline tile (first pair: A tile block_m / slice index, B tile block_n / slice index)
        tile_pipeline = device_pipeline.get_tile(smem,
                                                 (block_m, np.int32(initial_term)),
                                                 (block_n, np.int32(initial_diag)))
        
        accumulator = BLAS.suggest_accumulator()

        # This CTA's output tile across all diagonals.
        c_views = tensor_c[
            block_start_m : block_end_m,
            block_start_n : block_end_n,
            :
        ]
        # Leading dimension in elements: the larger of the two in-plane strides.
        ldc = max(c_views.strides[:2]) // c_views.itemsize
        
        # ============================================
        # 2. OZAKI SCHEME DIAGONAL ITERATION
        # ============================================
        for diag in range(-1): # EXERCISE --> for loop over diagonals

            # Initialize accumulator for this diagonal
            accumulator.clear()

            # ==========================================
            # 3. SLICE COMBINATION COMPUTATION
            # ==========================================
            for term in range(-1): # EXERCISE --> for loop to iterate along the diagonal
                # =========================================
                # 4. N-STAGE MEMORY PIPELINE FOR GEMM
                # =========================================
                tile_pipeline.execute(accumulator)

                # EXERCISE --> Determine which slice of A and slice of B to multiply
                next_slice_row = -1
                next_slice_col = -1

                device_pipeline.reset_tile(tile_pipeline,
                                           (block_m, np.int32(next_slice_row)),
                                           (block_n, np.int32(next_slice_col)))

            # ========================================
            # 5. RESULT RECONSTRUCTION AND EPILOGUE
            # ========================================
            if accumulator.is_thread_active():
                gmem_c = make_tensor(c_views[:,:,diag], BLAS.get_layout_gmem_c(ldc))
                accumulator.partition_and_copy(accumulator.get_results(), gmem_c)

        # The solution kernel calls tile_pipeline._del() here once the loops are
        # complete (presumably to release the pipeline tile — confirm in nvmath docs).
        # tile_pipeline._del()

    return gemm_kernel

In [None]:
def partial_fused_dgemm_ozaki(tensor_slicedA_cupy, tensor_slicedB_cupy, tensor_diag_cupy, context, warmup=True):
    """Launch the fused slice-product GEMM kernel once.

    The CuPy operands are wrapped as Numba device arrays, a cuBLASDx device
    pipeline is built over the two sliced operands, and ``context["gemm_kernel"]``
    is launched with the pipeline buffer as dynamic shared memory. The kernel
    writes its per-diagonal sums into ``tensor_diag_cupy``.

    Args:
        tensor_slicedA_cupy: sliced A operand (CuPy array).
        tensor_slicedB_cupy: sliced B operand (CuPy array).
        tensor_diag_cupy: output tensor for the per-diagonal sums (CuPy array).
        context: dict produced by ``setup_func`` / ``setup_func_solution``.
        warmup: when True, first raise the kernel's dynamic shared-memory limit
            to the pipeline buffer size.
    """
    blas = context["BLAS"]
    depth = context["PIPELINE_DEPTH"]
    kernel = context["gemm_kernel"]
    launch_grid = context["gemm_grid"]
    launch_block = context["gemm_block"]

    sliced_a = cuda.as_cuda_array(tensor_slicedA_cupy)
    sliced_b = cuda.as_cuda_array(tensor_slicedB_cupy)
    diag_out = cuda.as_cuda_array(tensor_diag_cupy)

    pipeline = blas.suggest_device_pipeline(depth, sliced_a, sliced_b)
    smem_bytes = pipeline.buffer_size

    if warmup:
        # One-time setup before the first launch: allow `smem_bytes` of dynamic shared memory.
        set_max_dynamic_shared_size_bytes(kernel, smem_bytes, diag_out, pipeline)

    kernel[launch_grid, launch_block, 0, smem_bytes](diag_out, pipeline)

In [None]:
def get_epilogue_kernel(block_size=64):
    """Build the EXERCISE version of the epilogue kernel.

    The returned kernel runs one thread per output element. It should turn the
    per-diagonal int32 sums in ``tensor_diag`` back into an FP64 value, undo the
    row/column exponent shifts, and write ``alpha * result + beta * C`` into
    ``tensor_out``. While the diagonal loop is ``range(-1)`` it is empty, so
    ``result`` stays 0.0 and the kernel only computes ``beta * C``.

    Args:
        block_size: threads per block, used for ``launch_bounds``.

    Returns:
        The ``cuda.jit`` kernel
        ``epilogue_kernel(slices, tensor_diag, tensor_shift_a, tensor_shift_b, tensor_out, alpha, beta)``.
    """
    uint8_width = get_width(np.uint8)

    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        # Diagonal `nth` carries weight 2**(-nth * uint8_width).
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        # epilogue_ldexp comes from emulation_utils (presumably ldexp-style scaling by 2**-exponent_shift).
        return epilogue_ldexp(value, -exponent_shift)

    @cuda.jit(launch_bounds=(block_size, 1))
    def epilogue_kernel(slices, tensor_diag, tensor_shift_a, tensor_shift_b, tensor_out, alpha, beta):
        tid_m = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
        tid_n = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y

        # Threads outside the output (partial edge blocks) do nothing.
        if tid_m >= tensor_out.shape[0] or tid_n >= tensor_out.shape[1]:
            return

        # One exponent shift per row of A and per column of B.
        shift_a = tensor_shift_a[tid_m]
        shift_b = tensor_shift_b[tid_n]
        
        # EXERCISE --> Complete the implementation of the epilogue kernel
        diag_view = tensor_diag[tid_m, tid_n, :]

        result = 0.0
        for diag in range(-1): # EXERCISE --> loop over diagonals
            result += nth_slice_to_fp64(diag, diag_view[diag], shift_a + shift_b)

        tensor_out[tid_m, tid_n] = alpha * result + beta * tensor_out[tid_m, tid_n]

    return epilogue_kernel

def epilogue(slices, tensor_products, tensor_shift_a, tensor_shift_b, tensor_c, alpha, beta, context):
    """Launch ``context["epilogue_kernel"]`` over ``context["epilogue_grid"]`` / ``["epilogue_block"]``.

    All remaining arguments are forwarded to the kernel unchanged; the kernel
    updates ``tensor_c`` in place and nothing is returned.
    """
    kernel = context["epilogue_kernel"]
    launch_config = (context["epilogue_grid"], context["epilogue_block"])

    launch = kernel[launch_config]
    launch(slices, tensor_products, tensor_shift_a, tensor_shift_b, tensor_c, alpha, beta)

In [None]:
def setup_func(m, n, k):
    """Build the launch context for the partially fused Ozaki-I emulated DGEMM.

    Tile, block and pipeline sizes mirror the C++ ``cublasdx_config.hpp.inc`` /
    ``pipeline_config.hpp.inc`` / ``epilogue_config.hpp.inc`` cells.

    Args:
        m, n, k: GEMM problem dimensions.

    Returns:
        dict with the nvmath ``Matmul`` description ("BLAS"), the pipeline depth,
        the exercise GEMM and epilogue kernels, and their grid/block shapes.

    Raises:
        AssertionError: if a dimension is not a multiple of its tile size, or if
            ``k`` is too small to fill the pipeline.
    """
    tile_m = 128
    tile_n = 128
    tile_k = 128
    block_size = 128
    
    pipeline_depth = 3

    epilogue_tile_m = 16
    epilogue_tile_n = 16

    assert m % tile_m == 0, "Unsupported dimension m for TILE_M"
    assert n % tile_n == 0, "Unsupported dimension n for TILE_N"
    assert k % tile_k == 0, "Unsupported dimension k for TILE_K"
    # The pipeline keeps `pipeline_depth` K tiles in flight, so k must cover at least that many.
    assert k >= (tile_k * pipeline_depth), "Unsupported pipeline depth for k"

    assert m % epilogue_tile_m == 0, "Unsupported dimension for EPILOGUE_TILE_M"
    assert n % epilogue_tile_n == 0, "Unsupported dimension for EPILOGUE_TILE_N"
    
    BLAS = Matmul(size=(tile_m, tile_n, tile_k),
                  precision=(np.int8, np.int8, np.int32),
                  data_type="real",
                  alignment=MAX_ALIGNMENT,
                  arrangement=("row_major", "col_major", "col_major"), # Do not change
                  execution="Block",
                  block_size=block_size,
                  with_pipeline=True,
                  enable_input_streaming=True,
                  static_block_dim=True)

    # One CTA per (tile_m, tile_n) output tile.
    gemm_grid = (m // tile_m, n // tile_n)
    gemm_block = BLAS.block_dim

    # One epilogue thread per output element.
    epilogue_grid = (m // epilogue_tile_m, n // epilogue_tile_n)
    epilogue_block = (epilogue_tile_m, epilogue_tile_n)

    return {
        "BLAS": BLAS,
        "PIPELINE_DEPTH": pipeline_depth,
        "gemm_kernel" : get_emulated_gemm_kernel(BLAS),
        "gemm_grid": gemm_grid,
        "gemm_block": gemm_block,
        "epilogue_kernel": get_epilogue_kernel(math.prod(epilogue_block)),
        "epilogue_grid": epilogue_grid,
        "epilogue_block": epilogue_block
    }

In [None]:
# Benchmark the exercise kernels (helper from the `benchmark` utilities module).
benchmark_partially_fused_emulated_dgemm(problems, setup_func, partial_fused_dgemm_ozaki, epilogue)

#### Solution

In [None]:
def get_emulated_gemm_kernel_solution(BLAS):
    """Build the reference fused slice-product GEMM kernel.

    The returned Numba kernel is launched with one CTA per (tile_m, tile_n) output
    tile. For each diagonal ``diag`` (from ``slices - 1`` down to 0) it sums
    A_slice[term] * B_slice[diag - term] for term = 0 .. diag into one int32
    accumulator, then stores the sum into ``tensor_c[..., diag]``.

    Args:
        BLAS: nvmath ``Matmul`` description; its A and B value types must match.

    Returns:
        The ``cuda.jit`` kernel ``gemm_kernel(tensor_c, device_pipeline)``.
    """

    assert BLAS.a_value_type == BLAS.b_value_type, "Invalid BLAS configuration"

    tile_m, tile_n = BLAS.c_dim
    
    @cuda.jit(extensions=pipeline_extensions, launch_bounds=(BLAS.block_size, 1))
    def gemm_kernel(tensor_c, device_pipeline: DevicePipeline):
        # Last axis of tensor_c indexes the diagonal.
        _, _, slices = tensor_c.shape

        block_m = cuda.blockIdx.x
        block_n = cuda.blockIdx.y

        # shape=(0,): dynamic shared memory, sized by the launch (device_pipeline.buffer_size).
        smem = cuda.shared.array(shape=(0,), dtype=BLAS.a_value_type, alignment=device_pipeline.buffer_alignment)

        block_start_m = block_m * tile_m
        block_end_m = (block_m + 1) * tile_m

        block_start_n = block_n * tile_n
        block_end_n = (block_n + 1) * tile_n

        # First product: A slice 0 times B slice (slices - 1), i.e. the first term of the last diagonal.
        initial_diag = slices - 1
        initial_term = 0

        tile_pipeline = device_pipeline.get_tile(smem,
                                                 (block_m, np.int32(initial_term)),
                                                 (block_n, np.int32(initial_diag)))

        # This CTA's output tile across all diagonals.
        c_views = tensor_c[
            block_start_m : block_end_m,
            block_start_n : block_end_n,
            :
        ]
        # Leading dimension in elements: the larger of the two in-plane strides.
        ldc = max(c_views.strides[:2]) // c_views.itemsize
        
        accumulator = BLAS.suggest_accumulator()
        for diag in range(initial_diag, -1, -1):
            accumulator.clear()

            for term in range(initial_term, diag + 1):
                tile_pipeline.execute(accumulator)

                # Next pair: (term + 1, diag - term - 1) on this diagonal, or after the
                # last term, (0, diag - 1) = first term of the next diagonal.
                # NOTE(review): after the final product (diag == 0) this requests B slice -1;
                # confirm reset_tile does not issue out-of-bounds loads for that prefetch.
                next_slice_row =          0 if term == diag else term + 1
                next_slice_col = (diag - 1) if term == diag else diag - next_slice_row

                device_pipeline.reset_tile(tile_pipeline,
                                           (block_m, np.int32(next_slice_row)),
                                           (block_n, np.int32(next_slice_col)))

            if accumulator.is_thread_active():
                gmem_c = make_tensor(c_views[:,:,diag], BLAS.get_layout_gmem_c(ldc))
                accumulator.partition_and_copy(accumulator.get_results(), gmem_c)

        # Presumably releases the pipeline tile's resources — confirm in nvmath docs.
        tile_pipeline._del()

    return gemm_kernel

In [None]:
def get_epilogue_kernel_solution(block_size=64):
    """Build the reference epilogue kernel.

    The returned kernel runs one thread per output element (tid_m, tid_n). It
    converts each per-diagonal int32 sum in ``tensor_diag`` back to FP64, undoes
    the A-row / B-column exponent shifts, and writes
    ``alpha * result + beta * C`` into ``tensor_out``. ``C`` is only read when
    ``beta != 0``.

    Args:
        block_size: threads per block, used for ``launch_bounds``.

    Returns:
        The ``cuda.jit`` kernel
        ``epilogue_kernel(slices, tensor_diag, tensor_shift_a, tensor_shift_b, tensor_out, alpha, beta)``.
    """
    uint8_width = get_width(np.uint8)

    @cuda.jit(device=True, forceinline=True)
    def nth_slice_to_fp64(nth, nth_slice, exponent_shift):
        # Diagonal `nth` carries weight 2**(-nth * uint8_width).
        ko = math.pow(2.0, -nth * uint8_width)

        value = ko * np.float64(nth_slice)
        # epilogue_ldexp comes from emulation_utils (presumably ldexp-style scaling by 2**-exponent_shift).
        return epilogue_ldexp(value, -exponent_shift)

    @cuda.jit(launch_bounds=(block_size, 1))
    def epilogue_kernel(slices, tensor_diag, tensor_shift_a, tensor_shift_b, tensor_out, alpha, beta):
        tid_m = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
        tid_n = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y

        # Threads outside the output (partial edge blocks) must not touch memory.
        if tid_m >= tensor_out.shape[0] or tid_n >= tensor_out.shape[1]:
            return

        # One exponent shift per row of A and per column of B.
        shift_a = tensor_shift_a[tid_m]
        shift_b = tensor_shift_b[tid_n]

        diag_view = tensor_diag[tid_m, tid_n, :]

        result = 0.0
        for diag in range(slices-1, -1, -1):
            result += nth_slice_to_fp64(diag, diag_view[diag], shift_a + shift_b)

        # Skip the read of C entirely when it does not contribute.
        if beta != 0:
            result = alpha * result + beta * tensor_out[tid_m, tid_n]
        else:
            result = alpha * result

        tensor_out[tid_m, tid_n] = result

    return epilogue_kernel

In [None]:
def setup_func_solution(m, n, k):
    """Return the ``setup_func`` context with the reference GEMM and epilogue kernels swapped in."""
    context = setup_func(m, n, k)
    solution_gemm = get_emulated_gemm_kernel_solution(context["BLAS"])
    solution_epilogue = get_epilogue_kernel_solution(math.prod(context["epilogue_block"]))
    context["gemm_kernel"] = solution_gemm
    context["epilogue_kernel"] = solution_epilogue
    return context

In [None]:
# Benchmark the reference (solution) kernels on the same problems.
benchmark_partially_fused_emulated_dgemm(problems, setup_func_solution, partial_fused_dgemm_ozaki, epilogue)

### Performance Model

In [None]:
import numpy as np
import math

# INT8 TOPS, MEMORY BANDWIDTH (GB/s)
GPU_SPECS = {
    "L40S": (733, 864),
    "B200": (4500, 8000)
}

# NOTE: This model is very simplistic and does not take quantization or other overheads like slicing and FP64 operations into account
def roofline_prediction_3_2(m, n, k, slices=7, TILE_M=128, TILE_N=128, TILE_K=128):
    INT8_TOPS, MEMORY_BANDWIDTH_GBS = GPU_SPECS["L40S"]

    num_products = (slices * (slices + 1)) // 2

    # By design since each thread is computing one output element
    tiles = math.ceil(m / TILE_M) * math.ceil(n / TILE_N)

    # Each tile does TILE_M * TILE_N dot products which each have k multiplications and k additions for every product
    flops_per_tile = 2 * TILE_M * TILE_N * k * num_products

    fp64_size = np.dtype(np.float64).itemsize
    int32_size = np.dtype(np.float64).itemsize
    int8_size = np.dtype(np.int8).itemsize

    # We load a TILE_M rows of matrix A, TILE_N columns of matrix B for each product.
    # Then, we read from and write to TILE_M * TILE_N elements of matrix C
    # This needs to happen once for each diagonal
    memory_per_tile = ((TILE_M * k + TILE_N * k) * int8_size + 2 * TILE_M * TILE_N * int32_size) * num_products

    # In the epilogue kernel, we load the products and write the output
    memory_per_tile += (TILE_M * TILE_N) * (num_products * int32_size + fp64_size)

    total_memory_gb = tiles * memory_per_tile * 1e-9
    total_tflop = tiles * flops_per_tile * 1e-12

    return total_tflop / INT8_TOPS, total_memory_gb / MEMORY_BANDWIDTH_GBS

time_flops, time_membw = roofline_prediction_3_2(2048, 2048, 2048)

print(f"The runtime from the math operations {time_flops * 1e3} ms and the runtime from memory is {time_membw * 1e3} ms")

# We will either be bottlenecked by FLOPS or Memory Bandwidth, so we take the maximum
print(f"Therefore, the estimated best case runtime is {max(time_flops, time_membw) * 1e3} ms")

## Conclusion

In this exercise, we've learned how to use the pipeline APIs to implement more complex routines. Specifically, we've learned how to:

1. Clear the accumulator when we are ready for a new computation
2. Reset the tile pipeline (`reset_tile`) so that it loads the next pair of A and B slices
3. Iterate over 3D tensors with the pipeline API