In [None]:
import numpy as np
from tqdm.auto import tqdm

def transition_function(x: float) -> float:
    """Smooth, infinitely differentiable step from -1 to +1.

    Returns exactly -1.0 for x <= -0.5 and exactly 1.0 for x >= 0.5; in between
    it blends two exponential bump functions so the curve passes through 0 at
    x = 0. NaN input falls through both guards and yields NaN.
    """
    if x <= -0.5:
        return -1.0
    elif x >= 0.5:
        return 1.0
    doubled = 2.0 * x
    # Both exponents are strictly negative inside (-0.5, 0.5), so neither term can overflow.
    left = np.exp(-2.0 / (doubled + 1.0))
    right = np.exp(2.0 / (doubled - 1.0))
    return 2.0 * left / (left + right) - 1.0

def theta_eff(s, psi1, p):
    """Tangent angle at arclength ``s`` during the effective stroke.

    ``psi1`` is the phase measured from the start of the effective stroke (as
    passed by ``theta``). ``p`` must already hold the derived keys
    ``effective_bend_speed`` and ``effective_bend_width`` that
    ``cilium_element_position`` adds. The result blends a rigid-rotation term
    with a bend front that travels toward larger ``s`` as ``psi1`` grows,
    weighted by ``effective_bend_importance``.
    """
    amp = p['amplitude']
    weight = p['effective_bend_importance']
    rotation = np.cos(0.5 * psi1 / p['effective_frac'])
    # psi1 / (2*pi) / frequency: elapsed time if frequency is in cycles per unit time.
    travelled = p['effective_bend_speed'] * psi1 / (2.0 * np.pi) / p['frequency']
    # Width is floored to avoid division by zero for a zero-width bend.
    front = (s - travelled) / max(p['effective_bend_width'], 1e-12) + 0.5
    bend = -transition_function(front)
    return amp * ((1.0 - weight) * rotation - weight * bend)

def theta_rec(s, psi1, p):
    """Tangent angle at arclength ``s`` during the recovery stroke.

    ``psi1`` is the phase measured from the start of the recovery stroke (as
    passed by ``theta``). ``p`` must already hold the derived keys
    ``recovery_frac``, ``recovery_bend_speed`` and ``recovery_bend_width`` that
    ``cilium_element_position`` adds. The rotation term is offset by pi relative
    to the effective stroke, and the bend front enters with the opposite sign.
    """
    amp = p['amplitude']
    weight = p['recovery_bend_importance']
    rotation = np.cos(0.5 * psi1 / p['recovery_frac'] + np.pi)
    # psi1 / (2*pi) / frequency: elapsed time if frequency is in cycles per unit time.
    travelled = p['recovery_bend_speed'] * psi1 / (2.0 * np.pi) / p['frequency']
    # Width is floored to avoid division by zero for a zero-width bend.
    front = (s - travelled) / max(p['recovery_bend_width'], 1e-12) + 0.5
    bend = transition_function(front)
    return amp * ((1.0 - weight) * rotation - weight * bend)

def theta(s, psi1, p):
    """Tangent angle at arclength ``s`` for beat phase ``psi1``.

    The phase is lagged in proportion to arclength (``phase_delay`` cycles over
    the full ``length``) and wrapped to [0, 2*pi). The first
    ``2*pi*effective_frac`` of the cycle is the effective stroke; the remainder
    is the recovery stroke, whose local phase restarts at zero.
    """
    two_pi = 2.0 * np.pi
    lagged = (psi1 - two_pi * p['phase_delay'] * s / p['length']) % two_pi
    stroke_split = two_pi * p['effective_frac']
    return (
        theta_eff(s, lagged, p)
        if lagged < stroke_split
        else theta_rec(s, lagged - stroke_split, p)
    )

def rot_y(psi2):
    """Return the 3x3 float rotation matrix about the y axis by ``psi2`` radians.

    For positive angles, +x is rotated toward -z and +z toward +x.
    """
    cos_a = np.cos(psi2)
    sin_a = np.sin(psi2)
    matrix = np.eye(3, dtype=float)
    matrix[0, 0] = matrix[2, 2] = cos_a
    matrix[0, 2] = sin_a
    matrix[2, 0] = -sin_a
    return matrix

