# 🎛 Custom convolutions

This notebook demonstrates how to override the operation of serket's convolution layers using [kernex](https://github.com/ASEM000/kernex/blob/main/README.md). Only the kernel operation needs to be defined: the resulting layer is used exactly like the original one, and parameter initialization and shape checking are handled automatically.

## Direct convolution

This example demonstrates how to recreate the convolution operation using the `kernex` library. `kernex` offers a function transformation, similar to `jax.vmap`, that wraps a kernel operation (e.g. `lambda input, kernel: sum(input * kernel)`) and returns a function that applies it to every sliding-window view of an array.

In [1]:
!pip install git+https://github.com/ASEM000/serket --quiet
!pip install kernex --quiet

In [2]:
import kernex as kex  # for stencil operations like convolutions
import serket as sk
import jax
import jax.random as jr
import jax.numpy as jnp
import numpy.testing as npt


def my_conv(
    input: jax.Array,
    weight: jax.Array,
    bias: jax.Array | None,
    strides: tuple[int, ...],
    padding: tuple[tuple[int, int], ...],
    dilation: tuple[int, ...],
    groups: int,
    mask: jax.Array | None,
) -> jax.Array:
    """Plain (non-dilated, ungrouped) convolution implemented with ``kernex``.

    Uses the same signature as ``serket.nn.conv_nd`` so it can replace the
    ``conv_op`` of a serket convolution layer.

    Args:
        input: channel-first input, ``(in_features, *spatial)``.
        weight: kernel, ``(out_features, in_features, *kernel_size)``.
        bias: added to the output; ``None`` to skip.
        strides: stride for each spatial dimension.
        padding: ``(before, after)`` padding for each spatial dimension.
        dilation: must be all ones (dilation is not implemented here).
        groups: must be ``1`` (grouped convolution is not implemented here).
        mask: if given, multiplied element-wise with ``weight`` before the
            convolution (assumed to mirror serket's masked-kernel
            semantics -- confirm).

    Returns:
        Array of shape ``(out_features, *output_spatial)``.

    Raises:
        NotImplementedError: if ``dilation`` or ``groups`` ask for a feature
            this simplified implementation does not provide, instead of
            silently returning a wrong result.
    """
    if any(d != 1 for d in dilation):
        raise NotImplementedError(f"Only dilation=1 is supported, got {dilation=}")
    if groups != 1:
        raise NotImplementedError(f"Only groups=1 is supported, got {groups=}")
    if mask is not None:
        weight = weight * mask

    _, in_features, *kernel_size = weight.shape

    # The window covers every input channel at once (channel stride 1, no
    # channel padding), so the channel axis of the result has size 1.
    @kex.kmap(
        kernel_size=(in_features, *kernel_size),
        strides=(1, *strides),
        padding=((0, 0), *padding),
    )
    def conv_func(input, weight):
        # the kernel operation: weighted sum of one window with one filter
        return jnp.sum(input * weight)

    # vectorize over the out_features (filters) of the weight
    out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)
    # drop the size-1 channel axis produced by the full-depth window
    out = jnp.squeeze(out, axis=1)
    return out if bias is None else out + bias


class CustomConv2D(sk.nn.Conv2D):
    """``sk.nn.Conv2D`` whose convolution is computed by ``my_conv``.

    Parameter initialization and shape checks come from the base layer;
    only the operation applied to the input is replaced.
    """

    # ``staticmethod`` prevents ``my_conv`` from being bound as an instance
    # method, so calling it through ``self`` does not inject ``self`` as the
    # first positional argument.
    conv_op = staticmethod(my_conv)


k1, k2 = jr.split(jr.PRNGKey(0), 2)

# both layers use the same key and configuration so their weights match
basic_conv = sk.nn.Conv2D(
    in_features=1,
    out_features=2,
    kernel_size=3,
    bias_init=None,
    key=k1,
)

custom_conv = CustomConv2D(
    in_features=1,
    out_features=2,
    kernel_size=3,
    bias_init=None,
    key=k1,
)

# channel-first input (named `image` to avoid shadowing the builtin `input`)
image = jr.uniform(k2, shape=(1, 10, 10))

