In [None]:
# import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# import multiprocess as mp
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint

from typing import TYPE_CHECKING, Callable, Union, Optional

In [None]:
class VariablesClass(eqx.Module):
    """Material parameters of the bistable-buckle chain.

    Attributes given by the user:
        `k_soft`: jax.Array: torque constants for each hinge in the soft direction
        `k_stiff`: jax.Array: torque constants for each hinge in the stiff direction
        `thetas_ss`: jax.Array: rest angles of each hinge

    Attributes computed by the class:
        `k_stretch`: jax.Array: scalar stretch stiffness shared by every rod,
            equal to `stretch_scale * max(k_stiff)`
    """

    # `(hinges,)` stiffnesses in soft direction
    k_soft: jax.Array = eqx.field(init=False)

    # `(hinges,)` stiffnesses in stiff direction
    k_stiff: jax.Array = eqx.field(init=False)

    # `(hinges,)` rest angles of hinges
    thetas_ss: jax.Array = eqx.field(init=False)

    # scalar stiffness of the rods, chosen large so the rods are nearly rigid
    k_stretch: jax.Array = eqx.field(init=False)

    def __init__(self, k_soft: Optional[jax.Array] = None, k_stiff: Optional[jax.Array] = None,
                 thetas_ss: Optional[jax.Array] = None, stretch_scale: float = 50.0,
                 hinges: Optional[int] = None):
        """Store the per-hinge parameters as float32 arrays.

        Args:
            k_soft: soft-direction stiffness per hinge; defaults to ones.
            k_stiff: stiff-direction stiffness per hinge; defaults to ones.
            thetas_ss: rest angle per hinge; defaults to ones.
            stretch_scale: factor multiplying `max(k_stiff)` to obtain `k_stretch`.
            hinges: number of hinges, used only to size the defaults. When omitted
                it is taken from the first array that was supplied.

        Raises:
            ValueError: if a default is needed but the hinge count can be
                determined neither from `hinges` nor from a supplied array.
        """
        supplied = [
            None if arr is None else jnp.asarray(arr, jnp.float32)
            for arr in (k_soft, k_stiff, thetas_ss)
        ]

        # The hinge count is only needed when at least one default must be built.
        if any(arr is None for arr in supplied):
            if hinges is None:
                known = [arr for arr in supplied if arr is not None]
                if not known:
                    raise ValueError(
                        "VariablesClass needs `hinges` or at least one of "
                        "`k_soft`, `k_stiff`, `thetas_ss` to size its defaults"
                    )
                hinges = int(known[0].size)
            supplied = [
                jnp.ones(int(hinges), jnp.float32) if arr is None else arr
                for arr in supplied
            ]

        self.k_soft, self.k_stiff, self.thetas_ss = supplied

        # A single stretch stiffness (applied to every edge). Read the stored
        # array rather than the raw argument so a defaulted `k_stiff` works too.
        self.k_stretch = jnp.asarray(stretch_scale * jnp.max(self.k_stiff), jnp.float32)

