In [None]:
# SPDX-License-Identifier: Apache-2.0 AND CC-BY-NC-4.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<img src="./images/nvmath_head_panel@0.5x.png" alt="nvmath-python" />

# Getting Started with nvmath-python: Kernel Fusion



## Exercise: Evaluate performance of CuPy `@`-based vs. `matmul`-based implementations of GEMM

In [None]:
## Exercise: Evaluate performance of CuPy `@`-based vs. `matmul`-based implementations of GEMM

import nvmath # noqa: F401 Facilitates shared objects loading required by CuPy (Workaround for CuPy being unable to find nvrtc installed in wheels)
import cupy as cp
import numpy as np
import cupyx as cpx

# GEMM problem definition: alpha * (A @ B) + beta * C with
# A of shape (m, k), B of shape (k, n) and C of shape (m, n).
m = 10_000_000  # rows of A and C (tall, skinny problem)
n = 40  # columns of B and C
k = 10  # inner (contraction) dimension

# Random float32 CuPy operands, drawn in the order A, B, C.
a, b, c = (
    cp.random.rand(rows, cols, dtype=cp.float32)
    for rows, cols in ((m, k), (k, n), (m, n))
)

# Scalar coefficients of the GEMM update
alpha, beta = 1.5, 0.5

# Benchmarking function

def benchmark(
    F,
    F_name="Implementation",
    F_alternative=None,
    F_alternative_name="Alternative implementation",
    n_repeat=10,
    n_warmup=1,
):
    """Time ``F`` (and optionally ``F_alternative``) on the GPU and print the results.

    Each callable is run through ``cupyx.profiler.benchmark`` with ``n_warmup``
    warm-up calls followed by ``n_repeat`` timed calls. The best (minimum) GPU
    time, in seconds, is reported. When ``F_alternative`` is given, its best
    time is also reported, together with the speedup of ``F`` relative to it
    (``time(F_alternative) / time(F)``).

    Returns:
        ``(perf, perf_alt)``: the best GPU times in seconds. ``perf_alt`` is
        ``None`` when no alternative implementation is provided.
    """

    def best_gpu_time(fn):
        # Best time across the repeated (post warm-up) runs.
        result = cpx.profiler.benchmark(fn, n_repeat=n_repeat, n_warmup=n_warmup)
        return np.min(result.gpu_times)

    perf = best_gpu_time(F)
    print(f"{F_name} performance = {perf:0.4f} sec")

    perf_alt = None
    if F_alternative is None:
        return perf, perf_alt

    perf_alt = best_gpu_time(F_alternative)
    print(f"{F_alternative_name} performance = {perf_alt:0.4f} sec")
    # A ratio above 1 means F is faster than F_alternative.
    print(f"Speedup = {perf_alt / perf:0.4f}x")
    return perf, perf_alt

# Write two functions that implement GEMM using `@` operator and `matmul` function.
def gemm_operator_form(a, b, c, alpha, beta):
    """Return ``alpha * a @ b + beta * c`` written with the ``@`` operator.

    ``*`` and ``@`` share the same precedence and associate left to right, so
    the scalar ``alpha`` scales ``a`` *before* the matrix product. The
    parentheses below only make that grouping explicit.
    """
    return (alpha * a) @ b + beta * c

def gemm_matmul_form(a, b, c, alpha, beta):
    """Return ``alpha * matmul(a, b) + beta * c`` using ``cp.matmul``.

    Unlike the ``@``-operator form, the scalar ``alpha`` here scales the
    product of shape (m, n) rather than the input ``a``.
    """
    return alpha * (cp.matmul(a, b)) + beta * c

# Time both GEMM variants. The `@` form is F and the `matmul` form is the
# alternative, so the reported speedup is time(matmul form) / time(@ form).
benchmark(
    F=lambda: gemm_operator_form(a, b, c, alpha, beta),
    F_name="GEMM @ form",
    F_alternative=lambda: gemm_matmul_form(a, b, c, alpha, beta),
    F_alternative_name="GEMM matmul form",
    n_repeat=5,
    n_warmup=1,
)

# Compute the number of flops for the two implementations
def gemm_operator_form_flops(a, b, c):
    """Return the floating-point operation count of ``gemm_operator_form``.

    ``gemm_operator_form`` evaluates ``alpha * a @ b + beta * c``. Python parses
    this as ``((alpha * a) @ b) + (beta * c)`` because ``*`` and ``@`` share
    precedence and associate left to right. With ``a`` of shape (m, k), ``b``
    of shape (k, n) and ``c`` of shape (m, n), the count is:

    * ``alpha * a``: m*k multiplications
    * ``(...) @ b``: m*n*(2k - 1) operations (k multiplies and k - 1 adds for
      each of the m*n output elements)
    * ``beta * c``: m*n multiplications
    * final ``+``: m*n additions

    Only the ``.shape`` attribute of each argument is read.
    """
    m, k = a.shape[0], a.shape[1]
    n = b.shape[1]
    # Each of the m*n output entries is a length-k dot product.
    matmul_flops = m * n * (2 * k - 1)
    alpha_a_flops = m * k
    beta_c_flops = c.shape[0] * c.shape[1]
    add_flops = m * n
    return matmul_flops + alpha_a_flops + beta_c_flops + add_flops

def gemm_matmul_form_flops(a, b, c):
    """Return the floating-point operation count of ``gemm_matmul_form``.

    ``gemm_matmul_form`` evaluates ``alpha * matmul(a, b) + beta * c``. With
    ``a`` of shape (m, k), ``b`` of shape (k, n) and ``c`` of shape (m, n),
    the count is:

    * ``matmul(a, b)``: m*n*(2k - 1) operations (k multiplies and k - 1 adds
      for each of the m*n output elements)
    * ``alpha * (...)``: m*n multiplications (the scalar scales the product)
    * ``beta * c``: m*n multiplications
    * final ``+``: m*n additions

    Only the ``.shape`` attribute of each argument is read.
    """
    m, k = a.shape[0], a.shape[1]
    n = b.shape[1]
    # Each of the m*n output entries is a length-k dot product.
    matmul_flops = m * n * (2 * k - 1)
    alpha_x_flops = m * n
    beta_c_flops = c.shape[0] * c.shape[1]
    add_flops = m * n
    return matmul_flops + alpha_x_flops + beta_c_flops + add_flops

# Print the total work of each implementation in GFLOP. This is an operation
# count, not a rate: divide by the measured time to obtain GFLOP/s.
print(f"GEMM operator form: {gemm_operator_form_flops(a, b, c) * 1e-9:.2f} GFLOP")
print(f"GEMM matmul form: {gemm_matmul_form_flops(a, b, c) * 1e-9:.2f} GFLOP")