# the forward passes must agree
npt.assert_allclose(
    basic_conv(image),
    custom_conv(image),
    atol=1e-6,
)
# let's check the gradients with respect to the input as well
npt.assert_allclose(
    jax.grad(lambda x: basic_conv(x).sum())(image),
    jax.grad(lambda x: custom_conv(x).sum())(image),
    atol=1e-6,
)

## Depthwise convolution

Similar to the example above, recreating depthwise convolution only additionally requires vectorizing the kernel operation over the channel dimension with `jax.vmap`, so that each input channel is convolved with its own filters.

In [3]:
import kernex as kex  # for stencil operations like convolutions
import jax
import jax.random as jr
import jax.numpy as jnp
import numpy.testing as npt


def my_depthwise_conv(
    input: jax.Array,
    weight: jax.Array,
    bias: jax.Array | None,
    strides: tuple[int, ...],
    padding: tuple[tuple[int, int], ...],
    mask: jax.Array | None,
) -> jax.Array:
    """Depthwise convolution implemented with ``kernex``.

    Uses the same signature as ``serket.nn.depthwise_conv_nd``. Each input
    channel is convolved only with its own group of ``depth_multiplier``
    filters; output channel ``o`` uses input channel
    ``o // depth_multiplier``. This matches the grouping of
    ``jax.lax.conv_general_dilated`` with ``feature_group_count=in_features``.

    Args:
        input: channel-first input, ``(in_features, *spatial)``.
        weight: kernel, assumed ``(in_features * depth_multiplier, 1,
            *kernel_size)`` as created by serket's depthwise layers --
            confirm.
        bias: added to the output; ``None`` to skip.
        strides: stride for each spatial dimension.
        padding: ``(before, after)`` padding for each spatial dimension.
        mask: if given, multiplied element-wise with ``weight`` before the
            convolution (assumed to mirror serket's masked-kernel
            semantics -- confirm).

    Returns:
        Array of shape ``(in_features * depth_multiplier, *output_spatial)``.
    """
    if mask is not None:
        weight = weight * mask

    in_features = input.shape[0]
    out_features, _, *kernel_size = weight.shape
    depth_multiplier = out_features // in_features

    @kex.kmap(
        kernel_size=tuple(kernel_size),
        strides=strides,
        padding=padding,
    )
    def conv_func(input, weight):
        # the kernel operation: weighted sum of one single-channel window
        return jnp.sum(input * weight)

    # Group the filters by the input channel they belong to:
    # (in_features * depth_multiplier, 1, *k) -> (in_features, depth_multiplier, *k)
    weight = weight.reshape(in_features, depth_multiplier, *kernel_size)
    # inner vmap: apply every filter of a group to the same input channel
    per_channel = jax.vmap(conv_func, in_axes=(None, 0))
    # outer vmap: pair each input channel with its own filter group
    out = jax.vmap(per_channel)(input, weight)
    # (in_features, depth_multiplier, *out_spatial) -> (out_features, *out_spatial)
    out = out.reshape(out_features, *out.shape[2:])
    return out if bias is None else out + bias


class CustomDepthwiseConv2D(sk.nn.DepthwiseConv2D):
    """``sk.nn.DepthwiseConv2D`` whose convolution is ``my_depthwise_conv``.

    Parameter initialization and shape checks come from the base layer;
    only the operation applied to the input is replaced.
    """

    # ``staticmethod`` prevents the function from being bound as an instance
    # method, so calling it through ``self`` does not inject ``self`` as the
    # first positional argument.
    conv_op = staticmethod(my_depthwise_conv)


k1, k2 = jr.split(jr.PRNGKey(0), 2)

# both layers use the same key and configuration so their weights match
basic_conv = sk.nn.DepthwiseConv2D(
    in_features=1,
    depth_multiplier=2,
    kernel_size=3,
    bias_init=None,
    key=k1,
)

custom_conv = CustomDepthwiseConv2D(
    in_features=1,
    depth_multiplier=2,
    kernel_size=3,
    bias_init=None,
    key=k1,
)