def quat_to_rot_matrix(q):
    """Convert a quaternion ``[w, x, y, z]`` into a 3x3 rotation matrix.

    The quaternion need not be unit length: the 2/|q|^2 scale factor
    normalises it implicitly. A (near-)zero quaternion (|q|^2 < 1e-24) maps
    to the identity. Raises AssertionError if ``q`` does not have 4 elements.
    """
    quat = np.array(q, dtype=float).reshape(-1)
    assert quat.size == 4, "Quaternion must be length-4 [w, x, y, z]"
    w, x, y, z = quat
    norm_sq = w*w + x*x + y*y + z*z
    if norm_sq < 1e-24:
        return np.eye(3)
    k = 2.0 / norm_sq
    # Scaled pairwise products; the diagonal uses the squares, off-diagonals mix them.
    xx, yy, zz = x*x*k, y*y*k, z*z*k
    xy, xz, yz = x*y*k, x*z*k, y*z*k
    wx, wy, wz = w*x*k, w*y*k, w*z*k
    row0 = [1.0 - (yy + zz), xy - wz, xz + wy]
    row1 = [xy + wz, 1.0 - (xx + zz), yz - wx]
    row2 = [xz - wy, yz + wx, 1.0 - (xx + yy)]
    return np.array([row0, row1, row2], dtype=float)

def cilium_element_position(
    s, psi1, psi2, params, base=None, steps=200, *,
    orientation_quat=None, orientation_matrix=None
):
    """Position of the point at arclength ``s`` along one cilium.

    The centreline is integrated from ``base`` with a right-endpoint Riemann
    sum of the unit tangent ``(cos th, 0, sin th)``, where
    ``th = theta(s_i, psi1, p)``. Each tangent is rotated by
    ``orientation @ rot_y(psi2)``.

    Parameters
    ----------
    s : float
        Arclength; clipped to ``[0, params['length']]``.
    psi1 : float
        Beat phase forwarded to ``theta``.
    psi2 : float
        Angle of the rotation about the local y axis.
    params : mapping
        Must contain: length, frequency, amplitude, effective_frac, phase_delay,
        effective_bend_importance, recovery_bend_importance,
        relative_effective_bend_width, relative_recovery_bend_width.
        It is copied; derived keys are added to the copy only.
    base : array-like, optional
        Start point of the integration (default: the origin).
    steps : int
        Number of integration steps; values below 1 are raised to 1.
    orientation_quat : array-like ``[w, x, y, z]``, optional
        Used only when ``orientation_matrix`` is None.
    orientation_matrix : 3x3 array-like, optional
        Takes precedence over ``orientation_quat``.

    Returns
    -------
    numpy.ndarray
        The integrated position (shape (3,) for a 3-element or absent ``base``).
    """
    p = dict(params)
    length = p['length']
    frequency = p['frequency']
    eff_frac = p['effective_frac']
    rec_frac = 1.0 - eff_frac
    eff_duration = eff_frac / frequency
    rec_duration = rec_frac / frequency
    eff_width = p['relative_effective_bend_width'] * length
    rec_width = p['relative_recovery_bend_width'] * length
    p.update(
        recovery_frac=rec_frac,
        effective_duration=eff_duration,
        recovery_duration=rec_duration,
        effective_bend_width=eff_width,
        recovery_bend_width=rec_width,
        # Bend fronts cross the whole filament plus their own width within one stroke.
        effective_bend_speed=(length + eff_width) / max(eff_duration, 1e-12),
        recovery_bend_speed=(length + rec_width) / max(rec_duration, 1e-12),
    )

    s = float(np.clip(s, 0.0, length))

    # Overall frame = orientation * RotY(psi2); an explicit matrix beats a quaternion.
    about_y = rot_y(psi2)
    if orientation_matrix is not None:
        orientation = np.array(orientation_matrix, dtype=float).reshape(3, 3)
    elif orientation_quat is not None:
        orientation = quat_to_rot_matrix(orientation_quat)
    else:
        orientation = np.eye(3)
    frame = orientation @ about_y

    start = np.zeros(3) if base is None else np.array(base, dtype=float)
    if s == 0.0:
        return start

    n_steps = max(1, int(steps))
    h = s / n_steps
    pos = start
    for k in range(n_steps):
        angle = theta((k + 1) * h, psi1, p)
        pos = pos + frame @ np.array([np.cos(angle), 0.0, np.sin(angle)]) * h
    return pos

