Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flat #51

Merged
merged 5 commits into from
Mar 26, 2024
Merged

Flat #51

Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,76 @@ def sym(c: Cochain) -> Cochain:
its symmetric part.
"""
return scalar_mul(add(c, transpose(c)), 0.5)


def convolution(c: Cochain, kernel: Cochain, kernel_window: float) -> Cochain:
""" Compute the convolution between two scalar 0-cochains.

Args:
c: a scalar 0-cochain.
kernel: the scalar 0-cochain kernel.
kernel_window: the kernel window.

Returns:
the convolution rho*kernel.
"""
# we build a kernel matrix K by rolling the kernel vector k + 1 times, where
# k is the kernel window. In this way we can express the
# convolution between c and k as SK @c_coeffs where S is the hodge star
n = len(c.coeffs)
K = jnp.zeros((n, n), dtype=dt.float_dtype)
buffer = jnp.empty((n, n*2 - 1))

# generate a wider array that we want a slice into
buffer = buffer.at[:, :n].set(kernel.coeffs[:n].T)
buffer = buffer.at[:, n:].set(kernel.coeffs[:n-1].T)

rolled = buffer.reshape(-1)[n-1:-1].reshape(n, -1)
K_full_roll = jnp.roll(rolled[:, :n], shift=1, axis=0)
K_non_zero = K_full_roll[:n - kernel_window + 1]
K = K.at[:n - kernel_window + 1, :].set(K_non_zero)
kernel_coch = Cochain(c.dim, c.is_primal, c.complex, K)
Smantii marked this conversation as resolved.
Show resolved Hide resolved

# apply hodge star
star_kernel = star(kernel_coch)
conv = Cochain(c.dim, c.is_primal, c.complex, star_kernel.coeffs@c.coeffs)
return conv


def constant_sub(k: float, c: Cochain) -> Cochain:
Smantii marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the cochain subtraction between a constant cochain and another cochain.

Args:
k: a constant.
c: a cochain.

Returns:
the resulting subtraction
"""
return Cochain(c.dim, c.is_primal, c.complex, k - c.coeffs)


def abs(c: Cochain) -> Cochain:
""" Compute the absolute value of a cochain.

Args:
c: a cochain.

Returns:
its absolute value.
"""
return Cochain(c.dim, c.is_primal, c.complex, jnp.abs(c.coeffs))


def maximum(c_1: Cochain, c_2: Cochain) -> Cochain:
""" Compute the component-wise maximum between two cochains.

Args:
c_1: a cochain.
c_2: a cochain.

Returns:
the component-wise maximum
"""
return Cochain(c_1.dim, c_1.is_primal, c_1.complex,
jnp.maximum(c_1.coeffs, c_2.coeffs))
19 changes: 14 additions & 5 deletions src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from dctkit.dec import cochain as C
from jax import Array, vmap
from functools import partial
from typing import Callable, Dict, Optional


def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
edges: C.CochainP1V | C.CochainD1V) -> C.CochainP1 | C.CochainD1:
def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V |
C.CochainD1V, weighted_I: Optional[Callable] = None,
Smantii marked this conversation as resolved.
Show resolved Hide resolved
weighted_I_args: Optional[Dict] = {}) -> C.CochainP1 | C.CochainD1:
"""Applies the flat to a vector/tensor-valued cochain representing a discrete
vector/tensor field to obtain a scalar-valued cochain over primal/dual edges.

Expand All @@ -19,12 +21,19 @@ def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
interpolation scheme chosen for the input discrete vector/tensor field.
edges: vector-valued cochain collecting the primal/dual edges over which the
discrete vector/tensor field should be integrated.
weighted_I: interpolation function (callable) taking in input the cochain c
and providing in output a 1-cochain of the same type (primal/dual). If
it is None, then an interpolation function is built as W^T@c.coeffs.
weighted_I_args: additional keyword arguments for weighted_I
Returns:
a primal/dual scalar/vector-valued cochain defined over primal/dual edges.
"""
# contract over the simplices of the input cochain (last axis of weights, first axis
# of input cochain coeffs)
weighted_v = jnp.tensordot(weights.T, c.coeffs, axes=1)
if weighted_I is None:
# contract over the simplices of the input cochain (last axis of weights,
# first axis of input cochain coeffs)
def weighted_I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1)
weighted_I_args = {}
weighted_v = weighted_I(c, **weighted_I_args)
# contract input vector/tensors with edge vectors (last indices of both
# coefficient matrices), for each target primal/dual edge
contract = partial(jnp.tensordot, axes=([-1,], [-1,]))
Expand Down
17 changes: 17 additions & 0 deletions tests/test_cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,23 @@ def test_codifferential(setup_test):
assert np.allclose(inner_all[i], cod_inner_all[i])


def test_convolution(setup_test):
mesh_1, _ = util.generate_line_mesh(11, 1.)
S_1 = util.build_complex_from_mesh(mesh_1)
S_1.get_hodge_star()
n_1 = S_1.S[1].shape[0]
vD0 = np.arange(n_1, dtype=dctkit.float_dtype)
cD0 = C.CochainD0(complex=S_1, coeffs=vD0)
kernel = 4*np.arange(1, 4)
kernel_coeffs = np.zeros_like(vD0)
kernel_coeffs[:len(kernel)] = kernel
kernel_coch = C.CochainD0(complex=S_1, coeffs=kernel_coeffs)
conv = C.convolution(cD0, kernel_coch, len(kernel))
conv_true = np.array([3.2, 5.6, 8., 10.4, 12.8, 15.2,
17.6, 20., 0., 0.]).reshape(-1, 1)
assert np.allclose(conv.coeffs, conv_true)


def test_coboundary_closure(setup_test):
mesh_2, _ = util.generate_square_mesh(1.0)
S_2 = util.build_complex_from_mesh(mesh_2, is_well_centered=False)
Expand Down
Loading