# Warp Core Tutorial: Generics

In [None]:
!pip install warp-lang

In [None]:
import warp as wp

# Ask Warp to keep its console output quiet so that the notebook
# cells below only show the results that they print themselves.
wp.config.quiet = True

# Explicitly initializing Warp is not necessary but
# we do it here to ensure everything is good to go.
wp.init()

## Function Overloading

Warp allows defining multiple functions with the same name, as long as their parameter signatures differ.

In [None]:
# Overload of `product()` for 2D vectors:
# multiplies the two components of the vector together.
@wp.func
def product(v: wp.vec2) -> float:
    x = v[0]
    y = v[1]
    return x * y


# Overload of `product()` for 2x2 matrices:
# multiplies the four entries of the matrix together, in row-major order.
@wp.func
def product(m: wp.mat22) -> float:
    acc = m[0, 0] * m[0, 1]
    acc = acc * m[1, 0]
    acc = acc * m[1, 1]
    return acc


# Define a kernel that multiplies together all the components of a vector
# and all the entries of a matrix, and stores the resulting scalar in the
# first element of the output array. Each call to `product()` is resolved
# to the overload matching the type of its argument.
@wp.kernel
def product_kernel(v: wp.vec2, m: wp.mat22, out_product: wp.array(dtype=float)):
    vec_product = product(v)
    mat_product = product(m)
    out_product[0] = vec_product * mat_product


print("\nproduct:")

# Inputs: (2 * 4) for the vector and (3 * 5 * 7 * 9) for the matrix.
v = wp.vec2(2.0, 4.0)
m = wp.mat22(3.0, 5.0, 7.0, 9.0)

# Single-element array receiving the result, filled by one kernel thread.
out_product = wp.empty(shape=1, dtype=float)
wp.launch(kernel=product_kernel, dim=1, inputs=[v, m], outputs=[out_product])
print(out_product)

## Generic Functions

A complementary approach to overloading functions is to use one of the generic types `typing.Any`, `wp.Int`, `wp.Float`, or `wp.Scalar`, and let Warp infer the final function's signature based on the arguments being passed to it.

In [None]:
# Generic function returning its argument multiplied by itself.
# `wp.Scalar` accepts integer and floating-point types of any width,
# and the return type follows the type of the argument.
@wp.func
def square(x: wp.Scalar) -> wp.Scalar:
    squared = x * x
    return squared


# Kernel squaring, in place, every value of a 16-bit integer array.
# A counterpart for 64-bit floating-points is defined right after it;
# both rely on the same generic `square()` function.
@wp.kernel
def square_kernel_i16(arr: wp.array(dtype=wp.int16)):
    tid = wp.tid()
    value = arr[tid]
    arr[tid] = square(value)


# Kernel squaring, in place, every value of a 64-bit floating-point array
# using the generic `square()` function.
@wp.kernel
def square_kernel_f64(arr: wp.array(dtype=wp.float64)):
    tid = wp.tid()
    value = arr[tid]
    arr[tid] = square(value)


# The kernels themselves are not generic: it is the `square()` function
# that gets implicitly specialized, here for 16-bit integers.
print("\narr_i16:")
arr_i16 = wp.array([1, 2, 3], dtype=wp.int16)
wp.launch(kernel=square_kernel_i16, dim=arr_i16.shape, inputs=[arr_i16])
print(arr_i16)

# Same generic function, this time specialized for 64-bit floating-points.
print("\narr_f64:")
arr_f64 = wp.array([4, 5, 6], dtype=wp.float64)
wp.launch(kernel=square_kernel_f64, dim=arr_f64.shape, inputs=[arr_f64])
print(arr_f64)

## Generic Kernels

The same generic types `typing.Any`, `wp.Int`, `wp.Float`, and `wp.Scalar` can also be used when annotating parameters on a kernel.

To generate the final kernels from such generic types, Warp supports implicit and explicit instantiations.

### Implicit Instantiation

By default, Warp infers the final kernel's signature and implementation based on the arguments being passed to it when calling `wp.launch()`.