In [None]:
class StructureClass(eqx.Module):
    """Bistable buckle structure (1D chain in the plane).

    The chain has `hinges + 2` points, `hinges + 1` edges (rods) joining
    consecutive points, and `hinges` hinges joining consecutive edges.
    """

    # --- user-provided (static) ---
    hinges: int = eqx.field(static=True)   # number of hinges in the chain
    shims: int  = eqx.field(static=True)   # e.g. shim count per hinge (kept static)
    L: float = eqx.field(static=True)      # rest length of rods

    # --- computed in __init__ ---
    edges_arr:  jax.Array = eqx.field(init=False)            # (hinges+1, 2) point indices
    # Static like `hinges`, so `jnp.arange(self.edges)` stays concrete under jit.
    edges: int = eqx.field(init=False, static=True)          # number of edges (hinges+1)
    hinges_arr: jax.Array = eqx.field(init=False)            # (hinges, 2)  edge indices
    rest_lengths: jax.Array = eqx.field(init=False)          # (hinges+1,)  floats

    def __init__(self, hinges: int, shims: int, L: float, rest_lengths: Optional[jax.Array] = None):
        """Build the index tables and rest lengths of the chain.

        Args:
            hinges: number of hinges.
            shims: shim count, stored as given (not used by this class).
            L: uniform rest length, used when `rest_lengths` is None.
            rest_lengths: optional `(hinges+1,)` per-edge rest lengths.

        Raises:
            ValueError: if `rest_lengths` is given with the wrong shape.
        """
        self.hinges = int(hinges)
        self.shims  = int(shims)
        self.L = float(L)

        self.edges_arr  = self._build_edges(self.hinges)            # (E=hinges+1, 2)
        self.edges      = int(self.edges_arr.shape[0])
        self.hinges_arr = self._build_hinges(self.hinges)           # (H=hinges, 2)

        self.rest_lengths = self._build_rest_lengths(self.hinges, self.L, rest_lengths=rest_lengths)

    # --- builders ---
    @staticmethod
    def _build_edges(hinges: int) -> jax.Array:
        """Edges between consecutive points: [[0,1],[1,2],...,[hinges,hinges+1]]."""
        starts = jnp.arange(hinges + 1, dtype=jnp.int32)
        ends   = starts + 1
        return jnp.stack([starts, ends], axis=1)

    @staticmethod
    def _build_hinges(hinges: int) -> jax.Array:
        """Hinges connect consecutive edges: [[0,1],[1,2],...,[hinges-1,hinges]]."""
        if hinges <= 0:
            return jnp.empty((0, 2), dtype=jnp.int32)
        starts = jnp.arange(hinges, dtype=jnp.int32)
        ends   = starts + 1
        return jnp.stack([starts, ends], axis=1)

    @staticmethod
    def _build_rest_lengths(hinges: int, L, *, rest_lengths: Optional[jax.Array]) -> jax.Array:
        """Return the `(hinges+1,)` float32 rest lengths of the edges.

        Uses `rest_lengths` when given (after a shape check), otherwise fills
        every edge with `L`.
        """
        edges = hinges + 1
        # 1) user-provided non-uniform
        if rest_lengths is not None:
            rl = jnp.asarray(rest_lengths, jnp.float32)
            if rl.shape != (edges,):
                raise ValueError(f"rest_lengths shape {rl.shape} != ({edges},)")
            return rl
        # 2) uniform spacing
        return jnp.full((edges,), L, dtype=jnp.float32)

    # vectorized helpers (handy + jit-friendly)
    def all_edge_lengths(self, pos: jax.Array) -> jax.Array:
        """Lengths of all edges for positions `pos`, one entry per edge."""
        return jax.vmap(lambda e: self._get_edge_length(pos, e))(jnp.arange(self.edges))

    def all_hinge_angles(self, pos: jax.Array) -> jax.Array:
        """Signed angles of all hinges for positions `pos`, one entry per hinge."""
        return jax.vmap(lambda h: self._get_theta(pos, h))(jnp.arange(self.hinges))

    # --- geometry ---
    def _get_theta(self, pos_arr: jax.Array, hinge: int):
        """Scalar angle at a hinge (radians), CCW positive.

        The angle is measured from the direction of the hinge's first edge to
        the direction of its second edge.
        """
        # (2, 2, 2): for each of the hinge's two edges, the coordinates of both endpoints
        fourpoints = pos_arr[self.edges_arr[self.hinges_arr[hinge]]]
        vecs = fourpoints[:, 1, :] - fourpoints[:, 0, :]   # (2, 2) edge direction vectors
        # Index (not slice) so the result is a scalar; slicing would give shape (1,)
        # and the vmapped angles would become (H, 1) instead of (H,).
        u, v = vecs[0], vecs[1]
        dot = jnp.sum(u * v)
        cross_z = u[0] * v[1] - u[1] * v[0]      # scalar z of 2D cross
        return jnp.arctan2(cross_z, dot)         # signed angle from u -> v

    def _get_edge_length(self, pos_arr, edge):
        """Length of one edge given current positions `pos_arr`: (Npoints, 2)."""
        twopoints = pos_arr[self.edges_arr[edge]]   # (2, 2) endpoints of the edge
        vec = twopoints[1, :] - twopoints[0, :]     # (2,)
        return jnp.linalg.norm(vec)