# channel-first input (named `image` to avoid shadowing the builtin `input`)
image = jr.uniform(k2, shape=(1, 10, 10))

# the forward passes must agree
npt.assert_allclose(
    basic_conv(image),
    custom_conv(image),
    atol=1e-6,
)
# let's check the gradients with respect to the input as well
npt.assert_allclose(
    jax.grad(lambda x: basic_conv(x).sum())(image),
    jax.grad(lambda x: custom_conv(x).sum())(image),
    atol=1e-6,
)

## Positive kernel convolution

In this example, a custom convolution operation is defined. As a toy example, the operation uses only the non-negative weight values; negative weights are treated as zero.

In [4]:
import kernex as kex  # for stencil operations like convolutions
import serket as sk
import jax
import jax.random as jr
import jax.numpy as jnp
import numpy.testing as npt


def my_custom_conv(
    input: jax.Array,
    weight: jax.Array,
    bias: jax.Array | None,
    strides: tuple[int, ...],
    padding: tuple[tuple[int, int], ...],
    dilation: tuple[int, ...],
    groups: int,
    mask: jax.Array | None,
) -> jax.Array:
    """Toy convolution that ignores negative kernel weights.

    Same as a plain convolution, except that negative weights are replaced
    by zero before each window's weighted sum. Uses the same signature as
    ``serket.nn.conv_nd`` so it can replace a serket layer's ``conv_op``.

    Args:
        input: channel-first input, ``(in_features, *spatial)``.
        weight: kernel, ``(out_features, in_features, *kernel_size)``.
        bias: added to the output; ``None`` to skip.
        strides: stride for each spatial dimension.
        padding: ``(before, after)`` padding for each spatial dimension.
        dilation: must be all ones (dilation is not implemented here).
        groups: must be ``1`` (grouped convolution is not implemented here).
        mask: if given, multiplied element-wise with ``weight`` before the
            convolution (assumed to mirror serket's masked-kernel
            semantics -- confirm).

    Returns:
        Array of shape ``(out_features, *output_spatial)``.

    Raises:
        NotImplementedError: if ``dilation`` or ``groups`` ask for a feature
            this simplified implementation does not provide, instead of
            silently returning a wrong result.
    """
    if any(d != 1 for d in dilation):
        raise NotImplementedError(f"Only dilation=1 is supported, got {dilation=}")
    if groups != 1:
        raise NotImplementedError(f"Only groups=1 is supported, got {groups=}")
    if mask is not None:
        weight = weight * mask

    _, in_features, *kernel_size = weight.shape

    # The window covers every input channel at once (channel stride 1, no
    # channel padding), so the channel axis of the result has size 1.
    @kex.kmap(
        kernel_size=(in_features, *kernel_size),
        strides=(1, *strides),
        padding=((0, 0), *padding),
    )
    def conv_func(input, weight):
        # custom kernel operation: negative weights contribute nothing
        return jnp.sum(input * jnp.where(weight < 0, 0, weight))

    # vectorize over the out_features (filters) of the weight
    out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)
    # drop the size-1 channel axis produced by the full-depth window
    out = jnp.squeeze(out, axis=1)
    return out if bias is None else out + bias


class CustomConv2D(sk.nn.Conv2D):
    """``sk.nn.Conv2D`` whose convolution is computed by ``my_custom_conv``.

    Parameter initialization and shape checks come from the base layer;
    only the operation applied to the input is replaced.
    """

    # ``staticmethod`` prevents the function from being bound as an instance
    # method, so calling it through ``self`` does not inject ``self`` as the
    # first positional argument.
    conv_op = staticmethod(my_custom_conv)


k1, k2 = jr.split(jr.PRNGKey(0), 2)


custom_conv = CustomConv2D(
    in_features=1,
    out_features=2,
    kernel_size=3,
    bias_init=None,
    key=k1,
)

# channel-first input (named `image` to avoid shadowing the builtin `input`)
image = jr.uniform(k2, shape=(1, 10, 10))

# inspect the custom layer defined in this cell (the previous code called
# `basic_conv`, which was the depthwise layer left over from the last cell)
custom_conv(image).shape

(2, 10, 10)