In [None]:
# Generic kernel multiplying, in place, every value of an array by a
# coefficient. The array elements and the coefficient share the same
# `wp.Scalar` type, which can be an integer or a floating-point of any width.
@wp.kernel
def scale_kernel(arr: wp.array(dtype=wp.Scalar), coeff: wp.Scalar):
    tid = wp.tid()
    arr[tid] = arr[tid] * coeff


# Launching with 16-bit integer arguments implicitly instantiates
# the generic kernel for that type.
print("arr_i16:")
arr_i16 = wp.array([1, 2, 3], dtype=wp.int16)
wp.launch(kernel=scale_kernel, dim=arr_i16.shape, inputs=[arr_i16, wp.int16(2)])
print(arr_i16)

# Launching with 64-bit floating-point arguments implicitly instantiates
# the generic kernel a second time, for that other type.
print("\narr_f64:")
arr_f64 = wp.array([4, 5, 6], dtype=wp.float64)
wp.launch(kernel=scale_kernel, dim=arr_f64.shape, inputs=[arr_f64, wp.float64(2)])
print(arr_f64)

### Explicit Instantiation

It's also possible to specify which types a kernel should be instantiated against, before even needing to call `wp.launch()`. This is done using the `@wp.overload` decorator.

One advantage of this approach is that it speeds up kernel launches since Warp won't need to try inferring and generating a new kernel instance each time. Another is related to module reloading, as detailed in the [documentation here](https://nvidia.github.io/warp/modules/generics.html#module-reloading-behavior).

In [None]:
# Generic kernel multiplying, in place, every value of an array by a
# coefficient. Both the elements and the coefficient use the same
# `wp.Scalar` type: an integer or a floating-point of any width.
@wp.kernel
def scale_kernel(arr: wp.array(dtype=wp.Scalar), coeff: wp.Scalar):
    tid = wp.tid()
    scaled = arr[tid] * coeff
    arr[tid] = scaled


# Explicit instantiation for 16-bit integers.
# Only the signature is declared here; the body is left as `...` and
# the implementation comes from the generic `scale_kernel` defined above.
@wp.overload
def scale_kernel(arr: wp.array(dtype=wp.int16), coeff: wp.int16):
    ...


# Explicit instantiation for 64-bit floating-points.
# Only the signature is declared here; the body is left as `...` and
# the implementation comes from the generic `scale_kernel` defined above.
@wp.overload
def scale_kernel(arr: wp.array(dtype=wp.float64), coeff: wp.float64):
    ...


# These arguments match the overload declared for 16-bit integers.
print("arr_i16:")
arr_i16 = wp.array([1, 2, 3], dtype=wp.int16)
wp.launch(kernel=scale_kernel, dim=arr_i16.shape, inputs=[arr_i16, wp.int16(2)])
print(arr_i16)

# These arguments match the overload declared for 64-bit floating-points.
print("\narr_f64:")
arr_f64 = wp.array([4, 5, 6], dtype=wp.float64)
wp.launch(kernel=scale_kernel, dim=arr_f64.shape, inputs=[arr_f64, wp.float64(2)])
print(arr_f64)

## Type Introspection

Due to Warp's strict typing rules and lack of implicit integer/floating-point promotion, the exact argument types must be passed when calling functions. For example, when constructing a `wp.vec3s()` instance, each argument that isn't already of type `wp.int16` must be explicitly cast to it, as in `wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3))`, since integer literals default to 32-bit.

In the context of a generic kernel/function where the parameter type is only known at runtime, Warp exposes a `type()` operator that allows retrieving the resolved type of a variable in order to initialize/cast values.

To retrieve the data type of an array's elements, `type()` can be called on one of its elements (for example, the first one), but the more convenient form `array.dtype` is also available.

In [None]:
# Define a kernel that increases the values of an array by a fixed amount.
# The element type is generic, so the integer literals being added must be
# cast to whichever type the kernel ends up being instantiated with.
@wp.kernel
def increase_kernel(arr: wp.array(dtype=wp.Scalar)):
    i = wp.tid()

    # These 2 statements show two equivalent ways of casting a literal to
    # the array's element type: `type()` applied to an element, and the
    # array's `dtype` attribute. Together they add 2 + 3 = 5 to each value.
    arr[i] += type(arr[0])(2)
    arr[i] += arr.dtype(3)