# Example:
# params = {
#     'length': 10e-6, 'frequency': 30.0, 'amplitude': np.deg2rad(40.0),
#     'effective_frac': 0.6, 'phase_delay': 0.1,
#     'effective_bend_importance': 0.4, 'recovery_bend_importance': 0.7,
#     'relative_effective_bend_width': 0.4, 'relative_recovery_bend_width': 0.6
# }
# p = cilium_element_position(5e-6, psi1=1.0, psi2=0.2, params=params)
# print(p)

In [None]:
# Build segment positions from phases and filament references/quaternions
def parse_phases_from_true_states(phase_data, num_fils):
    """Split a true-states table into per-filament phase arrays (psi1, psi2).

    The column layout is guessed from the column count M (column 0 is not read here):
      * M >= 1 + 2*num_fils -> psi1 = columns 1..num_fils, psi2 = the next num_fils columns
      * M >= 2 + num_fils   -> psi1 = columns 2..num_fils+1, psi2 = 0
      * M >= 1 + num_fils   -> psi1 = columns 1..num_fils,   psi2 = 0
    NOTE(review): these are heuristics; confirm against the writer of *_true_states.dat.

    Parameters
    ----------
    phase_data : array-like, shape (T, M) or (M,)
        A 1-D input is treated as a single frame, which is what np.loadtxt
        returns for a one-row file.
    num_fils : int
        Number of filaments whose phases to extract.

    Returns
    -------
    (psi1, psi2) : tuple of ndarray, each shape (T, num_fils), wrapped into [0, 2*pi).

    Raises
    ------
    ValueError
        If the table has fewer than 1 + num_fils columns, so num_fils phases
        cannot be extracted.
    """
    data = np.asarray(phase_data, dtype=float)
    if data.ndim == 1:
        data = data[np.newaxis, :]
    M = data.shape[1]
    if M >= 1 + 2*num_fils:
        psi1 = data[:, 1:1+num_fils]
        psi2 = data[:, 1+num_fils:1+2*num_fils]
    elif M >= 2 + num_fils:
        psi1 = data[:, 2:2+num_fils]
        psi2 = np.zeros_like(psi1)
    elif M >= 1 + num_fils:
        psi1 = data[:, 1:1+num_fils]
        psi2 = np.zeros_like(psi1)
    else:
        # Slicing here would silently yield fewer than num_fils columns.
        raise ValueError(
            f"true_states has {M} columns; need at least {1 + num_fils} for num_fils={num_fils}"
        )
    return np.mod(psi1, 2*np.pi), np.mod(psi2, 2*np.pi)

def _load_per_filament_rows(path, num_fils, width, label):
    """Load ``width`` floats per filament from ``path`` as an array of shape (num_fils, width).

    Accepts either a single flat line ("a1 b1 c1 a2 b2 c2 ...") or a table with at
    least ``num_fils`` rows and ``width`` columns; extra values are ignored.
    Raises ValueError if the file holds too few values.
    """
    # atleast_1d: a file containing a single number loads as a 0-d array.
    raw = np.atleast_1d(np.array(np.loadtxt(path), dtype=float))
    if raw.ndim == 1:
        if raw.size < width * num_fils:
            raise ValueError(f"{label} has {raw.size} values; expected at least {width*num_fils}")
        return raw[:width * num_fils].reshape(num_fils, width)
    if raw.shape[1] < width or raw.shape[0] < num_fils:
        raise ValueError(f"{label} shape {raw.shape} incompatible with num_fils={num_fils}")
    return raw[:num_fils, :width]


