In [1]:
"""
One-point direction-consistency sweep + run-end statistics
=========================================================

This script **tests Conclusion (C2) – _Directional consistency within
zero–crossing intervals_ – for the rotational‐dynamics model of a cell
(θ) and an internal fibre (ψ).**  For every (p₁,p₂,a,b) parameter
combination it

1. **Builds the function  f(δ) = dψ/dt – dθ/dt** (δ ≔ ψ − θ) and locates
   all of its zeros inside a single π-period.  
   → `one_period_zeros()`  
   • dense grid scan on [−π, π]  
   • 30-step bisection refines each sign change to far below 10⁻⁸
     (final bracket ≈ 3×10⁻¹³ wide with the default 20 000-point grid)

2. **Chooses one test initial difference δ₀ inside every zero interval
   (z_L,z_R).**  
   The midpoint is used:
       δ₀ = ½ (z_L + z_R) ∈ (z_L, z_R).  
   Because f(δ) is continuous and has no zero strictly between two
   consecutive zeros, its sign at the midpoint is its sign on the whole
   open interval.

3. **Predicts, using C2, which zero the trajectory must converge to.**  
   If sign f(δ₀) > 0 the flow should increase δ and hit z_R (index k+1);  
   otherwise it should decrease δ and hit z_L (index k).

4. **Integrates the full two-angle ODE from θ₀ = π/3, ψ₀ = θ₀ + δ₀ until
   δ(t) plateaus.**  
   Plateau test: a 10-time-unit sliding window (`WIN`) in which δ never
   deviates from the window's first sample by 5×10⁻⁴ (`TOL_PLATEAU`) or
   more.  
   → `converge_zero_index()`

5. **Compares the zero actually reached with the C2 prediction.**  
   Counts a *mismatch* if the wrong zero is reached, or *unsteady* if no
   plateau is found before `T_MAX`.

6. **Aggregates results over the whole sweep.**  
   • per-parameter-set table (intervals, mismatches, unsteady)  
   • global totals and distribution of “# intervals per π-period”

Successful execution with a mathematically correct C2 should report  
`total mismatches = 0` and `total unsteady launches = 0`.

Parameter sweep
---------------

``p1_list × p2_list × a_list × b_list``  
p₁ ∈ {0.5,1.0,1.5,2.0,2.5} | p₂ ∈ {2,…,6} | a ∈ {1,…,5} | b ∈ {0.5,…,2.5}

Key implementation choices
--------------------------

* **Midpoint sampling.**  Any interior δ₀ is a valid test point for C2;
  the midpoint is deterministic, simple, and symmetric.
* **Plateau detection.**  A fixed 10-unit window balances speed against
  robustness; enlarge `WIN` or relax `TOL_PLATEAU` if extremely slow
  convergence appears.
* **Zero list closure.**  The routine always includes both endpoints
  *(start, start+π)* so that every sign-definite interval is represented.

Running the script prints a concise summary.  To also save the full
dataframe to CSV for further analysis, uncomment the `to_csv` line at the
end of `main()`.
"""

import itertools, numpy as np, pandas as pd
from scipy.integrate import solve_ivp
from math import sin, cos, pi

# ───────── sweep lists ─────────
# The sweep visits the full Cartesian product p1 × p2 × a × b.
p1_list = [0.5, 1.0, 1.5, 2.0, 2.5]
p2_list = [2.0, 3.0, 4.0, 5.0, 6.0]
a_list  = [1.0, 2.0, 3.0, 4.0, 5.0]
b_list  = [0.5, 1.0, 1.5, 2.0, 2.5]

# ───────── integration / plateau params ─────────
# T_MAX : end of the integration span handed to solve_ivp.
# DT    : spacing of the output samples on which the plateau test runs.
T_MAX, DT = 100.0, 0.001
# linspace pins the last sample to exactly T_MAX.  A float-step arange can
# emit one extra sample beyond T_MAX for some (T_MAX, DT) pairs, and
# solve_ivp rejects t_eval values outside its t_span.
TIME_EVAL = np.linspace(0.0, T_MAX, int(round(T_MAX / DT)) + 1)
# Samples in the 10-time-unit plateau window; round() guards against
# 10 / DT landing a hair below an integer and being truncated.
WIN         = int(round(10 / DT))
# Largest deviation from the window's first sample that still counts as flat.
TOL_PLATEAU = 5e-4

# ───────── system definitions ─────────
def single_cell_ode(t, y, p1, p2, a, b):
    """Right-hand side of the two-angle system, in solve_ivp's (t, y) form.

    Parameters
    ----------
    t : float
        Time.  Ignored: the right-hand side depends only on the state.
    y : sequence of two floats
        State ``(θ, ψ)`` – cell angle and fibre angle.
    p1, p2, a, b : float
        Model parameters (the quantities swept by the script).

    Returns
    -------
    list of float
        ``[dθ/dt, dψ/dt]``.
    """
    cell_angle, fibre_angle = y
    delta = fibre_angle - cell_angle

    sin_delta = sin(delta)
    cos_delta = cos(delta)
    sin_2delta = sin(2*delta)

    # Shape-dependent coupling term; it is zero when a == b.
    coupling_num = (a*b)*(a**2 - b**2)*sin_2delta
    coupling_den = 2*((a*sin_delta)**2 + (b*cos_delta)**2)**1.5

    cell_rate = -1.0 - p1*coupling_num/coupling_den
    fibre_rate = -p2*sin_2delta
    return [cell_rate, fibre_rate]

