In [None]:
import jax
import jax.numpy as jnp
from kernex._src.utils import general_arange, general_product
import functools as ft

In [None]:
@ft.lru_cache(maxsize=None)
def views(shape, kernel_size, strides, border) -> tuple[jnp.ndarray, ...]:
    """Build the absolute index views used to sample every kernel window.

    For each axis, a range of indices is produced by ``general_arange``
    from that axis' size, kernel size, stride and ``(start, stop)`` border
    pair. The per-axis ranges are then combined by ``general_product``.
    Each combined entry is reshaped so that its last dimension equals the
    kernel size along the matching axis.

    The result is memoised with ``functools.lru_cache`` because the
    original author notes it is called repeatedly and is costly to compute.
    As a consequence, every argument must be hashable (e.g. tuples, not
    lists).

    Args:
        shape: Per-axis sizes of the input array.
        kernel_size: Per-axis kernel sizes. Also used as the trailing
            dimension of each returned view.
        strides: Per-axis strides.
        border: Per-axis ``(start, stop)`` pairs forwarded to
            ``general_arange``. Their exact meaning is defined by that
            helper; confirm it against the kernex source.

    Returns:
        A tuple with one array per axis, each shaped ``(-1, kernel_size[i])``.
    """
    # One index range per axis. The four per-axis sequences are consumed in
    # lockstep, so any surplus entries beyond the shortest one are ignored.
    per_axis_ranges = [
        general_arange(dim, ksize, stride, start, stop)
        for dim, ksize, stride, (start, stop) in zip(
            shape, kernel_size, strides, border
        )
    ]

    # Combine the per-axis ranges. Presumably this yields one index array
    # per axis covering every window position; TODO confirm with kernex.
    grid = general_product(*per_axis_ranges)

    # Give each axis' indices a trailing dimension of that axis' kernel size.
    return tuple(
        axis_indices.reshape(-1, ksize)
        for axis_indices, ksize in zip(grid, kernel_size)
    )