def reconstruct_seg_positions(base_path, num_fils, num_segs, params, *,
                              true_states_path=None, fil_refs_path=None, fil_q_path=None,
                              stride=1, steps_per_length=200, progress=True):
    """
    Reconstruct seg_positions (T, num_fils, num_segs, 3) when segment states are missing.

    Segment j of filament i is placed at arclength ``linspace(0, length, num_segs)[j]``
    via ``cilium_element_position``, using the filament's base point
    (``*_fil_references.dat``), its orientation quaternion (``*_fil_q.dat``) and its
    phases from ``*_true_states.dat``.

    Parameters
    ----------
    base_path : str
        Prefix used to build the default input file names.
    num_fils, num_segs : int
        Number of filaments and of sample points per filament.
    params : dict
        Beat-shape parameters forwarded to ``cilium_element_position``.
    true_states_path, fil_refs_path, fil_q_path : str, optional
        Override the default ``{base_path}_*.dat`` paths.
    stride : int
        Reconstruct every ``stride``-th row of the true-states file (minimum 1).
    steps_per_length : int
        Integration steps for a full-length filament (minimum 50); shorter
        arclengths use proportionally fewer steps.
    progress : bool
        Show tqdm progress bars.

    Returns
    -------
    seg_positions : ndarray, shape (len(frames), num_fils, num_segs, 3)
    phases : ndarray, shape (len(frames), num_fils)
        psi1 per frame, as returned by ``parse_phases_from_true_states``.
    time : ndarray
        Column 0 of the true-states file at the selected frames.
    frames : ndarray of int
        Row indices into the true-states file that were reconstructed.

    Raises
    ------
    ValueError
        If an input file has too few values/columns for ``num_fils``.

    NOTE: every segment is integrated independently from the base in pure Python,
    so the cost per frame is roughly num_fils * num_segs * steps_per_length / 2
    ``theta`` evaluations; use ``stride`` to subsample frames for long runs.
    """
    # Resolve paths
    if true_states_path is None:
        true_states_path = f"{base_path}_true_states.dat"
    if fil_refs_path is None:
        fil_refs_path = f"{base_path}_fil_references.dat"
    if fil_q_path is None:
        fil_q_path = f"{base_path}_fil_q.dat"

    # ndmin=2 keeps a single-frame file 2-D so column slicing and phase_data[:, 0] work.
    phase_data = np.loadtxt(true_states_path, ndmin=2)

    # Bases: "x1 y1 z1 x2 y2 z2 ..." on one line, or (num_fils, >=3)
    bases = _load_per_filament_rows(fil_refs_path, num_fils, 3, "fil_references")
    # Quats: "w1 x1 y1 z1 w2 x2 y2 z2 ..." on one line, or (num_fils, >=4)
    quats = _load_per_filament_rows(fil_q_path, num_fils, 4, "fil_q")

    psi1_all, psi2_all = parse_phases_from_true_states(phase_data, num_fils)
    if psi1_all.shape[1] < num_fils:
        raise ValueError(
            f"true_states provides {psi1_all.shape[1]} phase columns per frame; expected {num_fils}"
        )
    time = phase_data[:, 0]

    # Arclength sampling; the step count per segment depends only on s, so compute it once.
    L = float(params['length'])
    s_vals = np.linspace(0.0, L, num_segs)
    base_steps = max(50, int(steps_per_length))
    seg_steps = [max(1, int(base_steps * (s / max(L, 1e-12)))) for s in s_vals]

    frames = np.arange(0, time.size, int(max(1, stride)))
    seg_positions = np.empty((frames.size, num_fils, num_segs, 3), dtype=float)
    phases = np.empty((frames.size, num_fils), dtype=float)

    frame_iter = tqdm(frames, desc="Reconstructing frames", unit="frame") if progress else frames
    for ti, t in enumerate(frame_iter):
        psi1_t = psi1_all[t]
        psi2_t = psi2_all[t]

        fil_iter = tqdm(range(num_fils), desc="Filaments", unit="fil", leave=False) if progress else range(num_fils)
        for i in fil_iter:
            psi1_i = float(psi1_t[i])
            psi2_i = float(psi2_t[i])
            for j, (s, steps) in enumerate(zip(s_vals, seg_steps)):
                seg_positions[ti, i, j, :] = cilium_element_position(
                    s, psi1_i, psi2_i, params,
                    base=bases[i], steps=steps, orientation_quat=quats[i]
                )
        phases[ti] = psi1_t

    return seg_positions, phases, time[frames], frames