In [None]:
class StateClass(eqx.Module):
    """
    Dynamic state of the chain (positions + hinge stiffness regime).

    The state is initialised as a straight chain along the x axis.
    """

    # ---- state / derived (each field declared exactly once) ----
    pos_arr: jax.Array                # (H+2, 2) current positions (float32)
    rest_lengths: jax.Array           # (H+1,) edge rest lengths of the initial layout
    initial_hinge_angles: jax.Array   # (H,)   hinge angles of the initial layout (zeros)
    buckle: jax.Array                 # (H,) in {+1,-1} per hinge (direction of stiff side)

    def __init__(self, Strctr: "StructureClass", buckle: Optional[jax.Array] = None):
        """Create the straight-chain state.

        Args:
            Strctr: structure providing `hinges` and the rod rest length `L`.
            buckle: `(hinges,)` array of +1/-1; defaults to all +1 when None.
        """
        # default buckle: all +1
        if buckle is None:
            self.buckle = jnp.ones((Strctr.hinges,), dtype=jnp.int32)
        else:
            self.buckle = jnp.asarray(buckle, jnp.int32)
            assert self.buckle.shape == (Strctr.hinges,)

        # Space the points by L so the initial layout matches the rest lengths below.
        self.pos_arr = self._initiate_pos(Strctr.hinges, Strctr.L)  # (N=hinges+2, 2)
        # with a straight chain, each edge's rest length = initial spacing
        self.rest_lengths = jnp.full((Strctr.hinges + 1,), Strctr.L, dtype=jnp.float32)
        # straight chain -> 0 resting hinge angles
        self.initial_hinge_angles = jnp.zeros((Strctr.hinges,), dtype=jnp.float32)

    # --- build ---

    @staticmethod
    def _initiate_pos(hinges: int, L: float = 1.0) -> jax.Array:
        """`(hinges+2, 2)` float32 points on the x axis: [[0, 0], [L, 0], [2L, 0], ...]."""
        x = jnp.arange(hinges + 2, dtype=jnp.float32) * L
        return jnp.stack([x, jnp.zeros_like(x)], axis=1)

    @eqx.filter_jit
    def energy(self, Variabs: "VariablesClass", Strctr: "StructureClass"):
        """Compute the potential energy of the chain at the current positions.

        A hinge is treated as stiff when `buckle == +1` and `theta > -thetas_ss`,
        or when `buckle == -1` and `theta < thetas_ss`; otherwise it is soft.

        Returns:
            Array `[total_energy, rotation_energy, stretch_energy]`.
        """
        # Force shape (H,) so the angles line up elementwise with the per-hinge
        # parameters; an (H, 1) array would broadcast to (H, H) and overcount.
        thetas = jnp.reshape(Strctr.all_hinge_angles(self.pos_arr), (Strctr.hinges,))
        edges_length = Strctr.all_edge_lengths(self.pos_arr)

        stiff_mask = (
            ((self.buckle == 1) & (thetas > -Variabs.thetas_ss))
            | ((self.buckle == -1) & (thetas < Variabs.thetas_ss))
        )
        k_rot_state = jnp.where(stiff_mask, Variabs.k_stiff, Variabs.k_soft)   # (H,)

        rotation_energy = 0.5 * jnp.sum(
            k_rot_state * (thetas - Variabs.thetas_ss) ** 2
        )
        stretch_energy = 0.5 * jnp.sum(
            Variabs.k_stretch * (edges_length - Strctr.rest_lengths) ** 2
        )

        total_energy = rotation_energy + stretch_energy
        return jnp.array([total_energy, rotation_energy, stretch_energy])

    @eqx.filter_jit
    def total_potential_energy(self, variabs: "VariablesClass", strctr: "StructureClass") -> jax.Array:
        """Return only the total energy (first entry of `energy`)."""
        return self.energy(variabs, strctr)[0]

In [None]:
# Demo problem size: H hinges in the chain, S shims.
H, S = 5, 1

# --- geometry: all topology lives in StructureClass (rod rest length 1) ---
Strctr = StructureClass(hinges=H, shims=S, L=1)

# --- per-hinge parameters ---
k_soft = jnp.ones((H,), dtype=jnp.float32)
k_stiff = jnp.full((H,), 10.0, dtype=jnp.float32)
thetas_ss = jnp.full((H,), jnp.deg2rad(33.0), dtype=jnp.float32)  # 33 degrees per hinge
buckle = jnp.ones((H,), dtype=jnp.int32)

# stretch_scale=10 here, so k_stretch = 10 * max(k_stiff)
Variabs = VariablesClass(
    k_soft=k_soft,
    k_stiff=k_stiff,
    thetas_ss=thetas_ss,
    stretch_scale=10.0,
)

# --- state: straight chain, every hinge buckled in the +1 direction ---
State = StateClass(Strctr, buckle)

# --- energy: the result unpacks as (total, rotation, stretch) ---
E_total, E_rot, E_stretch = State.energy(Variabs, Strctr)
for label, value in (("Total:", E_total), ("Rotation:", E_rot), ("Stretch:", E_stretch)):
    print(label, float(value))