def f_delta(d,p1,p2,a,b):
    """Evaluate f(δ) = dψ/dt − dθ/dt at angle difference(s) ``d``.

    Built on NumPy ufuncs, so ``d`` may be a scalar or an array; the result
    has the matching shape.

    Parameters
    ----------
    d : float or ndarray
        Angle difference δ = ψ − θ.
    p1, p2, a, b : float
        Model parameters.

    Returns
    -------
    float or ndarray
        Rate of change of δ under the two-angle system.
    """
    sin_d = np.sin(d)
    cos_d = np.cos(d)
    sin_2d = np.sin(2*d)

    coupling_num = (a*b)*(a**2-b**2)*sin_2d
    coupling_den = 2*((a*sin_d)**2+(b*cos_d)**2)**1.5

    cell_rate = -1.0 - p1*coupling_num/coupling_den
    fibre_rate = -p2*sin_2d
    return fibre_rate - cell_rate

# ───────── helpers ─────────
def converge_zero_index(d0, zeros, p1,p2,a,b):
    """Integrate from δ(0) = d0 and report which entry of ``zeros`` δ settles on.

    The two-angle ODE is started at θ₀ = π/3, ψ₀ = θ₀ + d0 and sampled on
    ``TIME_EVAL``.  The first window of ``WIN`` consecutive samples in which
    δ never deviates from the window's first sample by ``TOL_PLATEAU`` or
    more is taken as the plateau, and that first sample is the plateau value.

    Parameters
    ----------
    d0 : float
        Initial angle difference δ₀ = ψ₀ − θ₀.
    zeros : sequence of float
        Candidate zeros of f(δ); the plateau value is matched to the nearest.
    p1, p2, a, b : float
        Model parameters forwarded to ``single_cell_ode``.

    Returns
    -------
    int or None
        Index into ``zeros`` of the entry closest to the plateau value, or
        ``None`` if the integrator reports failure or no plateau is found.
    """
    θ0 = pi/3
    ψ0 = θ0 + d0

    def rhs(t, y):
        return single_cell_ode(t, y, p1, p2, a, b)

    sol = solve_ivp(rhs, (0.0, T_MAX), [θ0, ψ0],
                    t_eval=TIME_EVAL, rtol=1e-6, atol=1e-9)
    if not sol.success:
        return None

    δ = sol.y[1] - sol.y[0]
    # "+ 1" so the last full window (ending on the final sample) is tested
    # too; without it a plateau that fills exactly that window is missed.
    for i in range(len(δ) - WIN + 1):
        w = δ[i:i+WIN]
        if np.max(np.abs(w - w[0])) < TOL_PLATEAU:
            steady = w[0]
            break
    else:
        # Loop ran out of windows without finding a flat one.
        return None
    return int(np.argmin(np.abs(np.asarray(zeros) - steady)))

def one_period_zeros(p1,p2,a,b,n_scan=20000):
    """Return the sign-change zeros of ``f_delta`` spanning one π-wide period.

    ``f_delta`` is sampled on ``n_scan`` equally spaced points of [−π, π].
    Every adjacent pair of samples with opposite signs is refined by up to
    30 bisection steps.  The result starts at the smallest root found
    (``start``), lists the roots strictly inside (start, start+π), and ends
    at ``start + π``.

    Parameters
    ----------
    p1, p2, a, b : float
        Model parameters forwarded to ``f_delta``.
    n_scan : int, optional
        Number of grid points used for the sign-change scan.

    Returns
    -------
    list of float
        Ascending zeros ``[start, ..., start+π]``, or ``[]`` when the scan
        finds no sign change.

    Notes
    -----
    Only strict sign changes between neighbouring samples are detected; a
    zero that falls exactly on a grid point, or a tangential zero, is not.
    """
    x = np.linspace(-pi, pi, n_scan)
    y = f_delta(x, p1, p2, a, b)

    roots = []
    for i in range(n_scan - 1):
        if y[i]*y[i+1] < 0:
            lo, hi = x[i], x[i+1]
            # Cache f(lo): it only changes when lo itself moves.
            f_lo = f_delta(lo, p1, p2, a, b)
            for _ in range(30):
                mid = 0.5*(lo + hi)
                f_mid = f_delta(mid, p1, p2, a, b)
                if f_mid == 0:
                    # Exact hit.  Without this stop, lo would be moved onto
                    # the zero and the bracket would then drift toward hi.
                    lo = hi = mid
                    break
                if f_lo*f_mid < 0:
                    hi = mid
                else:
                    lo, f_lo = mid, f_mid
            roots.append(0.5*(lo + hi))

    roots.sort()
    if not roots:
        return []

    start = roots[0]
    end = start + pi
    per = [z for z in roots if start < z < end]
    # Close the period with start+π unless the scan already produced that
    # root (within 1e-8).  The `not per` guard covers the case of no
    # interior roots, where per[-1] would raise IndexError.
    if not per or abs(end - per[-1]) > 1e-8:
        per.append(end)
    per.insert(0, start)
    return per

