In [None]:
import torch
import numpy as np

In [None]:
from bend_energy_3D import get_strain_curvature_3D
from stretch_energy import get_strain_stretch2D

In [None]:
def elastic_force(self, q: torch.Tensor) -> torch.Tensor:
    """Return the elastic force: the negative gradient of the energy w.r.t. ``q``.

    For every row of ``self.bend_springs`` the three referenced nodes are
    sliced out of ``q`` (node ``n`` occupies DOFs ``3n .. 3n + 2``), a
    longitudinal strain and a curvature are computed from them, and
    ``self.energy_model`` maps the stacked strains to an energy. All spring
    energies are summed and differentiated once with respect to ``q``, which
    is equivalent to scattering the per-spring gradients into a zero vector.

    Args:
        q: Generalized coordinates of shape ``(..., ndof)``.

    Returns:
        Tensor with the same shape, dtype and device as ``q``. It is built
        with ``create_graph=True`` so it can itself be differentiated.

    Raises:
        RuntimeError: If the strains do not depend on ``q`` in the autograd
            graph, in which case no force can be derived from them.

    NOTE(review): ``get_strain_stretch2D`` and ``get_strain_curvature_3D``
    are defined outside this file. This implementation assumes they accept
    torch tensors and use differentiable torch operations -- confirm.
    Converting the nodes to NumPy (as was done previously) detaches them
    from the autograd graph, so no gradient w.r.t. ``q`` can be obtained.
    """
    # NOTE(review): the spring table is described as (i, j, k, leff), but a
    # fixed effective length has always been used here -- confirm whether
    # the per-spring column should be used instead.
    l_eff = 0.1

    # Only the three node indices are needed. Plain Python ints work for
    # slicing regardless of the table's dtype/device and of whether a
    # fourth column is present.
    node_triples = self.bend_springs[:, :3].to(torch.long).tolist()
    if not node_triples:
        # No springs: no elastic energy, hence zero force everywhere.
        return torch.zeros_like(q)

    # enable_grad() keeps this working when the caller is inside no_grad().
    with torch.enable_grad():
        # autograd needs a tensor that tracks gradients; reuse q when it
        # already does so that the result stays connected to the caller's
        # graph, otherwise differentiate w.r.t. a detached copy.
        q_var = q if q.requires_grad else q.detach().requires_grad_(True)

        energy_total = q_var.new_zeros(())
        for i, j, k in node_triples:
            node0 = q_var[..., 3 * i:3 * i + 3]
            node1 = q_var[..., 3 * j:3 * j + 3]
            node2 = q_var[..., 3 * k:3 * k + 3]

            longitudinal_strain = get_strain_stretch2D(
                node0, node1, node2, l_eff, l_eff
            )
            curvature = get_strain_curvature_3D(node0, node1, node2)

            # Stack without leaving torch so the graph back to q survives.
            strains = torch.stack(
                [
                    torch.as_tensor(
                        longitudinal_strain, dtype=q.dtype, device=q.device
                    ),
                    torch.as_tensor(
                        curvature, dtype=q.dtype, device=q.device
                    ),
                ],
                dim=-1,
            ).unsqueeze(-1)  # (..., 2, 1)

            if not strains.requires_grad:
                raise RuntimeError(
                    "strains are detached from q; the strain helpers must be "
                    "implemented with differentiable torch operations"
                )

            # .sum() reduces whatever shape the model returns to a scalar.
            energy_total = energy_total + self.energy_model(strains).sum()

        (dE_dq,) = torch.autograd.grad(
            outputs=energy_total,
            inputs=q_var,
            create_graph=True,
        )

    return -dE_dq