In [None]:
def gaunt_tensor_product_fourier_2D(
    input1: e3nn.IrrepsArray,
    input2: e3nn.IrrepsArray,
    *,
    res_theta: int,
    res_phi: int,
    convolution_type: str,
    filter_ir_out=None,
) -> e3nn.IrrepsArray:
    """Gaunt tensor product using 2D Fourier functions.

    Both inputs are projected onto 2D Fourier coefficients. The coefficients
    are convolved, and the result is projected back onto coefficients with
    irreps ``e3nn.s2_irreps(lmax1 + lmax2)``.

    Args:
        input1: First operand. If its irreps differ from
            ``e3nn.s2_irreps(input1.irreps.lmax)``, it is zero-padded to them.
        input2: Second operand. It is zero-padded in the same way.
        res_theta: Resolution passed to the ``gtp_utils`` grid constructors.
        res_phi: Resolution passed to the ``gtp_utils`` grid constructors.
        convolution_type: ``"direct"`` selects
            ``gtp_utils.convolve_2D_direct``; ``"fft"`` selects
            ``gtp_utils.convolve_2D_fft``.
        filter_ir_out: Optional output filter. It is normalised by
            ``_validate_filter_ir_out`` and applied with
            ``IrrepsArray.filter``.

    Returns:
        The real part of the back-projected coefficients, as an
        ``e3nn.IrrepsArray`` with irreps ``e3nn.s2_irreps(lmax1 + lmax2)``,
        filtered by ``filter_ir_out``.

    Raises:
        ValueError: If ``convolution_type`` is neither ``"direct"`` nor
            ``"fft"``.
    """
    filter_ir_out = _validate_filter_ir_out(filter_ir_out)

    # Resolve the convolution first, so an invalid choice fails before any
    # (potentially large) change-of-basis matrices are built.
    if convolution_type == "direct":
        convolve = gtp_utils.convolve_2D_direct
    elif convolution_type == "fft":
        convolve = gtp_utils.convolve_2D_fft
    else:
        raise ValueError(f"Unknown convolution type {convolution_type}.")

    # Pad the inputs with zeros so each covers every irrep up to its lmax.
    lmax1 = input1.irreps.lmax
    if input1.irreps != e3nn.s2_irreps(lmax1):
        input1 = input1.extend_with_zeros(e3nn.s2_irreps(lmax1))

    lmax2 = input2.irreps.lmax
    if input2.irreps != e3nn.s2_irreps(lmax2):
        input2 = input2.extend_with_zeros(e3nn.s2_irreps(lmax2))

    with jax.ensure_compile_time_eval():
        # Precompute the change of basis matrices.
        y1_grid = gtp_utils.compute_y_grid(lmax1, res_theta=res_theta, res_phi=res_phi)
        y2_grid = gtp_utils.compute_y_grid(lmax2, res_theta=res_theta, res_phi=res_phi)
        z_grid = gtp_utils.compute_z_grid(lmax1 + lmax2, res_theta=res_theta, res_phi=res_phi)

        # Convert to sparse arrays. Rounding to 8 decimals first presumably
        # turns floating-point noise into exact zeros, so `fromdense` does
        # not store them as non-zero entries.
        y1_grid_sp = sparse.BCOO.fromdense(y1_grid.round(8))
        y2_grid_sp = sparse.BCOO.fromdense(y2_grid.round(8))
        z_grid_sp = sparse.BCOO.fromdense(z_grid.round(8))

    @sparse.sparsify
    def to_2D_fourier_coeffs(coeffs, y_grid):
        # `coeffs` is already a plain array (the caller passes `.array`), so
        # it is contracted directly; it has no `.array` attribute of its own.
        return jnp.einsum("...a,auv->...uv", coeffs, y_grid)

    # Convert to 2D Fourier coefficients.
    input1_uv = to_2D_fourier_coeffs(input1.array, y1_grid_sp)
    input2_uv = to_2D_fourier_coeffs(input2.array, y2_grid_sp)

    # Perform the convolution in Fourier space, either directly or using FFT.
    output_uv = convolve(input1_uv, input2_uv)

    @sparse.sparsify
    def to_SH_coeffs(coeffs_uv, z_grid):
        # NOTE(review): the conjugate is taken before contracting with z_grid;
        # this presumably matches how gtp_utils.compute_z_grid is defined —
        # confirm there.
        return jnp.einsum("...uv,auv->...a", coeffs_uv.conj(), z_grid)

    # Convert back to SH coefficients. Only the real part is kept; presumably
    # the imaginary part is numerical noise for real-valued inputs — verify.
    output_lm = to_SH_coeffs(output_uv, z_grid_sp)
    output_lm = e3nn.IrrepsArray(
        e3nn.s2_irreps(lmax1 + lmax2),
        output_lm.real,
    )
    return output_lm.filter(filter_ir_out)