Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving/generalizing flat operators. #48

Merged
merged 3 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dctkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import enum
import jax.numpy as jnp
from jax.config import config as cfg
from jax import config as cfg

if sys.version_info[:2] >= (3, 8):
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
Expand Down
57 changes: 1 addition & 56 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,46 +261,6 @@ def coboundary(c: Cochain) -> Cochain:
return dc


def coboundary_closure(c: CochainP) -> CochainD:
"""Implements the operator that complements the coboundary on the boundary
of dual (n-1)-simplices, where n is the dimension of the complex.

Args:
c: a primal (n-1)-cochain
Returns:
the coboundary closure of c, resulting in a dual n-cochain with non-zero
coefficients in the "uncompleted" cells.
"""
n = c.complex.dim
num_tets = c.complex.S[n].shape[0]
num_dual_faces = c.complex.S[n-1].shape[0]

# to extract only the boundary components with the right orientation
# we construct a dual n-2 cochain and we take the (true) coboundary.
# In this way the obtain a cochain such that an entry is 0 if it's in
# the interior of the complex and ±1 if it's in the boundary
ones = CochainD(dim=n-2, complex=c.complex, coeffs=jnp.ones(num_tets))
diagonal_elems = coboundary(ones).coeffs
diagonal_matrix_rows = jnp.arange(num_dual_faces)
diagonal_matrix_cols = diagonal_matrix_rows
diagonal_matrix_COO = [diagonal_matrix_rows, diagonal_matrix_cols, diagonal_elems]

# build the absolute value of the (n-1)-coboundary
abs_dual_coboundary_faces = c.complex.boundary[n-1].copy()
# same of doing abs(dual_coboundary_faces)
abs_dual_coboundary_faces[2] = abs_dual_coboundary_faces[2]**2
# with this product, we extract with the right orientation the boundary pieces
diagonal_times_c = spmv.spmm(diagonal_matrix_COO, c.coeffs,
transpose=False,
shape=c.complex.S[n-1].shape[0])
# here we sum their contribution taking into account the orientation
d_closure_coeffs = spmv.spmm(abs_dual_coboundary_faces, diagonal_times_c,
transpose=False,
shape=c.complex.num_nodes)
d_closure = CochainD(dim=n, complex=c.complex, coeffs=0.5*d_closure_coeffs)
return d_closure


def star(c: Cochain) -> Cochain:
"""Implements the diagonal Hodge star operator (see Grinspun et al.).

Expand All @@ -324,7 +284,7 @@ def star(c: Cochain) -> Cochain:
return star_c


def inner_product(c1: Cochain, c2: Cochain) -> Array:
def inner(c1: Cochain, c2: Cochain) -> Array:
"""Computes the inner product between two cochains.

Args:
Expand Down Expand Up @@ -388,21 +348,6 @@ def laplacian(c: Cochain) -> Cochain:
return laplacian


def deformation_gradient(c: Cochain) -> Cochain:
"""Compute the deformation gradient of a primal vector-valued 0-cochain
representing the node coordinates.

Args:
c: a primal vector-valued 0 cochain.

Returns:
the deformation gradient for each dual nodes, i.e. a dual tensor-valued
0-cochain.
"""
F = c.complex.get_deformation_gradient(c.coeffs)
return Cochain(0, not c.is_primal, c.complex, F)


def transpose(c: Cochain) -> Cochain:
"""Compute the transpose of a tensor-valued cochain.

Expand Down
54 changes: 54 additions & 0 deletions src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jax.numpy as jnp
from dctkit.dec import cochain as C
from jax import Array


def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
edges: C.CochainP1V | C.CochainD1V) -> C.CochainP1 | C.CochainD1:

weighted_v = c.coeffs @ weights
if c.coeffs.ndim == 2:
# vector field case
# perform dot product row-wise with the edge vectors
# of the dual edges (see definition of DPD in Hirani, pag. 54).
weighted_v_T = weighted_v.T
coch_coeffs = jnp.einsum("ij, ij -> i", weighted_v_T, edges.coeffs)
elif c.coeffs.ndim == 3:
# tensor field case
# apply each matrix (rows of the multiarray weighted_v_T fixing the first axis)
# to the edge vector of the corresponding dual edge
weighted_v_T = jnp.transpose(weighted_v, axes=(2, 0, 1))
coch_coeffs = jnp.einsum("ijk, ik -> ij", weighted_v_T, edges.coeffs)

if edges.is_primal:
return C.CochainP1(c.complex, coch_coeffs)
else:
return C.CochainD1(c.complex, coch_coeffs)


def flat_DPD(c: C.CochainD0V | C.CochainD0T) -> C.CochainD1:
"""Implements the flat DPD operator for dual discrete vector fields.

