
Implementation strategy #2

Open
shoyer opened this issue Jan 23, 2023 · 2 comments
Labels: question (Further information is requested)

shoyer commented Jan 23, 2023

This project looks really cool!

I would love to understand at a high level how this package works -- how do you actually implement stencil computations in JAX? Do you reuse jax.lax.scan or something else? Does it support auto-diff? How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

@ASEM000 ASEM000 self-assigned this Jan 23, 2023
@ASEM000 ASEM000 added the question Further information is requested label Jan 23, 2023
ASEM000 (Owner) commented Jan 23, 2023

Hello Stephan,

I would love to understand at a high level how this package works

  • First, I generate view indices using jax.vmap:

    kernex/kernex/_src/utils.py

    Lines 133 to 163 in f5dd7f2

    def general_product(*args):
        """Equivalent to `tuple(zip(*itertools.product(*args)))` for arrays
        Example:
        >>> general_product(
        ... jnp.array([[1,2],[3,4]]),
        ... jnp.array([[5,6],[7,8]]))
        (
            DeviceArray([[[1, 2],[1, 2]],[[3, 4],[3, 4]]], dtype=int32),
            DeviceArray([[[5, 6],[7, 8]],[[5, 6],[7, 8]]], dtype=int32)
        )
        >>> tuple(zip(*(itertools.product([[1,2],[3,4]],[[5,6],[7,8]]))))
        (
            ([1, 2], [1, 2], [3, 4], [3, 4]),
            ([5, 6], [7, 8], [5, 6], [7, 8])
        )
        """
        def nvmap(n):
            in_axes = [None] * len(args)
            in_axes[-n] = 0
            return (
                vmap(lambda *x: x, in_axes=in_axes)
                if n == 1
                else vmap(nvmap(n - 1), in_axes=in_axes)
            )

        return nvmap(len(args))(*args)

    @cached_property
    def views(self) -> tuple[jnp.ndarray, ...]:
        """Generate absolute sampling matrix"""
        # this function is cached because it is called multiple times
        # and it is expensive to calculate
        # the view is the indices of the array that is used to calculate
        # the output value
        dim_range = tuple(
            general_arange(di, ki, si, x0, xf)
            for (di, ki, si, (x0, xf)) in zip(
                self.shape, self.kernel_size, self.strides, self.border
            )
        )
        matrix = general_product(*dim_range)
        return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, self.kernel_size))
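
    As a minimal 1D sketch of the same idea (my own illustration, not kernex's exact code), jax.vmap can turn window start positions and in-window offsets into a grid of absolute view indices:

    import jax
    import jax.numpy as jnp

    # windows of size 3, stride 1, over an array of length 5
    starts = jnp.arange(0, 5 - 3 + 1)    # window start indices: [0, 1, 2]
    offsets = jnp.arange(3)              # in-window offsets:    [0, 1, 2]
    views = jax.vmap(lambda s: s + offsets)(starts)
    print(views)
    # [[0 1 2]
    #  [1 2 3]
    #  [2 3 4]]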

  • Second, I create a new function that applies the user function, taking view indices as its first input.
    Given a set of view indices, this function first retrieves the array portion using jnp.ix_, then applies the user function to it. If relative=True, i.e. the indexing is relative (the center is 0, like numba.stencil), I roll the array portion before applying the function. A small standalone sketch of this extraction step follows the two reducers below.

For kmap:

def reduce_map_func(self, func, *args, **kwargs) -> Callable:
    if self.relative:
        # if the function is relative, the function is applied to the view
        return lambda view, array: func(
            roll_view(array[ix_(*view)]), *args, **kwargs
        )
    else:
        return lambda view, array: func(array[ix_(*view)], *args, **kwargs)

For kscan:

def reduce_scan_func(self, func, *args, **kwargs) -> Callable:
    if self.relative:
        # if the function is relative, the function is applied to the view
        # the result is a 1D array of the same length as the number of views
        return lambda view, array: array.at[self.index_from_view(view)].set(
            func(roll_view(array[ix_(*view)]), *args, **kwargs)
        )
    else:
        return lambda view, array: array.at[self.index_from_view(view)].set(
            func(array[ix_(*view)], *args, **kwargs)
        )
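
As a small sketch of the extraction step (my own illustration; roll_view and index_from_view are kernex internals not shown here), jnp.ix_ turns a view's per-axis index vectors into a windowed portion of the padded array:

import jax.numpy as jnp

array = jnp.arange(16).reshape(4, 4)
rows, cols = jnp.array([1, 2, 3]), jnp.array([0, 1, 2])  # one 3x3 view
patch = array[jnp.ix_(rows, cols)]                       # the windowed portion
print(patch)
# [[ 4  5  6]
#  [ 8  9 10]
#  [12 13 14]]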


  • Third:

For kmap, I use jax.vmap to vectorize the new view-indices-accepting function over the array of all possible view indices:

def __single_call__(self, array: jnp.ndarray, *args, **kwargs):
    padded_array = jnp.pad(array, self.pad_width)
    # convert the function to a callable that takes a view and an array
    # and returns the result of the function applied to the view
    reduced_func = self.reduce_map_func(self.funcs[0], *args, **kwargs)
    # apply the function to each view using vmap
    # the result is a 1D array of the same length as the number of views
    result = vmap(lambda view: reduced_func(view, padded_array))(self.views)
    # reshape the result to the output shape
    # for example if the input shape is (3, 3) and the kernel shape is (2, 2)
    # and the stride is 1, and the padding is 0, the output shape is (2, 2)
    return result.reshape(*self.output_shape, *result.shape[1:])
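
To make the vmap-over-views pattern concrete, here is a minimal standalone sketch (my own, not kernex's code) that computes a moving-window sum by vmapping a "reduce one view" function over all view indices:

import jax
import jax.numpy as jnp

array = jnp.arange(5.0)
views = jax.vmap(lambda s: s + jnp.arange(3))(jnp.arange(3))   # all 3-point views
out = jax.vmap(lambda view: jnp.sum(array[view]))(views)       # moving-window sum
print(out)   # [3. 6. 9.]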

For kscan (my prime motivation), I use jax.lax.scan to scan over the array of view indices:

def __single_call__(self, array, *args, **kwargs):
    padded_array = jnp.pad(array, self.pad_width)
    reduced_func = self.reduce_scan_func(self.funcs[0], *args, **kwargs)

    def scan_body(padded_array, view):
        result = reduced_func(view, padded_array).reshape(padded_array.shape)
        return result, result[self.index_from_view(view)]

    return lax.scan(scan_body, padded_array, self.views)[1].reshape(
        self.output_shape
    )
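
And a minimal standalone sketch of the kscan idea (again my own illustration, not kernex's code): because jax.lax.scan carries the array itself, each window sees the values already written by earlier windows, unlike the vmap version:

import jax
import jax.numpy as jnp
from jax import lax

array = jnp.arange(1.0, 6.0)                                  # [1. 2. 3. 4. 5.]
views = jax.vmap(lambda s: s + jnp.arange(3))(jnp.arange(3))  # all 3-point views

def scan_body(carry, view):
    new_val = jnp.sum(carry[view])          # stencil: 3-point sum over the view
    carry = carry.at[view[1]].set(new_val)  # write back to the view's center
    return carry, new_val

final_array, per_view = lax.scan(scan_body, array, views)
print(per_view)   # [ 6. 13. 22.] -- each step sees the previous steps' writes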

Does it support auto-diff?

Yes, definitely. The library relies on jax.numpy, jax.vmap, jax.lax.scan, and jax.lax.switch for its internals.
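
A quick way to see the differentiability is to take jax.grad straight through the vmap-over-views sketch above (my own example, not library code):

import jax
import jax.numpy as jnp

views = jax.vmap(lambda s: s + jnp.arange(3))(jnp.arange(3))
moving_sum = lambda a: jnp.sum(jax.vmap(lambda v: jnp.sum(a[v]))(views))
print(jax.grad(moving_sum)(jnp.arange(5.0)))
# [1. 2. 3. 2. 1.] -- how many windows each input element participates in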

How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

I benchmarked kmap against jax.lax.conv_general_dilated_patches and jax.lax.conv_general_dilated.
The code is under tests_and_benchmarks. In general, kmap seems faster in many scenarios, especially on CPU; however, it needs more rigorous benchmarking, especially on TPU.

In general, my prime motivation is to solve PDEs using a stencil definition, which might require applying different functions at different locations of the array (e.g., at the boundary).
This is why kernex offers the ability to use kmap and kscan along with jax.lax.switch to apply different functions to different portions of the array. The following example introduces the function mesh concept, where different stencils can be applied using indexing; the backbone for this feature is jax.lax.switch.

Function mesh:

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0]**2
F[1:] = lambda x: x[0]**3

array = jnp.arange(1, 11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
...  216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
>>> [2., 12., 27., 48., 75.,
...  108., 147., 192., 243., 300.]

Array equivalent:

def F(x):
    f1 = lambda x: x**2
    f2 = lambda x: x**3
    x = x.at[0].set(f1(x[0]))
    x = x.at[1:].set(f2(x[1:]))
    return x

array = jnp.arange(1, 11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
...  216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
>>> [2., 12., 27., 48., 75.,
...  108., 147., 192., 243., 300.]
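
For completeness, here is a rough sketch of how jax.lax.switch can dispatch a different function per location via an integer function-id map (my own illustration of the backbone idea, not kernex's actual implementation):

import jax
import jax.numpy as jnp
from jax import lax

funcs = [lambda x: x**2, lambda x: x**3]
func_ids = jnp.array([0] + [1] * 9)          # location 0 -> funcs[0], rest -> funcs[1]
array = jnp.arange(1.0, 11.0)
apply_at = lambda i, x: lax.switch(i, funcs, x)
print(jax.vmap(apply_at)(func_ids, array))
# [   1.    8.   27.   64.  125.  216.  343.  512.  729. 1000.]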

shoyer (Author) commented Jan 23, 2023

OK great, thank you for sharing!

I agree that this is a very promising approach for implementing PDE kernels, and in general this is similar to the way I've implemented PDE solvers in JAX by hand (e.g., the wave equation solver).

conv_general_dilated and conv_general_dilated_patches use XLA's Convolution operation, which is really optimized for convolutional neural networks with large numbers of channels. I wouldn't expect them to work well for PDE kernels, except perhaps on TPUs.