# Instantiate and run the generic kernel on 16-bit integers.
print("arr_i16:")
arr_i16 = wp.array([1, 2, 3], dtype=wp.int16)
wp.launch(kernel=increase_kernel, dim=arr_i16.shape, inputs=[arr_i16])
print(arr_i16)

# Instantiate and run the generic kernel on 64-bit floating-points.
print("\narr_f64:")
arr_f64 = wp.array([4, 5, 6], dtype=wp.float64)
wp.launch(kernel=increase_kernel, dim=arr_f64.shape, inputs=[arr_f64])
print(arr_f64)

## Dynamic Code Generation

When more flexibility is desired than what the approaches covered so far can offer, we can make use of the dynamic nature of Python to generate kernels, functions, and even structs at runtime using closures that define values, types, or even functions as parameters.

In [None]:
# Operator functions that can be handed over to `create_kernel()` below,
# which embeds the chosen one into the kernel that it generates.


# Generic binary operator returning the sum of its two scalar operands.
@wp.func
def op_add(a: wp.Scalar, b: wp.Scalar) -> wp.Scalar:
    total = a + b
    return total


# Generic binary operator returning the product of its two scalar operands.
@wp.func
def op_mul(a: wp.Scalar, b: wp.Scalar) -> wp.Scalar:
    result = a * b
    return result


# Closure creating and returning a kernel.
# All the argument values will be embedded into the generated code
# that is to be compiled against the target architecture (CUDA or C++).
def create_kernel(vec_length: int, vec_dtype: wp.Scalar, num_iter: int, op_fn: wp.Function) -> wp.Kernel:
    """Build a kernel specialized for the given vector type, loop count, and operator.

    Args:
        vec_length: Number of components of the vector type being operated on.
        vec_dtype: Scalar type of the vector's components.
        num_iter: Number of times each vector gets scaled by its own reduction.
        op_fn: Warp function called as `op_fn(v[i], vec_dtype(i))` for each
            component `i` when reducing a vector.

    Returns:
        A kernel taking a single array of `wp.vec(vec_length, vec_dtype)`
        elements and writing one vector per thread into it.
    """
    # Define the vector type from its length/dtype.
    vec = wp.vec(vec_length, vec_dtype)

    # Define a function that reduces all of a vector's components into a single
    # value, using the provided operator function.
    @wp.func
    def reduce(v: vec) -> vec_dtype:
        # Start from zero and accumulate the operator's result for each
        # component, the second operand being the component's index.
        out = vec_dtype(0)
        for i in range(vec_length):
            out += op_fn(v[i], vec_dtype(i))

        return out

    # Define the kernel function to return.
    @wp.kernel
    def kernel(arr: wp.array(dtype=vec)):
        tid = wp.tid()

        # Initialize the vector's components to `tid`, `tid + 1`, and so on.
        v = vec()
        for i in range(vec_length):
            v[i] = vec_dtype(tid + i)

        # Repeatedly scale the vector by the reduction of its current value.
        for _ in range(num_iter):
            v *= reduce(v)

        arr[tid] = v

    return kernel


# First kernel: 2D vectors of 32-bit integers, 3 iterations,
# reducing with the multiplication operator.
print("arr_1:")
vec_length = 2
vec_dtype = wp.int32
num_iter = 3
op_fn = op_mul
kernel_1 = create_kernel(vec_length, vec_dtype, num_iter, op_fn)
arr_1 = wp.empty(shape=3, dtype=wp.vec(vec_length, vec_dtype))
wp.launch(kernel=kernel_1, dim=arr_1.shape, inputs=[arr_1])
print(arr_1)

# Second kernel: 3D vectors of 64-bit floating-points, 2 iterations,
# reducing with the addition operator.
print("\narr_2:")
vec_length = 3
vec_dtype = wp.float64
num_iter = 2
op_fn = op_add
kernel_2 = create_kernel(vec_length, vec_dtype, num_iter, op_fn)
arr_2 = wp.empty(shape=3, dtype=wp.vec(vec_length, vec_dtype))
wp.launch(kernel=kernel_2, dim=arr_2.shape, inputs=[arr_2])
print(arr_2)