Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flat #51

Merged
merged 5 commits into from
Mar 26, 2024
Merged

Flat #51

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,72 @@ def sym(c: Cochain) -> Cochain:
its symmetric part.
"""
return scalar_mul(add(c, transpose(c)), 0.5)


def convolution(c: Cochain, kernel: Cochain, kernel_window: float) -> Cochain:
""" Compute the convolution between two scalar 0-cochains.

Args:
c: a scalar 0-cochain.
kernel: the scalar 0-cochain kernel.
kernel_window: the kernel window.

Returns:
the convolution rho*kernel.
"""
# we build a kernel matrix K by rolling the kernel vector k + 1 times, where
# k is the kernel window.
# For example if kernel = [1,2,0,0] and kernel_window = 2, then
# K = [[1,2,0,0],
# [0,1,2,0],
# [0,0,1,2]].
# In this way we can express the
# convolution between c and k as SK @c_coeffs where S is the hodge star
n = len(c.coeffs)
K = jnp.zeros((n, n), dtype=dt.float_dtype)

# for simplicity, we roll the kernel coeffs n=len(kernel) times and
# this is a trick to do so
buffer = jnp.empty((n, n*2 - 1))
buffer = buffer.at[:, :n].set(kernel.coeffs[:n].T)
buffer = buffer.at[:, n:].set(kernel.coeffs[:n-1].T)

rolled = buffer.reshape(-1)[n-1:-1].reshape(n, -1)

K_full_roll = jnp.roll(rolled[:, :n], shift=1, axis=0)
# since we want to roll only k+1 times, we extract the correct portion
# of the matrix
K_non_zero = K_full_roll[:n - kernel_window + 1]
K = K.at[:n - kernel_window + 1, :].set(K_non_zero)
K_coch = Cochain(c.dim, c.is_primal, c.complex, K)

# apply hodge star to compute SK
star_kernel = star(K_coch)
conv = Cochain(c.dim, c.is_primal, c.complex, star_kernel.coeffs@c.coeffs)
return conv


def abs(c: Cochain) -> Cochain:
""" Compute the absolute value of a cochain.

Args:
c: a cochain.

Returns:
its absolute value.
"""
return Cochain(c.dim, c.is_primal, c.complex, jnp.abs(c.coeffs))


def maximum(c_1: Cochain, c_2: Cochain) -> Cochain:
""" Compute the component-wise maximum between two cochains.

Args:
c_1: a cochain.
c_2: a cochain.

Returns:
the component-wise maximum
"""
return Cochain(c_1.dim, c_1.is_primal, c_1.complex,
jnp.maximum(c_1.coeffs, c_2.coeffs))
19 changes: 14 additions & 5 deletions src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from dctkit.dec import cochain as C
from jax import Array, vmap
from functools import partial
from typing import Callable, Dict, Optional


def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
edges: C.CochainP1V | C.CochainD1V) -> C.CochainP1 | C.CochainD1:
def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V |
C.CochainD1V, interp_func: Optional[Callable] = None,
interp_func_args: Optional[Dict] = {}) -> C.CochainP1 | C.CochainD1:
"""Applies the flat to a vector/tensor-valued cochain representing a discrete
vector/tensor field to obtain a scalar-valued cochain over primal/dual edges.

Expand All @@ -19,12 +21,19 @@ def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
interpolation scheme chosen for the input discrete vector/tensor field.
edges: vector-valued cochain collecting the primal/dual edges over which the
discrete vector/tensor field should be integrated.
interp_func: interpolation function (callable) taking in input the cochain c
and providing in output a 1-cochain of the same type (primal/dual). If
it is None, then an interpolation function is built as W^T@c.coeffs.
interp_func_args: additional keyword arguments for interp_func
Returns:
a primal/dual scalar/vector-valued cochain defined over primal/dual edges.
"""
# contract over the simplices of the input cochain (last axis of weights, first axis
# of input cochain coeffs)
weighted_v = jnp.tensordot(weights.T, c.coeffs, axes=1)
if interp_func is None:
# contract over the simplices of the input cochain (last axis of weights,
# first axis of input cochain coeffs)
def interp_func(x): return jnp.tensordot(weights.T, x.coeffs, axes=1)
interp_func_args = {}
weighted_v = interp_func(c, **interp_func_args)
# contract input vector/tensors with edge vectors (last indices of both
# coefficient matrices), for each target primal/dual edge
contract = partial(jnp.tensordot, axes=([-1,], [-1,]))
Expand Down
17 changes: 17 additions & 0 deletions tests/test_cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,23 @@ def test_codifferential(setup_test):
assert np.allclose(inner_all[i], cod_inner_all[i])


def test_convolution(setup_test):
mesh_1, _ = util.generate_line_mesh(11, 1.)
S_1 = util.build_complex_from_mesh(mesh_1)
S_1.get_hodge_star()
n_1 = S_1.S[1].shape[0]
vD0 = np.arange(n_1, dtype=dctkit.float_dtype)
cD0 = C.CochainD0(complex=S_1, coeffs=vD0)
kernel = 4*np.arange(1, 4)
kernel_coeffs = np.zeros_like(vD0)
kernel_coeffs[:len(kernel)] = kernel
kernel_coch = C.CochainD0(complex=S_1, coeffs=kernel_coeffs)
conv = C.convolution(cD0, kernel_coch, len(kernel))
conv_true = np.array([3.2, 5.6, 8., 10.4, 12.8, 15.2,
17.6, 20., 0., 0.]).reshape(-1, 1)
assert np.allclose(conv.coeffs, conv_true)


def test_coboundary_closure(setup_test):
mesh_2, _ = util.generate_square_mesh(1.0)
S_2 = util.build_complex_from_mesh(mesh_2, is_well_centered=False)
Expand Down
Loading