def direction_check_one(p1,p2,a,b):
    """Run the one-point direction check on every zero interval of one parameter set.

    For each pair of consecutive zeros the midpoint is used as δ₀.  The sign
    of ``f_delta`` there gives the predicted zero: the right one (index k+1)
    when positive, otherwise the left one (index k).  The zero actually
    reached comes from ``converge_zero_index``.

    Parameters
    ----------
    p1, p2, a, b : float
        Model parameters.

    Returns
    -------
    dict or None
        ``{"intervals", "mismatches", "unsteady"}`` counts, or ``None`` when
        fewer than two zeros are available.
    """
    zeros = one_period_zeros(p1, p2, a, b)
    n_intervals = len(zeros) - 1
    if n_intervals < 1:
        return None

    wrong_zero = 0
    no_plateau = 0
    for k, (left, right) in enumerate(zip(zeros, zeros[1:])):
        d0 = 0.5*(left + right)
        moves_right = np.sign(f_delta(d0, p1, p2, a, b)) > 0
        expected = k + 1 if moves_right else k
        reached = converge_zero_index(d0, zeros, p1, p2, a, b)
        if reached is None:
            no_plateau += 1
        elif reached != expected:
            wrong_zero += 1

    return {"intervals": n_intervals,
            "mismatches": wrong_zero,
            "unsteady": no_plateau}

# ───────── main sweep ─────────
def main():
    """Sweep every (p1, p2, a, b) combination and print the results.

    Prints the per-parameter-set table, the aggregate totals, and the
    distribution of interval counts.  Parameter sets for which
    ``direction_check_one`` returns nothing are left out of the table.
    """
    combos = list(itertools.product(p1_list, p2_list, a_list, b_list))
    total_sets = len(combos)

    # tqdm is optional; fall back to plain iteration when it is missing.
    try:
        from tqdm import tqdm
        sweep = tqdm(combos, total=total_sets, desc="Sweeping")
    except ModuleNotFoundError:
        print("(Install tqdm for a progress bar)")
        sweep = combos

    rows = []
    for p1, p2, a, b in sweep:
        outcome = direction_check_one(p1, p2, a, b)
        if not outcome:
            continue
        rows.append({"p1": p1, "p2": p2, "a": a, "b": b, **outcome})

    df = pd.DataFrame(rows)
    print("\nSweep summary (one-point check per interval):")
    if df.empty:
        print("  No parameter set had a zero.")
        return

    table_columns = ["p1", "p2", "a", "b",
                     "intervals", "mismatches", "unsteady"]
    print(df[table_columns].to_string(index=False))

    # ───── aggregate statistics ─────
    n_with_zero = len(df)
    n_intervals = df['intervals'].sum()
    n_mismatch = df['mismatches'].sum()
    n_unsteady = df['unsteady'].sum()
    print("\nAggregate statistics")
    print("────────────────────")
    print(f" Parameter sets tested           : {total_sets}")
    print(f" Sets with ≥1 zero                : {n_with_zero}")
    print(f" Total intervals examined         : {n_intervals}")
    print(f"   • total SIGN_MISMATCH intervals: {n_mismatch}")
    print(f"   • total UNSTEADY intervals     : {n_unsteady}")

    # How many parameter sets produced each interval count.
    per_count = df["intervals"].value_counts().sort_index()
    counts = per_count.rename_axis("# intervals").to_frame("sets")
    print("\nDistribution of interval counts per π-period:")
    print(counts.to_string())

    # Uncomment to save the full table:
    # df.to_csv("direction_sweep_summary.csv", index=False)

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


Sweeping: 100%|██████████| 625/625 [01:52<00:00,  5.58it/s]


Sweep summary (one-point check per interval):
 p1  p2   a   b  intervals  mismatches  unsteady
0.5 2.0 1.0 0.5          2           0         0
0.5 2.0 1.0 1.0          2           0         0
0.5 2.0 1.0 1.5          2           0         0
0.5 2.0 1.0 2.0          2           0         0
0.5 2.0 1.0 2.5          2           0         0
0.5 2.0 2.0 0.5          2           0         0
0.5 2.0 2.0 1.0          2           0         0
0.5 2.0 2.0 1.5          2           0         0
0.5 2.0 2.0 2.0          2           0         0
0.5 2.0 2.0 2.5          2           0         0
0.5 2.0 3.0 0.5          4           0         0
0.5 2.0 3.0 1.0          2           0         0
0.5 2.0 3.0 1.5          2           0         0
0.5 2.0 3.0 2.0          2           0         0
0.5 2.0 3.0 2.5          2           0         0
0.5 2.0 4.0 0.5          4           0         0
0.5 2.0 4.0 1.0          4           0         0
0.5 2.0 4.0 1.5          2           0         0
0.5 2.0 4.0 2.0       


