In [None]:
# SPDX-License-Identifier: Apache-2.0 AND CC-BY-NC-4.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<img src="./images/nvmath_head_panel@0.5x.png" alt="nvmath-python" />

# Getting Started with nvmath-python: Stateful APIs and Autotuning

## Exercise: Batch Dimension vs. Batch Sequence

In the above example, we implemented batching as a sequence of matrices processed one by one in a loop. This is a common technique for streaming data or when the entire batch does not fit into GPU memory. An alternative approach is to add a dedicated batch dimension and operate on the batch as a single tensor. The **nvmath-python** library supports both use cases.

Implement the batch dimension approach and compare its performance to that of the batch sequence approach. Explain the performance difference (if any).

In [None]:
## Exercise: Batch Dimension vs. Batch Sequence

import nvmath
from nvmath.linalg.advanced import MatmulEpilog
import cupy as cp
import numpy as np
import cupyx as cpx


def benchmark(
    F, F_name="Implementation", F_alternative=None, F_alternative_name="Alternative implementation", n_repeat=10, n_warmup=1
):
    """Time ``F`` and, optionally, ``F_alternative``, printing the results.

    Each callable is timed with ``cpx.profiler.benchmark`` (``n_warmup``
    warm-up runs followed by ``n_repeat`` measured runs); the smallest value
    found in the reported ``gpu_times`` is taken as its performance.

    Args:
        F: Zero-argument callable to benchmark.
        F_name: Label printed next to the timing of ``F``.
        F_alternative: Optional second zero-argument callable to benchmark.
        F_alternative_name: Label printed next to the timing of
            ``F_alternative``.
        n_repeat: Number of measured runs per callable.
        n_warmup: Number of warm-up runs per callable.

    Returns:
        A ``(perf, perf_alt)`` tuple with the best times in seconds.
        ``perf_alt`` is ``None`` when ``F_alternative`` is not given.
        When it is given, the speedup of ``F`` relative to ``F_alternative``
        (``perf_alt / perf``) is printed as well.
    """

    def _best_gpu_time(func):
        # Warm-up + repeated runs; keep only the fastest measured run.
        measurement = cpx.profiler.benchmark(func, n_repeat=n_repeat, n_warmup=n_warmup)
        return np.min(measurement.gpu_times)

    perf = _best_gpu_time(F)
    print(f"{F_name} performance = {perf:0.4f} sec")

    # Nothing to compare against: report the single measurement only.
    if F_alternative is None:
        return perf, None

    perf_alt = _best_gpu_time(F_alternative)
    print(f"{F_alternative_name} performance = {perf_alt:0.4f} sec")
    print(f"Speedup = {perf_alt / perf:0.4f}x")
    return perf, perf_alt


# Problem dimensions, chosen large enough for benchmarking purposes:
# every batch element multiplies an (m x k) matrix by a (k x n) matrix.
m, n, k = 64, 128, 256
batch_size = 512

# Random single-precision operands with a leading batch dimension.
a = cp.random.rand(batch_size, m, k, dtype=cp.float32)
b = cp.random.rand(batch_size, k, n, dtype=cp.float32)
# Output buffer, allocated once up front and written to by the matmul functions.
d = cp.empty(shape=(batch_size, m, n), dtype=cp.float32)
# One bias column of length m per batch element.
bias = cp.random.rand(batch_size, m, 1, dtype=cp.float32)


def matmul_batched_stateless(a, b, bias):
    """Multiply the whole batch with one stateless ``matmul`` call.

    The operands are passed with their leading batch dimension, so the batch
    is inferred from the operand shapes. A RELU_BIAS epilog is applied with
    ``bias`` as its input.

    Args:
        a: Batched left operand.
        b: Batched right operand.
        bias: Batched bias passed to the epilog.

    The result is copied into the pre-allocated module-level array ``d``
    (to save memory); nothing is returned.
    """
    relu_bias = MatmulEpilog(MatmulEpilog.RELU_BIAS)
    product = nvmath.linalg.advanced.matmul(a, b, epilog=relu_bias, epilog_inputs={"bias": bias})
    # Slice assignment fills the existing buffer instead of rebinding the name.
    d[:] = product


def matmul_batched_stateful_sequence(a, b, bias):
    """Multiply the batch element by element, reusing one stateful ``Matmul``.

    A ``Matmul`` object is created and planned once for the first batch
    element (``a[0]``, ``b[0]``, ``bias[0]`` are assumed to be representative
    of the whole batch). The same plan is then reused for every remaining
    element by resetting the operands before each execution.

    Args:
        a: Batched left operand; ``a[i]`` is the i-th matrix.
        b: Batched right operand; ``b[i]`` is the i-th matrix.
        bias: Batched bias; ``bias[i]`` is the epilog input for element ``i``.

    Every result, including the one for element 0, is stored in the
    pre-allocated module-level array ``d``; nothing is returned.
    """
    with nvmath.linalg.advanced.Matmul(a[0], b[0]) as mm:  # Matmul object for the first batch element.
        mm.plan(epilog=MatmulEpilog(MatmulEpilog.RELU_BIAS), epilog_inputs={"bias": bias[0]})
        # The first execution uses the operands given at construction time, so no
        # reset is needed. Its result must still be stored: otherwise d[0] is never
        # written and this variant does less work than the stateless one.
        d[0] = mm.execute()
        # Take the batch length from the operand itself rather than a global.
        for i in range(1, a.shape[0]):
            # Subsequent executions require resetting the operands first.
            mm.reset_operands(a=a[i], b=b[i], epilog_inputs={"bias": bias[i]})
            d[i] = mm.execute()


# Compare the single batched call against the per-element stateful loop;
# the reported speedup is that of the stateless approach over the sequence approach.
benchmark(
    F=lambda: matmul_batched_stateless(a, b, bias),
    F_name="Stateless batched dimension approach",
    F_alternative=lambda: matmul_batched_stateful_sequence(a, b, bias),
    F_alternative_name="Stateful with a batch sequence approach",
)