# Example usage:
# base_path = '../data/no_tilt_4/20250730/ciliate_309fil_18000blob_8.00R_0.1500torsion_0.0000tilt_0.3000f_eff_1.4960theta0_0.0000freqshift'
# params = {
#     'length': 10e-6, 'frequency': 30.0, 'amplitude': np.deg2rad(40.0),
#     'effective_frac': 0.6, 'phase_delay': 0.1,
#     'effective_bend_importance': 0.4, 'recovery_bend_importance': 0.7,
#     'relative_effective_bend_width': 0.4, 'relative_recovery_bend_width': 0.6
# }
# seg_positions, phases, time, frames = reconstruct_seg_positions(base_path, num_fils=309, num_segs=20, params=params)

In [6]:
# Prefix of the simulation outputs (*_true_states.dat, *_fil_references.dat, *_fil_q.dat).
base_path = (
    '../data/cell_gaps_2/20250901/'
    'ciliate_298fil_18000blob_8.00R_0.1500torsion_0.2182tilt_0.3000f_eff_1.4960theta0_0.0000freqshift'
)
# Beat-shape parameters forwarded to cilium_element_position.
params = dict(
    length=49.4,
    frequency=2 * np.pi,
    amplitude=np.pi / 2.1,
    effective_frac=0.3,
    phase_delay=0.05,
    effective_bend_importance=0.1,
    recovery_bend_importance=0.85,
    relative_effective_bend_width=0.8,
    relative_recovery_bend_width=0.4,
)
# NOTE(review): num_fils presumably must match the "298fil" in the run name.
# The recorded run of this cell was interrupted (KeyboardInterrupt); consider stride=... to subsample frames.
seg_positions, phases, time, frames = reconstruct_seg_positions(
    base_path, params=params, num_fils=298, num_segs=20,
)

KeyboardInterrupt: 

In [None]:
from pathlib import Path
import json

def _json_default(obj):
    """``json.dumps`` fallback that turns NumPy scalars/arrays into plain Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_reconstructed_npz(out_path, seg_positions, phases, time, frames, params, base_path, *, to_float32=True, compress=True):
    """Save a reconstruction to a single ``.npz`` archive and print its size.

    Stored arrays: ``seg_positions``, ``phases``, ``time``, ``frames`` (int32), and
    ``meta``, a JSON string with ``params`` and ``base_path``.

    Parameters
    ----------
    out_path : str or Path
        Target file. ``.npz`` is appended when missing, matching what ``np.savez``
        does, and the parent directory is created if needed.
    seg_positions, phases, time : array-like
        Float arrays. They are stored as float32 when ``to_float32`` is True;
        otherwise their dtype is kept.
    frames : array-like of int
        Frame indices; stored as int32.
    params : dict
        Serialised into ``meta``; NumPy scalars and arrays are converted to
        plain Python values.
    base_path : str or Path
        Recorded in ``meta`` as a string.
    to_float32 : bool
        Downcast the float arrays to float32.
    compress : bool
        Use ``np.savez_compressed`` instead of ``np.savez``.
    """
    out_path = Path(out_path)
    # np.savez appends ".npz" to names lacking it; mirror that so stat() below sees the real file.
    if not out_path.name.endswith(".npz"):
        out_path = out_path.with_name(out_path.name + ".npz")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    float_dtype = np.float32 if to_float32 else None  # None keeps the input dtype
    sp = np.asarray(seg_positions, dtype=float_dtype)
    ph = np.asarray(phases, dtype=float_dtype)
    tm = np.asarray(time, dtype=float_dtype)
    fr = np.asarray(frames).astype(np.int32, copy=False)

    meta = json.dumps({"params": params, "base_path": str(base_path)}, default=_json_default)
    saver = np.savez_compressed if compress else np.savez
    saver(out_path, seg_positions=sp, phases=ph, time=tm, frames=fr, meta=meta)

    try:
        sz_mb = out_path.stat().st_size / (1024**2)
        print(f"Saved: {out_path} ({sz_mb:.2f} MB)")
    except OSError:
        # Size reporting is best-effort only; the archive has already been written.
        print(f"Saved: {out_path}")

In [None]:
# Write the reconstruction to analysis_output/, named after the simulation run.
out_dir = Path("analysis_output")
npz_path = out_dir.joinpath(f"{Path(base_path).name}_reconstructed.npz")
save_reconstructed_npz(
    npz_path, seg_positions, phases, time, frames,
    params=params, base_path=base_path,
)