Skip to content

Commit

Permalink
Improved comments and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Smantii authored and alucantonio committed Jul 14, 2023
1 parent ccd3f61 commit 6ca4d44
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 28 deletions.
7 changes: 4 additions & 3 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,15 @@ def coboundary(c: Cochain) -> Cochain:
return dc


def coboundary_closure(c: CochainD) -> CochainD:
def coboundary_closure(c: CochainP) -> CochainD:
"""Implements the operator that complements the coboundary on the boundary
of dual (n-1)-simplices, where n is the dimension of the complex.
Args:
c: a dual (n-1)-cochain
c: a primal (n-1)-cochain
Returns:
the coboundary closure of c.
the coboundary closure of c, resulting in a dual n-cochain with non-zero
coefficients in the "uncompleted" cells.
"""
n = c.complex.dim
Expand Down
102 changes: 81 additions & 21 deletions src/dctkit/physics/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,24 @@ def force_balance_residual_dual(self, node_coords: C.CochainP0, f: C.CochainD2,
strain = self.get_GreenLagrange_strain(node_coords=node_coords.coeffs)
stress = self.get_stress(strain=strain)
stress_tensor = V.DiscreteTensorFieldD(S=self.S, coeffs=stress.T, rank=2)
# compute forces on dual edges
stress_integrated = V.flat_DPP(stress_tensor)
forces = C.star(stress_integrated)
# FIXME: comment it!
# compute the tractions on boundary primal edges
forces_closure = C.star(V.flat_DPD(stress_tensor))
# set tractions on given sub-portions of the boundary
forces_closure_update = self.set_boundary_tractions(
forces_closure, boundary_tractions)
balance_forces_closure = C.coboundary_closure(forces_closure_update)
balance = C.add(C.coboundary(forces), balance_forces_closure)
residual = C.add(balance, f)
return residual

def obj_linear_elasticity(self, node_coords: npt.NDArray | Array,
f: npt.NDArray | Array, gamma: float, boundary_values:
Dict[str, Tuple[Array, Array]],
boundary_tractions: Dict[str, Tuple[Array, Array]],
is_dual_balance=None) -> float:
def obj_linear_elasticity_primal(self, node_coords: npt.NDArray | Array,
f: npt.NDArray | Array, gamma: float,
boundary_values: Dict[str, Tuple[Array, Array]],
boundary_tractions:
Dict[str, Tuple[Array, Array]]) -> float:
"""Objective function of the optimization problem associated to linear
elasticity balance equation with Dirichlet boundary conditions on a portion
of the boundary.
Expand Down Expand Up @@ -178,23 +180,81 @@ def obj_linear_elasticity(self, node_coords: npt.NDArray | Array,
"""
node_coords_reshaped = node_coords.reshape(self.S.node_coords.shape)
node_coords_coch = C.CochainP0(complex=self.S, coeffs=node_coords_reshaped)
if is_dual_balance is not None:
f = f.reshape((self.S.num_nodes, self.S.space_dim-1))
f_coch = C.CochainD2(complex=self.S, coeffs=f)
residual = self.force_balance_residual_dual(
node_coords_coch, f_coch, boundary_tractions).coeffs
else:
f = f.reshape((self.S.S[2].shape[0], self.S.space_dim-1))
f_coch = C.CochainP2(complex=self.S, coeffs=f)
residual = self.force_balance_residual_primal(
node_coords_coch, f_coch, boundary_tractions).coeffs
f = f.reshape((self.S.S[2].shape[0], self.S.space_dim-1))
f_coch = C.CochainP2(complex=self.S, coeffs=f)
residual = self.force_balance_residual_primal(
node_coords_coch, f_coch, boundary_tractions).coeffs
penalty = self.set_displacement_bc(node_coords=node_coords_reshaped,
boundary_values=boundary_values,
gamma=gamma)
energy = jnp.sum(residual**2) + penalty
return energy

def obj_linear_elasticity_dual(self, node_coords: npt.NDArray | Array,
f: npt.NDArray | Array, gamma: float,
boundary_values: Dict[str, Tuple[Array, Array]],
boundary_tractions:
Dict[str, Tuple[Array, Array]]) -> float:
"""Objective function of the optimization problem associated to linear
elasticity balance equation with Dirichlet boundary conditions on a portion
of the boundary.
Args:
node_coords: 1-dimensional array obtained after flattening
the matrix with node coordinates arranged row-wise.
f: 1-dimensional array obtained after flattening the
matrix of external sources (constant term of the system).
gamma: penalty factor.
boundary_values: a dictionary of tuples. Each key represent the type of
coordinate to manipulate (x,y, or both), while each tuple consists of
two np.arrays in which the first encodes the indices of boundary
values, while the last encodes the boundary values themselves.
boundary_tractions: a dictionary of tuples. Each key represent the type
of coordinate to manipulate (x,y, or both), while each tuple consists
of two jax arrays, in which the first encordes the indices where we want
to impose the boundary tractions, while the last encodes the boundary
traction values themselves. It is None when we perform the force balance
on dual cells.
Returns:
the value of the objective function at node_coords.
"""
node_coords_reshaped = node_coords.reshape(self.S.node_coords.shape)
node_coords_coch = C.CochainP0(complex=self.S, coeffs=node_coords_reshaped)
f = f.reshape((self.S.num_nodes, self.S.space_dim-1))
f_coch = C.CochainD2(complex=self.S, coeffs=f)
residual = self.force_balance_residual_dual(
node_coords_coch, f_coch, boundary_tractions).coeffs
penalty = self.set_displacement_bc(node_coords=node_coords_reshaped,
boundary_values=boundary_values,
gamma=gamma)
energy = jnp.sum(residual**2) + penalty
return energy

def set_displacement_bc(self, node_coords: npt.NDArray | Array, boundary_values:
Dict[str, Tuple[Array, Array]],
gamma: float) -> float:
""" Set displacement boundary conditions as a quadratic penalty term.
Args:
node_coords: node coordinates of the current configuration.
boundary_values: a dictionary of tuples. Each key represent the type of
coordinate to manipulate (x,y, or both), while each tuple consists of
two np.arrays in which the first encodes the indices of boundary
values, while the last encodes the boundary values themselves.
gamma: penalty factor.
Return:
the penalty term
"""
penalty = 0.
for key in boundary_values:
idx, values = boundary_values[key]
if key == ":":
penalty += jnp.sum((node_coords_reshaped[idx, :] - values)**2)
penalty += jnp.sum((node_coords[idx, :] - values)**2)
else:
penalty += jnp.sum((node_coords_reshaped[idx, :]
[:, int(key)] - values)**2)
energy = jnp.sum(residual**2) + gamma*penalty
return energy
penalty += jnp.sum((node_coords[idx, :]
[:, int(key)] - values)**2)
return gamma*penalty
2 changes: 1 addition & 1 deletion tests/test_cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def test_coboundary_closure(setup_test):
S_2 = util.build_complex_from_mesh(mesh_2, is_well_centered=False)
S_2.get_hodge_star()

c = C.CochainD1(complex=S_2, coeffs=np.arange(1, 9, dtype=dctkit.float_dtype))
c = C.CochainP1(complex=S_2, coeffs=np.arange(1, 9, dtype=dctkit.float_dtype))
cob_clos_c = C.coboundary_closure(c)
cob_clos_c_true = np.array([-0.5, 2.5, 5., 2., 0.], dtype=dctkit.float_dtype)
assert np.allclose(cob_clos_c.coeffs, cob_clos_c_true)
6 changes: 3 additions & 3 deletions tests/test_linear_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_linear_elasticity_primal(setup_test):

prb = optctrl.OptimizationProblem(dim=S.node_coords.size,
state_dim=S.node_coords.size,
objfun=ela.obj_linear_elasticity)
objfun=ela.obj_linear_elasticity_primal)

prb.set_obj_args(obj_args)
node_coords_flattened = S.node_coords.flatten()
Expand Down Expand Up @@ -148,11 +148,11 @@ def test_linear_elasticity_dual(setup_test):
f = np.zeros((S.num_nodes, (embedded_dim-1))).flatten()

obj_args = {'f': f, 'gamma': gamma, 'boundary_values': boundary_values,
'boundary_tractions': boundary_tractions, 'is_dual_balance': 1}
'boundary_tractions': boundary_tractions}

prb = optctrl.OptimizationProblem(dim=S.node_coords.size,
state_dim=S.node_coords.size,
objfun=ela.obj_linear_elasticity)
objfun=ela.obj_linear_elasticity_dual)

prb.set_obj_args(obj_args)
node_coords_flattened = S.node_coords.flatten()
Expand Down

0 comments on commit 6ca4d44

Please sign in to comment.