In [None]:
import torch
import numpy as np

In [None]:
from strains_torch import get_strain_stretch2D_torch, get_strain_curvature_3D_torch

In [None]:
def elastic_force(
    self, q: torch.Tensor, *, l_eff: float = 0.1
) -> torch.Tensor:
    """Return the elastic force ``-dE/dq`` accumulated over all bend springs.

    For every ``(i, j, k)`` row of ``self.bend_springs`` the three node
    positions are sliced out of ``q``, converted to a stretch strain and a
    curvature, fed to ``self.energy_model``, and the energy is differentiated
    with respect to those nine coordinates. The negated gradients are summed
    into a force vector shaped like ``q``.

    Args:
        q: Generalized coordinates of shape ``(..., ndof)``. Node ``n`` is
            read from ``q[..., 3 * n : 3 * n + 3]``.
        l_eff: Length passed to both strain helpers (twice to the stretch
            helper, once to the curvature helper). Defaults to the previously
            hard-coded ``0.1``.
            NOTE(review): presumably the reference/effective segment length;
            confirm against ``strains_torch``.

    Returns:
        Tensor of shape ``(..., ndof)``. When autograd is enabled at call
        time the result stays differentiable (``create_graph=True``), as
        before. When the call is made under ``torch.no_grad()`` the forces
        are still computed, but no graph is attached to the result.
    """
    # (n_springs, 3) rows of (i, j, k). Converted once to plain Python ints:
    # iterating the tensor itself yields 0-dim tensors, and every slice bound
    # or index built from them needs its own tensor -> int conversion.
    springs = self.bend_springs.tolist()

    f_full = torch.zeros_like(q)
    batch_shape = q.shape[:-1]

    # Only build a differentiable result if the caller could use it.
    keep_graph = torch.is_grad_enabled()

    # autograd.grad needs a recorded graph, so force grad mode on locally;
    # otherwise this function raises when called under torch.no_grad().
    with torch.enable_grad():
        for i, j, k in springs:
            dof_indices = [
                3 * i, 3 * i + 1, 3 * i + 2,
                3 * j, 3 * j + 1, 3 * j + 2,
                3 * k, 3 * k + 1, 3 * k + 2,
            ]

            # torch.stack allocates a new tensor, so no extra clone is needed
            # before marking it as the differentiation variable.
            q_spring = torch.stack(
                [
                    q[..., 3 * i:3 * i + 3],
                    q[..., 3 * j:3 * j + 3],
                    q[..., 3 * k:3 * k + 3],
                ],
                dim=-2,
            ).requires_grad_(True)  # (..., 3, 3)

            node0_s = q_spring[..., 0, :]
            node1_s = q_spring[..., 1, :]
            node2_s = q_spring[..., 2, :]

            longitudinal_strain = get_strain_stretch2D_torch(
                node0_s, node1_s, node2_s, l_eff, l_eff
            )
            curvature = get_strain_curvature_3D_torch(
                node0_s, node1_s, node2_s, l_eff
            )

            strains = torch.stack(
                [longitudinal_strain, curvature], dim=-1
            )  # (..., 2)

            E_spring = self.energy_model(strains)  # (..., 1) or scalar

            # Summing over the batch yields per-sample gradients because each
            # sample's energy depends only on its own coordinates.
            (dE_dq_spring,) = torch.autograd.grad(
                outputs=E_spring.sum(),
                inputs=q_spring,
                create_graph=keep_graph,
            )  # (..., 3, 3)

            f_spring = -dE_dq_spring.reshape(*batch_shape, 9)  # (..., 9)
            # Springs are processed one at a time, so nodes shared between
            # springs accumulate correctly across iterations.
            f_full[..., dof_indices] += f_spring

    return f_full