Args:
v: a dual discrete vector field.
Returns:
the dual 1-cochain resulting from the application of the flat operator.
"""
dual_edges = c.complex.dual_edges_vectors[:, :c.coeffs.shape[0]]
flat_matrix = c.complex.flat_DPD_weights

return flat(c, flat_matrix, C.CochainD1(c.complex, dual_edges))


def flat_DPP(c: C.CochainD0V | C.CochainD0T) -> C.CochainP1:
"""Implements the flat DPP operator for dual discrete vector fields.

Args:
v: a dual discrete vector field.
Returns:
the primal 1-cochain resulting from the application of the flat operator.
"""
primal_edges = c.complex.primal_edges_vectors[:, :c.coeffs.shape[0]]
flat_matrix = c.complex.flat_DPP_weights

return flat(c, flat_matrix, C.CochainP1(c.complex, primal_edges))
164 changes: 0 additions & 164 deletions src/dctkit/dec/vector.py

This file was deleted.

43 changes: 43 additions & 0 deletions src/dctkit/experimental/coboundary_closure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import dctkit.dec.cochain as C
from dctkit.math import spmv
import jax.numpy as jnp


def coboundary_closure(c: C.CochainP) -> C.CochainD:
"""Implements the operator that complements the coboundary on the boundary
of dual (n-1)-simplices, where n is the dimension of the complex.

Args:
c: a primal (n-1)-cochain
Returns:
the coboundary closure of c, resulting in a dual n-cochain with non-zero
coefficients in the "uncompleted" cells.
"""
n = c.complex.dim
num_tets = c.complex.S[n].shape[0]
num_dual_faces = c.complex.S[n-1].shape[0]

# to extract only the boundary components with the right orientation
# we construct a dual n-2 cochain and we take the (true) coboundary.
# In this way the obtain a cochain such that an entry is 0 if it's in
# the interior of the complex and ±1 if it's in the boundary
ones = C.CochainD(dim=n-2, complex=c.complex, coeffs=jnp.ones(num_tets))
diagonal_elems = C.coboundary(ones).coeffs
diagonal_matrix_rows = jnp.arange(num_dual_faces)
diagonal_matrix_cols = diagonal_matrix_rows
diagonal_matrix_COO = [diagonal_matrix_rows, diagonal_matrix_cols, diagonal_elems]

# build the absolute value of the (n-1)-coboundary
abs_dual_coboundary_faces = c.complex.boundary[n-1].copy()
# same of doing abs(dual_coboundary_faces)
abs_dual_coboundary_faces[2] = abs_dual_coboundary_faces[2]**2
# with this product, we extract with the right orientation the boundary pieces
diagonal_times_c = spmv.spmm(diagonal_matrix_COO, c.coeffs,
transpose=False,
shape=c.complex.S[n-1].shape[0])
# here we sum their contribution taking into account the orientation
d_closure_coeffs = spmv.spmm(abs_dual_coboundary_faces, diagonal_times_c,
transpose=False,
shape=c.complex.num_nodes)
d_closure = C.CochainD(dim=n, complex=c.complex, coeffs=0.5*d_closure_coeffs)
return d_closure
4 changes: 2 additions & 2 deletions src/dctkit/physics/elastica.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def energy(self, theta: npt.NDArray, B: float, theta_0: float, F: float) -> Arra

# potential of the applied load
A_coch = C.scalar_mul(self.ones_coch, A)
load = C.inner_product(C.sin(theta_coch), A_coch)
load = C.inner(C.sin(theta_coch), A_coch)

energy = 0.5*C.inner_product(moment, curvature) - load
energy = 0.5*C.inner(moment, curvature) - load

return energy

Expand Down
Loading
Loading