<div style="text-align:center">
    <h1 style="color:#1f77b4; font-weight:1000;">
        Coupled Oscillations: Pendulums and Mass–Spring Systems
    </h1>
</div>

> A theory-first, computation-forward notebook on coupled oscillators: nonlinear dynamics, linear normal modes, frequency-domain analysis, and visualization.

## 📑 Table of Contents

> **Quick Navigation:** Click any section to jump directly. Expand sections to see detailed subsections.


### 🎯 Core Setup
- [📚 Introduction](#introduction)
- [⚙️ Imports and Notebook Setup](#imports-and-notebook-setup)

---

<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin-top: 20px;">

<div>

### ⚖️ **Pendulum Systems**

<details open>
<summary><strong>üîó 1. Two Spring-Coupled Pendulums (Nonlinear Model)</strong></summary>

- [1.1 Geometry and Coordinate Definitions (Two Pendulums)](#11-geometry-and-coordinate-definitions-two-pendulums)
- [1.2 Numerical Integration and Animation (Two Pendulums)](#12-numerical-integration-and-animation-two-pendulums)
  - [1.2.1 Nonlinear Equations of Motion in State-Space Form](#121-nonlinear-equations-of-motion-in-state-space-form)
  - [1.2.2 Linearization and Theoretical Normal-Mode Frequencies (Small-Angle)](#122-linearization-and-theoretical-normal-mode-frequencies-small-angle)
  - [1.2.3 Numerical Frequency Estimation (Modal Projection + FFT)](#123-numerical-frequency-estimation-modal-projection--fft)
  - [1.2.4 Animation of the Coupled Pendulums](#124-animation-of-two-spring-coupled-pendulums)
- [1.3 Phase Portraits and Configuration-Space Trajectories](#13-phase-portraits-and-configuration-space-trajectories)
- [1.4 Time-Series Diagnostics (Angles)](#14-time-series-diagnostics-angles)
- [1.5 Frequency-Domain Analysis (FFT and Peak Picking)](#15-frequency-domain-analysis-fft-and-peak-picking)

</details>

<details>
<summary><strong>⚙️ 2. Three Spring-Coupled Pendulums (Nonlinear Model)</strong></summary>

- [2.1 Geometry and Coordinates (Three Pendulums)](#21-geometry-and-coordinates-three-pendulums)
- [2.2 Numerical Integration and Animation (Three Pendulums)](#22-numerical-integration-and-animation-three-pendulums)
  - [2.2.1 Equations of Motion (State-Space Form)](#221-equations-of-motion-state-space-form)
  - [2.2.2 Linearization and Theoretical Normal-Mode Frequencies](#222-linearization-and-theoretical-normal-mode-frequencies)
  - [2.2.3 Dominant Frequencies from Numerical Data](#223-dominant-frequencies-from-numerical-data)
  - [2.2.4 Animation of the Three Spring-Coupled Pendulums](#224-animation-of-the-three-pendulum-system)
- [2.3 Time-Series Diagnostics](#23-time-series-diagnostics-angles)
- [2.4 Frequency-Domain Analysis (FFT)](#24-frequency-domain-analysis-fft)
- [2.5 Phase Portraits and Configuration-Space Trajectories](#25-phase-space-and-configuration-space-trajectories)

</details>

</div>

<div>

### 〰️ **Mass-Spring Systems**

<details open>
<summary><strong>⚡ 3. Two Coupled Mass-Spring Oscillators</strong></summary>

- [3.1 Geometry and Parameters](#31-geometry-and-parameters)
- [3.2 ODE Model, Normal Modes, and Animation](#32-ode-model-normal-modes-and-animation)
  - [3.2.1 Equations of Motion (State-Space Form)](#321-equations-of-motion-state-space-form)
  - [3.2.2 Theoretical Normal-Mode Frequencies](#322-theoretical-normal-mode-frequencies)
  - [3.2.3 Numerical Frequency Estimation](#323-numerical-frequency-estimation-from-simulation-data)
  - [3.2.4 Animation of the Coupled Mass-Spring System](#324-animating-the-motion-of-coupled-mass-spring-system)
- [3.3 Time-Series and Phase-Space Visualizations](#33-time-series-and-phase-space-visualizations)
- [3.4 Two Coupled Mass-Spring System Analysis](#34-coupled-mass-spring-system-analysis-with-two-masses)
- [3.5 Effect of Coupling Strength on Dynamics](#35-coupling-strength-effects)

</details>

<details>
<summary><strong>üîó 4. N-Mass-Spring Chain: Normal Modes and Dynamics</strong></summary>

- [4.1 State-Space Derivatives for an N-Mass Chain](#41-state-space-derivatives-for-an-n-mass-chain)
- [4.2 Theoretical Normal Modes via Eigenvalue Problem](#42-theoretical-normal-modes-via-eigenvalue-problem)
- [4.3 Numerical Frequency Estimates (FFT)](#43-numerical-frequency-estimates-fft)
- [4.4 Animation of the N-Mass Chain](#44-animation-of-the-n-mass-chain)
- [4.5 Phase Portraits and Configuration-Space Plots (Example: N=3)](#45-phase-portraits-and-configuration-space-plots-example-n3)
- [4.6 Poincaré Sections (Example: N=3)](#46-poincaré-sections-example-n3)

</details>

</div>

</div>

---

### 📚 Additional Resources
- [📖 References](#5-references)


## Introduction

This notebook develops a **unified view of coupled oscillations** across two canonical settings: (i) **spring-coupled pendulums** (nonlinear dynamics with a spring coupling that depends on horizontal displacement), and (ii) **coupled mass–spring systems** (finite chains with matrix/eigenmode structure).

The workflow is deliberately *theory → model → numerics → diagnostics*:

- **Theory and derivation:** We start from the Lagrangian formulation and derive the governing equations (exact nonlinear form for coupled pendulums, then linearized small-angle limits). For mass–spring chains we connect the equations of motion to the standard eigenvalue problem that produces normal modes.
- **Normal modes (rigorous + computational):** The small-amplitude normal-mode frequencies are obtained from the **dynamical matrix** (eigenvalues give $\omega^2$). These are compared against **numerical estimates** extracted from time-domain simulations.
- **Numerical integration:** The nonlinear ODE systems are solved using `scipy.integrate.solve_ivp` with clean, reusable derivative functions written in state-space form.
- **Diagnostics and insight:** Multiple complementary views are included: time series, phase portraits, configuration-space trajectories, and frequency-domain analysis (FFT + peak-picking). These help identify in-phase/out-of-phase motion, mixed-mode excitation, and beat-like energy exchange.
- **Coupling strength:** We investigate how the coupling strength between two oscillators controls energy exchange and beat formation in a coupled mass–spring system. A parametric sweep varies the middle spring constant $k_2$ while keeping $k_1$ and $k_3$ fixed, and compares weak, medium, and strong coupling regimes.
- **Visualization:** Animations are produced for each model to connect the mathematics to physical intuition (geometry + dynamics).

By the end, you should be able to (1) derive and implement coupled-oscillator ODEs, (2) compute theoretical normal modes, and (3) validate/interpret simulations using both time-domain and spectral diagnostics.

## Imports and Notebook Setup

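This section imports the numerical and plotting libraries used throughout the notebook. Uncommenting the `%matplotlib ipympl` line switches to the interactive Matplotlib backend, which requires the `ipympl` package. That backend enables interactive figures and animations inside Jupyter.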

In [None]:
# For handling file paths and OS operations
import os

# Numerical and scientific computing
import numpy as np
from scipy.integrate import solve_ivp
from scipy.linalg import eigh
from scipy.signal import find_peaks

# Matplotlib visualization
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Rectangle, FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

# Jupyter notebook magic command
# %matplotlib ipympl

## 1. Two Spring-Coupled Pendulums (Nonlinear Model)

Before diving into the simulation, we derive the equations of motion for the coupled pendulum system using the Lagrangian formalism. Below is a step-by-step derivation of the exact (nonlinear) Euler–Lagrange equations for two spring-coupled planar pendulums.

Consider two planar pendulums of (common) length $L$, with bob masses $m_1,m_2$, moving in the same vertical plane. Let the generalized coordinates be the angles $\theta_1(t),\theta_2(t)$ measured from the downward vertical. Take gravity $g$ downward.

We derive the **exact (nonlinear) Euler–Lagrange equations** (i.e., no small-angle approximation) for the coupling model used in the simulation. In that model the spring energy depends on the **horizontal separation** of the two bobs, i.e., on $\Delta x = x_2-x_1 = L(\sin\theta_2-\sin\theta_1)$.

**1) Kinematics (exact)**

Place each pivot at the origin of its own local coordinates; equivalently, we only need each bob‚Äôs displacement from its pivot. For pendulum $i\in\{1,2\}$:

$$
x_i = L\sin\theta_i,\qquad y_i = -L\cos\theta_i.
$$

Differentiate:

$$
\dot x_i = L\cos\theta_i\,\dot\theta_i,\qquad
\dot y_i = L\sin\theta_i\,\dot\theta_i.
$$

Hence the speed squared is

$$
v_i^2 = \dot x_i^2+\dot y_i^2
= L^2\dot\theta_i^2(\cos^2\theta_i+\sin^2\theta_i)
= L^2\dot\theta_i^2.
$$

So the kinetic energy is

$$
T = \frac12 m_1 L^2\dot\theta_1^2+\frac12 m_2 L^2\dot\theta_2^2.
$$



**2) Potential energy (exact)**

**(a) Gravitational potential**
Using the standard zero reference at the bottom ($\theta_i=0$), the height increase is
$$
\Delta h_i = L(1-\cos\theta_i).
$$
Therefore
$$
V_g = m_1 gL(1-\cos\theta_1)+m_2 gL(1-\cos\theta_2).
$$

**(b) Spring potential (the coupling used in this model)**
Assume the spring effectively “measures” the **horizontal difference** of the bobs. Then the extension variable is
$$
\Delta x = x_2-x_1 = L(\sin\theta_2-\sin\theta_1).
$$
Taking the spring’s unstretched length such that $\Delta x=0$ is equilibrium for the coupling part, the spring potential is
$$
V_s=\frac12 k(\Delta x)^2
=\frac12 kL^2(\sin\theta_2-\sin\theta_1)^2.
$$

So the total potential is
$$
V = V_g+V_s.
$$
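As a quick check of the energy expressions above, the following sketch evaluates $T+V$ for a given state. It is a minimal illustration that assumes the parameter values shown; they are not taken from the notebook's simulation cells. The helper name `pendulum_pair_energy` is hypothetical. The model has no damping, so this quantity should stay constant along an accurate numerical solution, which makes it a useful integrator diagnostic.

```python
import numpy as np

def pendulum_pair_energy(theta1, omega1, theta2, omega2, g, L, k, m1, m2):
    """Exact total energy T + V_g + V_s of the spring-coupled pendulum pair (hypothetical helper)."""
    T = 0.5 * m1 * L**2 * omega1**2 + 0.5 * m2 * L**2 * omega2**2
    V_g = m1 * g * L * (1 - np.cos(theta1)) + m2 * g * L * (1 - np.cos(theta2))
    V_s = 0.5 * k * L**2 * (np.sin(theta2) - np.sin(theta1)) ** 2
    return T + V_g + V_s

# Illustrative parameters (assumed, not the notebook's simulation settings)
g, L, k, m1, m2 = 9.81, 1.0, 2.0, 1.0, 1.5

print(pendulum_pair_energy(0.0, 0.0, 0.0, 0.0, g, L, k, m1, m2))  # equilibrium: 0.0
# In-phase displacement at rest: the spring term vanishes, only gravity contributes
print(pendulum_pair_energy(0.1, 0.0, 0.1, 0.0, g, L, k, m1, m2))
```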


**3) Lagrangian**

$$
\mathcal{L}(\theta_1,\theta_2,\dot\theta_1,\dot\theta_2)=T-V
$$

$$
\boxed{
\mathcal{L} =
\frac12 m_1L^2\dot\theta_1^2+\frac12 m_2L^2\dot\theta_2^2
-\Big[m_1 gL(1-\cos\theta_1)+m_2 gL(1-\cos\theta_2)\Big]
-\frac12 kL^2(\sin\theta_2-\sin\theta_1)^2
}
$$

No approximation has been used so far.

**4) Euler–Lagrange equations (explicit derivation)**

For each generalized coordinate $\theta_i$:
$$
\frac{d}{dt}\left(\frac{\partial \mathcal{L}}{\partial \dot\theta_i}\right)-
\frac{\partial \mathcal{L}}{\partial \theta_i}=0.
$$

**Equation for $\theta_1$**

**Step 1: derivative w.r.t. $\dot\theta_1$**
$$
\frac{\partial \mathcal{L}}{\partial \dot\theta_1} = m_1L^2\dot\theta_1
\quad\Rightarrow\quad
\frac{d}{dt}\left(\frac{\partial \mathcal{L}}{\partial \dot\theta_1}\right) = m_1L^2\ddot\theta_1.
$$

**Step 2: derivative w.r.t. $\theta_1$**
Only the potential terms depend on $\theta_1$.

- Gravity part:
$$
\frac{\partial}{\partial\theta_1}\Big(m_1 gL(1-\cos\theta_1)\Big) = m_1 gL\sin\theta_1.
$$

- Spring part:
$$
V_s=\frac12 kL^2(\sin\theta_2-\sin\theta_1)^2.
$$
Differentiate using chain rule:
$$
\frac{\partial V_s}{\partial\theta_1} =\frac12 kL^2\cdot 2(\sin\theta_2-\sin\theta_1)\cdot \frac{\partial}{\partial\theta_1}(\sin\theta_2-\sin\theta_1).
$$
But
$$
\frac{\partial}{\partial\theta_1}(\sin\theta_2-\sin\theta_1)= -\cos\theta_1.
$$
So
$$
\boxed{\frac{\partial V_s}{\partial\theta_1} = -kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_1.}
$$

Since $\mathcal{L}=T-V$, we have
$$
\frac{\partial\mathcal{L}}{\partial\theta_1}
= -\frac{\partial V}{\partial\theta_1}
= -\Big(m_1 gL\sin\theta_1+\frac{\partial V_s}{\partial\theta_1}\Big)
= -m_1 gL\sin\theta_1 + kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_1.
$$

**Step 3: Euler–Lagrange**

$$
m_1L^2\ddot\theta_1-\left[-m_1 gL\sin\theta_1 + kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_1\right]=0
$$
$$
\boxed{
m_1L^2\ddot\theta_1 + m_1 gL\sin\theta_1 - kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_1 = 0
}
$$

Divide by $m_1L^2$:
$$
\boxed{
\ddot\theta_1 =
-\frac{g}{L}\sin\theta_1
+\frac{k}{m_1}\,(\sin\theta_2-\sin\theta_1)\cos\theta_1
}
$$

**Equation for $\theta_2$**

**Step 1:**
$$
\frac{\partial \mathcal{L}}{\partial \dot\theta_2}=m_2L^2\dot\theta_2
\quad\Rightarrow\quad
\frac{d}{dt}\left(\frac{\partial \mathcal{L}}{\partial \dot\theta_2}\right)=m_2L^2\ddot\theta_2.
$$
**Step 2 (gravity):**
$$
\frac{\partial}{\partial\theta_2}\Big(m_2 gL(1-\cos\theta_2)\Big)
= m_2 gL\sin\theta_2.
$$

**Step 2 (spring):**
$$
\frac{\partial V_s}{\partial\theta_2}
=\frac12 kL^2\cdot 2(\sin\theta_2-\sin\theta_1)\cdot \frac{\partial}{\partial\theta_2}(\sin\theta_2-\sin\theta_1)
$$
and
$$
\frac{\partial}{\partial\theta_2}(\sin\theta_2-\sin\theta_1)=\cos\theta_2,
$$
so
$$
\boxed{\frac{\partial V_s}{\partial\theta_2}
= kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_2.}
$$

Then
$$
\frac{\partial\mathcal{L}}{\partial\theta_2}
= -m_2 gL\sin\theta_2 - kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_2.
$$

**Step 3 (Euler–Lagrange):**

$$
m_2L^2\ddot\theta_2-\left[-m_2 gL\sin\theta_2 - kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_2\right]=0
$$

$$
\boxed{
m_2L^2\ddot\theta_2 + m_2 gL\sin\theta_2 + kL^2(\sin\theta_2-\sin\theta_1)\cos\theta_2 = 0
}
$$

Divide by $m_2L^2$:

$$
\boxed{
\ddot\theta_2 =
-\frac{g}{L}\sin\theta_2
-\frac{k}{m_2}\,(\sin\theta_2-\sin\theta_1)\cos\theta_2
}
$$

**5) Final exact nonlinear equations (no small-angle approximation)**

$$
\boxed{
\begin{aligned}
\ddot\theta_1 &=
-\frac{g}{L}\sin\theta_1
+\frac{k}{m_1}(\sin\theta_2-\sin\theta_1)\cos\theta_1,\\[4pt]
\ddot\theta_2 &=
-\frac{g}{L}\sin\theta_2
-\frac{k}{m_2}(\sin\theta_2-\sin\theta_1)\cos\theta_2.
\end{aligned}}
$$

These are “exact” in the strict sense: they follow exactly from the chosen (nonlinear) Lagrangian with spring energy $V_s=\tfrac12 kL^2(\sin\theta_2-\sin\theta_1)^2$, with no Taylor expansions.
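The boxed equations can be cross-checked numerically. Because $T$ has the constant mass matrix $\mathrm{diag}(m_1L^2, m_2L^2)$, the Euler–Lagrange equations reduce to $m_iL^2\ddot\theta_i = -\partial V/\partial\theta_i$. The sketch below compares the boxed accelerations with a central-difference gradient of $V$ at deliberately large angles. The helper names are hypothetical and the parameter values are assumed.

```python
import numpy as np

def potential(theta, g, L, k, m1, m2):
    """Exact potential V = V_g + V_s at theta = (theta1, theta2)."""
    t1, t2 = theta
    V_g = m1 * g * L * (1 - np.cos(t1)) + m2 * g * L * (1 - np.cos(t2))
    V_s = 0.5 * k * L**2 * (np.sin(t2) - np.sin(t1)) ** 2
    return V_g + V_s

def accel_from_boxed_equations(theta, g, L, k, m1, m2):
    """Angular accelerations from the boxed nonlinear equations of motion."""
    t1, t2 = theta
    ds = np.sin(t2) - np.sin(t1)
    a1 = -(g / L) * np.sin(t1) + (k / m1) * ds * np.cos(t1)
    a2 = -(g / L) * np.sin(t2) - (k / m2) * ds * np.cos(t2)
    return np.array([a1, a2])

def accel_from_gradient(theta, g, L, k, m1, m2, h=1e-6):
    """theta_ddot_i = -(1 / (m_i L^2)) dV/dtheta_i, with dV/dtheta_i from central differences."""
    theta = np.asarray(theta, dtype=float)
    grad = np.zeros(2)
    for i in range(2):
        e = np.zeros(2)
        e[i] = h
        grad[i] = (potential(theta + e, g, L, k, m1, m2)
                   - potential(theta - e, g, L, k, m1, m2)) / (2 * h)
    return -grad / (np.array([m1, m2]) * L**2)

g, L, k, m1, m2 = 9.81, 1.0, 2.0, 1.0, 1.5  # illustrative values (assumed)
theta = (0.9, -0.4)                          # far from the small-angle regime on purpose
print(accel_from_boxed_equations(theta, g, L, k, m1, m2))
print(accel_from_gradient(theta, g, L, k, m1, m2))
```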

**Small-angle approximation (linearized equations and approximate solutions)**

Starting from the exact nonlinear equations,
$$
\ddot\theta_1 =
-\frac{g}{L}\sin\theta_1
+\frac{k}{m_1}(\sin\theta_2-\sin\theta_1)\cos\theta_1,
\qquad
\ddot\theta_2 =
-\frac{g}{L}\sin\theta_2
-\frac{k}{m_2}(\sin\theta_2-\sin\theta_1)\cos\theta_2,
$$

apply the small-angle approximations valid for $|\theta_1|,|\theta_2|\ll 1$:

$$
\sin\theta \approx \theta,\qquad \cos\theta \approx 1.
$$
Then $\sin\theta_2-\sin\theta_1\approx \theta_2-\theta_1$ and the coupled equations become **linear**:
$$
\boxed{
\begin{aligned}
\ddot\theta_1 &= -\frac{g}{L}\theta_1+\frac{k}{m_1}(\theta_2-\theta_1),\\[4pt]
\ddot\theta_2 &= -\frac{g}{L}\theta_2-\frac{k}{m_2}(\theta_2-\theta_1).
\end{aligned}}
$$
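The sketch below shows how quickly the linearization degrades. It compares the nonlinear and linearized accelerations along the direction $(\theta_1,\theta_2)=(a,-a/2)$ for increasing amplitude $a$. The parameters and the probe direction are arbitrary illustrative choices. The neglected terms are $O(\theta^3)$, so the relative error should grow roughly like $a^2$.

```python
import numpy as np

def accel_nonlinear(t1, t2, g, L, k, m1, m2):
    """Exact angular accelerations (boxed nonlinear equations)."""
    ds = np.sin(t2) - np.sin(t1)
    return np.array([-(g / L) * np.sin(t1) + (k / m1) * ds * np.cos(t1),
                     -(g / L) * np.sin(t2) - (k / m2) * ds * np.cos(t2)])

def accel_linear(t1, t2, g, L, k, m1, m2):
    """Small-angle (linearized) angular accelerations."""
    return np.array([-(g / L) * t1 + (k / m1) * (t2 - t1),
                     -(g / L) * t2 - (k / m2) * (t2 - t1)])

def relative_linearization_error(amp, g, L, k, m1, m2):
    """Max-norm relative difference of the two models at (theta1, theta2) = (amp, -amp/2)."""
    an = accel_nonlinear(amp, -0.5 * amp, g, L, k, m1, m2)
    al = accel_linear(amp, -0.5 * amp, g, L, k, m1, m2)
    return np.max(np.abs(an - al)) / np.max(np.abs(al))

params = (9.81, 1.0, 2.0, 1.0, 1.5)  # g, L, k, m1, m2 (illustrative values)
for amp in (0.01, 0.1, 0.5):
    err = relative_linearization_error(amp, *params)
    print(f"amplitude {amp:4.2f} rad -> relative error {err:.2e}")
```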

It is often convenient to rewrite them in matrix form. Multiply by $m_iL^2$ to get
$$
\boxed{
\begin{aligned}
m_1L^2\ddot\theta_1 + (m_1gL+kL^2)\theta_1 - kL^2\theta_2 &= 0,\\
m_2L^2\ddot\theta_2 - kL^2\theta_1 + (m_2gL+kL^2)\theta_2 &= 0.
\end{aligned}}
$$
Define the vectors and matrices
$$
\mathbf{\Theta}=
\begin{bmatrix}\theta_1\\ \theta_2\end{bmatrix},\quad
M=
\begin{bmatrix}
m_1L^2 & 0\\
0 & m_2L^2
\end{bmatrix},\quad
K=
\begin{bmatrix}
m_1gL+kL^2 & -kL^2\\
-kL^2 & m_2gL+kL^2
\end{bmatrix}.
$$
Then the equations become
$$
\boxed{M\ddot{\mathbf{\Theta}}+K\mathbf{\Theta}=0.}
$$


**Normal modes (eigenvalue problem) and mode types**

Look for normal-mode solutions of the form
$$
\mathbf{\Theta}(t)=\mathbf{A}\,e^{i\omega t},
$$
so $\ddot{\mathbf{\Theta}}=-\omega^2\mathbf{\Theta}$. Substitution into $M\ddot{\mathbf{\Theta}}+K\mathbf{\Theta}=0$ gives
$$
(K-\omega^2M)\mathbf{A}=0.
$$
A nontrivial mode shape $\mathbf{A}\neq 0$ exists only if
$$
\boxed{\det(K-\omega^2M)=0.}
$$

Carrying out the determinant yields two normal-mode frequencies:
$$
\boxed{\omega_1^2=\frac{g}{L}},\qquad
\boxed{\omega_2^2=\frac{g}{L}+k\left(\frac{1}{m_1}+\frac{1}{m_2}\right)}.
$$

**(1) In-phase mode (symmetric mode)**
For $\omega_1^2=\frac{g}{L}$, the eigenvector condition gives
$$
\boxed{A_1=A_2.}
$$
So the two pendulums swing together:
$$
\boxed{\theta_1(t)=\theta_2(t)\propto \cos(\omega_1 t+\phi).}
$$
Physically: the spring extension is (approximately) zero because $\theta_2-\theta_1\approx 0$, so the spring does not affect the frequency, giving $\omega_1=\sqrt{g/L}$.

**(2) Out-of-phase mode (antisymmetric mode)**
For $\omega_2$, the eigenvector relation becomes
$$
\boxed{A_2=-\frac{m_1}{m_2}A_1.}
$$
So the pendulums swing oppositely (with unequal amplitudes if $m_1\neq m_2$):
$$
\boxed{\theta_2(t)=-\frac{m_1}{m_2}\,\theta_1(t)\propto \cos(\omega_2 t+\phi).}
$$
Physically: the spring is stretched/compressed strongly, increasing the restoring effect and raising the frequency.
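The next block is a sketch that checks both the frequencies and the mode shapes numerically, using assumed parameter values. $M$ is diagonal and positive, so $K\mathbf{A}=\omega^2M\mathbf{A}$ can be reduced to an ordinary symmetric eigenproblem for $M^{-1/2}KM^{-1/2}$ and solved with `numpy.linalg.eigh`. This reduction is a choice made for the sketch and is not necessarily what the notebook's own functions do.

```python
import numpy as np

g, L, k, m1, m2 = 9.81, 1.0, 2.0, 1.0, 1.5  # illustrative values (assumed)

M = np.diag([m1 * L**2, m2 * L**2])
K = np.array([[m1 * g * L + k * L**2, -k * L**2],
              [-k * L**2, m2 * g * L + k * L**2]])

# Reduce K A = w^2 M A to a symmetric problem: A = M^{-1/2} u, (M^{-1/2} K M^{-1/2}) u = w^2 u
M_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(M)))
w2, U = np.linalg.eigh(M_inv_sqrt @ K @ M_inv_sqrt)  # eigenvalues in ascending order
modes = M_inv_sqrt @ U                                 # columns are mode shapes A

w2_closed = np.array([g / L, g / L + k * (1 / m1 + 1 / m2)])
print("numerical w^2:  ", w2)
print("closed-form w^2:", w2_closed)
print("A2/A1 per mode: ", modes[1] / modes[0])  # expected [1, -m1/m2]
```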


**General small-angle solution (superposition of the two modes)**

Because the linearized system is linear, any motion is a sum of the two normal modes:

$$
\boxed{
\mathbf{\Theta}(t) =
C_1\,\mathbf{v}_1\cos(\omega_1 t+\phi_1)
+
C_2\,\mathbf{v}_2\cos(\omega_2 t+\phi_2),
}
$$

where one convenient choice of eigenvectors is

$$
\mathbf{v}_1=
\begin{bmatrix}1\\1\end{bmatrix},
\qquad
\mathbf{v}_2=
\begin{bmatrix}1\\-\frac{m_1}{m_2}\end{bmatrix}.
$$
The constants $C_1,C_2,\phi_1,\phi_2$ are determined by the initial conditions $\theta_1(0),\theta_2(0),\dot\theta_1(0),\dot\theta_2(0)$.
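One way to obtain $C_i,\phi_i$ is modal projection. Normalize the eigenvectors so that $V^TMV=I$. The modal initial data are then $\mathbf{q}(0)=V^TM\mathbf{\Theta}(0)$ and $\dot{\mathbf{q}}(0)=V^TM\dot{\mathbf{\Theta}}(0)$. Each mode satisfies $q_i(t)=C_i\cos(\omega_i t+\phi_i)$ with $C_i\cos\phi_i=q_i(0)$ and $-C_i\omega_i\sin\phi_i=\dot q_i(0)$. The block below is a minimal sketch of that procedure, and `modal_solution` is a hypothetical helper. Under this normalization the columns of $V$ are rescaled versions of $\mathbf{v}_1,\mathbf{v}_2$, so the $C_i$ absorb the scaling.

```python
import numpy as np

def modal_solution(theta0, thetadot0, g, L, k, m1, m2):
    """Return Theta(t) for the linearized two-pendulum system, plus omega, C, phi per mode."""
    M = np.diag([m1 * L**2, m2 * L**2])
    K = np.array([[m1 * g * L + k * L**2, -k * L**2],
                  [-k * L**2, m2 * g * L + k * L**2]])
    Mis = np.diag(1.0 / np.sqrt(np.diag(M)))
    w2, U = np.linalg.eigh(Mis @ K @ Mis)
    V = Mis @ U                       # M-orthonormal: V.T @ M @ V = I
    omega = np.sqrt(w2)

    q0 = V.T @ M @ np.asarray(theta0, dtype=float)      # modal positions at t = 0
    qd0 = V.T @ M @ np.asarray(thetadot0, dtype=float)  # modal velocities at t = 0
    C = np.sqrt(q0**2 + (qd0 / omega) ** 2)
    phi = np.arctan2(-qd0 / omega, q0)

    def Theta(t):
        t = np.atleast_1d(t)
        q = C[:, None] * np.cos(omega[:, None] * t[None, :] + phi[:, None])  # shape (2, nt)
        return V @ q                                                         # shape (2, nt)

    return Theta, omega, C, phi

# Usage: pendulum 1 displaced, both released from rest (illustrative values)
Theta, omega, C, phi = modal_solution([0.1, 0.0], [0.0, 0.0], 9.81, 1.0, 2.0, 1.0, 1.5)
print(Theta(np.linspace(0.0, 10.0, 5)))  # rows: theta1(t), theta2(t)
```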

**Beats / energy transfer (typical observation)**

If the initial condition is **not** a pure eigenvector, both modes are excited and their superposition can produce **beats**, i.e., a slow exchange of energy between the pendulums.

For the equal-mass case $m_1=m_2=m$, define modal coordinates
$$
q_+(t)=\frac{\theta_1+\theta_2}{2},\qquad q_-(t)=\frac{\theta_1-\theta_2}{2}.
$$
Then the small-angle equations decouple into
$$
\boxed{\ddot q_+ + \omega_1^2 q_+ = 0},\qquad
\boxed{\ddot q_- + \omega_2^2 q_- = 0},
$$
with
$$
\omega_1^2=\frac{g}{L},\qquad \omega_2^2=\frac{g}{L}+\frac{2k}{m}.
$$

A classic beat setup is: displace only pendulum 1 and release from rest,
$$
\theta_1(0)=\theta_0,\quad \theta_2(0)=0,\quad \dot\theta_1(0)=\dot\theta_2(0)=0.
$$
This gives
$$
\theta_1(t)=\frac{\theta_0}{2}\big(\cos\omega_1 t+\cos\omega_2 t\big),\qquad
\theta_2(t)=\frac{\theta_0}{2}\big(\cos\omega_1 t-\cos\omega_2 t\big).
$$
Using trigonometric identities:
$$
\boxed{
\theta_1(t)=\theta_0\cos\!\left(\frac{\omega_1+\omega_2}{2}t\right)\cos\!\left(\frac{\omega_2-\omega_1}{2}t\right),
}
$$
$$
\boxed{
\theta_2(t)=\theta_0\sin\!\left(\frac{\omega_1+\omega_2}{2}t\right)\sin\!\left(\frac{\omega_2-\omega_1}{2}t\right).
}
$$
So the oscillations occur at a fast “carrier” frequency $\frac{\omega_1+\omega_2}{2}$, while the amplitude varies slowly with envelope frequency $\frac{\omega_2-\omega_1}{2}$. The beat (energy-exchange) period is
$$
\boxed{T_{\mathrm{beat}}=\frac{2\pi}{\omega_2-\omega_1}.}
$$
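The sum-to-product step can be checked numerically. This sketch uses illustrative equal-mass parameters. It evaluates the two-cosine form and the product form from $\cos a \pm \cos b$ identities on a time grid, then reports the beat period. At $t=T_{\mathrm{beat}}/2$ the envelope of $\theta_1$ vanishes; in the linear model, that is when the energy has moved entirely to pendulum 2.

```python
import numpy as np

g, L, k, m = 9.81, 1.0, 0.5, 1.0  # illustrative equal-mass parameters (assumed)
theta0 = 0.1
w1 = np.sqrt(g / L)
w2 = np.sqrt(g / L + 2 * k / m)

t = np.linspace(0.0, 60.0, 6001)
# Superposition of the two normal modes (release of pendulum 1 from rest)
theta1_sum = 0.5 * theta0 * (np.cos(w1 * t) + np.cos(w2 * t))
theta2_sum = 0.5 * theta0 * (np.cos(w1 * t) - np.cos(w2 * t))

# Product (carrier x envelope) form from the sum-to-product identities
carrier, envelope = 0.5 * (w1 + w2), 0.5 * (w2 - w1)
theta1_prod = theta0 * np.cos(carrier * t) * np.cos(envelope * t)
theta2_prod = theta0 * np.sin(carrier * t) * np.sin(envelope * t)

T_beat = 2 * np.pi / (w2 - w1)
print(f"w1 = {w1:.4f} rad/s, w2 = {w2:.4f} rad/s, T_beat = {T_beat:.2f} s")
print("max |sum - product| for theta2:", np.max(np.abs(theta2_sum - theta2_prod)))
```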



### 1.1 Geometry and Coordinate Definitions (Two Pendulums)

We first define the geometry used by the simulator (pivot locations, bob positions, and the spring connection) and verify that the drawing routine correctly reflects **in-phase** and **out-of-phase** motion through spring extension/compression.
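The geometric point behind the three test drawings can be checked without plotting. This sketch recomputes the bob positions with the same conventions as `calculate_pendulum_mass_positions`: angles in degrees and pivots at plus or minus half the anchor separation. It then returns the bob-to-bob distance. `bob_distance` is a hypothetical helper. Its defaults mirror `draw_coupled_pendulum`'s `ceiling_length=8`, `separation_ratio=0.8`, and `string_length=12`.

```python
import numpy as np

def bob_distance(theta1_deg, theta2_deg, anchor_sep=0.8 * 8, string_length=12):
    """Distance between the two bobs (hypothetical helper mirroring the drawing code)."""
    t1, t2 = np.radians([theta1_deg, theta2_deg])
    p1 = np.array([-anchor_sep / 2 + string_length * np.sin(t1), -string_length * np.cos(t1)])
    p2 = np.array([anchor_sep / 2 + string_length * np.sin(t2), -string_length * np.cos(t2)])
    return np.linalg.norm(p2 - p1)

for label, (a, b) in {"relaxed": (0, 0), "out-of-phase": (10, -10), "in-phase": (10, 10)}.items():
    print(f"{label:>12}: bob-to-bob distance = {bob_distance(a, b):.3f}")
```

In the in-phase case both bobs shift by the same vector, so the distance equals the relaxed value. In the out-of-phase case it shrinks by $2L\sin 10^\circ$.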

In [None]:
def draw_spring_with_hook(start_pos, end_pos, num_coils, radius, hook_length):
    """
    Draw a spring from start_pos to end_pos along the line connecting them.
    """
    start_x, start_y = start_pos
    end_x, end_y = end_pos

    # Calculate the length and angle of the line connecting the masses
    dx = end_x - start_x
    dy = end_y - start_y
    current_dist = np.sqrt(dx**2 + dy**2)

    if current_dist < 1e-9:
        return np.array([start_x]), np.array([start_y])

    # Angle of the line connecting the masses
    angle = np.arctan2(dy, dx)

    # Calculate the actual length available for the coiled part of the spring
    # It's the total distance minus the two fixed hooks
    spring_len = current_dist - 2 * hook_length

    # If masses are too close (closer than 2*hook_length), we clamp spring_len to 0
    # This prevents the spring from inverting, though physically they shouldn't overlap
    if spring_len < 0:
        spring_len = 0
        # Adjust hooks to meet in the middle
        actual_hook_len = current_dist / 2
    else:
        actual_hook_len = hook_length

    # Create spring along the line
    t = np.linspace(0, num_coils * 2 * np.pi, 300)

    # Parametric position along the coiled part (0 to 1)
    s = np.linspace(0, 1, 300)

    # Base position along the line for the COILED part
    # It starts after the first hook
    start_coil_x = start_x + actual_hook_len * np.cos(angle)
    start_coil_y = start_y + actual_hook_len * np.sin(angle)

    # The coil vector spans the spring_len
    coil_dx = spring_len * np.cos(angle)
    coil_dy = spring_len * np.sin(angle)

    x_base = start_coil_x + s * coil_dx
    y_base = start_coil_y + s * coil_dy

    # Perpendicular oscillations
    perp_x = -np.sin(angle) * radius * np.sin(t)
    perp_y = np.cos(angle) * radius * np.sin(t)

    # Spring coordinates
    x_spring = x_base + perp_x
    y_spring = y_base + perp_y

    # Add hooks at both ends
    # Hook at start (straight line from mass 1 to start of coil)
    hook_start_x = np.linspace(start_x, start_coil_x, 50)
    hook_start_y = np.linspace(start_y, start_coil_y, 50)

    # Hook at end (straight line from end of coil to mass 2)
    end_coil_x = start_coil_x + coil_dx
    end_coil_y = start_coil_y + coil_dy

    hook_end_x = np.linspace(end_coil_x, end_x, 50)
    hook_end_y = np.linspace(end_coil_y, end_y, 50)

    # Combine all parts
    x_full = np.concatenate([hook_start_x, x_spring, hook_end_x])
    y_full = np.concatenate([hook_start_y, y_spring, hook_end_y])

    return x_full, y_full


def calculate_pendulum_mass_positions(x_initial, string_length, theta_degrees):
    """Calculate the (x, y) positions of the pendulum mass."""
    theta_radians = np.radians(theta_degrees)
    x_position = x_initial + string_length * np.sin(theta_radians)
    y_position = -string_length * np.cos(theta_radians)
    return x_position, y_position


def calculate_fixed_hook_length(
    ceiling_length, string_length, separation_ratio, hook_ratio=0.12
):
    """
    Calculate the fixed hook length based on the relaxed state (theta=0).
    hook_ratio is the ratio hook_len / relaxed_spring_len (default 0.12).
    """
    # Geometry in relaxed state
    anchor_separation = separation_ratio * ceiling_length
    left_anchor_x = -anchor_separation / 2
    right_anchor_x = anchor_separation / 2

    # Positions at rest (theta = 0)
    left_mass_x, _ = calculate_pendulum_mass_positions(left_anchor_x, string_length, 0)
    right_mass_x, _ = calculate_pendulum_mass_positions(
        right_anchor_x, string_length, 0
    )

    # Distance between masses at rest
    relaxed_dist = np.abs(right_mass_x - left_mass_x)

    # relaxed_dist = relaxed_spring_len + 2 * hook_len
    # hook_len = hook_ratio * relaxed_spring_len
    # => relaxed_dist = relaxed_spring_len * (1 + 2 * hook_ratio)
    relaxed_spring_len = relaxed_dist / (1 + 2 * hook_ratio)
    hook_len = hook_ratio * relaxed_spring_len

    return hook_len


def calculate_plot_limits(
    ceiling_length, left_mass_x, right_mass_x, left_mass_y, right_mass_y, padding=0.15
):
    """Calculate appropriate plot limits."""
    x_min = min(-ceiling_length / 2, left_mass_x, right_mass_x)
    x_max = max(ceiling_length / 2, left_mass_x, right_mass_x)
    y_min = min(left_mass_y, right_mass_y)
    y_max = 0

    x_range = x_max - x_min
    y_range = y_max - y_min

    x_pad = x_range * padding
    y_pad = y_range * padding

    xlim = (x_min - x_pad, x_max + x_pad)
    ylim = (y_min - y_pad, y_max + y_pad)

    return xlim, ylim


def draw_pendulum_string(ax, x_anchor, y_anchor, x_mass, y_mass, color="black", lw=2):
    ax.plot([x_anchor, x_mass], [y_anchor, y_mass], color=color, lw=lw)


def draw_ceiling(ax, ceiling_length, ceiling_height, color="saddlebrown"):
    rect = Rectangle(
        (-ceiling_length / 2, -ceiling_height / 2),
        ceiling_length,
        ceiling_height,
        color=color,
        zorder=5,
    )
    ax.add_patch(rect)
    return rect


def draw_coupled_pendulum(
    theta_1=0,
    theta_2=10,
    ceiling_length=8,
    string_length=12,
    separation_ratio=0.8,
    num_coils=8,
    mass_size=25,
    figsize=(10, 8),
    padding=0.15,
):
    """
    Draw a coupled pendulum system.
    """
    # Calculate system dimensions
    ceiling_height = 0.1 * ceiling_length
    anchor_separation = separation_ratio * ceiling_length

    # Calculate anchor positions
    left_anchor_x = -anchor_separation / 2
    right_anchor_x = anchor_separation / 2
    anchor_y = 0

    # Calculate mass positions
    left_mass_x, left_mass_y = calculate_pendulum_mass_positions(
        left_anchor_x, string_length, theta_1
    )
    right_mass_x, right_mass_y = calculate_pendulum_mass_positions(
        right_anchor_x, string_length, theta_2
    )

    # 1. Calculate FIXED hook length based on relaxed state
    hook_length = calculate_fixed_hook_length(
        ceiling_length, string_length, separation_ratio
    )

    # 2. Fixed spring radius
    spring_radius = 0.1 * ceiling_length

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Draw ceiling
    draw_ceiling(ax, ceiling_length, ceiling_height)

    # Draw pendulum strings
    draw_pendulum_string(ax, left_anchor_x, anchor_y, left_mass_x, left_mass_y)
    draw_pendulum_string(ax, right_anchor_x, anchor_y, right_mass_x, right_mass_y)

    # Draw spring
    # We pass the mass positions directly. The function handles the fixed hooks.
    x_spring, y_spring = draw_spring_with_hook(
        (left_mass_x, left_mass_y),
        (right_mass_x, right_mass_y),
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring, y_spring, color="darkgrey", lw=2)

    # Draw masses
    # left_mass
    ax.plot(left_mass_x, left_mass_y, "ro", ms=mass_size)
    # right_mass
    ax.plot(right_mass_x, right_mass_y, "bo", ms=mass_size)

    # Calculate and set plot limits
    xlim, ylim = calculate_plot_limits(
        ceiling_length, left_mass_x, right_mass_x, left_mass_y, right_mass_y, padding
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(
        f"Coupled Pendulum Geometry: $\\theta_1 = {theta_1}^\\circ$, $\\theta_2 = {theta_2}^\\circ$"
    )

    return fig, ax


# 1. Relaxed state (hooks should be visible, spring normal)
fig1, ax1 = draw_coupled_pendulum(theta_1=0, theta_2=0)
plt.show()

# 2. Out of phase (Spring compressed)
# The hooks remain same size, coils get squashed
fig2, ax2 = draw_coupled_pendulum(theta_1=10, theta_2=-10)
plt.show()

# 3. In phase (spring length same as relaxed, just translated)
fig3, ax3 = draw_coupled_pendulum(theta_1=10, theta_2=10)
plt.show()


### 1.2 Numerical Integration and Animation (Two Pendulums)

We implement the nonlinear ODEs in state-space form, integrate them numerically, and visualize the dynamics. The same run is then analyzed in time and frequency domains to connect simulation output with normal-mode theory.

#### 1.2.1 Nonlinear Equations of Motion in State-Space Form

The integrator expects a first-order system. We therefore write the second-order equations for $(\theta_1,\theta_2)$ as a first-order ODE in the state vector $y=[\theta_1,\dot\theta_1,\theta_2,\dot\theta_2]$.
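The reduction to first order and a simple integrator fit in a few lines. This sketch swaps `solve_ivp` for a hand-written fixed-step classical Runge-Kutta (RK4) scheme so that it needs only NumPy. The step size, parameters, and helper names are assumptions. The undamped model conserves the total energy $T+V$, so that quantity serves as an accuracy check.

```python
import numpy as np

def state_derivative(y, g, L, k, m1, m2):
    """First-order form: y = [theta1, omega1, theta2, omega2] -> dy/dt."""
    theta1, omega1, theta2, omega2 = y
    ds = np.sin(theta2) - np.sin(theta1)
    return np.array([
        omega1,
        -(g / L) * np.sin(theta1) + (k / m1) * ds * np.cos(theta1),
        omega2,
        -(g / L) * np.sin(theta2) - (k / m2) * ds * np.cos(theta2),
    ])

def rk4(y0, dt, n_steps, *params):
    """Classical fixed-step RK4 (a stand-in for solve_ivp in this sketch)."""
    ys = np.empty((n_steps + 1, len(y0)))
    ys[0] = y0
    for i in range(n_steps):
        y = ys[i]
        k1 = state_derivative(y, *params)
        k2 = state_derivative(y + 0.5 * dt * k1, *params)
        k3 = state_derivative(y + 0.5 * dt * k2, *params)
        k4 = state_derivative(y + dt * k3, *params)
        ys[i + 1] = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return ys

def energy(y, g, L, k, m1, m2):
    """Exact total energy T + V for a state vector y."""
    theta1, omega1, theta2, omega2 = y
    T = 0.5 * L**2 * (m1 * omega1**2 + m2 * omega2**2)
    V = (m1 * g * L * (1 - np.cos(theta1)) + m2 * g * L * (1 - np.cos(theta2))
         + 0.5 * k * L**2 * (np.sin(theta2) - np.sin(theta1)) ** 2)
    return T + V

params = (9.81, 1.0, 2.0, 1.0, 1.5)  # g, L, k, m1, m2 (illustrative)
traj = rk4(np.array([0.5, 0.0, 0.0, 0.0]), 1e-3, 10_000, *params)
E0 = energy(traj[0], *params)
drift = abs(energy(traj[-1], *params) - E0) / E0
print(f"relative energy drift after 10 s: {drift:.1e}")
```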

In [None]:
def coupled_pendulum_derivatives(t, y, g, L, k, m1, m2):
    """
    Compute the derivatives for the coupled pendulum system.

    Uses the exact non-linear equations of motion:
    Œ∏Ãà‚ÇÅ = -(g/L)sin(Œ∏‚ÇÅ) + (k/m‚ÇÅ)(sin(Œ∏‚ÇÇ) - sin(Œ∏‚ÇÅ))cos(Œ∏‚ÇÅ)
    Œ∏Ãà‚ÇÇ = -(g/L)sin(Œ∏‚ÇÇ) - (k/m‚ÇÇ)(sin(Œ∏‚ÇÇ) - sin(Œ∏‚ÇÅ))cos(Œ∏‚ÇÇ)

    Parameters
    ----------
    t : float
        Time (not used explicitly, required by solve_ivp)
    y : array-like
        State vector [θ₁, ω₁, θ₂, ω₂] where ω = dθ/dt
    g : float
        Gravitational acceleration
    L : float
        Length of pendulum rods
    k : float
        Spring constant
    m1, m2 : float
        Masses of the two pendulum bobs

    Returns
    -------
    list
        Derivatives [dθ₁/dt, dω₁/dt, dθ₂/dt, dω₂/dt]
    """
    theta1, omega1, theta2, omega2 = y

    # Spring extension term
    delta_sin = np.sin(theta2) - np.sin(theta1)

    # Angular accelerations (exact non-linear equations)
    d_theta1 = omega1
    d_omega1 = -(g / L) * np.sin(theta1) + (k / m1) * delta_sin * np.cos(theta1)

    d_theta2 = omega2
    d_omega2 = -(g / L) * np.sin(theta2) - (k / m2) * delta_sin * np.cos(theta2)

    return [d_theta1, d_omega1, d_theta2, d_omega2]


#### 1.2.2 Linearization and Theoretical Normal-Mode Frequencies (Small-Angle)

To obtain closed-form mode frequencies we linearize for small angles ($\sin\theta\approx\theta$, $\cos\theta\approx 1$), which yields a coupled linear system whose eigenvalues give $\omega^2$. The results provide a clean reference for validating the nonlinear simulations at small amplitudes.

In [None]:
def th_normal_modes_double_pendulum(g, L, k, m1, m2):
    """
    Calculate the theoretical normal mode frequencies for the linearized system.

    For equal masses (m1 = m2 = m):
        ω₁² = g/L           (in-phase mode)
        ω₂² = g/L + 2k/m    (out-of-phase mode)

    For unequal masses, we solve the eigenvalue problem of the linearized system.

    Parameters
    ----------
    g, L, k, m1, m2 : float
        System parameters

    Returns
    -------
    tuple
        (omega1, omega2, f1, f2) - angular frequencies and frequencies in Hz
    """
    # For the linearized system with potentially different masses
    # The equations become:
    # m1*L²*θ̈₁ + m1*g*L*θ₁ - k*L²*(θ₂ - θ₁) = 0
    # m2*L²*θ̈₂ + m2*g*L*θ₂ + k*L²*(θ₂ - θ₁) = 0

    # This gives a matrix equation: M*θ̈ = -K*θ
    # where M = diag(m1*L², m2*L²) and K is the stiffness matrix

    # Construct the dynamical matrix A = M^(-1) * K
    # For eigenvalue problem: det(K - ω²M) = 0

    # Coefficients for the linearized system
    a11 = g / L + k / m1
    a12 = -k / m1
    a21 = -k / m2
    a22 = g / L + k / m2

    # Form the matrix [[a11, a12], [a21, a22]]
    # Eigenvalues give ω²
    A = np.array([[a11, a12], [a21, a22]])
    eigenvalues = np.linalg.eigvals(A)

    # Sort eigenvalues (ω² values)
    omega_sq = np.sort(np.real(eigenvalues))

    omega1 = np.sqrt(omega_sq[0])  # Lower frequency (in-phase)
    omega2 = np.sqrt(omega_sq[1])  # Higher frequency (out-of-phase)

    # Convert to Hz
    f1 = omega1 / (2 * np.pi)
    f2 = omega2 / (2 * np.pi)

    return omega1, omega2, f1, f2


#### 1.2.3 Numerical Frequency Estimation (Modal Projection + FFT)

From the simulated time series we estimate dominant frequencies. A particularly robust approach is to project $[\theta_1(t),\theta_2(t)]$ onto the **linear normal-mode coordinates** and then perform FFT peak-picking on each modal coordinate.

In `simulate_double_pendulum`, the “dominant frequencies” are extracted from the numerically solved angle time series by an FFT-based modal analysis:

- **Step 1: numerically solve for** $\theta_1(t),\theta_2(t)$  
  `solve_ivp` integrates the nonlinear ODEs and evaluates the dense solution on a uniform grid `t_eval`; the arrays `theta1 = y[0]` and `theta2 = y[2]` are the sampled signals used for frequency analysis.

- **Step 2: compute mode frequencies from the *data***  
  `simulate_double_pendulum` calls `est_freqs_double_pendulum(t_eval, theta1, theta2, ...)` to estimate two frequencies numerically.

- **Step 3: project motion onto normal-mode coordinates (key idea)**  
  Inside `est_freqs_double_pendulum`, the code:
  - stacks $\theta(t)=\begin{bmatrix}\theta_1(t)\\ \theta_2(t)\end{bmatrix}$ and removes the mean (DC offset) per channel
  - builds the *small-angle linear* generalized eigenproblem $K\mathbf{v}=\omega^2 M\mathbf{v}$ (with $M=\mathrm{diag}(m_1,m_2)$ and $K=\begin{bmatrix}m_1 g/L+k & -k\\-k & m_2 g/L+k\end{bmatrix}$)
  - solves it using `eigh(K, M)` and then forms **modal coordinates**
    $$
    \mathbf{q}(t)=V^T\,M\,\theta(t)
    $$
    so each $q_i(t)$ ideally contains (mostly) one mode.

- **Step 4: FFT and peak-picking**  
  For each modal signal $q_i(t)$, it optionally applies a Hann window (reduces spectral leakage), computes the real FFT `rfft`, builds the frequency axis with `rfftfreq(n, dt)`, and takes magnitudes $|Q_i(f)|$.

- **Step 5: choose the dominant frequency (ignore DC + ignore unexcited modes)**  
  It ignores the DC bin (`amp[i, 0]`), finds the index of the largest remaining peak with `argmax(amp[i, 1:])` (adding 1 to restore the offset), and returns that bin's frequency. If a mode's spectral RMS is below 1% of the strongest mode's (`rms_rel < 1e-2`), the mode is reported as "not excited" and its frequency is returned as `0.0`.

This is why the printed `f1_num, f2_num` are best interpreted as "dominant frequencies of the two **normal-mode components** present in the simulated motion," rather than "FFT peaks of $\theta_1$ and $\theta_2$ directly."
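As a sanity check on Step 3, the small-angle problem has a closed form. Substituting $\mathbf{v}=(1,1)^T$ into $K\mathbf{v}=\omega^2 M\mathbf{v}$ gives $\omega_1^2=g/L$ for any masses, and the $M$-orthogonal vector $\mathbf{v}=(m_2,-m_1)^T$ gives $\omega_2^2=g/L+k\,(1/m_1+1/m_2)$. The sketch below is a standalone re-derivation, not the notebook's code: the helper name `linear_modes` and the unequal-mass parameters are illustrative. It compares `scipy.linalg.eigh` against these formulas and confirms that $\mathbf{q}=V^T M\theta$ recovers a pure modal signal exactly.

```python
import numpy as np
from scipy.linalg import eigh


def linear_modes(g, L, k, m1, m2):
    """Small-angle M, K (already divided by L**2) and their generalized eigenpairs."""
    M = np.diag([m1, m2]).astype(float)
    K = np.array([[m1 * g / L + k, -k],
                  [-k, m2 * g / L + k]], dtype=float)
    omega_sq, V = eigh(K, M)  # ascending omega^2; columns satisfy V.T @ M @ V = I
    return M, K, omega_sq, V


# Illustrative unequal masses, so the out-of-phase shape is not simply (1, -1)
g, L, k, m1, m2 = 9.81, 2.0, 5.0, 1.0, 3.0
M, K, omega_sq, V = linear_modes(g, L, k, m1, m2)

# Closed-form eigenvalues derived above
omega_sq_exact = np.array([g / L, g / L + k * (1.0 / m1 + 1.0 / m2)])
print("eigh: ", omega_sq)
print("exact:", omega_sq_exact)

# Mode shapes: v1 ∝ (1, 1), v2 ∝ (m2, -m1); ratios do not depend on eigenvector sign
ratio_1 = V[0, 0] / V[1, 0]
ratio_2 = V[0, 1] / V[1, 1]

# Synthesize theta(t) from a pure mode-2 signal and project it back
t = np.linspace(0.0, 10.0, 1001)
q_true = np.vstack([np.zeros_like(t), 0.1 * np.cos(np.sqrt(omega_sq[1]) * t)])
theta = V @ q_true         # theta(t) = V q(t)
q = V.T @ (M @ theta)      # modal coordinates, same formula as est_freqs_double_pendulum
print("max |q[0]| (mode-1 content, should be ~0):", np.max(np.abs(q[0])))
```

Because $V^T M V = I$, the projection inverts $\theta = V\mathbf{q}$ exactly in the linear model; any mode-1 content that shows up when the same projection is applied to nonlinear simulation output is a measure of how far the motion is from the small-angle regime.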

In [None]:
def est_freqs_double_pendulum(t, theta1, theta2, g, L, k, m1, m2, use_hann_window=True):
    """
    Estimate the two normal-mode frequencies (masses may differ) by:
    1) solving the *generalized* eigenproblem K v = ω² M v (small-angle linear model)
    2) projecting the simulated θ(t) onto the eigenvectors (modal coordinates)
    3) taking the FFT of each modal coordinate and picking the dominant peak

    Parameters
    ----------
    t : array
        Time array
    theta1, theta2 : array
        Angle arrays from simulation
    g, L, k, m1, m2 : float
        System parameters
    use_hann_window : bool
        Whether to apply a Hann window before FFT

    Returns
    -------
    tuple
        (f1_est, f2_est) - Estimated frequencies in Hz
    """
    t = np.asarray(t)
    dt = t[1] - t[0]
    n = t.size

    # Stack angles: shape (2, n)
    theta = np.vstack([np.asarray(theta1), np.asarray(theta2)])

    # Remove DC offset per channel
    theta = theta - theta.mean(axis=1, keepdims=True)

    # --- Build generalized eigenproblem (small-angle linearization) ---
    # For coordinates θ:
    #   T = 1/2 Σ m_i L^2 θ̇_i^2  -> M_θ = diag(m_i L^2)
    #   V = 1/2 Σ m_i g L θ_i^2 + 1/2 k L^2 (θ2-θ1)^2
    # Divide by L^2 => generalized eigen in terms of:
    #   M = diag(m_i)
    #   K = diag(m_i g/L) + spring contributions
    omega0_sq = g / L

    M = np.diag([m1, m2]).astype(float)
    K = np.array(
        [
            [m1 * omega0_sq + k, -k],
            [-k, m2 * omega0_sq + k],
        ],
        dtype=float,
    )

    # eigh for symmetric generalized eigenproblem -> real ω², eigenvectors V
    # SciPy normalizes eigenvectors so that V.T @ M @ V = I
    omega_sq, V = eigh(K, M)  # omega_sq sorted ascending
    omega_sq = np.maximum(omega_sq, 0.0)

    # Modal coordinates q(t) = V^T M θ(t)
    q = V.T @ (M @ theta)  # shape (2, n)

    # Optional Hann window to reduce spectral leakage
    if use_hann_window:
        w = np.hanning(n)
        q = q * w

    # FFT on each modal coordinate
    freqs = np.fft.rfftfreq(n, dt)
    Q = np.fft.rfft(q, axis=1)
    amp = np.abs(Q)

    rms = np.sqrt(np.mean(amp**2, axis=1))
    rms_rel = rms / (np.max(rms) + 1e-30)

    f_est = []
    for i in range(2):
        # Relative-RMS threshold below which a mode counts as "not excited" (tunable, e.g. 1e-2 or 1e-3)
        if rms_rel[i] < 1e-2:
            f_est.append(0.0)
            continue

        # Ignore DC component (index 0)
        idx = int(np.argmax(amp[i, 1:])) + 1
        f_est.append(float(freqs[idx]))

    return f_est[0], f_est[1]
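Because `est_freqs_double_pendulum` returns the frequency of a single FFT bin, its output is quantized to the bin spacing $\Delta f = 1/(n\,\Delta t)$ of `rfftfreq(n, dt)`. With the default `simulation_time=20.0` and `fps=30`, $n=601$ and $\Delta t = 1/30$ s, so $\Delta f \approx 0.05$ Hz: a printed numerical frequency can differ from the theoretical one by up to about 0.025 Hz even for a perfectly linear signal. The standalone sketch below applies the same DC-removal, Hann-window, FFT, and peak-picking steps to a single pure cosine near the out-of-phase frequency for the default parameters (about 0.614 Hz) and checks that the estimate lands within half a bin. The helper `dominant_frequency` is hypothetical and not part of the notebook.

```python
import numpy as np


def dominant_frequency(x, dt, use_hann_window=True):
    """Frequency (Hz) of the largest non-DC rfft bin of a real signal (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()                        # remove DC offset, as in the notebook
    if use_hann_window:
        x = x * np.hanning(x.size)          # taper the ends to reduce leakage
    amp = np.abs(np.fft.rfft(x))
    freqs = np.fft.rfftfreq(x.size, dt)
    idx = int(np.argmax(amp[1:])) + 1       # skip the DC bin
    return float(freqs[idx])


# Same sampling grid as simulate_double_pendulum's defaults (20 s at 30 fps)
simulation_time, fps = 20.0, 30
n = int(simulation_time * fps) + 1
t = np.linspace(0.0, simulation_time, n)
dt = t[1] - t[0]
df = 1.0 / (n * dt)                         # spacing of rfftfreq(n, dt)

f_true = 0.6144                             # ≈ out-of-phase frequency for the defaults
f_est = dominant_frequency(np.cos(2 * np.pi * f_true * t), dt)
print(f"bin spacing = {df:.4f} Hz, true = {f_true} Hz, estimate = {f_est:.4f} Hz")
```

Halving $\Delta f$ requires doubling `simulation_time`. Interpolating around the peak bin (for example, fitting a parabola through the three bins around the maximum) is a common refinement, but it is not something the notebook's code does.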

#### 1.2.4 Animation of Two Spring-Coupled Pendulums

The function `simulate_double_pendulum`, defined below, has two main phases: **(1) simulate the dynamics** and **(2) animate the precomputed results**.

- **Numerical simulation (ODE solve)**  
  It converts the initial angles (degrees → radians), builds the state vector $y=[\theta_1,\omega_1,\theta_2,\omega_2]$ (here $\omega_i=\dot\theta_i$ are angular velocities, not the normal-mode frequencies), and integrates the coupled nonlinear equations using `solve_ivp(..., method="RK45")`. It then evaluates the *dense solution* on a uniform time grid `t_eval` and extracts arrays $\theta_1(t)$, $\theta_2(t)$, $\omega_1(t)$, $\omega_2(t)$.

- **Precomputation for fast animation**  
  Before animating, it computes everything needed *for every frame*:
  - bob positions from geometry: $x_i(t)=x_{\text{pivot},i}+L\sin\theta_i(t)$ and $y_i(t)=y_{\text{pivot}}-L\cos\theta_i(t)$  
  - horizontal spring extension $\Delta x(t)=L(\sin\theta_2-\sin\theta_1)$ (the change in horizontal bob separation relative to the pivot spacing) and energies (KE, gravitational PE, spring PE, total)  
  This avoids recomputing expensive pieces inside the animation loop.

- **Matplotlib scene setup ("artists")**  
  It creates a `fig, ax`, sets plot limits, draws the ceiling/pivots/reference lines, and initializes plot objects (rods as lines, spring as a line, masses as markers, and short "trace" lines). It also creates text boxes for time/angles/energies and frequency info.

- **Animation update loop**  
  Two nested functions define the animation:
  - `init()` clears all artists (empty data) for a clean start.
  - `animate(frame)` uses the precomputed arrays at index `frame` to:
    - update rod endpoints,
    - redraw the spring shape using `draw_spring_with_hook(start_pos=(x1,y1), end_pos=(x2,y2), ...)`,
    - move the mass markers,
    - append to and render short traces,
    - update the time and the info text (current $\theta_i$, $\dot\theta_i$, spring extension, energies).

- **Create and (optionally) save the animation**  
  It builds `FuncAnimation(fig, animate, init_func=init, frames=len(t_eval), interval=1000/fps, ...)`. If `save_anim=True`, it writes the output to `OUTPUTS/ANIMATIONS/...` (tries ffmpeg first, falls back to pillow); otherwise it displays with `plt.show()`.
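The precomputed spring extension has a simple geometric reading: since $x_2-x_1 = d + L(\sin\theta_2-\sin\theta_1)$, where $d$ is the pivot separation (`pivot_separation = 1.5` in the code), $\Delta x$ is exactly the change in horizontal bob separation relative to the pivots; the formula ignores any vertical offset between the bobs. The sketch below uses synthetic angle histories (assumed for illustration, not simulation output) to check that identity and to show the vectorized precompute-then-index pattern that `animate(frame)` relies on.

```python
import numpy as np

# Illustrative parameters and synthetic angle histories (not simulation output)
L, d, k = 2.0, 1.5, 5.0
t = np.linspace(0.0, 10.0, 301)
theta1 = np.radians(10.0) * np.cos(2.2 * t)
theta2 = np.radians(10.0) * np.cos(3.9 * t)

pivot1_x, pivot2_x, pivot_y = -d / 2, d / 2, 0.0

# Vectorized precomputation: one array per quantity, one entry per frame
x1_all = pivot1_x + L * np.sin(theta1)
y1_all = pivot_y - L * np.cos(theta1)
x2_all = pivot2_x + L * np.sin(theta2)
y2_all = pivot_y - L * np.cos(theta2)
spring_ext_all = L * (np.sin(theta2) - np.sin(theta1))
PE_spring_all = 0.5 * k * spring_ext_all**2

# Identity: horizontal extension = bob separation minus pivot separation
assert np.allclose(spring_ext_all, (x2_all - x1_all) - d)

# Inside an animation callback, each frame is then just an index lookup
frame = 150
print(f"frame {frame}: x1={x1_all[frame]:.3f}, x2={x2_all[frame]:.3f}, "
      f"dx={spring_ext_all[frame]:+.3f} m, PE_spring={PE_spring_all[frame]:.4f} J")
```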

In [None]:
def simulate_double_pendulum(
    theta_1_init=30.0,
    theta_2_init=0.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=20.0,
    fps=30,
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a coupled pendulum system.

    Parameters
    ----------
    theta_1_init : float
        Initial angle of pendulum 1 in degrees
    theta_2_init : float
        Initial angle of pendulum 2 in degrees
    m1, m2 : float
        Masses of pendulum bobs
    k : float
        Spring constant
    L : float
        Length of pendulum rods
    g : float
        Gravitational acceleration
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for animation
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for saved animation (auto-generated if None)

    Returns
    -------
    FuncAnimation
        The animation object.
    """
    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂] (starting from rest)
    y0 = [theta1_0, 0.0, theta2_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k, m1, m2),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )
    if solution.sol is None:
        raise RuntimeError(
            "solve_ivp did not return a dense solution (solution.sol is None)"
        )

    y = solution.sol(t=t_eval)
    theta1 = y[0]
    omega1 = y[1]
    theta2 = y[2]
    omega2 = y[3]

    print(f"Solution computed: {len(t_eval)} time steps")

    # =========================================================================
    # PRECOMPUTE SPRING EXTENSION + ENERGIES
    # =========================================================================

    theta1_deg_all = np.degrees(theta1)
    theta2_deg_all = np.degrees(theta2)

    spring_ext_all = L * (np.sin(theta2) - np.sin(theta1))

    KE_all = 0.5 * m1 * (L * omega1) ** 2 + 0.5 * m2 * (L * omega2) ** 2
    PE_grav_all = -m1 * g * L * np.cos(theta1) - m2 * g * L * np.cos(theta2)
    PE_spring_all = 0.5 * k * spring_ext_all**2
    total_E_all = KE_all + PE_grav_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, f1_theory, f2_theory = (
        th_normal_modes_double_pendulum(g, L, k, m1, m2)
    )

    # Numerical frequency estimation
    f1_num, f2_num = est_freqs_double_pendulum(t_eval, theta1, theta2, g, L, k, m1, m2)

    def _format_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _format_freq(f1_num)
    f2_str = _format_freq(f2_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}")

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Pivot positions
    pivot_separation = 1.5  # Distance between pivots
    pivot1_x = -pivot_separation / 2
    pivot2_x = pivot_separation / 2
    pivot_y = 0

    ceiling_length = pivot_separation + 1.0  # Add some margin
    separation_ratio = pivot_separation / ceiling_length
    spring_hook_length = calculate_fixed_hook_length(
        ceiling_length, L, separation_ratio, hook_ratio=0.1
    )
    spring_radius = 0.03 * ceiling_length  # Radius of the coiled part
    spring_num_coils = 10  # Number of coils in the spring
    ceiling_height = 0.05 * ceiling_length  # Height of the ceiling bar

    # Precompute mass positions
    x1_all = pivot1_x + L * np.sin(theta1)
    y1_all = pivot_y - L * np.cos(theta1)
    x2_all = pivot2_x + L * np.sin(theta2)
    y2_all = pivot_y - L * np.cos(theta2)

    # Figure setup
    fig, ax = plt.subplots(figsize=(12, 8))
    fig.subplots_adjust(top=0.87, bottom=0.02, left=0.02, right=0.81)
    fig.suptitle(
        "Spring-Coupled Double Pendulum Motion", fontsize=16, fontweight="bold"
    )

    xlim, ylim = calculate_plot_limits(
        ceiling_length,
        np.min(x1_all),
        np.max(x2_all),
        np.min(y1_all),
        np.min(y2_all),
        padding=0.2,
    )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

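    # Determine motion type based on initial conditions.
    # In the small-angle model the in-phase shape (1, 1) is a normal mode for any
    # masses, while the out-of-phase shape is (m2, -m1), i.e. m1*θ1 + m2*θ2 = 0
    # (this reduces to θ1 = -θ2 only when m1 == m2).
    if abs(theta_1_init - theta_2_init) < 1e-6:
        motion_type = "In-Phase Oscillation (Normal Mode 1)"
    elif abs(m1 * theta_1_init + m2 * theta_2_init) < 1e-6:
        motion_type = "Out-of-Phase Oscillation (Normal Mode 2)"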
    elif theta_1_init == 0 or theta_2_init == 0:
        motion_type = "Energy Transfer (Beats Phenomenon)"
    else:
        motion_type = "Mixed-Mode Oscillation (Superposition)"

    title_text = (
        f"Coupled Pendulum System: {motion_type}\n"
        f"Normal Mode Frequencies: "
        f"$\\omega_1={omega1_theory:.3f}$ rad/s (in-phase), "
        f"$\\omega_2={omega2_theory:.3f}$ rad/s (out-of-phase)"
    )

    ax.set_title(title_text, fontsize=13, fontweight="bold", pad=10)

    # Draw ceiling
    draw_ceiling(ax, ceiling_length, ceiling_height)

    # Pivot points
    ax.plot(pivot1_x, pivot_y, "o", color="black", markersize=15, zorder=6)
    ax.plot(pivot2_x, pivot_y, "o", color="black", markersize=15, zorder=6)

    # Reference lines for equilibrium positions
    ax.axvline(
        pivot1_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.7
    )
    ax.axvline(
        pivot2_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.7
    )

    # Initialize plot elements
    # Rods
    (rod1_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod2_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)

    # Spring
    (spring_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)

    # Masses (as markers)
    mass_radius_min = 20
    mass_radius_ref = 25
    avg_mass = (m1 + m2) / 2.0 if (m1 + m2) > 0 else 1.0
    mass_1_size = max(mass_radius_min, mass_radius_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(mass_radius_min, mass_radius_ref * np.cbrt(m2 / avg_mass))

    (mass1_plot,) = ax.plot(
        [],
        [],
        "o",
        color="crimson",
        markersize=mass_1_size,
        zorder=4,
        label=f"Mass 1 ($m_1$={m1} kg)",
    )
    (mass2_plot,) = ax.plot(
        [],
        [],
        "o",
        color="royalblue",
        markersize=mass_2_size,
        zorder=4,
        label=f"Mass 2 ($m_2$={m2} kg)",
    )

    # Traces
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    (trace1_line,) = ax.plot([], [], "crimson", lw=0.8, alpha=0.5, zorder=1)
    (trace2_line,) = ax.plot([], [], "royalblue", lw=0.8, alpha=0.5, zorder=1)

    # Dynamic text annotations
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        va="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # Info box with angles and frequencies
    info_box = ax.text(
        1.02,
        0.99,
        "",
        transform=ax.transAxes,
        fontsize=10,
        va="top",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightcyan", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = (
        "Frequency Comparison:\n"
        f"{'-' * 20}\n"
        f"Theoretical Normal Modes:\n"
        f"$f_1={f1_theory:.3f}$ Hz\n$f_2={f2_theory:.3f}$ Hz\n"
        f"Numerical Normal Modes:\n"
        f"$f_1\\approx${f1_str}\n$f_2\\approx${f2_str}\n"
        f"Ratio $\\omega_2/\\omega_1 = {omega2_theory / omega1_theory:.3f}$"
    )

    sys_info_text = (
        f"System Parameters:\n"
        f"{'-' * 20}\n"
        f"Masses: \n"
        f"$m_1$={m1} kg, $m_2$={m2} kg\n"
        f"Spring: $k$={k} N/m\n"
        f"Length: $L$={L} m\n"
        f"Gravity: $g$={g} m/s²\n"
        f"Initial Angles:\n"
        f"$\\theta_1={theta_1_init:.2f}^{{\\circ}}$, $\\theta_2={theta_2_init:.2f}^{{\\circ}}$"
    )
    ax.text(
        1.02,
        0.01,
        freq_text,
        transform=ax.transAxes,
        fontsize=9,
        va="bottom",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="honeydew", alpha=0.8),
    )
    ax.text(
        1.02,
        0.5,
        sys_info_text,
        transform=ax.transAxes,
        fontsize=10,
        va="center",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lavender", alpha=0.8),
    )

    # Legend
    ax.legend(loc="lower right", fontsize=10, markerscale=0.5, ncol=2)

    # =========================================================================
    # ANIMATION FUNCTIONS
    # =========================================================================

    def init():
        """Initialize animation."""
        rod1_line.set_data([], [])
        rod2_line.set_data([], [])
        spring_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])
        time_text.set_text("")
        info_box.set_text("")
        return (
            rod1_line,
            rod2_line,
            spring_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            time_text,
            info_box,
        )

    def animate(frame):
        """Update animation for each frame."""
        # Current time and angles
        current_time = t_eval[frame]
        theta1_deg = theta1_deg_all[frame]
        theta2_deg = theta2_deg_all[frame]

        # Current positions
        x1, y1 = x1_all[frame], y1_all[frame]
        x2, y2 = x2_all[frame], y2_all[frame]

        # Update rods
        rod1_line.set_data([pivot1_x, x1], [pivot_y, y1])
        rod2_line.set_data([pivot2_x, x2], [pivot_y, y2])

        spring_x, spring_y = draw_spring_with_hook(
            start_pos=(x1, y1),
            end_pos=(x2, y2),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring_line.set_data(spring_x, spring_y)

        # Update masses
        mass1_plot.set_data([x1], [y1])
        mass2_plot.set_data([x2], [y2])

        # Update traces
        trace1_x.append(x1)
        trace1_y.append(y1)
        trace2_x.append(x2)
        trace2_y.append(y2)

        # Limit trace length for performance
        max_trace = 50
        if len(trace1_x) > max_trace:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)

        # Update timer
        time_text.set_text(f"Time: {current_time:.2f} s")

        # Update info box with current angles and angular velocities
        omega1_val = omega1[frame]
        omega2_val = omega2[frame]

        # Accessing precomputed spring extension and energies
        spring_ext = spring_ext_all[frame]
        KE = KE_all[frame]
        PE_grav = PE_grav_all[frame]
        PE_spring = PE_spring_all[frame]
        total_E = total_E_all[frame]

        info_text = (
            f"$\\theta_1 = {theta1_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_2 = {theta2_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\dot{{\\theta}}_1 = {omega1_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_2 = {omega2_val:+6.3f}$ rad/s\n"
            f"─────────────\n"
            f"Spring $\\Delta x = {spring_ext:+.3f}$ m\n"
            f"KE = {KE:.3f} J\n"
            f"PE = {PE_grav + PE_spring:.3f} J\n"
            f"Total E = {total_E:.3f} J"
        )
        info_box.set_text(info_text)

        return (
            rod1_line,
            rod2_line,
            spring_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            time_text,
            info_box,
        )

    # Create animation
    print("\nCreating animation...")
    anim = FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(t_eval),
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            # Create default filename
            filename = (
                f"coupled_pendulum_m1_{m1}_m2_{m2}_k_{k}_L_{L}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}.gif"
            )

        # Ensure ANIMATIONS directory exists
        anim_dir = "OUTPUTS/ANIMATIONS/coupled_pendulum/double_pendulum"
        os.makedirs(anim_dir, exist_ok=True)

        save_path = os.path.join(anim_dir, filename)
        print(f"Saving animation to: {save_path}")

        try:
            anim.save(save_path, writer="ffmpeg", fps=fps, codec="gif", dpi=120)
            print("Animation saved successfully!")
        except Exception as e:
            print(f"Error saving animation: {e}")
            print("Trying pillow writer...")
            try:
                gif_path = save_path.replace(".mp4", ".gif")
                anim.save(gif_path, writer="pillow", fps=fps)
                print(f"Animation saved as GIF: {gif_path}")
            except Exception as e2:
                print(f"Could not save animation: {e2}")

        plt.close(fig)
    else:
        plt.show()

    return anim


# =============================================================================
# EXAMPLE USAGE AND DEMONSTRATIONS
# =============================================================================


def demo_normal_modes():
    """
    Print the theoretical normal modes of the coupled pendulum for fixed demo
    parameters (k = 10 N/m), independent of the parameters used in the simulation.
    """
    print("=" * 60)
    print("NORMAL MODE DEMONSTRATIONS")
    print("=" * 60)

    # Parameters
    m1, m2, k, L, g = 1.0, 1.0, 10.0, 2.0, 9.81

    # Calculate theoretical frequencies
    omega1, omega2, f1, f2 = th_normal_modes_double_pendulum(g, L, k, m1, m2)
    T1 = 1 / f1
    T2 = 1 / f2

    print("\nSystem Parameters:")
    print(f"  m₁ = m₂ = {m1} kg")
    print(f"  k = {k} N/m")
    print(f"  L = {L} m")
    print(f"  g = {g} m/s²")
    print("\nTheoretical Normal Modes:")
    print(f"  Mode 1 (in-phase):     ω₁ = {omega1:.4f} rad/s, T₁ = {T1:.4f} s")
    print(f"  Mode 2 (out-of-phase): ω₂ = {omega2:.4f} rad/s, T₂ = {T2:.4f} s")
    print(f"  Frequency ratio: ω₂/ω₁ = {omega2 / omega1:.4f}")

    return omega1, omega2


def main():
    """Main function to run simulations and demonstrations."""
    header = "STARTING SPRING-COUPLED DOUBLE PENDULUM SIMULATION"
    print("\n" + "=" * len(header))
    print(header)
    print("=" * len(header))

    default_params = {
        "theta_1_init": 0.0,
        "theta_2_init": 10.0,
        "m1": 1.0,
        "m2": 1.0,
        "k": 5.0,
        "L": 2.0,
        "g": 9.81,
        "simulation_time": 10.0,
        "fps": 30,
    }

    use_default = input("Use default simulation parameters? (y/n): ").strip().lower()
    if use_default == "y":
        params = default_params.copy()
    else:
        params = {}
        try:
            params["theta_1_init"] = float(
                input(
                    f"Enter initial angle for pendulum 1 in degrees [{default_params['theta_1_init']}]: "
                )
                or default_params["theta_1_init"]
            )
            params["theta_2_init"] = float(
                input(
                    f"Enter initial angle for pendulum 2 in degrees [{default_params['theta_2_init']}]: "
                )
                or default_params["theta_2_init"]
            )
            params["m1"] = float(
                input(
                    f"Enter mass of the first pendulum in kg [{default_params['m1']}]: "
                )
                or default_params["m1"]
            )
            params["m2"] = float(
                input(
                    f"Enter mass of the second pendulum in kg [{default_params['m2']}]: "
                )
                or default_params["m2"]
            )
            params["k"] = float(
                input(f"Enter spring constant in N/m [{default_params['k']}]: ")
                or default_params["k"]
            )
            params["L"] = float(
                input(
                    f"Enter length of the pendulum rods in m [{default_params['L']}]: "
                )
                or default_params["L"]
            )
            params["g"] = float(
                input(
                    f"Enter gravitational acceleration in m/s² [{default_params['g']}]: "
                )
                or default_params["g"]
            )
            params["simulation_time"] = float(
                input(
                    f"Enter simulation time in seconds [{default_params['simulation_time']}]: "
                )
                or default_params["simulation_time"]
            )
            params["fps"] = int(
                input(
                    f"Enter frames per second for animation [{default_params['fps']}]: "
                )
                or default_params["fps"]
            )
        except ValueError:
            print("Invalid input. Using default parameters.")
            params = default_params.copy()

    save_anim = input("Save animation to file? (y/n): ").strip().lower() == "y"
    filename = None
    if save_anim:
        filename_input = input(
            "Enter filename for animation (or press Enter for default): "
        )
        if filename_input:
            if not filename_input.endswith(".gif"):
                filename_input += ".gif"
            filename = filename_input

    print("\nRunning simulation with parameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")

    demo_normal_modes()

    animation = simulate_double_pendulum(
        theta_1_init=params["theta_1_init"],
        theta_2_init=params["theta_2_init"],
        m1=params["m1"],
        m2=params["m2"],
        k=params["k"],
        L=params["L"],
        g=params["g"],
        simulation_time=params["simulation_time"],
        fps=params["fps"],
        save_anim=save_anim,
        filename=filename,
    )

    return animation


if __name__ == "__main__":
    animation = main()
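The motion-type label in `simulate_double_pendulum` is chosen with simple tests on the initial angles. In the small-angle model the in-phase shape $(1,1)$ is a normal mode for any masses, but the out-of-phase shape is $(m_2,-m_1)$, which equals $(1,-1)$ only when $m_1=m_2$. A mass-aware alternative, sketched below as a hypothetical helper (`classify_initial_condition` is not part of the notebook, and the relative tolerance `rtol=1e-6` is an arbitrary choice), projects the initial angles onto the modes with $\mathbf{q}_0=V^T M\theta_0$. Because every simulation starts from rest, these two amplitudes determine how much of each mode is excited in the linear regime.

```python
import numpy as np
from scipy.linalg import eigh


def classify_initial_condition(theta1_deg, theta2_deg, g, L, k, m1, m2, rtol=1e-6):
    """Label a from-rest initial condition by its small-angle modal content (hypothetical helper)."""
    M = np.diag([m1, m2]).astype(float)
    K = np.array([[m1 * g / L + k, -k],
                  [-k, m2 * g / L + k]], dtype=float)
    _, V = eigh(K, M)                      # columns: in-phase, out-of-phase modes (k > 0)
    theta0 = np.radians([theta1_deg, theta2_deg])
    q0 = np.abs(V.T @ (M @ theta0))        # initial modal amplitudes
    scale = q0.max()
    if scale == 0.0:
        return "At rest (no mode excited)"
    if q0[1] <= rtol * scale:
        return "In-Phase Oscillation (Normal Mode 1)"
    if q0[0] <= rtol * scale:
        return "Out-of-Phase Oscillation (Normal Mode 2)"
    return "Mixed-Mode Oscillation (Superposition)"


params = dict(g=9.81, L=2.0, k=5.0)
print(classify_initial_condition(10.0, 10.0, m1=1.0, m2=3.0, **params))   # in-phase
print(classify_initial_condition(30.0, -10.0, m1=1.0, m2=3.0, **params))  # m1*θ1 + m2*θ2 = 0
print(classify_initial_condition(10.0, -10.0, m1=1.0, m2=3.0, **params))  # mixed for unequal masses
```

The last call shows the case a plain $\theta_1=-\theta_2$ test would mislabel: with $m_2=3m_1$ the symmetric anti-phase start excites both modes.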


### 1.3 Time-Series Diagnostics (Angles)

This code simulates the **spring-coupled double pendulum system** and shows the animation side by side with angle time-series plots:

**Key Features:**

**Left Column (Main Animation):**
- Visual representation of two pendulums connected by a spring
- Real-time display of pendulum rods, masses, and spring deformation
- Motion traces showing trajectory paths
- Dynamic info boxes displaying:
    - Current angles ($\theta_1$, $\theta_2$) and angular velocities
    - Spring extension
    - Energy components (kinetic, potential, total)
    - System parameters and normal mode frequencies

**Right Column (Time Series Plots):**
- **Top plot**: $\theta_1$ vs time
- **Bottom plot**: $\theta_2$ vs time
- Both show the complete angle evolution with a moving marker indicating current position
- Helps visualize oscillation patterns, beat phenomena, and normal modes

**Physics Implementation:**
- Solves the coupled nonlinear ODEs with `solve_ivp` (RK45 method)
- Computes the theoretical (small-angle) and numerical normal-mode frequencies
- Classifies the motion type from the initial conditions (in-phase, out-of-phase, energy transfer, mixed-mode)
- Tracks kinetic, gravitational, and spring potential energy, so drift in the total energy can serve as an accuracy check
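The energy bookkeeping listed above can be made quantitative by monitoring the relative drift of the total energy. The notebook's `coupled_pendulum_derivatives` is defined earlier and is not reproduced here; the sketch below instead *assumes* equations of motion derived from the same energy expressions the animation uses, $V=-\sum_i m_i gL\cos\theta_i+\tfrac12 kL^2(\sin\theta_2-\sin\theta_1)^2$, which give $\ddot\theta_1=-(g/L)\sin\theta_1+(k/m_1)(\sin\theta_2-\sin\theta_1)\cos\theta_1$ and $\ddot\theta_2=-(g/L)\sin\theta_2-(k/m_2)(\sin\theta_2-\sin\theta_1)\cos\theta_2$. If the notebook's right-hand side differs, its reported energy drift will also reflect that mismatch. RK45 is not symplectic, so some drift is expected; it should shrink as the tolerances tighten. The first tolerance pair matches the `rtol=1e-5, atol=1e-7` used in the notebook.

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs_from_energy(t, y, g, L, k, m1, m2):
    """ASSUMED equations of motion from the notebook's energy expressions; y = [θ1, ω1, θ2, ω2]."""
    th1, w1, th2, w2 = y
    stretch = np.sin(th2) - np.sin(th1)  # horizontal spring extension divided by L
    a1 = -(g / L) * np.sin(th1) + (k / m1) * stretch * np.cos(th1)
    a2 = -(g / L) * np.sin(th2) - (k / m2) * stretch * np.cos(th2)
    return [w1, a1, w2, a2]


def total_energy(y, g, L, k, m1, m2):
    """KE + gravitational PE + spring PE, same formulas as the animation code."""
    th1, w1, th2, w2 = y
    KE = 0.5 * m1 * (L * w1) ** 2 + 0.5 * m2 * (L * w2) ** 2
    PE_grav = -m1 * g * L * np.cos(th1) - m2 * g * L * np.cos(th2)
    PE_spring = 0.5 * k * (L * (np.sin(th2) - np.sin(th1))) ** 2
    return KE + PE_grav + PE_spring


params = (9.81, 2.0, 5.0, 1.0, 1.0)      # g, L, k, m1, m2 (notebook defaults)
y0 = [0.0, 0.0, np.radians(10.0), 0.0]   # θ1 = 0°, θ2 = 10°, from rest (main() defaults)
t_eval = np.linspace(0.0, 20.0, 601)

drifts = {}
for rtol, atol in [(1e-5, 1e-7), (1e-9, 1e-11)]:
    sol = solve_ivp(rhs_from_energy, (0.0, 20.0), y0, args=params,
                    method="RK45", rtol=rtol, atol=atol, t_eval=t_eval)
    E = total_energy(sol.y, *params)
    drifts[rtol] = np.max(np.abs(E - E[0])) / abs(E[0])
    print(f"rtol={rtol:.0e}: max relative energy drift = {drifts[rtol]:.2e}")
```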


In [None]:
def double_pendulum_animation_with_plots(
    theta_1_init=0.0,
    theta_2_init=10.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=20.0,
    fps=30,
    save_format="gif",
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a coupled pendulum system alongside θ₁(t) and θ₂(t) time-series plots.

    Parameters
    ----------
    theta_1_init : float
        Initial angle of pendulum 1 in degrees
    theta_2_init : float
        Initial angle of pendulum 2 in degrees
    m1, m2 : float
        Masses of pendulum bobs
    k : float
        Spring constant
    L : float
        Length of pendulum rods
    g : float
        Gravitational acceleration
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for animation
    save_format : str
        Format to save animation ('gif' or 'mp4')
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for saved animation (auto-generated if None)

    Returns
    -------
    tuple
        (fig, anim) - Figure and animation objects
    """

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂] (starting from rest)
    y0 = [theta1_0, 0.0, theta2_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k, m1, m2),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if solution.sol is None:
        raise RuntimeError(
            "solve_ivp did not return a dense solution (solution.sol is None)"
        )

    y = solution.sol(t=t)
    theta1 = y[0]
    omega1 = y[1]
    theta2 = y[2]
    omega2 = y[3]

    print(f"Solution computed: {len(t)} time steps")

    # =========================================================================
    # PRECOMPUTE SPRING EXTENSION + ENERGIES
    # =========================================================================

    theta1_deg_all = np.degrees(theta1)
    theta2_deg_all = np.degrees(theta2)

    spring_ext_all = L * (np.sin(theta2) - np.sin(theta1))

    KE_all = 0.5 * m1 * (L * omega1) ** 2 + 0.5 * m2 * (L * omega2) ** 2
    PE_grav_all = -m1 * g * L * np.cos(theta1) - m2 * g * L * np.cos(theta2)
    PE_spring_all = 0.5 * k * spring_ext_all**2
    total_E_all = KE_all + PE_grav_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, f1_theory, f2_theory = (
        th_normal_modes_double_pendulum(g, L, k, m1, m2)
    )

    # Numerical frequency estimation
    f1_num, f2_num = est_freqs_double_pendulum(t, theta1, theta2, g, L, k, m1, m2)

    def _format_freq(f):
        return f"{f:.3f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _format_freq(f1_num)
    f2_str = _format_freq(f2_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.3f} rad/s (f₁ = {f1_theory:.3f} Hz)")
    print(f"               ω₂ = {omega2_theory:.3f} rad/s (f₂ = {f2_theory:.3f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}")

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Pivot positions
    pivot_separation = 1.5
    pivot1_x = -pivot_separation / 2
    pivot2_x = pivot_separation / 2
    pivot_y = 0

    ceiling_length = pivot_separation + 1.0
    separation_ratio = pivot_separation / ceiling_length
    spring_hook_length = calculate_fixed_hook_length(
        ceiling_length, L, separation_ratio, hook_ratio=0.1
    )
    spring_radius = 0.03 * ceiling_length
    spring_num_coils = 10
    ceiling_height = 0.05 * ceiling_length

    # Precompute mass positions
    x1_all = pivot1_x + L * np.sin(theta1)
    y1_all = pivot_y - L * np.cos(theta1)
    x2_all = pivot2_x + L * np.sin(theta2)
    y2_all = pivot_y - L * np.cos(theta2)

    # Figure setup with gridspec
    fig = plt.figure(figsize=(14, 8))
    gs = GridSpec(
        2,
        2,
        figure=fig,
        width_ratios=[1.5, 1],
        left=0.02,
        right=0.98,
        top=0.84,
        bottom=0.08,
        hspace=0.32,
        wspace=0.47,
    )

    # Main pendulum animation (left column)
    ax = fig.add_subplot(gs[:, 0])

    # Time series plots (right column)
    ax_theta1 = fig.add_subplot(gs[0, 1])
    ax_theta2 = fig.add_subplot(gs[1, 1])

    fig.suptitle(
        "Spring-Coupled Double Pendulum Motion", fontsize=16, fontweight="bold"
    )

    xlim, ylim = calculate_plot_limits(
        ceiling_length,
        np.min(x1_all),
        np.max(x2_all),
        np.min(y1_all),
        np.min(y2_all),
        padding=0.2,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

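    # Determine motion type based on initial conditions.
    # In the small-angle model the in-phase shape (1, 1) is a normal mode for any
    # masses, while the out-of-phase shape is (m2, -m1), i.e. m1*θ1 + m2*θ2 = 0
    # (this reduces to θ1 = -θ2 only when m1 == m2).
    if abs(theta_1_init - theta_2_init) < 1e-6:
        motion_type = "In-Phase Oscillation (Normal Mode 1)"
    elif abs(m1 * theta_1_init + m2 * theta_2_init) < 1e-6:
        motion_type = "Out-of-Phase Oscillation (Normal Mode 2)"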
    elif theta_1_init == 0 or theta_2_init == 0:
        motion_type = "Energy Transfer (Beats Phenomenon)"
    else:
        motion_type = "Mixed-Mode Oscillation (Superposition)"

    title_text = (
        f"Coupled Pendulum System: {motion_type}\n"
        f"Normal Mode Frequencies\n"
        f"$\\omega_1={omega1_theory:.3f}$ rad/s (in-phase), "
        f"$\\omega_2={omega2_theory:.3f}$ rad/s (out-of-phase)"
    )

    ax.set_title(title_text, fontsize=12, fontweight="bold", pad=10)

    # Draw ceiling
    draw_ceiling(ax, ceiling_length, ceiling_height)

    # Pivot points
    ax.plot(pivot1_x, pivot_y, "o", color="black", markersize=15, zorder=6)
    ax.plot(pivot2_x, pivot_y, "o", color="black", markersize=15, zorder=6)

    # Reference lines for equilibrium positions
    ax.axvline(
        pivot1_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.7
    )
    ax.axvline(
        pivot2_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.7
    )

    # Initialize plot elements
    # Rods
    (rod1_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod2_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)

    # Spring
    (spring_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)

    # Masses (as markers)
    mass_radius_min = 20
    mass_radius_ref = 25
    avg_mass = (m1 + m2) / 2.0 if (m1 + m2) > 0 else 1.0
    mass_1_size = max(mass_radius_min, mass_radius_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(mass_radius_min, mass_radius_ref * np.cbrt(m2 / avg_mass))

    (mass1_plot,) = ax.plot(
        [],
        [],
        "o",
        color="crimson",
        markersize=mass_1_size,
        zorder=4,
        label=f"Mass 1 ($m_1$={m1} kg)",
    )
    (mass2_plot,) = ax.plot(
        [],
        [],
        "o",
        color="royalblue",
        markersize=mass_2_size,
        zorder=4,
        label=f"Mass 2 ($m_2$={m2} kg)",
    )

    # Legend
    ax.legend(loc="lower right", fontsize=8, markerscale=0.5, ncol=2)
    # Traces
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    (trace1_line,) = ax.plot([], [], "crimson", lw=0.8, alpha=0.5, zorder=1)
    (trace2_line,) = ax.plot([], [], "royalblue", lw=0.8, alpha=0.5, zorder=1)

    # Trace markers (front of trace)
    (trace1_marker,) = ax.plot(
        [], [], "o", color="crimson", markersize=8, zorder=5, alpha=0.8
    )
    (trace2_marker,) = ax.plot(
        [], [], "o", color="royalblue", markersize=8, zorder=5, alpha=0.8
    )

    # Setup time series plots
    # Theta1 vs time
    ax_theta1.set_xlim(0, 1.1 * simulation_time)
    ax_theta1.set_ylim(np.degrees(theta1).min() - 5, np.degrees(theta1).max() + 5)
    ax_theta1.set_xlabel("Time (s)", fontsize=10)
    ax_theta1.set_ylabel("$\\theta_1$ (deg)", fontsize=10)
    ax_theta1.grid(True, alpha=0.3)
    ax_theta1.set_title("$\\theta_1$ vs $t$", fontsize=12, fontweight="bold")
    (theta1_time_line,) = ax_theta1.plot([], [], "crimson", lw=2, label="$\\theta_1$")
    (theta1_current_point,) = ax_theta1.plot(
        [], [], "o", color="crimson", markersize=8, zorder=5
    )
    ax_theta1.legend(loc="upper right", fontsize=9)

    # Theta2 vs time
    ax_theta2.set_xlim(0, 1.1 * simulation_time)
    ax_theta2.set_ylim(np.degrees(theta2).min() - 5, np.degrees(theta2).max() + 5)
    ax_theta2.set_xlabel("Time (s)", fontsize=10)
    ax_theta2.set_ylabel("$\\theta_2$ (deg)", fontsize=10)
    ax_theta2.grid(True, alpha=0.3)
    ax_theta2.set_title("$\\theta_2$ vs $t$", fontsize=12, fontweight="bold")
    (theta2_time_line,) = ax_theta2.plot([], [], "royalblue", lw=2, label="$\\theta_2$")
    (theta2_current_point,) = ax_theta2.plot(
        [], [], "o", color="royalblue", markersize=8, zorder=5
    )
    ax_theta2.legend(loc="upper right", fontsize=9)

    # Store time series data
    time_history = []
    theta1_history = []
    theta2_history = []

    # Dynamic text annotations
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=11,
        va="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # Info box with angles and frequencies
    info_box = ax.text(
        1.02,
        0.99,
        "",
        transform=ax.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightcyan", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = (
        "Frequency Comparison:\n"
        f"{'-' * 20}\n"
        "Theoretical (linear):\n"
        f"$f_1={f1_theory:.3f}$ Hz\n$f_2={f2_theory:.3f}$ Hz\n"
        "Numerical:\n"
        f"$f_1\\approx$ {f1_str}\n"
        f"$f_2\\approx$ {f2_str}\n"
        f"Ratio: $\\omega_2/\\omega_1 = {omega2_theory / omega1_theory:.3f}$"
    )

    sys_info_text = (
        f"System Parameters:\n"
        f"{'-' * 20}\n"
        f"Masses: \n"
        f"$m_1$={m1} kg, $m_2$={m2} kg\n"
        f"Spring: $k$={k} N/m\n"
        f"Length: $L$={L} m\n"
        f"Gravity: $g$={g} m/s²\n"
        f"Initial Angles:\n"
        f"$\\theta_1={theta_1_init:.2f}^{{\\circ}}$\n"
        f"$\\theta_2={theta_2_init:.2f}^{{\\circ}}$"
    )
    ax.text(
        1.02,
        0.01,
        freq_text,
        transform=ax.transAxes,
        fontsize=9,
        va="bottom",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="honeydew", alpha=0.8),
    )
    ax.text(
        1.02,
        0.5,
        sys_info_text,
        transform=ax.transAxes,
        fontsize=9,
        va="center",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lavender", alpha=0.8),
    )

    # =========================================================================
    # ANIMATION FUNCTIONS
    # =========================================================================

    def init():
        """Initialize animation."""
        rod1_line.set_data([], [])
        rod2_line.set_data([], [])
        spring_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])
        trace1_marker.set_data([], [])
        trace2_marker.set_data([], [])
        theta1_time_line.set_data([], [])
        theta2_time_line.set_data([], [])
        theta1_current_point.set_data([], [])
        theta2_current_point.set_data([], [])
        time_text.set_text("")
        info_box.set_text("")
        return (
            rod1_line,
            rod2_line,
            spring_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            trace1_marker,
            trace2_marker,
            theta1_time_line,
            theta2_time_line,
            theta1_current_point,
            theta2_current_point,
            time_text,
            info_box,
        )

    def animate(frame):
        """Update animation for each frame."""
        # Current time and angles
        current_time = t[frame]
        theta1_deg = theta1_deg_all[frame]
        theta2_deg = theta2_deg_all[frame]

        # Current positions
        x1, y1 = x1_all[frame], y1_all[frame]
        x2, y2 = x2_all[frame], y2_all[frame]

        # Update rods
        rod1_line.set_data([pivot1_x, x1], [pivot_y, y1])
        rod2_line.set_data([pivot2_x, x2], [pivot_y, y2])

        spring_x, spring_y = draw_spring_with_hook(
            start_pos=(x1, y1),
            end_pos=(x2, y2),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring_line.set_data(spring_x, spring_y)

        # Update masses
        mass1_plot.set_data([x1], [y1])
        mass2_plot.set_data([x2], [y2])

        # Update traces
        trace1_x.append(x1)
        trace1_y.append(y1)
        trace2_x.append(x2)
        trace2_y.append(y2)

        # Limit trace length for performance
        max_trace = 50
        if len(trace1_x) > max_trace:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)

        # Update trace markers (at the front of traces)
        if len(trace1_x) > 0:
            trace1_marker.set_data([trace1_x[-1]], [trace1_y[-1]])
            trace2_marker.set_data([trace2_x[-1]], [trace2_y[-1]])

        # Update time series data
        time_history.append(current_time)
        theta1_history.append(theta1_deg)
        theta2_history.append(theta2_deg)

        # Update time series plots
        theta1_time_line.set_data(time_history, theta1_history)
        theta2_time_line.set_data(time_history, theta2_history)

        # Update current point markers on time series
        theta1_current_point.set_data([current_time], [theta1_deg])
        theta2_current_point.set_data([current_time], [theta2_deg])

        # Update time text
        time_text.set_text(f"Time: {current_time:.2f} s")

        # Update info box with angular velocities
        omega1_val = omega1[frame]
        omega2_val = omega2[frame]

        # Use precomputed spring extension + energies
        spring_ext = spring_ext_all[frame]
        KE = KE_all[frame]
        PE_grav = PE_grav_all[frame]
        PE_spring = PE_spring_all[frame]
        total_E = total_E_all[frame]

        info_text = (
            f"$\\theta_1 = {theta1_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_2 = {theta2_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\dot{{\\theta}}_1 = {omega1_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_2 = {omega2_val:+6.3f}$ rad/s\n"
            f"─────────────\n"
            f"Spring $\\Delta x = {spring_ext:+.3f}$ m\n"
            f"KE = {KE:.3f} J\n"
            f"PE = {PE_grav + PE_spring:.3f} J\n"
            f"Total E = {total_E:.3f} J"
        )
        info_box.set_text(info_text)

        return (
            rod1_line,
            rod2_line,
            spring_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            trace1_marker,
            trace2_marker,
            theta1_time_line,
            theta2_time_line,
            theta1_current_point,
            theta2_current_point,
            time_text,
            info_box,
        )

    # Create animation
    print("\nCreating animation...")
    anim = FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(t),
        interval=1000 / fps,
        blit=True,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            ext = ".gif" if save_format == "gif" else ".mp4"
            # Create default filename
            filename = (
                f"time_series_m1_{m1}_m2_{m2}_k_{k}_L_{L}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}{ext}"
            )

        # Ensure ANIMATIONS directory exists
        anim_dir = "OUTPUTS/ANIMATIONS/coupled_pendulum/double_pendulum"
        os.makedirs(anim_dir, exist_ok=True)

        save_path = os.path.join(anim_dir, filename)
        print(f"Saving animation to: {save_path}")

        try:
            if save_format == "gif":
                anim.save(save_path, writer="ffmpeg", fps=fps, codec="gif", dpi=120)
            else:
                anim.save(save_path, writer="ffmpeg", fps=fps, dpi=120)
            print("Animation saved successfully!")
        except Exception as e:
            print(f"Error saving animation: {e}")

        plt.close(fig)
    else:
        plt.show()

    return anim


def main():
    """Main function to run simulations and demonstrations."""
    header = "STARTING SPRING-COUPLED DOUBLE PENDULUM SIMULATION"
    print("\n" + "=" * len(header))
    print(header)
    print("=" * len(header))

    default_params = {
        "theta_1_init": 0.0,
        "theta_2_init": 10.0,
        "m1": 1.0,
        "m2": 1.0,
        "k": 5.0,
        "L": 2.0,
        "g": 9.81,
        "simulation_time": 10.0,
        "fps": 30,
        "save_format": "gif",
    }

    use_default = input("Use default simulation parameters? (y/n): ").strip().lower()
    if use_default == "y":
        params = default_params.copy()
    else:
        params = {}
        try:
            params["theta_1_init"] = float(
                input(
                    f"Enter initial angle for pendulum 1 in degrees [{default_params['theta_1_init']}]: "
                )
                or default_params["theta_1_init"]
            )
            params["theta_2_init"] = float(
                input(
                    f"Enter initial angle for pendulum 2 in degrees [{default_params['theta_2_init']}]: "
                )
                or default_params["theta_2_init"]
            )
            params["m1"] = float(
                input(f"Enter mass of pendulum 1 in kg [{default_params['m1']}]: ")
                or default_params["m1"]
            )
            params["m2"] = float(
                input(f"Enter mass of pendulum 2 in kg [{default_params['m2']}]: ")
                or default_params["m2"]
            )
            params["k"] = float(
                input(f"Enter spring constant in N/m [{default_params['k']}]: ")
                or default_params["k"]
            )
            params["L"] = float(
                input(f"Enter pendulum rod length in m [{default_params['L']}]: ")
                or default_params["L"]
            )
            params["g"] = float(
                input(
                    f"Enter gravitational acceleration in m/s² [{default_params['g']}]: "
                )
                or default_params["g"]
            )
            params["simulation_time"] = float(
                input(
                    f"Enter simulation time in seconds [{default_params['simulation_time']}]: "
                )
                or default_params["simulation_time"]
            )
            params["fps"] = int(
                input(
                    f"Enter frames per second for animation [{default_params['fps']}]: "
                )
                or default_params["fps"]
            )
            params["save_format"] = (
                input(
                    f"Enter animation save format ('gif' or 'mp4') [{default_params['save_format']}]: "
                )
                .strip()
                .lower()
                or default_params["save_format"]
            )
            if params["save_format"] not in ["gif", "mp4"]:
                print("Invalid format. Using default 'gif'.")
                params["save_format"] = "gif"
        except ValueError:
            print("Invalid input. Using default parameters.")
            params = default_params.copy()

    save_anim = input("Save animation to file? (y/n): ").strip().lower() == "y"
    filename = None
    if save_anim:
        filename_input = input(
            "Enter filename for animation (or press Enter for default): "
        )
        if filename_input:
            if not filename_input.endswith((".gif", ".mp4")):
                filename_input += f".{params['save_format']}"
            filename = filename_input

    print("\nRunning simulation with parameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")

    demo_normal_modes()

    animation = double_pendulum_animation_with_plots(
        theta_1_init=params["theta_1_init"],
        theta_2_init=params["theta_2_init"],
        m1=params["m1"],
        m2=params["m2"],
        k=params["k"],
        L=params["L"],
        g=params["g"],
        simulation_time=params["simulation_time"],
        fps=params["fps"],
        save_format=params["save_format"],
        save_anim=save_anim,
        filename=filename,
    )

    return animation


if __name__ == "__main__":
    animation = main()
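
For the default parameters in `main()` ($\theta_1=0^\circ$, $\theta_2=10^\circ$, $m_1=m_2=1$ kg, $k=5$ N/m, $L=2$ m), the "Energy Transfer (Beats Phenomenon)" label can be checked with a quick calculation. The sketch below assumes the standard small-angle result for identical-length pendulums coupled at the bobs, $\omega_1^2=g/L$ and $\omega_2^2=g/L+k\,(1/m_1+1/m_2)$; the notebook's `th_normal_modes_double_pendulum` may compute this differently. With equal masses the in-phase and out-of-phase amplitudes are $(\theta_1\pm\theta_2)/2$, and each pendulum's amplitude envelope repeats with period $1/(f_2-f_1)$. The helper names are illustrative only.

```python
import numpy as np


def small_angle_frequencies(g, L, k, m1, m2):
    """Assumed small-angle normal-mode angular frequencies in rad/s.

    Assumes identical rod lengths L and a spring attached at the bobs.
    """
    omega_in = np.sqrt(g / L)  # in-phase: the spring never stretches
    omega_out = np.sqrt(g / L + k * (1.0 / m1 + 1.0 / m2))  # out-of-phase
    return omega_in, omega_out


def beat_summary(theta1_deg, theta2_deg, g, L, k, m):
    """Equal-mass case: modal amplitudes and the envelope (energy-exchange) period."""
    a_in = 0.5 * (theta1_deg + theta2_deg)   # in-phase (mode 1) amplitude
    a_out = 0.5 * (theta1_deg - theta2_deg)  # out-of-phase (mode 2) amplitude
    w1, w2 = small_angle_frequencies(g, L, k, m, m)
    f1, f2 = w1 / (2 * np.pi), w2 / (2 * np.pi)
    t_env = 1.0 / (f2 - f1)                  # period of each pendulum's envelope
    return a_in, a_out, f1, f2, t_env


# Defaults from main(): theta_1 = 0 deg, theta_2 = 10 deg, m = 1 kg, k = 5 N/m, L = 2 m
a_in, a_out, f1, f2, t_env = beat_summary(0.0, 10.0, g=9.81, L=2.0, k=5.0, m=1.0)
print(f"mode amplitudes: {a_in:+.1f} deg (in-phase), {a_out:+.1f} deg (out-of-phase)")
print(f"f1 = {f1:.4f} Hz, f2 = {f2:.4f} Hz, envelope period = {t_env:.2f} s")
```

Because the two modal amplitudes have equal magnitude, pendulum 2 should come momentarily to rest about half an envelope period after release while pendulum 1 swings at full amplitude, so the 10 s default run covers roughly two and a half envelope periods. Unequal modal magnitudes would give only partial exchange.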


### 1.4 Frequency-Domain Analysis (FFT and Peak Picking)

**FFT idea (what it does):**
- For a signal $\theta(t)$ sampled at $N$ points spaced $\Delta t$ apart (a window of length $T=N\,\Delta t$), the discrete Fourier transform (DFT) decomposes it into sinusoids at the discrete frequencies $f_k=\frac{k}{N\,\Delta t}=\frac{k}{T}$. The FFT is simply a fast algorithm for computing those DFT coefficients.
- The **dominant frequency** is typically taken as the $f_k$ at which the **magnitude spectrum** $|\Theta(f_k)|$ is largest, usually ignoring the $f=0$ “DC” component (the signal mean).
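
A minimal NumPy-only sketch of the two bullets above: `np.fft.rfftfreq` builds the grid $f_k=k/(N\,\Delta t)$, and the dominant frequency is the largest bin after skipping $k=0$. The helper `dominant_frequency` and the synthetic signal are illustrative, not part of the notebook. The constant offset is deliberate: without skipping the DC bin, it would win the argmax.

```python
import numpy as np


def dominant_frequency(signal, dt):
    """Frequency of the largest non-DC bin of the one-sided magnitude spectrum."""
    n = len(signal)
    freqs = np.fft.rfftfreq(n, dt)            # f_k = k / (n * dt), k = 0..n//2
    magnitude = np.abs(np.fft.rfft(signal))
    k = 1 + np.argmax(magnitude[1:])          # skip k = 0 (the DC component)
    return freqs[k], freqs[1] - freqs[0]      # peak frequency and bin spacing 1/T


dt, n = 0.05, 400                             # T = n * dt = 20 s, bins every 0.05 Hz
t = np.arange(n) * dt
theta = 0.3 + 0.1 * np.sin(2 * np.pi * 0.6 * t)  # DC offset plus a 0.6 Hz tone
f_peak, df = dominant_frequency(theta, dt)
print(f"dominant frequency = {f_peak:.3f} Hz (resolution {df:.3f} Hz)")
```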

**How it's used in this notebook: two related but distinct methods**

1) **Modal-FFT dominant frequencies (used for the reported “numerical normal modes”)**  
`est_freqs_double_pendulum` first projects the simulated angles onto the *linearized normal-mode coordinates* $q_1(t),q_2(t)$, then takes an FFT of each $q_i(t)$ and picks the largest non-DC bin as that mode's frequency. It also applies a Hann window to reduce spectral leakage and uses a threshold to flag modes as “not excited.”  
See the FFT and peak-picking step of `est_freqs_double_pendulum` in notebook.ipynb.

2) **Direct FFT spectrum + explicit peak detection/plotting (used to “display dominant peaks”)**  
`plot_fft_spectrum_double_pendulum` computes FFTs of the *raw* angle signals $\theta_1(t)$ and $\theta_2(t)$ and then runs `scipy.signal.find_peaks` on the amplitude spectra:
- It first removes DC offset: $\theta_i \leftarrow \theta_i-\langle \theta_i\rangle$.
- Applies a Hann window: $\theta_i^{(w)}(t)=\theta_i(t)\,w(t)$.
- Computes real FFT: `rfft`, frequency axis: `rfftfreq`.
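- Builds a one-sided amplitude spectrum scaled by $2/N$:
  $$
  A(f_k)=\frac{2}{N}\,|\Theta(f_k)|,
  $$
  where $\Theta(f_k)$ is the FFT of the windowed signal. This scaling does not undo the Hann window's coherent gain ($\frac{1}{N}\sum_n w_n\approx\frac{1}{2}$), so a pure sinusoid of amplitude $A_0$ appears as a peak of roughly $A_0/2$ (slightly less when its frequency falls between bins). Peak *locations* are unaffected.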
- Detects peaks using constraints like `prominence`, `height`, and `distance`.  
This is implemented in notebook.ipynb.
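
The processing chain of `plot_fft_spectrum_double_pendulum` (centering, Hann window, `rfft`/`rfftfreq`, $2/N$ scaling, `find_peaks`) can be reproduced on a synthetic two-tone signal to see what the reported peak heights mean. `amplitude_spectrum` and its `coherent` flag are hypothetical, for illustration only; `coherent=True` divides by $\sum_n w_n$ instead of $N$, which is one common way to compensate for the window.

```python
import numpy as np
from scipy.signal import find_peaks


def amplitude_spectrum(signal, dt, coherent=False):
    """Centered, Hann-windowed one-sided amplitude spectrum."""
    n = len(signal)
    window = np.hanning(n)
    x = (signal - np.mean(signal)) * window        # remove DC offset, then taper
    freqs = np.fft.rfftfreq(n, dt)
    magnitude = np.abs(np.fft.rfft(x))
    # 2/N mirrors the notebook; 2/sum(w) also undoes the Hann coherent gain (~0.5).
    scale = 2.0 / (window.sum() if coherent else n)
    return freqs, scale * magnitude


dt, n = 0.01, 2000                                 # T = 20 s, bins every 0.05 Hz
t = np.arange(n) * dt
theta = 0.2 * np.cos(2 * np.pi * 0.5 * t) + 0.1 * np.cos(2 * np.pi * 2.0 * t)

freqs, amp = amplitude_spectrum(theta, dt)
peaks, props = find_peaks(amp, prominence=0.01, distance=5)
_, amp_coherent = amplitude_spectrum(theta, dt, coherent=True)
for idx in peaks:
    print(f"{freqs[idx]:.2f} Hz: 2/N -> {amp[idx]:.4f}, 2/sum(w) -> {amp_coherent[idx]:.4f}")
```

With bin-centered tones, the $2/N$ heights come out near half of the true amplitudes (0.2 and 0.1 rad here), while the coherent-gain correction recovers them. The peak frequencies are identical either way, so the $2/N$ scaling does not affect frequency estimation.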

**How it displays the dominant peaks:**
- **Console output:** it prints how many peaks were found for $\theta_1$ and $\theta_2$, and lists up to the first 5 peaks with their frequency and amplitude.
- **Plots:** it plots the full FFT amplitude spectrum and overlays the detected peaks as large markers. Then it annotates (labels) the **top 5 peaks by amplitude** with their frequencies (yellow callouts + arrows). It also draws vertical dashed lines at the *theoretical* $f_1,f_2$ for comparison, and adds a small text box with the modal-decomposition numerical estimates.
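
One subtlety in the display logic: the console lists the first five detected peaks in order of *frequency* (`peaks[:5]`), while the plot annotates the five *largest* peaks (`np.argsort(...)[-5:][::-1]`), and the two selections can differ. The hypothetical helper below, run on a made-up amplitude array, returns both side by side.

```python
import numpy as np


def first_and_top_peaks(peaks, amplitude, n_show=5):
    """Return the first n_show peaks by frequency and the n_show largest by amplitude."""
    first = peaks[:n_show]                          # console listing order
    order = np.argsort(amplitude[peaks])[::-1]      # indices into peaks, largest first
    top = peaks[order[:n_show]]                     # plot annotation choice
    return first, top


# Made-up spectrum: six isolated peaks with very different heights.
amplitude = np.zeros(20)
peaks = np.array([2, 5, 8, 11, 14, 17])
amplitude[peaks] = [0.01, 0.02, 0.50, 0.03, 0.20, 0.04]

first, top = first_and_top_peaks(peaks, amplitude)
print("first 5 by frequency:", first)   # expected indices: 2, 5, 8, 11, 14
print("top 5 by amplitude:  ", top)     # expected indices: 8, 14, 17, 11, 5
```

Here the 0.01 peak at index 2 is printed but never annotated, and the 0.04 peak at index 17 is annotated but never printed.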

Net: the notebook both (a) estimates mode frequencies robustly via **modal FFT** and (b) **visualizes dominant spectral peaks** directly from $\theta_1,\theta_2$ using `find_peaks` + annotations.
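
Method 1 (the modal FFT) can be sketched as follows. This is not the notebook's `est_freqs_double_pendulum`, whose source is not shown here, but a stand-in under stated assumptions: the small-angle model $M\ddot{\boldsymbol\theta}+K\boldsymbol\theta=0$ with $M=\mathrm{diag}(m_1,m_2)$ and $K=\begin{pmatrix}m_1 g/L+k & -k\\ -k & m_2 g/L+k\end{pmatrix}$ (spring attached at the bobs), and a hypothetical "not excited" rule that discards a modal peak below `rel_threshold` times the strongest modal peak. `scipy.linalg.eigh(K, M)` returns $M$-orthonormal eigenvectors $V$, so the modal coordinates are $\mathbf q = V^{\mathsf T} M \boldsymbol\theta$.

```python
import numpy as np
from scipy.linalg import eigh


def linear_modes(g, L, k, m1, m2):
    """Assumed small-angle M and K matrices and their normal modes."""
    M = np.diag([m1, m2])
    K = np.array([[m1 * g / L + k, -k],
                  [-k, m2 * g / L + k]])
    # Generalized symmetric eigenproblem K v = w^2 M v. eigh returns eigenvalues
    # in ascending order and eigenvectors normalized so that V.T @ M @ V = I.
    w2, V = eigh(K, M)
    return np.sqrt(w2), V, M


def modal_fft_freqs(t, theta1, theta2, g, L, k, m1, m2, rel_threshold=1e-3):
    """Hypothetical modal-FFT estimator: returns ([f1, f2] estimates, linear f1, f2)."""
    omegas, V, M = linear_modes(g, L, k, m1, m2)
    theta = np.vstack([theta1, theta2])            # shape (2, N)
    q = V.T @ M @ theta                            # modal coordinates (V^-1 = V.T M)
    n, dt = len(t), t[1] - t[0]                    # assumes uniform sampling
    freqs = np.fft.rfftfreq(n, dt)
    window = np.hanning(n)
    spectra = [np.abs(np.fft.rfft((qi - qi.mean()) * window)) for qi in q]
    strongest = max(s[1:].max() for s in spectra)
    estimates = []
    for s in spectra:
        idx = 1 + np.argmax(s[1:])                 # largest non-DC bin
        excited = s[idx] > rel_threshold * strongest
        estimates.append(freqs[idx] if excited else 0.0)  # 0.0 marks "not excited"
    return estimates, omegas / (2 * np.pi)


# Synthetic check: build angles from known modal motions, then recover them.
g, L, k, m1, m2 = 9.81, 2.0, 5.0, 1.0, 1.0
t = np.linspace(0.0, 200.0, 8000, endpoint=False)   # frequency resolution 1/T = 0.005 Hz
omegas, V, M = linear_modes(g, L, k, m1, m2)
q_true = np.vstack([0.10 * np.cos(omegas[0] * t), 0.05 * np.cos(omegas[1] * t)])
theta = V @ q_true
est, f_lin = modal_fft_freqs(t, theta[0], theta[1], g, L, k, m1, m2)
print("estimated:", np.round(est, 4), "linear theory:", np.round(f_lin, 4))
```

Because the $q_i$ are decoupled in the linear model, each modal spectrum shows a single dominant line, which is why picking one peak per mode is reasonable. At larger amplitudes, nonlinear terms in the full equations shift the peaks away from the linear $f_1,f_2$.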

In [None]:
def plot_fft_spectrum_double_pendulum(
    t,
    theta1,
    theta2,
    g,
    L,
    k,
    m1,
    m2,
    theta_1_init,
    theta_2_init,
    prominence=0.01,
    height=None,
    distance=5,
    save_fig=False,
    filename=None,
):
    """
    Plot FFT spectrum of theta_1 and theta_2 with peak detection and annotation.
    Also displays theoretical normal mode frequencies.

    Parameters
    ----------
    t : array
        Time array
    theta1, theta2 : array
        Angle arrays from simulation
    g, L, k, m1, m2 : float
        System parameters for calculating theoretical frequencies
    theta_1_init, theta_2_init : float
        Initial angles in degrees (for motion type identification)
    prominence : float, optional
        Minimum prominence of peaks (default: 0.01)
    height : float, optional
        Minimum height of peaks (default: None)
    distance : int, optional
        Minimum distance between peaks in samples (default: 5)
    save_fig : bool, optional
        Whether to save the figure (default: False)
    filename : str or None, optional
        Filename to save the figure (if None, auto-generate)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure object containing the plots
    axes : tuple of matplotlib.axes.Axes
        The (theta_1, theta_2) spectrum axes, returned as ``(ax1, ax2)``
    """
    # Calculate theoretical frequencies
    _, _, f1_theory, f2_theory = th_normal_modes_double_pendulum(g, L, k, m1, m2)

    # Numerical frequency estimation using the existing function
    f1_num, f2_num = est_freqs_double_pendulum(t, theta1, theta2, g, L, k, m1, m2)

    # Format frequency strings
    def _format_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _format_freq(f1_num)
    f2_str = _format_freq(f2_num)

    # Identify motion type based on initial conditions
    tol = 1e-6  # Tolerance for comparison

    if abs(theta_1_init - theta_2_init) < tol:
        if abs(theta_1_init) < tol:  # Both zero
            motion_type = "No Initial Displacement"
            excited_mode = "None"
        else:
            motion_type = "In-Phase Mode"
            excited_mode = "Mode 1 (Lower Frequency)"
    elif abs(theta_1_init + theta_2_init) < tol:
        motion_type = "Out-of-Phase Mode"
        excited_mode = "Mode 2 (Higher Frequency)"
    elif abs(theta_1_init) < tol or abs(theta_2_init) < tol:
        motion_type = "Energy Transfer (Beats Phenomenon)"
        excited_mode = "Both Modes (Beat Frequency Visible)"
    else:
        motion_type = "Mixed-Mode Oscillation"
        excited_mode = "Both Modes (Superposition)"

    # Compute FFT for both angles
    dt = t[1] - t[0]
    n = len(t)

    # Remove DC offset
    theta1_centered = theta1 - np.mean(theta1)
    theta2_centered = theta2 - np.mean(theta2)

    # Apply Hann window to reduce spectral leakage
    window = np.hanning(n)
    theta1_windowed = theta1_centered * window
    theta2_windowed = theta2_centered * window

    # Compute FFT
    freqs = np.fft.rfftfreq(n, dt)
    fft_theta1 = np.fft.rfft(theta1_windowed)
    fft_theta2 = np.fft.rfft(theta2_windowed)

    # Compute amplitude spectrum (normalized)
    amp_theta1 = 2.0 * np.abs(fft_theta1) / n
    amp_theta2 = 2.0 * np.abs(fft_theta2) / n

    # Find peaks for theta1
    peaks1, properties1 = find_peaks(
        amp_theta1, prominence=prominence, height=height, distance=distance
    )

    # Find peaks for theta2
    peaks2, properties2 = find_peaks(
        amp_theta2, prominence=prominence, height=height, distance=distance
    )

    # Print peak information
    print("\n--- FFT Peak Detection ---")
    print(f"Motion Type: {motion_type}")
    print(f"Excited Mode: {excited_mode}")
    print(f"\nθ₁ - Found {len(peaks1)} peaks:")
    if len(peaks1) > 0:
        for i, idx in enumerate(peaks1[:5]):  # Show first 5 peaks
            print(
                f"  Peak {i + 1}: f = {freqs[idx]:.4f} Hz, amplitude = {amp_theta1[idx]:.6f}"
            )
    else:
        print(
            f"  No peaks found! Max amplitude = {np.max(amp_theta1):.6f} at f = {freqs[np.argmax(amp_theta1)]:.4f} Hz"
        )

    print(f"\nθ₂ - Found {len(peaks2)} peaks:")
    if len(peaks2) > 0:
        for i, idx in enumerate(peaks2[:5]):  # Show first 5 peaks
            print(
                f"  Peak {i + 1}: f = {freqs[idx]:.4f} Hz, amplitude = {amp_theta2[idx]:.6f}"
            )
    else:
        print(
            f"  No peaks found! Max amplitude = {np.max(amp_theta2):.6f} at f = {freqs[np.argmax(amp_theta2)]:.4f} Hz"
        )

    # Create figure with shared x-axis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True)
    fig.subplots_adjust(top=0.89, bottom=0.08, left=0.1, right=0.88, hspace=0.25)

    # Title with motion type
    title_text = (
        f"FFT Spectrum of Coupled Double Pendulum Angles\n"
        f"Motion Type: {motion_type} | Excited: {excited_mode}"
    )
    fig.suptitle(title_text, fontsize=15, fontweight="bold")

    # Determine appropriate xlim based on theoretical frequencies
    max_freq_display = max(f1_theory, f2_theory) * 3.0

    # Plot theta1 FFT
    ax1.plot(
        freqs, amp_theta1, "crimson", linewidth=1.5, label="FFT Amplitude", zorder=1
    )

    # Plot detected peaks
    if len(peaks1) > 0:
        ax1.plot(
            freqs[peaks1],
            amp_theta1[peaks1],
            "o",
            color="blue",
            markeredgecolor="black",
            markeredgewidth=1.5,
            markersize=10,
            label="Detected Peaks",
            zorder=5,
        )

        # Annotate peaks for theta1 (only top 5 by amplitude)
        peak_amps = amp_theta1[peaks1]
        top_peaks_idx = np.argsort(peak_amps)[-5:][::-1]  # Top 5 peaks

        for i, local_idx in enumerate(top_peaks_idx):
            peak_idx = peaks1[local_idx]
            peak_freq = freqs[peak_idx]
            peak_amp = amp_theta1[peak_idx]
            ax1.annotate(
                f"{peak_freq:.4f} Hz",
                xy=(peak_freq, peak_amp),
                xytext=(15, -25),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                bbox=dict(
                    boxstyle="round,pad=0.5",
                    facecolor="yellow",
                    edgecolor="black",
                    alpha=0.85,
                ),
                arrowprops=dict(
                    arrowstyle="->",
                    connectionstyle="arc3,rad=0.2",
                    color="blue",
                    lw=1.5,
                ),
                zorder=6,
            )

    # Add theoretical frequency lines for theta1
    ax1.axvline(
        f1_theory,
        color="green",
        linestyle="--",
        linewidth=2.5,
        alpha=0.8,
        label=f"Theory: $f_1={f1_theory:.4f}$ Hz",
        zorder=3,
    )
    ax1.axvline(
        f2_theory,
        color="orange",
        linestyle="--",
        linewidth=2.5,
        alpha=0.8,
        label=f"Theory: $f_2={f2_theory:.4f}$ Hz",
        zorder=3,
    )

    # Add numerical frequency estimates as text
    textstr1 = (
        f"Numerical Estimates:\n"
        f"(via modal decomposition)\n"
        f"$f_1 \\approx$ {f1_str}\n"
        f"$f_2 \\approx$ {f2_str}\n"
        f"──────────────\n"
        f"Peaks detected: {len(peaks1)}"
    )
    ax1.text(
        0.99,
        0.03,
        textstr1,
        transform=ax1.transAxes,
        fontsize=9,
        va="bottom",
        ha="right",
        bbox=dict(
            boxstyle="round",
            facecolor="lightblue",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        zorder=6,
    )

    # System info text box
    sys_info_text = (
        f"System Parameters:\n"
        f"{'─' * 18}\n"
        f"Initial Angles:\n"
        f"  $\\theta_1^{{(0)}} = {theta_1_init:+.2f}^{{\\circ}}$\n"
        f"  $\\theta_2^{{(0)}} = {theta_2_init:+.2f}^{{\\circ}}$\n"
        f"Masses:\n"
        f"  $m_1 = {m1:.2f}$ kg\n"
        f"  $m_2 = {m2:.2f}$ kg\n"
        f"Spring: $k = {k:.2f}$ N/m\n"
        f"Length: $L = {L:.2f}$ m\n"
        f"Gravity: $g = {g:.2f}$ m/s²\n"
        f"Simulation: $T = {t[-1]:.1f}$ s\n"
        f"Points: $N = {len(t)}$"
    )
    ax1.text(
        1.01,
        0.99,
        sys_info_text,
        transform=ax1.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        bbox=dict(
            boxstyle="round",
            facecolor="lavender",
            edgecolor="black",
            alpha=0.9,
            linewidth=1.5,
        ),
        fontfamily="monospace",
        zorder=6,
    )

    ax1.set_ylabel("Amplitude (rad)", fontsize=12, fontweight="bold")
    ax1.set_title(r"$\theta_1$ (Pendulum 1)", fontsize=13, fontweight="bold", pad=10)
    ax1.grid(True, alpha=0.3, linestyle="--")
    ax1.legend(loc="upper right", fontsize=9, framealpha=0.9)
    ax1.set_xlim(0, max_freq_display)
    ax1.set_ylim(bottom=0)  # Start y-axis at 0

    # Plot theta2 FFT
    ax2.plot(
        freqs, amp_theta2, "royalblue", linewidth=1.5, label="FFT Amplitude", zorder=1
    )

    # Plot detected peaks
    if len(peaks2) > 0:
        ax2.plot(
            freqs[peaks2],
            amp_theta2[peaks2],
            "o",
            color="red",
            markeredgecolor="black",
            markeredgewidth=1.5,
            markersize=10,
            label="Detected Peaks",
            zorder=5,
        )

        # Annotate peaks for theta2 (only top 5 by amplitude)
        peak_amps = amp_theta2[peaks2]
        top_peaks_idx = np.argsort(peak_amps)[-5:][::-1]  # Top 5 peaks

        for i, local_idx in enumerate(top_peaks_idx):
            peak_idx = peaks2[local_idx]
            peak_freq = freqs[peak_idx]
            peak_amp = amp_theta2[peak_idx]
            ax2.annotate(
                f"{peak_freq:.4f} Hz",
                xy=(peak_freq, peak_amp),
                xytext=(15, -25),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                bbox=dict(
                    boxstyle="round,pad=0.5",
                    facecolor="yellow",
                    edgecolor="black",
                    alpha=0.85,
                ),
                arrowprops=dict(
                    arrowstyle="->", connectionstyle="arc3,rad=0.2", color="red", lw=1.5
                ),
                zorder=6,
            )

    # Add theoretical frequency lines for theta2
    ax2.axvline(
        f1_theory,
        color="green",
        linestyle="--",
        linewidth=2.5,
        alpha=0.8,
        label=f"Theory: $f_1={f1_theory:.4f}$ Hz",
        zorder=3,
    )
    ax2.axvline(
        f2_theory,
        color="orange",
        linestyle="--",
        linewidth=2.5,
        alpha=0.8,
        label=f"Theory: $f_2={f2_theory:.4f}$ Hz",
        zorder=3,
    )

    textstr2 = (
        f"Numerical Estimates:\n"
        f"(via modal decomposition)\n"
        f"$f_1 \\approx$ {f1_str}\n"
        f"$f_2 \\approx$ {f2_str}\n"
        f"──────────────\n"
        f"Peaks detected: {len(peaks2)}"
    )
    ax2.text(
        0.99,
        0.03,
        textstr2,
        transform=ax2.transAxes,
        fontsize=9,
        va="bottom",
        ha="right",
        bbox=dict(
            boxstyle="round",
            facecolor="lightblue",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        zorder=6,
    )

    ax2.set_xlabel("Frequency (Hz)", fontsize=12, fontweight="bold")
    ax2.set_ylabel("Amplitude (rad)", fontsize=12, fontweight="bold")
    ax2.set_title(r"$\theta_2$ (Pendulum 2)", fontsize=13, fontweight="bold", pad=10)
    ax2.grid(True, alpha=0.3, linestyle="--")
    ax2.legend(loc="upper right", fontsize=9, framealpha=0.9)
    ax2.set_xlim(0, max_freq_display)
    ax2.set_ylim(bottom=0)

    # Save figure if requested (before showing, so the saved file is not blank
    # on backends that close or clear the figure once it has been shown)
    if save_fig:
        if filename is None:
            filename = (
                f"fft_spectrum_double_pendulum_m1_{m1}_m2_{m2}_k_{k}_L_{L}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}.png"
            )

        # Ensure FIGURES directory exists
        fig_dir = "OUTPUTS/FIGURES/coupled_pendulum/double_pendulum"
        os.makedirs(fig_dir, exist_ok=True)

        save_path = os.path.join(fig_dir, filename)
        print(f"Saving FFT spectrum figure to: {save_path}")

        try:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            print("Figure saved successfully!")
        except Exception as e:
            print(f"Error saving figure: {e}")

    plt.show()

    return fig, (ax1, ax2)


# Solve the coupled pendulum system and plot FFT
def analyze_coupled_pendulum_fft(
    theta_1_init=30.0,
    theta_2_init=0.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    save_fig=False,
    filename=None,
):
    """
    Solve coupled pendulum equations and plot FFT spectrum.

    Parameters
    ----------
    theta_1_init, theta_2_init : float
        Initial angles in degrees
    m1, m2 : float
        Pendulum masses (kg)
    k : float
        Spring constant (N/m)
    L : float
        Pendulum length (m)
    g : float
        Gravitational acceleration (m/s²)
    simulation_time : float
        Simulation duration in seconds
    n_points : int
        Number of time points
    save_fig : bool, optional
        Whether to save the FFT figure (default: False)
    filename : str or None, optional
        Filename to save the figure (if None, auto-generate)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with FFT plots
    """
    print("=" * 60)
    print("FFT ANALYSIS OF COUPLED DOUBLE PENDULUM")
    print("=" * 60)

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂]
    y0 = [theta1_0, 0.0, theta2_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    # Solve the ODE
    print("\nSolving differential equations...")
    print(f"  Initial angles: θ₁={theta_1_init}°, θ₂={theta_2_init}°")
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of points: {n_points}")

    solution = solve_ivp(
        coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k, m1, m2),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-6,
        atol=1e-8,
    )

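    # Only report success if the integrator actually succeeded
    if not solution.success:
        raise RuntimeError(f"ODE integration failed: {solution.message}")

    theta1 = solution.y[0]
    theta2 = solution.y[2]
    t = solution.t

    print("Solution computed successfully!")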

    # Calculate theoretical frequencies
    omega1_th, omega2_th, f1_th, f2_th = th_normal_modes_double_pendulum(
        g, L, k, m1, m2
    )

    print("\nTheoretical Normal Mode Frequencies:")
    print(f"  ω₁ = {omega1_th:.4f} rad/s  →  f₁ = {f1_th:.4f} Hz (in-phase mode)")
    print(f"  ω₂ = {omega2_th:.4f} rad/s  →  f₂ = {f2_th:.4f} Hz (out-of-phase mode)")

    print("\nPlotting FFT spectrum...")

    # Plot FFT spectrum with adjusted parameters
    fig, axes = plot_fft_spectrum_double_pendulum(
        t,
        theta1,
        theta2,
        g,
        L,
        k,
        m1,
        m2,
        theta_1_init,
        theta_2_init,
        prominence=0.005,
        distance=5,
        save_fig=save_fig,
        filename=filename,
    )

    print("\nFFT analysis complete!")
    print(
        "\nNote: Both angles show similar frequency content because they are coupled."
    )
    print("The amplitudes differ based on how each pendulum participates in each mode.")

    return fig


fig = analyze_coupled_pendulum_fft(
    theta_1_init=10.0,
    theta_2_init=-10.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    save_fig=True,
    filename="fft_spectrum_double_pendulum_out_of_phase.png",
)


### 1.5 Phase Space and Configuration Space Trajectories

**Phase space vs configuration space**

- **Configuration space** is the space of generalized coordinates that specify the system’s instantaneous geometric state.  
  For the coupled pendulums, the configuration at time $t$ is $(\theta_1(t),\theta_2(t))$, so the configuration space is the 2D plane spanned by $(\theta_1,\theta_2)$. In general, for $n$ degrees of freedom, configuration space is $n$-dimensional.

- **Phase space** is the space of coordinates *and their conjugate momenta* (or velocities; here $p_i=m_iL^2\dot\theta_i$, so the two choices differ only by a constant scale).  
  For the full 2-DOF system the phase space is 4D: $(\theta_1,\dot\theta_1,\theta_2,\dot\theta_2)$. Each per-pendulum portrait $(\theta_i,\dot\theta_i)$ is a 2D projection of this 4D space, so a projected trajectory can cross itself even though trajectories in the full phase space never intersect.

**Why it’s important to study them**

- **Qualitative dynamics at a glance:** closed loops indicate periodic motion; bands or tangled, space-filling curves can indicate multi-frequency (quasi-periodic) motion or, in other systems, chaos.
- **Energy/stability intuition:** for a single conservative oscillator, phase-space orbits are curves of constant energy, and fixed points and their stability can be read off visually. In the coupled system energy flows between the pendulums, so a per-pendulum portrait is generally *not* a constant-energy curve; only the total energy is conserved.
- **Mode identification:** in configuration space, motion along $\theta_2=\theta_1$ corresponds to the in-phase mode and motion along $\theta_2=-\theta_1$ to the out-of-phase mode (small angles, equal masses).
- **Energy transfer/beats:** the trajectory can wander between these directions, showing how the system’s state moves between “mostly mode 1” and “mostly mode 2”.
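The mode-identification idea can be checked numerically. The sketch below assumes equal masses, small angles, and the linearized two-pendulum equations $\ddot\theta_1=-\omega_0^2\theta_1+\frac{k}{m}(\theta_2-\theta_1)$ and $\ddot\theta_2=-\omega_0^2\theta_2-\frac{k}{m}(\theta_2-\theta_1)$ (the small-angle limit of the horizontal-spring model); `linear_rhs` is an illustrative name, not the notebook's `coupled_pendulum_derivatives`. Projecting the configuration-space trajectory onto the two mode directions gives $q_{\text{in}}=(\theta_1+\theta_2)/2$ and $q_{\text{out}}=(\theta_1-\theta_2)/2$, which should each be a single cosine, at $\omega_0$ and $\sqrt{\omega_0^2+2k/m}$ respectively.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Assumed example parameters: equal masses, notebook-style default values
g, L, k, m = 9.81, 2.0, 5.0, 1.0
w0_sq = g / L


def linear_rhs(t, y):
    """Linearized two-pendulum model (assumed form); y = [th1, w1, th2, w2]."""
    th1, w1, th2, w2 = y
    a1 = -w0_sq * th1 + (k / m) * (th2 - th1)
    a2 = -w0_sq * th2 - (k / m) * (th2 - th1)
    return [w1, a1, w2, a2]


# Beats initial condition: only pendulum 1 displaced, both released from rest
t = np.linspace(0.0, 25.0, 2001)
y0 = [np.radians(10.0), 0.0, 0.0, 0.0]
sol = solve_ivp(linear_rhs, (t[0], t[-1]), y0, t_eval=t, rtol=1e-10, atol=1e-12)
th1, th2 = sol.y[0], sol.y[2]

# Project the configuration-space trajectory onto the two mode directions
q_in = 0.5 * (th1 + th2)   # component along theta2 = theta1
q_out = 0.5 * (th1 - th2)  # component along theta2 = -theta1

# In the linear model each projection is a pure cosine at its own frequency
w_in, w_out = np.sqrt(w0_sq), np.sqrt(w0_sq + 2 * k / m)
err_in = np.max(np.abs(q_in - q_in[0] * np.cos(w_in * t)))
err_out = np.max(np.abs(q_out - q_out[0] * np.cos(w_out * t)))
print(f"max deviation from a single cosine: in-phase {err_in:.1e}, out-of-phase {err_out:.1e} rad")
```

With the full nonlinear equations the mode frequencies generally shift with amplitude, so at larger angles these projections are no longer exact single cosines.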

**How this is visualized in the notebook**

All of this is implemented in `plot_phase_and_config_space` and driven by `analyze_coupled_pendulum_phase_space`:

- **Data source:** `analyze_coupled_pendulum_phase_space` solves the nonlinear ODEs with `solve_ivp`, producing arrays $\theta_1(t),\dot\theta_1(t),\theta_2(t),\dot\theta_2(t)$.

- **Phase space plots (two separate 2D portraits):**
  - Plot 1: $\dot\theta_1$ (rad/s) vs $\theta_1$ (degrees).
  - Plot 2: $\dot\theta_2$ (rad/s) vs $\theta_2$ (degrees).
  - Each trajectory is drawn with a **time color gradient** (colormap `plasma`), a green “Start” marker, a red “End” marker, and directional arrows (`FancyArrowPatch`) at evenly spaced samples (their number set by `arrow_density`) to show the direction of flow.

- **Configuration space plot (2D trajectory):**
  - Plot 3: $\theta_2$ vs $\theta_1$ (both in degrees), with the same time gradient, arrows, and start/end markers.
  - Dashed reference lines mark the normal-mode directions: the in-phase line $\theta_2=\theta_1$ and the out-of-phase line $\theta_2=-\theta_1$.

- **Context overlays:** it adds a system-parameters text box and a horizontal colorbar mapping color to time.
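The segment-by-segment loops in `plot_phase_and_config_space` make one `ax.plot` call per time step, i.e. roughly a thousand artists per panel at the default `n_points`. A possible alternative, sketched below, is Matplotlib's `LineCollection`, which stores all segments in a single artist and colors them through a colormap. The helper `add_time_gradient_line` is hypothetical (not part of the notebook), the data are synthetic, and the `Figure` is created directly so the sketch does not depend on a pyplot backend.

```python
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.figure import Figure


def add_time_gradient_line(ax, x, y, t, cmap="plasma", linewidth=1.5, alpha=0.7):
    """Draw (x, y) as one LineCollection whose segments are colored by time t."""
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    # Pair consecutive points into segments of shape (N - 1, 2, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    lc = LineCollection(segments, cmap=cmap, linewidths=linewidth, alpha=alpha)
    lc.set_array(t[:-1])       # color of each segment = time at its start point
    lc.set_clim(t[0], t[-1])   # map the full time range onto the colormap
    ax.add_collection(lc)
    ax.autoscale_view()        # make sure the view includes the new collection
    return lc


# Synthetic two-frequency signal standing in for (theta, theta_dot) data
t = np.linspace(0.0, 25.0, 1000)
theta_deg = 5.0 * (np.cos(2.21 * t) + np.cos(3.86 * t))
theta_dot = np.gradient(np.radians(theta_deg), t)

fig = Figure(figsize=(5, 4))
ax = fig.subplots()
lc = add_time_gradient_line(ax, theta_deg, theta_dot, t)
fig.colorbar(lc, ax=ax, label="Time (s)")
ax.set_xlabel(r"$\theta$ (degrees)")
ax.set_ylabel(r"$\dot{\theta}$ (rad/s)")
```

Because the collection is itself a `ScalarMappable`, it can be passed straight to `fig.colorbar`, which would replace the separate `ScalarMappable` built in the notebook.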


In [None]:
def plot_phase_and_config_space(
    t,
    theta1,
    omega1,
    theta2,
    omega2,
    g,
    L,
    k,
    m1,
    m2,
    theta_1_init,
    theta_2_init,
    arrow_density=8,
    save_fig=False,
    filename=None,
):
    """
    Plot phase space diagrams for both pendulums and configuration space.

    Parameters
    ----------
    t : array
        Time array
    theta1, omega1 : array
        Angle and angular velocity for pendulum 1
    theta2, omega2 : array
        Angle and angular velocity for pendulum 2
    g, L, k, m1, m2 : float
        System parameters
    theta_1_init, theta_2_init : float
        Initial angles in degrees
    arrow_density : int, optional
        Number of evenly spaced anchor points along each trajectory; an arrow
        is drawn at every anchor except the last (default: 8)
    save_fig : bool, optional
        Whether to save the figure (default: False)
    filename : str or None, optional
        Filename for saved figure. Auto-generated if None (default: None)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure object
    """
    # Calculate theoretical frequencies
    _, _, f1_theory, f2_theory = th_normal_modes_double_pendulum(g, L, k, m1, m2)

    # Identify motion type
    tol = 1e-6
    if abs(theta_1_init - theta_2_init) < tol:
        if abs(theta_1_init) < tol:
            motion_type = "No Initial Displacement"
        else:
            motion_type = "In-Phase Mode"
    elif abs(theta_1_init + theta_2_init) < tol:
        motion_type = "Out-of-Phase Mode"
    elif abs(theta_1_init) < tol or abs(theta_2_init) < tol:
        motion_type = "Energy Transfer (Beats Phenomenon)"
    else:
        motion_type = "Mixed-Mode Oscillation"

    # Create figure
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 8))
    fig.subplots_adjust(top=0.88, bottom=0.12, left=0.05, right=0.85, wspace=0.3)

    # Main title
    fig.suptitle(
        "Phase Space and Configuration Space Analysis of Spring-Coupled Double Pendulum System"
        f"\nMotion Type: {motion_type}",
        fontsize=15,
        fontweight="bold",
    )

    colors = plt.cm.plasma(np.linspace(0, 1, len(t)))  # type: ignore

    n_points = len(t)

    # =========================================================================
    # Plot 1: Phase Space for Pendulum 1 (θ₁ vs ω₁)
    # =========================================================================

    # Convert theta1 to degrees for plotting
    theta1_deg = np.degrees(theta1)

    # Plot trajectory with color gradient
    for i in range(len(t) - 1):
        ax1.plot(
            theta1_deg[i : i + 2],
            omega1[i : i + 2],
            color=colors[i],
            linewidth=1.5,
            alpha=0.7,
        )

    # Mark initial and final points
    ax1.plot(
        theta1_deg[0],
        omega1[0],
        "o",
        color="lime",
        ms=10,
        mec="black",
        mew=2,
        label="Start",
        zorder=5,
    )
    ax1.plot(
        theta1_deg[-1],
        omega1[-1],
        "o",
        color="red",
        ms=10,
        mec="black",
        mew=2,
        label="End",
        zorder=5,
    )

    # Set preliminary limits to calculate arrow scaling
    ax1.set_xlim(
        theta1_deg.min() - 0.1 * np.ptp(theta1_deg),
        theta1_deg.max() + 0.1 * np.ptp(theta1_deg),
    )
    ax1.set_ylim(
        omega1.min() - 0.1 * np.ptp(omega1), omega1.max() + 0.1 * np.ptp(omega1)
    )

    # Add direction arrows using FancyArrowPatch
    arrow_indices = np.linspace(0, n_points - 1, arrow_density, dtype=int)

    # Calculate plot scale for arrow sizing - use BOTH axis ranges
    theta1_span = np.ptp(theta1_deg)
    omega1_span = np.ptp(omega1)

    for idx in arrow_indices[:-1]:
        if idx + 1 < n_points:
            # Calculate direction vector in data coordinates
            dx_data = theta1_deg[idx + 1] - theta1_deg[idx]
            dy_data = omega1[idx + 1] - omega1[idx]

            # Normalize by axis ranges to get direction in "normalized" space
            dx_norm = dx_data / theta1_span if theta1_span > 0 else 0
            dy_norm = dy_data / omega1_span if omega1_span > 0 else 0

            norm = np.sqrt(dx_norm**2 + dy_norm**2)

            if norm > 1e-9:
                # Normalize direction vector
                dx_norm = dx_norm / norm
                dy_norm = dy_norm / norm

                # Scale back to data coordinates: each component is 8% of the
                # corresponding axis span, so arrows look alike on both axes
                arrow_scale = 0.08
                arrow_len_x = dx_norm * arrow_scale * theta1_span
                arrow_len_y = dy_norm * arrow_scale * omega1_span

                # Use FancyArrowPatch
                arrow = FancyArrowPatch(
                    posA=(theta1_deg[idx], omega1[idx]),
                    posB=(theta1_deg[idx] + arrow_len_x, omega1[idx] + arrow_len_y),
                    arrowstyle="-|>",
                    mutation_scale=20,
                    color=colors[idx],
                    linewidth=2,
                    zorder=10,
                    alpha=0.8,
                )
                ax1.add_patch(arrow)

    ax1.set_xlabel("$\\theta_1$ (degrees)", fontsize=12, fontweight="bold")
    ax1.set_ylabel("$\\dot{\\theta}_1$ (rad/s)", fontsize=12, fontweight="bold")
    ax1.set_title("Phase Space: Pendulum 1", fontsize=13, fontweight="bold")
    ax1.grid(True, alpha=0.3, linestyle="--")
    ax1.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax1.axvline(0, color="black", linewidth=0.5, alpha=0.5)

    # =========================================================================
    # Plot 2: Phase Space for Pendulum 2 (θ₂ vs ω₂)
    # =========================================================================

    # Convert theta2 to degrees for plotting
    theta2_deg = np.degrees(theta2)

    # Plot trajectory with color gradient
    for i in range(len(t) - 1):
        ax2.plot(
            theta2_deg[i : i + 2],
            omega2[i : i + 2],
            color=colors[i],
            linewidth=1.5,
            alpha=0.7,
        )

    # Mark initial and final points
    ax2.plot(
        theta2_deg[0],
        omega2[0],
        "o",
        color="lime",
        ms=10,
        mec="black",
        mew=2,
        label="Start",
        zorder=5,
    )
    ax2.plot(
        theta2_deg[-1],
        omega2[-1],
        "o",
        color="red",
        ms=10,
        mec="black",
        mew=2,
        label="End",
        zorder=5,
    )

    ax2.set_xlim(
        theta2_deg.min() - 0.1 * np.ptp(theta2_deg),
        theta2_deg.max() + 0.1 * np.ptp(theta2_deg),
    )
    ax2.set_ylim(
        omega2.min() - 0.1 * np.ptp(omega2), omega2.max() + 0.1 * np.ptp(omega2)
    )

    theta2_span = np.ptp(theta2_deg)
    omega2_span = np.ptp(omega2)

    for idx in arrow_indices[:-1]:
        if idx + 1 < n_points:
            dx_data = theta2_deg[idx + 1] - theta2_deg[idx]
            dy_data = omega2[idx + 1] - omega2[idx]

            dx_norm = dx_data / theta2_span if theta2_span > 0 else 0
            dy_norm = dy_data / omega2_span if omega2_span > 0 else 0

            norm = np.sqrt(dx_norm**2 + dy_norm**2)

            if norm > 1e-9:
                dx_norm = dx_norm / norm
                dy_norm = dy_norm / norm

                arrow_scale = 0.08
                arrow_len_x = dx_norm * arrow_scale * theta2_span
                arrow_len_y = dy_norm * arrow_scale * omega2_span

                arrow = FancyArrowPatch(
                    posA=(theta2_deg[idx], omega2[idx]),
                    posB=(theta2_deg[idx] + arrow_len_x, omega2[idx] + arrow_len_y),
                    arrowstyle="-|>",
                    mutation_scale=20,
                    color=colors[idx],
                    linewidth=2,
                    zorder=10,
                    alpha=0.8,
                )
                ax2.add_patch(arrow)

    ax2.set_xlabel("$\\theta_2$ (degrees)", fontsize=12, fontweight="bold")
    ax2.set_ylabel("$\\dot{\\theta}_2$ (rad/s)", fontsize=12, fontweight="bold")
    ax2.set_title("Phase Space: Pendulum 2", fontsize=13, fontweight="bold")
    ax2.grid(True, alpha=0.3, linestyle="--")
    ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax2.axvline(0, color="black", linewidth=0.5, alpha=0.5)

    # =========================================================================
    # Plot 3: Configuration Space (θ₂ vs θ₁)
    # =========================================================================

    for i in range(len(t) - 1):
        ax3.plot(
            theta1_deg[i : i + 2],
            theta2_deg[i : i + 2],
            color=colors[i],
            linewidth=2,
            alpha=0.8,
        )

    ax3.plot(
        theta1_deg[0],
        theta2_deg[0],
        "o",
        color="lime",
        ms=10,
        mec="black",
        mew=2,
        label="Start",
        zorder=5,
    )
    ax3.plot(
        theta1_deg[-1],
        theta2_deg[-1],
        "o",
        color="red",
        ms=10,
        mec="black",
        mew=2,
        label="End",
        zorder=5,
    )

    arrow_indices_config = np.linspace(0, n_points - 1, arrow_density, dtype=int)

    theta1_config_span = theta1_deg.max() - theta1_deg.min()

    for idx in arrow_indices_config[:-1]:
        if idx + 1 < n_points:
            dx = theta1_deg[idx + 1] - theta1_deg[idx]
            dy = theta2_deg[idx + 1] - theta2_deg[idx]
            norm = np.sqrt(dx**2 + dy**2)

            if norm > 1e-9:
                dx_norm = dx / norm
                dy_norm = dy / norm

                arrow_len_x = dx_norm * (theta1_config_span * 0.08)
                arrow_len_y = dy_norm * (theta1_config_span * 0.08)

                arrow = FancyArrowPatch(
                    posA=(theta1_deg[idx], theta2_deg[idx]),
                    posB=(theta1_deg[idx] + arrow_len_x, theta2_deg[idx] + arrow_len_y),
                    arrowstyle="-|>",
                    mutation_scale=20,
                    color=colors[idx],
                    linewidth=1.5,
                    zorder=10,
                    alpha=0.8,
                )
                ax3.add_patch(arrow)

    # Add diagonal lines for normal modes
    xlim = ax3.get_xlim()
    ylim = ax3.get_ylim()
    max_range = max(abs(np.array([*xlim, *ylim])))

    # In-phase mode line (θ₁ = θ₂)
    ax3.plot(
        [-max_range, max_range],
        [-max_range, max_range],
        "g--",
        linewidth=2,
        alpha=0.6,
        label="In-Phase Mode\n$(\\theta_1 = \\theta_2)$",
        zorder=1,
    )

    # Out-of-phase mode line (θ₁ = -θ₂)
    ax3.plot(
        [-max_range, max_range],
        [max_range, -max_range],
        "orange",
        linestyle="--",
        linewidth=2,
        alpha=0.6,
        label="Out-of-Phase Mode\n$(\\theta_1 = -\\theta_2)$",
        zorder=1,
    )

    ax3.set_xlabel("$\\theta_1$ (degrees)", fontsize=12, fontweight="bold")
    ax3.set_ylabel("$\\theta_2$ (degrees)", fontsize=12, fontweight="bold")
    ax3.set_title(
        "Configuration Space: $\\theta_2$ vs $\\theta_1$",
        fontsize=13,
        fontweight="bold",
    )
    ax3.grid(True, alpha=0.3, linestyle="--")
    ax3.legend(
        loc="lower right", bbox_to_anchor=(1.53, -0.05), fontsize=9, framealpha=0.9
    )
    ax3.axhline(0, color="black", linewidth=0.8, alpha=0.5)
    ax3.axvline(0, color="black", linewidth=0.8, alpha=0.5)
    ax3.set_aspect("equal", adjustable="box")

    # Add system info box
    sys_info_text = (
        f"System Parameters:\n"
        f"{'─' * 16}\n"
        f"Initial Conditions:\n"
        f"$\\theta_1^{{(0)}} = {theta_1_init:+.2f}°$, $\\theta_2^{{(0)}} = {theta_2_init:+.2f}°$\n"
        f"$\\dot{{\\theta}}_1^{{(0)}} = 0.00$ rad/s, $\\dot{{\\theta}}_2^{{(0)}} = 0.00$ rad/s\n"
        f"$m_1 = {m1:.2f}$ kg, $m_2 = {m2:.2f}$ kg\n"
        f"$k = {k:.2f}$ N/m\n"
        f"$L = {L:.2f}$ m\n"
        f"$g = {g:.2f}$ m/s²\n"
        f"$T_{{sim}} = {t[-1]:.1f}$ s\n"
        f"Normal Modes:\n"
        f"$f_1 = {f1_theory:.4f}$ Hz\n"
        f"$f_2 = {f2_theory:.4f}$ Hz"
    )
    ax3.text(
        1.05,
        0.99,
        sys_info_text,
        transform=ax3.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        bbox=dict(
            boxstyle="round",
            facecolor="lightyellow",
            alpha=0.9,
            edgecolor="black",
            linewidth=1.5,
        ),
        fontfamily="monospace",
    )

    # Add colorbar for time
    sm = plt.cm.ScalarMappable(cmap="plasma", norm=plt.Normalize(vmin=0, vmax=t[-1]))  # type: ignore
    sm.set_array([])
    cbar = fig.colorbar(
        sm,
        ax=[ax1, ax2, ax3],
        orientation="horizontal",
        pad=0.15,
        aspect=40,
        shrink=0.5,
    )
    cbar.set_label("Time (s)", fontsize=11, fontweight="bold")

    # Save figure if requested
    if save_fig:
        if filename is None:
            # Auto-generate filename
            filename = (
                f"phase_config_space_"
                f"m1_{m1}_m2_{m2}_k_{k}_L_{L}_"
                f"theta1_{theta_1_init:.1f}_theta2_{theta_2_init:.1f}_"
                f"t_{t[-1]:.0f}s.png"
            )

        # Ensure output directory exists
        output_dir = "OUTPUTS/FIGURES/coupled_pendulum/double_pendulum"
        os.makedirs(output_dir, exist_ok=True)

        save_path = os.path.join(output_dir, filename)

        print(f"\nSaving figure to: {save_path}")
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print("Figure saved successfully!")

    plt.show()

    return fig


def analyze_coupled_pendulum_phase_space(
    theta_1_init=10.0,
    theta_2_init=0.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    arrow_density=8,
    save_fig=False,
    filename=None,
):
    """
    Solve coupled pendulum equations and plot phase/configuration space.

    Parameters
    ----------
    theta_1_init, theta_2_init : float
        Initial angles in degrees
    m1, m2 : float
        Masses
    k : float
        Spring constant
    L : float
        Pendulum length
    g : float
        Gravitational acceleration
    simulation_time : float
        Simulation duration in seconds
    n_points : int
        Number of time points
    arrow_density : int, optional
        Number of arrow anchor points along each trajectory (default: 8)
    save_fig : bool, optional
        Whether to save the figure
    filename : str or None, optional
        Filename for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with phase space plots
    """
    print("=" * 60)
    print("PHASE SPACE ANALYSIS OF COUPLED DOUBLE PENDULUM")
    print("=" * 60)

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂]
    y0 = [theta1_0, 0.0, theta2_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    # Solve the ODE
    print("\nSolving differential equations...")
    print(f"  Initial angles: θ₁={theta_1_init}°, θ₂={theta_2_init}°")
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of points: {n_points}")

    solution = solve_ivp(
        coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k, m1, m2),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-6,
        atol=1e-8,
    )

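    # Only report success if the integrator actually succeeded
    if not solution.success:
        raise RuntimeError(f"ODE integration failed: {solution.message}")

    theta1 = solution.y[0]
    omega1 = solution.y[1]
    theta2 = solution.y[2]
    omega2 = solution.y[3]
    t = solution.t

    print("Solution computed successfully!")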

    # Calculate theoretical frequencies
    omega1_th, omega2_th, f1_th, f2_th = th_normal_modes_double_pendulum(
        g, L, k, m1, m2
    )

    print("\nTheoretical Normal Mode Frequencies:")
    print(f"  ω₁ = {omega1_th:.4f} rad/s  →  f₁ = {f1_th:.4f} Hz (in-phase mode)")
    print(f"  ω₂ = {omega2_th:.4f} rad/s  →  f₂ = {f2_th:.4f} Hz (out-of-phase mode)")

    print("\nPlotting phase and configuration space...")
    # Plot phase space
    fig = plot_phase_and_config_space(
        t,
        theta1,
        omega1,
        theta2,
        omega2,
        g,
        L,
        k,
        m1,
        m2,
        theta_1_init,
        theta_2_init,
        arrow_density=arrow_density,
        save_fig=save_fig,
        filename=filename,
    )

    print("\nPhase space analysis complete!")

    return fig


# Run the analysis
fig = analyze_coupled_pendulum_phase_space(
    theta_1_init=10.0,
    theta_2_init=0.0,
    m1=1.0,
    m2=1.0,
    k=5.0,
    L=2.0,
    g=9.81,
    simulation_time=25.0,
    n_points=1000,
    arrow_density=20,
    save_fig=True,
    filename="double_pendulum_phase_space_beats.png",
)


## 2. Three Spring-Coupled Pendulums (Nonlinear Model)

Before implementing any code, we need to derive the equations of motion for three spring-coupled pendulums. Below we present the full derivation using the Lagrangian formalism.

Assume three planar pendulums of equal length $L$ with bob masses $m_1,m_2,m_3$. Their generalized coordinates are the angles from the downward vertical: $\theta_1(t),\theta_2(t),\theta_3(t)$. Neighboring bobs are coupled by springs: one between bobs 1 and 2 with constant $k_{12}$, and one between bobs 2 and 3 with constant $k_{23}$. We use the same coupling model as in the two-pendulum derivation: each spring’s extension is the change in the **horizontal separation** of the bobs it joins, so it depends on $\sin\theta_i$ (no small-angle approximation unless stated).


**1. Kinetic energy (exact)**

For pendulum $i$, the bob position relative to its pivot is
$$
x_i=L\sin\theta_i,\qquad y_i=-L\cos\theta_i.
$$
Differentiate:
$$
\dot x_i=L\cos\theta_i\,\dot\theta_i,\qquad \dot y_i=L\sin\theta_i\,\dot\theta_i,
$$
so
$$
v_i^2=\dot x_i^2+\dot y_i^2=L^2\dot\theta_i^2.
$$
Therefore the total kinetic energy is
$$
\boxed{
T=\frac12\sum_{i=1}^3 m_i L^2 \dot\theta_i^{\,2}
=\frac12 m_1L^2\dot\theta_1^2+\frac12 m_2L^2\dot\theta_2^2+\frac12 m_3L^2\dot\theta_3^2.
}
$$


**2. Potential energy (exact)**

**(a) Gravitational potential**

Using the zero reference at $\theta_i=0$ (lowest point), each bob rises by $\Delta h_i=L(1-\cos\theta_i)$, hence

$$
\boxed{
V_g=\sum_{i=1}^3 m_i gL(1-\cos\theta_i).
}
$$

**(b) Spring potential (neighbor coupling, horizontal-extension model)**

Horizontal displacement of bob $i$ is $x_i=L\sin\theta_i$, so the horizontal separation changes are
$$
\Delta x_{12}=x_2-x_1=L(\sin\theta_2-\sin\theta_1),\qquad
\Delta x_{23}=x_3-x_2=L(\sin\theta_3-\sin\theta_2).
$$
Taking each spring to be unstretched at equilibrium ($\theta_i=0$), the same convention as in the two-pendulum model,
$$
\boxed{
V_s=\frac12 k_{12}\Delta x_{12}^2+\frac12 k_{23}\Delta x_{23}^2
=\frac12 k_{12}L^2(\sin\theta_2-\sin\theta_1)^2+\frac12 k_{23}L^2(\sin\theta_3-\sin\theta_2)^2.
}
$$

Total potential:
$$
\boxed{V=V_g+V_s.}
$$


**3. Lagrangian**
$$
\boxed{\mathcal{L}=T-V}
$$

Explicitly:

$$
\boxed{
\mathcal{L}=
\frac12\sum_{i=1}^3 m_i L^2\dot\theta_i^{\,2}
-\sum_{i=1}^3 m_i gL(1-\cos\theta_i)
-\frac12 k_{12}L^2(\sin\theta_2-\sin\theta_1)^2
-\frac12 k_{23}L^2(\sin\theta_3-\sin\theta_2)^2
}
$$


**4. Euler–Lagrange equations (no small-angle approximation)**

For each $i=1,2,3$:

$$
\frac{d}{dt}\left(\frac{\partial\mathcal{L}}{\partial\dot\theta_i}\right)-\frac{\partial\mathcal{L}}{\partial\theta_i}=0.
$$

**Useful derivatives**

From the kinetic terms:
$$
\begin{aligned}
& \frac{\partial\mathcal{L}}{\partial\dot\theta_i}=m_iL^2\dot\theta_i \\
\Rightarrow & \frac{d}{dt}\left(\frac{\partial\mathcal{L}}{\partial\dot\theta_i}\right)=m_iL^2\ddot\theta_i.
\end{aligned}
$$

From gravity:
$$
\frac{\partial V_g}{\partial\theta_i}=m_i gL\sin\theta_i.
$$

Now differentiate the spring terms carefully.

**Spring (1–2)**

Let $s_{12}=\sin\theta_2-\sin\theta_1$. Then
$$
V_{12}=\frac12 k_{12}L^2 s_{12}^2.
$$
So
$$
\begin{aligned}
\frac{\partial V_{12}}{\partial\theta_1}
&=\frac12 k_{12}L^2\cdot 2 s_{12}\cdot\frac{\partial}{\partial\theta_1}(\sin\theta_2-\sin\theta_1)
= k_{12}L^2 s_{12}(-\cos\theta_1)
= -k_{12}L^2(\sin\theta_2-\sin\theta_1)\cos\theta_1\\
\frac{\partial V_{12}}{\partial\theta_2}
&=k_{12}L^2(\sin\theta_2-\sin\theta_1)\cos\theta_2.
\end{aligned}
$$

**Spring (2–3)**

Let $s_{23}=\sin\theta_3-\sin\theta_2$. Then
$$
V_{23}=\frac12 k_{23}L^2 s_{23}^2,
$$
so
$$
\begin{aligned}
\frac{\partial V_{23}}{\partial\theta_2}
&= k_{23}L^2 s_{23}\cdot\frac{\partial}{\partial\theta_2}(\sin\theta_3-\sin\theta_2)
= k_{23}L^2 s_{23}(-\cos\theta_2)
= -k_{23}L^2(\sin\theta_3-\sin\theta_2)\cos\theta_2\\
\frac{\partial V_{23}}{\partial\theta_3}
&= k_{23}L^2(\sin\theta_3-\sin\theta_2)\cos\theta_3
\end{aligned}
$$
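Sign errors are easy to make in these partial derivatives. A minimal cross-check, using arbitrary test values for $L$, $k_{12}$, $k_{23}$ and the angles (assumptions for the check only), compares the analytic expressions above against central finite differences of $V_s$:

```python
import numpy as np

# Arbitrary test values (assumptions for this check only)
L, k12, k23 = 2.0, 5.0, 3.0
theta = np.array([0.4, -0.7, 1.1])


def spring_potential(th):
    """V_s for the horizontal-extension model with nearest-neighbor springs."""
    s12 = np.sin(th[1]) - np.sin(th[0])
    s23 = np.sin(th[2]) - np.sin(th[1])
    return 0.5 * k12 * L**2 * s12**2 + 0.5 * k23 * L**2 * s23**2


def spring_gradient(th):
    """Analytic dV_s/dtheta_i, term by term from the derivation above."""
    s12 = np.sin(th[1]) - np.sin(th[0])
    s23 = np.sin(th[2]) - np.sin(th[1])
    return np.array([
        -k12 * L**2 * s12 * np.cos(th[0]),
        k12 * L**2 * s12 * np.cos(th[1]) - k23 * L**2 * s23 * np.cos(th[1]),
        k23 * L**2 * s23 * np.cos(th[2]),
    ])


# Central differences, perturbing one angle at a time
h = 1e-6
numeric = np.array([
    (spring_potential(theta + h * e) - spring_potential(theta - h * e)) / (2 * h)
    for e in np.eye(3)
])
print("analytic:", spring_gradient(theta))
print("numeric: ", numeric)
```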

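**Put into Euler–Lagrange**

Since $T$ depends only on the angular velocities and $V$ only on the angles, $\partial\mathcal{L}/\partial\theta_i=-\partial V/\partial\theta_i$, so each equation takes the compact form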
$$
m_iL^2\ddot\theta_i+\frac{\partial V}{\partial\theta_i}=0.
$$

So the **exact nonlinear equations of motion** are:
$$
\begin{aligned}
m_1L^2\ddot\theta_1+m_1 gL\sin\theta_1
-k_{12}L^2(\sin\theta_2-\sin\theta_1)\cos\theta_1&=0\\
m_2L^2\ddot\theta_2+m_2 gL\sin\theta_2
+k_{12}L^2(\sin\theta_2-\sin\theta_1)\cos\theta_2
-k_{23}L^2(\sin\theta_3-\sin\theta_2)\cos\theta_2&=0\\
m_3L^2\ddot\theta_3+m_3 gL\sin\theta_3
+k_{23}L^2(\sin\theta_3-\sin\theta_2)\cos\theta_3&=0
\end{aligned}
$$

Equivalently, solving for angular accelerations:
$$
\boxed{
\begin{aligned}
\ddot\theta_1
&=
-\frac{g}{L}\sin\theta_1
+\frac{k_{12}}{m_1}(\sin\theta_2-\sin\theta_1)\cos\theta_1\\
\ddot\theta_2
&=
-\frac{g}{L}\sin\theta_2
-\frac{k_{12}}{m_2}(\sin\theta_2-\sin\theta_1)\cos\theta_2
+\frac{k_{23}}{m_2}(\sin\theta_3-\sin\theta_2)\cos\theta_2\\
\ddot\theta_3
&=
-\frac{g}{L}\sin\theta_3
-\frac{k_{23}}{m_3}(\sin\theta_3-\sin\theta_2)\cos\theta_3
\end{aligned}
}
$$
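For integration with `solve_ivp`, the boxed accelerations become a first-order system. The sketch below assumes the state ordering $[\theta_1,\dot\theta_1,\theta_2,\dot\theta_2,\theta_3,\dot\theta_3]$ (extending the two-pendulum `[θ₁, ω₁, θ₂, ω₂]` layout), and `three_pendulum_rhs` and `total_energy` are illustrative names, not necessarily those used later in the notebook. Because the model is undamped, $E=T+V_g+V_s$ should stay constant, which checks both the derivation and the integrator tolerances.

```python
import numpy as np
from scipy.integrate import solve_ivp


def three_pendulum_rhs(t, y, g, L, k12, k23, m1, m2, m3):
    """Nonlinear EOM (boxed form above); y = [th1, w1, th2, w2, th3, w3]."""
    th1, w1, th2, w2, th3, w3 = y
    s12 = np.sin(th2) - np.sin(th1)  # spring 1-2 horizontal stretch divided by L
    s23 = np.sin(th3) - np.sin(th2)  # spring 2-3 horizontal stretch divided by L
    a1 = -g / L * np.sin(th1) + k12 / m1 * s12 * np.cos(th1)
    a2 = -g / L * np.sin(th2) - k12 / m2 * s12 * np.cos(th2) + k23 / m2 * s23 * np.cos(th2)
    a3 = -g / L * np.sin(th3) - k23 / m3 * s23 * np.cos(th3)
    return [w1, a1, w2, a2, w3, a3]


def total_energy(y, g, L, k12, k23, m1, m2, m3):
    """E = T + V_g + V_s for one state, or for states stacked along axis 1."""
    th1, w1, th2, w2, th3, w3 = y
    T = 0.5 * L**2 * (m1 * w1**2 + m2 * w2**2 + m3 * w3**2)
    Vg = g * L * (m1 * (1 - np.cos(th1)) + m2 * (1 - np.cos(th2)) + m3 * (1 - np.cos(th3)))
    Vs = 0.5 * L**2 * (k12 * (np.sin(th2) - np.sin(th1)) ** 2
                       + k23 * (np.sin(th3) - np.sin(th2)) ** 2)
    return T + Vg + Vs


# Example parameters (assumed); angles large enough that nonlinearity matters
params = (9.81, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0)  # g, L, k12, k23, m1, m2, m3
y0 = np.radians([30.0, 0.0, 0.0, 0.0, -20.0, 0.0])  # angles 30, 0, -20 deg; at rest
t_eval = np.linspace(0.0, 20.0, 2001)
sol = solve_ivp(three_pendulum_rhs, (0.0, 20.0), y0, args=params,
                t_eval=t_eval, rtol=1e-9, atol=1e-11)

E = total_energy(sol.y, *params)
rel_drift = np.max(np.abs(E - E[0])) / E[0]
print(f"success={sol.success}, max relative energy drift = {rel_drift:.1e}")
```

A drift that grows as `rtol` is loosened points to the integrator; a drift that does not shrink when `rtol` is tightened would point to an inconsistency between the accelerations and the energy expression.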


**5. Small-angle approximation (linearized equations)**

For $|\theta_i|\ll 1$, use $\sin\theta_i\approx \theta_i$ and $\cos\theta_i\approx 1$. Then
$$
(\sin\theta_2-\sin\theta_1)\cos\theta_1 \approx (\theta_2-\theta_1),\qquad
(\sin\theta_3-\sin\theta_2)\cos\theta_2 \approx (\theta_3-\theta_2),
$$
and the system becomes linear:
$$
\boxed{
\begin{aligned}
\ddot\theta_1+\left(\frac{g}{L}+\frac{k_{12}}{m_1}\right)\theta_1-\frac{k_{12}}{m_1}\theta_2=0,\\
\ddot\theta_2-\frac{k_{12}}{m_2}\theta_1+\left(\frac{g}{L}+\frac{k_{12}+k_{23}}{m_2}\right)\theta_2-\frac{k_{23}}{m_2}\theta_3=0\\
\ddot\theta_3-\frac{k_{23}}{m_3}\theta_2+\left(\frac{g}{L}+\frac{k_{23}}{m_3}\right)\theta_3=0
\end{aligned}
}
$$
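How small is "small"? A quick sketch, assuming equal masses and springs with example values and an arbitrary angle pattern scaled to several amplitudes, compares the exact and linearized angular accelerations. Since the neglected terms are cubic in the angles, the relative error should grow roughly like $\theta^2$.

```python
import numpy as np

g, L, m, k = 9.81, 2.0, 1.0, 5.0  # assumed example values (equal masses and springs)
P = np.array([[1.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]])


def accel_nonlinear(th):
    """Exact angular accelerations from the boxed nonlinear equations."""
    s12 = np.sin(th[1]) - np.sin(th[0])
    s23 = np.sin(th[2]) - np.sin(th[1])
    return np.array([
        -g / L * np.sin(th[0]) + k / m * s12 * np.cos(th[0]),
        -g / L * np.sin(th[1]) - k / m * (s12 - s23) * np.cos(th[1]),
        -g / L * np.sin(th[2]) - k / m * s23 * np.cos(th[2]),
    ])


def accel_linear(th):
    """Linearized accelerations: theta'' = -(g/L) theta - (k/m) P theta."""
    return -g / L * th - k / m * (P @ th)


def linearization_error(amp_deg, pattern=(1.0, -0.5, 0.3)):
    """Relative error of the linear model for an (arbitrary) angle pattern."""
    th = np.radians(amp_deg) * np.asarray(pattern)
    a_lin = accel_linear(th)
    return np.linalg.norm(accel_nonlinear(th) - a_lin) / np.linalg.norm(a_lin)


for amp in (1, 5, 10, 20, 40):
    print(f"{amp:>2} deg: relative error {linearization_error(amp):.1e}")
```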

**Matrix / normal-mode form**

Write $\boldsymbol{\theta}=[\theta_1,\theta_2,\theta_3]^T$. Multiplying the $i$-th linearized equation by $m_i$ gives a convenient symmetric generalized-eigenvalue form,
$$
M\ddot{\boldsymbol{\theta}}+K\boldsymbol{\theta}=0,
$$
with
$$
M=\mathrm{diag}(m_1,m_2,m_3),
$$
$$
K=
\begin{bmatrix}
m_1\omega_0^2+k_{12} & -k_{12} & 0\\
-k_{12} & m_2\omega_0^2+k_{12}+k_{23} & -k_{23}\\
0 & -k_{23} & m_3\omega_0^2+k_{23}
\end{bmatrix},
\qquad
\omega_0^2=\frac{g}{L}.
$$
Assume $\boldsymbol{\theta}(t)=\mathbf{a}e^{i\omega t}$, then
$$
(K-\omega^2 M)\mathbf{a}=0,
\qquad
\boxed{\det(K-\omega^2 M)=0.}
$$
This gives three normal-mode frequencies $\omega_1,\omega_2,\omega_3$. For general unequal parameters the determinant is a cubic polynomial in $\omega^2$, which is usually solved numerically.
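A symmetric generalized eigensolver handles the unequal-parameter case directly. The sketch below uses `scipy.linalg.eigh(K, M)`, which solves $K\mathbf a=\lambda M\mathbf a$ with $\lambda=\omega^2$; the function name `normal_modes_three` and the parameter values are illustrative assumptions.

```python
import numpy as np
from scipy.linalg import eigh


def normal_modes_three(g, L, m1, m2, m3, k12, k23):
    """Return (omegas, modes) for M theta'' + K theta = 0 (small-angle model)."""
    w0_sq = g / L
    M = np.diag([m1, m2, m3])
    K = np.array([
        [m1 * w0_sq + k12, -k12, 0.0],
        [-k12, m2 * w0_sq + k12 + k23, -k23],
        [0.0, -k23, m3 * w0_sq + k23],
    ])
    lam, modes = eigh(K, M)     # eigenvalues lambda = omega^2, in ascending order
    return np.sqrt(lam), modes  # columns of `modes` are M-orthonormal mode shapes


omegas, modes = normal_modes_three(9.81, 2.0, m1=1.0, m2=1.5, m3=0.8, k12=5.0, k23=2.0)
print("omega (rad/s):", omegas)
print("f (Hz):       ", omegas / (2 * np.pi))
```

In this model the spring part of $K$ annihilates $[1,1,1]^T$, so $(K-\omega_0^2M)[1,1,1]^T=0$ for any masses and spring constants: the lowest frequency is always exactly $\sqrt{g/L}$, which is a handy sanity check on the output.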


**6. Small-angle solutions (explicit, clean case: equal masses and equal springs)**

Take the common case
$$
m_1=m_2=m_3=m,\qquad k_{12}=k_{23}=k.
$$
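In this case $M^{-1}K=\omega_0^2 I+\frac{k}{m}P$, with $\omega_0^2=g/L$ and
$$
P=\begin{bmatrix}1&-1&0\\-1&2&-1\\0&-1&1\end{bmatrix},
$$
whose eigenvalues are $0$, $1$, and $3$. The three mode frequencies are therefore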
$$
\boxed{
\omega_1^2=\omega_0^2,
\qquad
\omega_2^2=\omega_0^2+\frac{k}{m},
\qquad
\omega_3^2=\omega_0^2+\frac{3k}{m}.}
$$

Corresponding (unnormalized) mode shapes (eigenvectors) are:
- **Mode 1 (all in-phase):**
  $$
  \mathbf{v}_1=\begin{bmatrix}1\\1\\1\end{bmatrix}
  \quad\Rightarrow\quad
  \theta_1=\theta_2=\theta_3
  $$
  (equal angles leave both springs unstretched, so this mode oscillates at the free-pendulum frequency $\sqrt{g/L}$).

- **Mode 2 (outer pendulums opposite, middle at rest):**
  $$
  \mathbf{v}_2=\begin{bmatrix}1\\0\\-1\end{bmatrix}
  \quad\Rightarrow\quad
  \theta_2=0,\;\theta_1=-\theta_3
  $$

- **Mode 3 (outer in-phase, middle opposite with larger amplitude):**
  $$
  \mathbf{v}_3=\begin{bmatrix}1\\-2\\1\end{bmatrix}
  \quad\Rightarrow\quad
  \theta_2=-2\theta_1=-2\theta_3
  $$
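The closed-form frequencies and mode shapes can be checked directly against $(K-\omega^2M)\mathbf v=0$. A short sketch, assuming example values $g=9.81$, $L=2$, $m=1$, $k=5$:

```python
import numpy as np

g, L, m, k = 9.81, 2.0, 1.0, 5.0  # assumed example values
w0_sq = g / L
P = np.array([[1, -1, 0], [-1, 2, -1], [0, -1, 1]], dtype=float)
M = m * np.eye(3)
K = m * w0_sq * np.eye(3) + k * P

# Closed-form squared frequencies and (unnormalized) mode shapes from the text
omega_sq = [w0_sq, w0_sq + k / m, w0_sq + 3 * k / m]
shapes = [np.array([1.0, 1.0, 1.0]), np.array([1.0, 0.0, -1.0]), np.array([1.0, -2.0, 1.0])]

for w2, v in zip(omega_sq, shapes):
    residual = (K - w2 * M) @ v  # vanishes for a true normal mode
    print(f"omega = {np.sqrt(w2):.4f} rad/s, |residual| = {np.linalg.norm(residual):.1e}")
```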

**General small-angle time-domain solution (superposition)**

Any small-angle motion can be written as a sum of modes:

$$
\boxed{
\boldsymbol{\theta}(t) = A_1\mathbf{v}_1\cos(\omega_1 t+\phi_1) + A_2\mathbf{v}_2\cos(\omega_2 t+\phi_2) + A_3\mathbf{v}_3\cos(\omega_3 t+\phi_3)
}
$$

where the constants $A_j,\phi_j$ are fixed by initial conditions.
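When the bobs are released from rest, all phases can be taken as $\phi_j=0$, and the amplitudes follow from projecting $\boldsymbol\theta(0)$ onto the modes. Because the mode shapes are $M$-orthogonal ($\mathbf v_i^TM\mathbf v_j=0$ for $i\neq j$), $A_j=\mathbf v_j^TM\boldsymbol\theta(0)/\mathbf v_j^TM\mathbf v_j$. A minimal sketch for the equal-parameter case, with an example initial condition (only the first pendulum displaced by $10^\circ$); `theta_small_angle` is a hypothetical helper name:

```python
import numpy as np

g, L, m, k = 9.81, 2.0, 1.0, 5.0  # assumed example values
w0_sq = g / L
M = m * np.eye(3)
V = np.array([[1.0, 1.0, 1.0], [1.0, 0.0, -1.0], [1.0, -2.0, 1.0]]).T  # columns v1, v2, v3
omegas = np.sqrt([w0_sq, w0_sq + k / m, w0_sq + 3 * k / m])

theta0 = np.radians([10.0, 0.0, 0.0])  # released from rest, so phi_j = 0

# Modal amplitudes via M-weighted projection onto each mode shape
A = (V.T @ M @ theta0) / np.diag(V.T @ M @ V)


def theta_small_angle(t):
    """theta(t) = sum_j A_j v_j cos(omega_j t); returns shape (3, len(t))."""
    t = np.atleast_1d(t)
    return V @ (A[:, None] * np.cos(np.outer(omegas, t)))


t = np.linspace(0.0, 20.0, 2001)
theta_t = theta_small_angle(t)
print("A (deg):", np.degrees(A))
```

For this initial condition the three modes are excited with amplitude ratios $1/3 : 1/2 : 1/6$ of the initial displacement.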

**Mode interpretation (physics):**
- Pure mode excitation $\Rightarrow$ all bobs oscillate at a single $\omega_j$ with a fixed amplitude ratio (no “energy exchange” pattern in the linear model).
- Mixed excitation (more than one $A_j\neq 0$) $\Rightarrow$ the angles show multi-frequency motion. When two excited frequencies $\omega_i,\omega_j$ are close (weak coupling, $k/m\ll g/L$), the motion shows **beating/energy transfer** between groups of pendulums, with a slow envelope of period $2\pi/|\omega_i-\omega_j|$.
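The envelope period $2\pi/|\omega_i-\omega_j|$ makes the weak-coupling condition concrete. A sketch comparing weak and strong coupling for the equal-parameter frequencies, with arbitrary example values of $k$ (`beat_periods` is a hypothetical helper):

```python
import numpy as np
from itertools import combinations

g, L, m = 9.81, 2.0, 1.0  # assumed example values


def beat_periods(k):
    """Envelope period 2*pi/|w_i - w_j| for each mode pair (equal masses and springs)."""
    w = np.sqrt([g / L, g / L + k / m, g / L + 3 * k / m])
    return {(i + 1, j + 1): 2 * np.pi / abs(w[j] - w[i]) for i, j in combinations(range(3), 2)}


for k in (0.1, 5.0):  # weak vs strong coupling (arbitrary example values)
    summary = ", ".join(f"modes {i}-{j}: {T:.1f} s" for (i, j), T in beat_periods(k).items())
    print(f"k = {k} N/m -> {summary}")
```

Weak coupling gives envelope periods much longer than the pendulum period $2\pi/\omega_0$, which is when the energy exchange is easiest to see; at strong coupling the envelope and carrier periods become comparable and the beats blur into general multi-frequency motion.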


### 2.1 Geometry and Coordinates (Three Pendulums)

We extend the geometry routines to a three-pendulum chain with nearest-neighbor spring couplings, so that the visual model matches the ODE model used in the simulation.

In [None]:
def calculate_fixed_hook_length_three_pendulum(
    ceiling_length, string_length, separation_ratio, hook_ratio=0.12
):
    """
    Calculate the fixed hook length for the three-pendulum system from the relaxed state (theta = 0).
    The hook length is derived from the distance between adjacent masses at rest.
    """
    # Geometry in relaxed state
    anchor_separation = separation_ratio * ceiling_length
    left_anchor_x = -anchor_separation
    middle_anchor_x = 0

    # Positions at rest (theta = 0)
    left_mass_x, _ = calculate_pendulum_mass_positions(left_anchor_x, string_length, 0)
    middle_mass_x, _ = calculate_pendulum_mass_positions(
        middle_anchor_x, string_length, 0
    )

    # Distance between adjacent masses at rest
    relaxed_dist = np.abs(middle_mass_x - left_mass_x)

    # Calculate hook length
    relaxed_spring_len = relaxed_dist / (1 + 2 * hook_ratio)
    hook_len = hook_ratio * relaxed_spring_len

    return hook_len


def calculate_plot_limits_3(
    ceiling_length,
    left_mass_x,
    right_mass_x,
    left_mass_y,
    right_mass_y,
    padding=0.15,
    middle_mass_x=None,
    middle_mass_y=None,
):
    """Calculate appropriate plot limits for 2 or 3 pendulum system."""
    mass_x_values = [left_mass_x, right_mass_x]
    mass_y_values = [left_mass_y, right_mass_y]

    if middle_mass_x is not None and middle_mass_y is not None:
        mass_x_values.append(middle_mass_x)
        mass_y_values.append(middle_mass_y)

    x_min = min(-ceiling_length / 2, *mass_x_values)
    x_max = max(ceiling_length / 2, *mass_x_values)
    y_min = min(mass_y_values)
    y_max = 0

    x_range = x_max - x_min
    y_range = y_max - y_min

    x_pad = x_range * padding
    y_pad = y_range * padding

    xlim = (x_min - x_pad, x_max + x_pad)
    ylim = (y_min - y_pad, y_max + y_pad)

    return xlim, ylim


def draw_three_coupled_pendulum(
    theta_1=0,
    theta_2=0,
    theta_3=10,
    ceiling_length=12,
    string_length=12,
    separation_ratio=0.6,
    num_coils=8,
    mass_size=25,
    figsize=(12, 8),
    padding=0.15,
):
    """
    Draw a three coupled pendulum system with two springs.
    m1 connected to m2 by spring 1, m2 connected to m3 by spring 2.

    Parameters:
    -----------
    theta_1, theta_2, theta_3 : float
        Angular displacements in degrees for the three pendulums
    ceiling_length : float
        Length of the ceiling support
    string_length : float
        Length of each pendulum string
    separation_ratio : float
        Ratio determining spacing between anchors
    num_coils : int
        Number of coils in each spring
    mass_size : float
        Size of the mass markers
    figsize : tuple
        Figure size
    padding : float
        Padding around the plot
    """
    # Calculate system dimensions
    ceiling_height = 0.05 * ceiling_length
    anchor_separation = separation_ratio * ceiling_length

    # Calculate anchor positions (evenly spaced)
    left_anchor_x = -anchor_separation
    middle_anchor_x = 0
    right_anchor_x = anchor_separation
    anchor_y = 0

    # Calculate mass positions
    left_mass_x, left_mass_y = calculate_pendulum_mass_positions(
        left_anchor_x, string_length, theta_1
    )
    middle_mass_x, middle_mass_y = calculate_pendulum_mass_positions(
        middle_anchor_x, string_length, theta_2
    )
    right_mass_x, right_mass_y = calculate_pendulum_mass_positions(
        right_anchor_x, string_length, theta_3
    )

    # Calculate FIXED hook length based on relaxed state
    hook_length = calculate_fixed_hook_length_three_pendulum(
        ceiling_length, string_length, separation_ratio
    )

    # Fixed spring radius
    spring_radius = 0.03 * ceiling_length

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Draw ceiling (extended for three pendulums)
    ceiling_width = 2 * anchor_separation + ceiling_length * 0.3
    draw_ceiling(ax, ceiling_width, ceiling_height)

    # Drawing the reference equilibrium lines (dashed)
    ax.axvline(x=left_anchor_x, color="gray", linestyle="--", lw=1, alpha=0.7)
    ax.axvline(x=middle_anchor_x, color="gray", linestyle="--", lw=1, alpha=0.7)
    ax.axvline(x=right_anchor_x, color="gray", linestyle="--", lw=1, alpha=0.7)

    # Draw pendulum strings
    draw_pendulum_string(ax, left_anchor_x, anchor_y, left_mass_x, left_mass_y)
    draw_pendulum_string(ax, middle_anchor_x, anchor_y, middle_mass_x, middle_mass_y)
    draw_pendulum_string(ax, right_anchor_x, anchor_y, right_mass_x, right_mass_y)

    # Draw first spring (m1 to m2)
    x_spring1, y_spring1 = draw_spring_with_hook(
        (left_mass_x, left_mass_y),
        (middle_mass_x, middle_mass_y),
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring1, y_spring1, color="darkgrey", lw=2)

    # Draw second spring (m2 to m3)
    x_spring2, y_spring2 = draw_spring_with_hook(
        (middle_mass_x, middle_mass_y),
        (right_mass_x, right_mass_y),
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring2, y_spring2, color="darkgrey", lw=2)

    # Draw masses (different colors for each)
    ax.plot(left_mass_x, left_mass_y, "ro", ms=mass_size, label="$m_1$")
    ax.plot(middle_mass_x, middle_mass_y, "go", ms=mass_size, label="$m_2$")
    ax.plot(right_mass_x, right_mass_y, "bo", ms=mass_size, label="$m_3$")

    # Calculate and set plot limits
    xlim, ylim = calculate_plot_limits_3(
        ceiling_length,
        left_mass_x,
        right_mass_x,
        left_mass_y,
        right_mass_y,
        padding,
        middle_mass_x,
        middle_mass_y,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(
        f"Three Coupled Pendulums: $\\theta_1 = {theta_1}^\\circ, \\theta_2 = {theta_2}^\\circ, \\theta_3 = {theta_3}^\\circ$"
    )
    ax.legend(loc="upper right", ncol=3, fontsize=8, markerscale=0.5)

    return fig, ax


fig, ax = draw_three_coupled_pendulum(theta_1=-10, theta_2=0, theta_3=10)
plt.show()
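The hook-length helper splits the rest distance $d$ between adjacent bobs into two straight hooks plus a coil body: $d=(1+2h)\,\ell_{\text{spring}}$, with each hook of length $h\,\ell_{\text{spring}}$. The sketch below reproduces that arithmetic for the default drawing parameters. `relaxed_hook_split` is a hypothetical helper, not part of the notebook. It assumes that at $\theta=0$ each bob hangs directly below its anchor, so that $d$ equals the anchor spacing.

```python
def relaxed_hook_split(ceiling_length, separation_ratio, hook_ratio=0.12):
    """Hypothetical helper mirroring calculate_fixed_hook_length_three_pendulum.

    Assumes each bob hangs directly below its anchor at theta = 0, so adjacent
    bobs are exactly one anchor spacing apart.
    """
    relaxed_dist = separation_ratio * ceiling_length   # bob-to-bob gap at rest
    spring_body = relaxed_dist / (1 + 2 * hook_ratio)  # coiled section
    hook = hook_ratio * spring_body                    # one straight hook
    return relaxed_dist, spring_body, hook


d, body, hook = relaxed_hook_split(12, 0.6)
# Two hooks plus the coil body exactly span the rest distance
print(f"d = {d:.3f}, body = {body:.3f}, hook = {hook:.3f}")
```

Because the hook length is computed once, from the rest state, it stays the same in every frame. Only the coiled section stretches or compresses as the bobs move.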


### 2.2 Numerical Integration and Animation (Three Pendulums)


As in the two-pendulum case, we express the dynamics as a first-order ODE system and integrate it numerically. We then compare theoretical normal-mode predictions (small-angle linearization) against dominant frequencies extracted from the simulated motion.

#### 2.2.1 Equations of Motion (State-Space Form)

In [None]:
def three_coupled_pendulum_derivatives(t, y, g, L, k1, k2, m1, m2, m3):
    """
    Compute the derivatives for the three coupled pendulum system.

    Uses the full nonlinear equations of motion for a model in which each spring
    acts horizontally, with extension L(sin θ_j - sin θ_i):
    θ̈₁ = -(g/L)sin(θ₁) + (k₁/m₁)(sin(θ₂) - sin(θ₁))cos(θ₁)
    θ̈₂ = -(g/L)sin(θ₂) - (k₁/m₂)(sin(θ₂) - sin(θ₁))cos(θ₂) + (k₂/m₂)(sin(θ₃) - sin(θ₂))cos(θ₂)
    θ̈₃ = -(g/L)sin(θ₃) - (k₂/m₃)(sin(θ₃) - sin(θ₂))cos(θ₃)

    Parameters
    ----------
    t : float
        Time (not used explicitly, required by solve_ivp)
    y : array-like
        State vector [θ₁, ω₁, θ₂, ω₂, θ₃, ω₃] where ω = dθ/dt
    g : float
        Gravitational acceleration
    L : float
        Length of pendulum rods
    k1, k2 : float
        Spring constants (k1 connects m1-m2, k2 connects m2-m3)
    m1, m2, m3 : float
        Masses of the three pendulum bobs

    Returns
    -------
    list
        Derivatives [dθ₁/dt, dω₁/dt, dθ₂/dt, dω₂/dt, dθ₃/dt, dω₃/dt]
    """
    theta1, omega1, theta2, omega2, theta3, omega3 = y

    # Horizontal spring extensions divided by L
    delta_sin_12 = np.sin(theta2) - np.sin(theta1)
    delta_sin_23 = np.sin(theta3) - np.sin(theta2)

    # Angular velocities and accelerations (full nonlinear equations)
    d_theta1 = omega1
    d_omega1 = -(g / L) * np.sin(theta1) + (k1 / m1) * delta_sin_12 * np.cos(theta1)

    d_theta2 = omega2
    d_omega2 = (
        -(g / L) * np.sin(theta2)
        - (k1 / m2) * delta_sin_12 * np.cos(theta2)
        + (k2 / m2) * delta_sin_23 * np.cos(theta2)
    )

    d_theta3 = omega3
    d_omega3 = -(g / L) * np.sin(theta3) - (k2 / m3) * delta_sin_23 * np.cos(theta3)

    return [d_theta1, d_omega1, d_theta2, d_omega2, d_theta3, d_omega3]
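These equations follow from a Lagrangian with kinetic energy $T=\tfrac12 L^2\sum_i m_i\dot\theta_i^2$, gravitational energy $V_g=-gL\sum_i m_i\cos\theta_i$, and spring energy $\tfrac12 k L^2(\sin\theta_j-\sin\theta_i)^2$ per spring. The total energy shown later in the animation should therefore stay constant, up to integrator error. The sketch below is a minimal sanity check that assumes SciPy is available. `rhs` and `energy` are hypothetical stand-ins that re-implement the same formulas so the block runs on its own, and the parameter values are illustrative.

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs(t, y, g, L, k1, k2, m1, m2, m3):
    """Hypothetical stand-in: same equations as three_coupled_pendulum_derivatives."""
    th1, w1, th2, w2, th3, w3 = y
    d12 = np.sin(th2) - np.sin(th1)  # spring 1 extension / L
    d23 = np.sin(th3) - np.sin(th2)  # spring 2 extension / L
    a1 = -(g / L) * np.sin(th1) + (k1 / m1) * d12 * np.cos(th1)
    a2 = (-(g / L) * np.sin(th2) - (k1 / m2) * d12 * np.cos(th2)
          + (k2 / m2) * d23 * np.cos(th2))
    a3 = -(g / L) * np.sin(th3) - (k2 / m3) * d23 * np.cos(th3)
    return [w1, a1, w2, a2, w3, a3]


def energy(y, g, L, k1, k2, m1, m2, m3):
    """Hypothetical helper: kinetic + gravitational + spring energy (J)."""
    th1, w1, th2, w2, th3, w3 = y
    ke = 0.5 * L**2 * (m1 * w1**2 + m2 * w2**2 + m3 * w3**2)
    pe_grav = -g * L * (m1 * np.cos(th1) + m2 * np.cos(th2) + m3 * np.cos(th3))
    pe_spring = 0.5 * L**2 * (k1 * (np.sin(th2) - np.sin(th1)) ** 2
                              + k2 * (np.sin(th3) - np.sin(th2)) ** 2)
    return ke + pe_grav + pe_spring


# Illustrative parameters (g, L, k1, k2, m1, m2, m3) with unequal masses/springs
params = (9.81, 2.0, 5.0, 3.0, 1.0, 1.5, 0.8)
y0 = [np.radians(30.0), 0.0, 0.0, 0.0, np.radians(-20.0), 0.0]  # large angles on purpose
t_eval = np.linspace(0.0, 20.0, 2001)

sol = solve_ivp(rhs, (0.0, 20.0), y0, args=params, method="DOP853",
                rtol=1e-10, atol=1e-12, t_eval=t_eval)

E = energy(sol.y, *params)
drift = np.max(np.abs(E - E[0])) / abs(E[0])
print(f"max relative energy drift over 20 s: {drift:.2e}")
```

A drift that grows far beyond the solver tolerances would suggest a sign or index error in the equations. Large initial angles are used on purpose so that the nonlinear terms are actually exercised.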


#### 2.2.2 Linearization and Theoretical Normal-Mode Frequencies


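The function `th_normal_modes_triple_pendulum` computes “theoretical” normal-mode frequencies using the **linearized (small-angle) model**:
- It forms the $3\times 3$ **dynamical matrix** $A=M^{-1}K$ (already in acceleration form), where $M=\mathrm{diag}(m_1,m_2,m_3)$ and $K$ is the small-angle stiffness matrix. The equations of motion then read
  $$
  \ddot{\boldsymbol{\theta}} = -A\,\boldsymbol{\theta},\qquad
  A=\begin{pmatrix}
  \omega_0^2+\dfrac{k_1}{m_1} & -\dfrac{k_1}{m_1} & 0\\
  -\dfrac{k_1}{m_2} & \omega_0^2+\dfrac{k_1+k_2}{m_2} & -\dfrac{k_2}{m_2}\\
  0 & -\dfrac{k_2}{m_3} & \omega_0^2+\dfrac{k_2}{m_3}
  \end{pmatrix},
  $$
  where $\omega_0^2=g/L$.
- It then computes the eigenvalues $\lambda_j$ of $A$ and returns $\omega_j=\sqrt{\lambda_j}$ (sorted in ascending order), together with $f_j=\omega_j/(2\pi)$. When the masses differ, $A$ is not symmetric. It is, however, similar to the symmetric matrix $M^{-1/2}KM^{-1/2}$, so its eigenvalues are still real.

This reflects the fact that, in the linear regime, each normal mode oscillates sinusoidally at a single frequency.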
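For equal masses and springs, $A=\omega_0^2 I+(k/m)\,P$. Here $P$ is the tridiagonal matrix with rows $(1,-1,0)$, $(-1,2,-1)$, $(0,-1,1)$, and its eigenvalues are $0$, $1$ and $3$. The sketch below checks numerically that this reproduces the closed forms in the function's docstring. It also confirms the mode shapes $(1,1,1)$, $(1,0,-1)$ and $(1,-2,1)$, which are used later to label the motion. The parameter values are illustrative.

```python
import numpy as np

g, L, k, m = 9.81, 2.0, 10.0, 1.0  # illustrative equal masses and springs
w0_sq = g / L

# Dynamical matrix A for theta_ddot = -A theta (equal m, equal k)
A = w0_sq * np.eye(3) + (k / m) * np.array([[1, -1, 0],
                                             [-1, 2, -1],
                                             [0, -1, 1]], dtype=float)

lam = np.sort(np.linalg.eigvals(A).real)
expected = np.array([w0_sq, w0_sq + k / m, w0_sq + 3 * k / m])

# Mode shapes matching the docstring's three cases, in ascending frequency
shapes = {
    "in-phase (sloshing)": np.array([1.0, 1.0, 1.0]),
    "anti-symmetric": np.array([1.0, 0.0, -1.0]),
    "breathing": np.array([1.0, -2.0, 1.0]),
}
for (name, v), lam_j in zip(shapes.items(), expected):
    # Each shape should satisfy A v = lambda v up to round-off
    residual = np.linalg.norm(A @ v - lam_j * v)
    print(f"{name:20s} omega = {np.sqrt(lam_j):.4f} rad/s, residual = {residual:.1e}")
```

These shapes are exact eigenvectors only when the masses are equal and the springs are equal. For any other parameters, the numerical eigensolve in the function is required.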

In [None]:
def th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3):
    """
    Calculate the theoretical normal mode frequencies for the linearized three-pendulum system.

    For equal masses (m1 = m2 = m3 = m) and equal springs (k1 = k2 = k):
        ω₁² = g/L               (sloshing mode - all move together)
        ω₂² = g/L + k/m         (anti-symmetric mode)
        ω₃² = g/L + 3k/m        (breathing mode)

    For unequal masses/springs, we solve the eigenvalue problem.

    Parameters
    ----------
    g, L, k1, k2, m1, m2, m3 : float
        System parameters

    Returns
    -------
    tuple
        (omega1, omega2, omega3, f1, f2, f3) - angular frequencies and frequencies in Hz
    """
    # Linearized equations of motion: M θ̈ = -K θ, where θ = [θ₁, θ₂, θ₃] and
    # M = diag(m1, m2, m3). Multiplying by M⁻¹ gives θ̈ = -A θ with the
    # dynamical matrix A = M⁻¹ K (row i of K divided by m_i), built directly here.
    omega0_sq = g / L

    A = np.array(
        [
            [omega0_sq + k1 / m1, -k1 / m1, 0],
            [-k1 / m2, omega0_sq + (k1 + k2) / m2, -k2 / m2],
            [0, -k2 / m3, omega0_sq + k2 / m3],
        ]
    )

    # Eigenvalues of A are ω². A is not symmetric for unequal masses, but it is
    # similar to the symmetric M^(-1/2) K M^(-1/2), so the eigenvalues are real.
    eigenvalues = np.linalg.eigvals(A)

    # Sort eigenvalues (ω² values); np.real drops round-off imaginary parts
    omega_sq = np.sort(np.real(eigenvalues))

    omega1 = np.sqrt(omega_sq[0])  # Lowest frequency (sloshing)
    omega2 = np.sqrt(omega_sq[1])  # Middle frequency (anti-symmetric)
    omega3 = np.sqrt(omega_sq[2])  # Highest frequency (breathing)

    # Convert to Hz
    f1 = omega1 / (2 * np.pi)
    f2 = omega2 / (2 * np.pi)
    f3 = omega3 / (2 * np.pi)

    return omega1, omega2, omega3, f1, f2, f3


#### 2.2.3 Dominant Frequencies from Numerical Data

 
The function `est_freqs_triple_pendulum` estimates the three dominant mode frequencies from the simulated $\theta_1(t),\theta_2(t),\theta_3(t)$ in three steps:

- **Generalized eigenproblem (linear model):** it builds $M=\mathrm{diag}(m_1,m_2,m_3)$ and the small-angle stiffness matrix $K$, then solves
  $$
  K\mathbf{v}=\omega^2 M\mathbf{v}
  $$
  using `eigh(K, M)`. This solver is appropriate because $K$ is symmetric and $M$ is symmetric positive definite.
- **Modal projection:** `eigh` returns eigenvectors normalized so that $V^T M V = I$. The expansion $\boldsymbol{\theta}(t)=V\mathbf{q}(t)$ can therefore be inverted to give the modal coordinates
  $$
  \mathbf{q}(t)=V^T M\,\boldsymbol{\theta}(t).
  $$
  In the linear regime, each $q_i(t)$ isolates one mode.
- **FFT peak pick:** it applies an optional Hann window, computes the `rfft` of each $q_i(t)$, and takes the largest non-DC spectral bin as the “dominant” frequency for that mode. A mode whose spectral RMS is below 1% of the largest mode's RMS is reported as “not excited” and returned as `0.0`.
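The following short check tests the two linear-algebra ingredients above for unequal masses and springs, using illustrative values. First, the generalized eigenvalues from `scipy.linalg.eigh(K, M)` should coincide with the eigenvalues of $A=M^{-1}K$ used in Section 2.2.2. Second, projecting a configuration built from known modal amplitudes should recover those amplitudes.

```python
import numpy as np
from scipy.linalg import eigh

g, L = 9.81, 2.0
k1, k2 = 5.0, 3.0           # illustrative unequal springs
m1, m2, m3 = 1.0, 1.5, 0.8  # illustrative unequal masses
w0_sq = g / L

M = np.diag([m1, m2, m3])
K = np.array([[m1 * w0_sq + k1, -k1, 0.0],
              [-k1, m2 * w0_sq + k1 + k2, -k2],
              [0.0, -k2, m3 * w0_sq + k2]])

# Symmetric generalized problem K v = omega^2 M v (eigenvalues ascending)
omega_sq, V = eigh(K, M)

# Same spectrum as the non-symmetric dynamical matrix A = M^{-1} K
lam_A = np.sort(np.linalg.eigvals(np.linalg.solve(M, K)).real)

# eigh returns M-orthonormal eigenvectors: V^T M V = I
print(np.round(V.T @ M @ V, 12))

# Build theta = V q from known modal amplitudes, then project back: q = V^T M theta
q_true = np.array([0.2, -0.05, 0.01])
theta = V @ q_true
q_rec = V.T @ (M @ theta)
print(q_rec)
```

The same identity underlies `q = V.T @ (M @ theta)` in `est_freqs_triple_pendulum`. Without the $M$ factor, the projection would mix modes whenever the masses differ.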

In [None]:
def est_freqs_triple_pendulum(
    t, theta1, theta2, theta3, g, L, k1, k2, m1, m2, m3, use_hann_window=True
):
    """
    Estimate the 3 normal-mode frequencies (works for unequal masses/springs) by:
    1) solving the *generalized* eigenproblem K v = ω² M v (small-angle linear model)
    2) projecting the simulated θ(t) onto the eigenvectors (modal coordinates)
    3) taking an FFT of each modal coordinate and picking the dominant peak

    Assumes `t` is uniformly spaced.

    Returns
    -------
    (f1, f2, f3) in Hz, ordered low -> high (matching ω1 <= ω2 <= ω3);
    a mode that is not excited is returned as 0.0.
    """
    t = np.asarray(t)
    dt = t[1] - t[0]
    n = t.size

    # Stack angles: shape (3, n)
    theta = np.vstack([np.asarray(theta1), np.asarray(theta2), np.asarray(theta3)])

    # Remove DC offset per channel
    theta = theta - theta.mean(axis=1, keepdims=True)
    omega0_sq = g / L

    M = np.diag([m1, m2, m3]).astype(float)
    K = np.array(
        [
            [m1 * omega0_sq + k1, -k1, 0.0],
            [-k1, m2 * omega0_sq + k1 + k2, -k2],
            [0.0, -k2, m3 * omega0_sq + k2],
        ],
        dtype=float,
    )

    omega_sq, V = eigh(K, M)
    omega_sq = np.maximum(omega_sq, 0.0)
    q = V.T @ (M @ theta)

    # Optional Hann window to reduce spectral leakage
    if use_hann_window:
        w = np.hanning(n)
        q = q * w

    # FFT on each modal coordinate
    freqs = np.fft.rfftfreq(n, dt)
    Q = np.fft.rfft(q, axis=1)
    amp = np.abs(Q)

    rms = np.sqrt(np.mean(amp**2, axis=1))
    rms_rel = rms / (np.max(rms) + 1e-30)

    f_est = []
    for i in range(3):
        if rms_rel[i] < 1e-2:
            f_est.append(0.0)
            continue

        idx = int(np.argmax(amp[i, 1:])) + 1
        f_est.append(float(freqs[idx]))

    return f_est[0], f_est[1], f_est[2]
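The following end-to-end sketch compares theory with simulation under the assumption of equal masses and equal springs. It starts the nonlinear system at small amplitude in the anti-symmetric shape $(1,0,-1)$, takes a Hann-windowed FFT, and compares the peak with $f_2=\sqrt{g/L+k/m}/(2\pi)$. For brevity, it transforms $\theta_1$ directly rather than calling `est_freqs_triple_pendulum`. For this symmetric start the anti-symmetric modal coordinate is proportional to $\theta_1$, so the dominant peak is the same. The agreement can only be checked to within one FFT bin, $\Delta f = 1/(N\,\Delta t)$.

```python
import numpy as np
from scipy.integrate import solve_ivp

g, L, k, m = 9.81, 2.0, 5.0, 1.0  # illustrative equal masses and springs


def rhs(t, y):
    """Nonlinear three-pendulum equations specialized to equal m and k."""
    th1, w1, th2, w2, th3, w3 = y
    d12 = np.sin(th2) - np.sin(th1)
    d23 = np.sin(th3) - np.sin(th2)
    return [w1, -(g / L) * np.sin(th1) + (k / m) * d12 * np.cos(th1),
            w2, -(g / L) * np.sin(th2) + (k / m) * (d23 - d12) * np.cos(th2),
            w3, -(g / L) * np.sin(th3) - (k / m) * d23 * np.cos(th3)]


# Small-amplitude start in the anti-symmetric shape (1, 0, -1), from rest
a = np.radians(2.0)
y0 = [a, 0.0, 0.0, 0.0, -a, 0.0]

dt, n = 0.05, 1200  # uniform sampling: 60 s record at 20 Hz
t = np.arange(n) * dt
sol = solve_ivp(rhs, (0.0, t[-1]), y0, t_eval=t, rtol=1e-9, atol=1e-11)

# Remove the mean, apply a Hann window, and pick the largest non-DC bin
x = sol.y[0] - sol.y[0].mean()
spec = np.abs(np.fft.rfft(x * np.hanning(n)))
freqs = np.fft.rfftfreq(n, dt)
f_peak = freqs[1 + np.argmax(spec[1:])]

f_theory = np.sqrt(g / L + k / m) / (2 * np.pi)  # small-angle anti-symmetric mode
df = freqs[1] - freqs[0]                         # FFT bin width = 1 / (n * dt)
print(f"peak {f_peak:.4f} Hz, theory {f_theory:.4f} Hz, bin width {df:.4f} Hz")
```

A longer record shrinks $\Delta f$ and tightens the comparison. At larger amplitudes, the softening nonlinear terms pull the measured peak below the small-angle prediction.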

#### 2.2.4 Animation of the Three-Pendulum System

**Animation (three spring-coupled pendulums)**  
The three-pendulum animation is implemented in `simulate_three_coupled_pendulum`. It has the same overall structure as the two-pendulum animation and proceeds in these steps:

- **Simulate first, animate second:** it integrates the nonlinear ODEs with `solve_ivp` using `three_coupled_pendulum_derivatives`. It then evaluates the dense solution on a uniform grid `t_eval`, one point per frame, to obtain the time series $\theta_i(t)$ and $\omega_i(t)$.
- **Precompute per-frame geometry/energies:** *before* animating, it computes the bob positions $x_i=p_i+L\sin\theta_i$ and $y_i=-L\cos\theta_i$, where $p_i$ is the $x$-coordinate of pivot $i$. It also computes the spring extensions $\Delta x_{12}=L(\sin\theta_2-\sin\theta_1)$ and $\Delta x_{23}=L(\sin\theta_3-\sin\theta_2)$, plus the energies. This keeps the per-frame update fast.
- **Matplotlib “artists” + update loop:** it creates line objects for the rods and springs and marker objects for the masses. `animate(frame)` then updates those artists from the precomputed arrays and redraws the two springs by calling `draw_spring_with_hook` between the bob positions.
- **FuncAnimation:** it wraps `animate` with `FuncAnimation(fig, animate, init_func=init, frames=len(t_eval), interval=int(1000 / fps), ...)`. The result can optionally be saved via ffmpeg or Pillow.
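The simulate-precompute-update pattern described above can be illustrated with a single pendulum to keep it short. This is a hypothetical minimal sketch, not the notebook's function. The trajectory is an illustrative small-angle solution rather than a `solve_ivp` result, and the non-interactive Agg backend is selected so the block runs headless.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# 1) "Simulate" first: illustrative small-angle trajectory for one bob
L, fps, duration = 2.0, 30, 2.0
t_eval = np.linspace(0.0, duration, int(duration * fps) + 1)
theta = np.radians(10.0) * np.cos(np.sqrt(9.81 / L) * t_eval)

# 2) Precompute per-frame geometry so the update only indexes arrays
x_all = L * np.sin(theta)
y_all = -L * np.cos(theta)

# 3) Create the artists once ...
fig, ax = plt.subplots()
ax.set_xlim(-1.2 * L, 1.2 * L)
ax.set_ylim(-1.2 * L, 0.2 * L)
(rod,) = ax.plot([], [], "k-", lw=2)
(bob,) = ax.plot([], [], "o", color="crimson", ms=12)


def animate(frame):
    # ... and only mutate their data on each frame
    rod.set_data([0.0, x_all[frame]], [0.0, y_all[frame]])
    bob.set_data([x_all[frame]], [y_all[frame]])
    return rod, bob


# 4) Wrap the update function; interval is the delay between frames in ms
anim = FuncAnimation(fig, animate, frames=len(t_eval),
                     interval=int(1000 / fps), blit=False, repeat=False)
```

In a notebook you would typically display `anim` with `IPython.display.HTML(anim.to_jshtml())` or write it to disk with `anim.save(...)`. Keeping physics computations out of `animate` is what keeps each frame cheap.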


In [None]:
def simulate_three_coupled_pendulum(
    theta_1_init=0.0,
    theta_2_init=0.0,
    theta_3_init=10.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=5.0,
    k2=5.0,
    L=2.0,
    g=9.81,
    simulation_time=20.0,
    fps=30,
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a three coupled pendulum system.

    Parameters
    ----------
    theta_1_init, theta_2_init, theta_3_init : float
        Initial angles in degrees
    m1, m2, m3 : float
        Masses of pendulum bobs
    k1, k2 : float
        Spring constants (k1 connects m1-m2, k2 connects m2-m3)
    L : float
        Length of pendulum rods
    g : float
        Gravitational acceleration
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for animation
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for saved animation (auto-generated if None)

    Returns
    -------
    tuple
        (fig, anim) - Figure and animation objects
    """

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)
    theta3_0 = np.radians(theta_3_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂, θ₃, ω₃] (starting from rest)
    y0 = [theta1_0, 0.0, theta2_0, 0.0, theta3_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        three_coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k1, k2, m1, m2, m3),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if solution.sol is None:
        raise RuntimeError(
            "solve_ivp did not return a dense solution (solution.sol is None)"
        )

    y = solution.sol(t=t_eval)
    theta1 = y[0]
    omega1 = y[1]
    theta2 = y[2]
    omega2 = y[3]
    theta3 = y[4]
    omega3 = y[5]

    print(f"Solution computed: {len(t_eval)} time steps")

    # =======================================
    # PRECOMPUTE SPRING EXTENSION + ENERGIES
    # =======================================

    theta1_deg_all = np.degrees(theta1)
    theta2_deg_all = np.degrees(theta2)
    theta3_deg_all = np.degrees(theta3)

    spring_ext_all_1 = L * (np.sin(theta2) - np.sin(theta1))
    spring_ext_all_2 = L * (np.sin(theta3) - np.sin(theta2))

    KE_all = (
        0.5 * m1 * (L * omega1) ** 2
        + 0.5 * m2 * (L * omega2) ** 2
        + 0.5 * m3 * (L * omega3) ** 2
    )
    PE_grav_all = (
        -m1 * g * L * np.cos(theta1)
        - m2 * g * L * np.cos(theta2)
        - m3 * g * L * np.cos(theta3)
    )
    PE_spring_all = 0.5 * k1 * spring_ext_all_1**2 + 0.5 * k2 * spring_ext_all_2**2
    total_E_all = KE_all + PE_grav_all + PE_spring_all

    # =========================
    # FREQUENCY ANALYSIS
    # =========================

    # Theoretical normal modes
    omega1_theory, omega2_theory, omega3_theory, f1_theory, f2_theory, f3_theory = (
        th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3)
    )

    # Numerical frequency estimation
    f1_num, f2_num, f3_num = est_freqs_triple_pendulum(
        t_eval, theta1, theta2, theta3, g, L, k1, k2, m1, m2, m3
    )

    def _fmt_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)
    f3_str = _fmt_freq(f3_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"               ω₃ = {omega3_theory:.4f} rad/s (f₃ = {f3_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}, f₃ ≈ {f3_str}")

    # =======================
    # ANIMATION SETUP
    # =======================

    # Pivot positions (evenly spaced)
    pivot_separation = 1.5  # Distance between adjacent pivots
    pivot1_x = -pivot_separation
    pivot2_x = 0
    pivot3_x = pivot_separation
    pivot_y = 0

    ceiling_length = 2 * pivot_separation + 0.5
    separation_ratio = pivot_separation / ceiling_length
    spring_hook_length = calculate_fixed_hook_length_three_pendulum(
        ceiling_length, L, separation_ratio, hook_ratio=0.15
    )
    spring_radius = 0.02 * ceiling_length
    spring_num_coils = 10
    ceiling_height = 0.05 * ceiling_length

    # Precompute mass positions
    x1_all = pivot1_x + L * np.sin(theta1)
    y1_all = pivot_y - L * np.cos(theta1)
    x2_all = pivot2_x + L * np.sin(theta2)
    y2_all = pivot_y - L * np.cos(theta2)
    x3_all = pivot3_x + L * np.sin(theta3)
    y3_all = pivot_y - L * np.cos(theta3)

    # Figure setup
    fig, ax = plt.subplots(figsize=(12, 8))
    fig.subplots_adjust(top=0.88, bottom=0.02, left=0.05, right=0.8)
    fig.suptitle(
        "Three Spring-Coupled Pendulums Motion", fontsize=16, fontweight="bold"
    )

    # Calculate plot limits with calculate_plot_limits_3 (defined in Section 2.1)
    xlim, ylim = calculate_plot_limits_3(
        ceiling_length,
        np.min(x1_all),
        np.max(x3_all),
        np.min([y1_all.min(), y2_all.min(), y3_all.min()]),
        np.min([y1_all.min(), y2_all.min(), y3_all.min()]),
        padding=0.2,
        middle_mass_x=np.mean(x2_all),
        middle_mass_y=np.mean(y2_all),
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Determine motion type from the initial angles. The mode shapes tested below
    # are exact normal modes only for equal masses and equal springs.
    if (
        abs(theta_1_init - theta_2_init) < 1e-6
        and abs(theta_2_init - theta_3_init) < 1e-6
    ):
        motion_type = "Sloshing Mode (Normal Mode 1 - All in phase)"
    elif abs(theta_1_init + theta_3_init) < 1e-6 and abs(theta_2_init) < 1e-6:
        motion_type = "Anti-Symmetric Mode (Normal Mode 2 - Middle at rest)"
    elif (
        abs(theta_1_init - theta_3_init) < 1e-6
        and abs(theta_2_init + 2 * theta_1_init) < 1e-6
    ):
        motion_type = "Breathing Mode (Normal Mode 3 - Maximum spring stretch)"
    else:
        motion_type = "Mixed-Mode Oscillation (Superposition)"

    title_text = (
        f"Three Coupled Pendulum System: {motion_type}\n"
        f"Normal Mode Frequencies: "
        f"$\\omega_1={omega1_theory:.3f}$ rad/s, "
        f"$\\omega_2={omega2_theory:.3f}$ rad/s, "
        f"$\\omega_3={omega3_theory:.3f}$ rad/s"
    )

    ax.set_title(title_text, fontsize=13, fontweight="bold")

    # Draw ceiling/support with the shared draw_ceiling helper
    draw_ceiling(ax, ceiling_length, ceiling_height)

    # Pivot points
    ax.plot(pivot1_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)
    ax.plot(pivot2_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)
    ax.plot(pivot3_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)

    # Reference lines for equilibrium positions
    ax.axvline(
        pivot1_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )
    ax.axvline(
        pivot2_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )
    ax.axvline(
        pivot3_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )

    # Initialize plot elements
    # Rods
    (rod1_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod2_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod3_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)

    # Springs
    (spring1_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)
    (spring2_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)

    # Masses (as markers)
    mass_size_min = 20
    mass_size_ref = 25
    avg_mass = (m1 + m2 + m3) / 3.0 if (m1 + m2 + m3) > 0 else 1.0
    mass_1_size = max(mass_size_min, mass_size_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(mass_size_min, mass_size_ref * np.cbrt(m2 / avg_mass))
    mass_3_size = max(mass_size_min, mass_size_ref * np.cbrt(m3 / avg_mass))

    (mass1_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="crimson",
        markersize=mass_1_size,
        zorder=4,
        label=f"$m_1$={m1} kg",
    )
    (mass2_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="green",
        markersize=mass_2_size,
        zorder=4,
        label=f"$m_2$={m2} kg",
    )
    (mass3_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="royalblue",
        markersize=mass_3_size,
        zorder=4,
        label=f"$m_3$={m3} kg",
    )

    # Traces
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    trace3_x, trace3_y = [], []
    (trace1_line,) = ax.plot([], [], "crimson", lw=0.8, alpha=0.5, zorder=1)
    (trace2_line,) = ax.plot([], [], "green", lw=0.8, alpha=0.5, zorder=1)
    (trace3_line,) = ax.plot([], [], "royalblue", lw=0.8, alpha=0.5, zorder=1)

    # Dynamic text annotations
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        va="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # Info box with angles and frequencies
    info_box = ax.text(
        1.02,
        0.99,
        "",
        transform=ax.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightcyan", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = (
        f"Frequency Comparison:\n"
        f"Theoretical (linear):\n"
        f"  $f_1={f1_theory:.3f}$ Hz (sloshing)\n"
        f"  $f_2={f2_theory:.3f}$ Hz (anti-sym)\n"
        f"  $f_3={f3_theory:.3f}$ Hz (breathing)\n"
        f"Numerical:\n"
        f"  $f_1\\approx${f1_str}\n"
        f"  $f_2\\approx${f2_str}\n"
        f"  $f_3\\approx${f3_str}"
    )

    sys_info_text = (
        f"System Parameters:\n"
        f"{'-' * 20}\n"
        f"Masses: \n"
        f"$m_1$={m1} kg\n"
        f"$m_2$={m2} kg\n"
        f"$m_3$={m3} kg\n"
        f"Springs:\n"
        f"$k_1$={k1} N/m\n"
        f"$k_2$={k2} N/m\n"
        f"Length: $L$={L} m\n"
        f"Gravity: $g$={g} m/s²\n"
        f"Initial Angles:\n"
        f"$\\theta_1={theta_1_init:.1f}^{{\\circ}}$\n"
        f"$\\theta_2={theta_2_init:.1f}^{{\\circ}}$\n"
        f"$\\theta_3={theta_3_init:.1f}^{{\\circ}}$"
    )

    ax.text(
        1.02,
        0.01,
        freq_text,
        transform=ax.transAxes,
        fontsize=9,
        va="bottom",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="honeydew", alpha=0.8),
    )
    ax.text(
        1.02,
        0.64,
        sys_info_text,
        transform=ax.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lavender", alpha=0.8),
    )

    # Legend
    ax.legend(loc="lower right", fontsize=9, ncol=3, framealpha=0.9, markerscale=0.5)

    # =================================
    # ANIMATION FUNCTIONS
    # =================================

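    def init():
        """Initialize animation and reset the trace buffers."""
        # Clear traces so a re-run (e.g. a second save attempt) starts clean
        for buf in (trace1_x, trace1_y, trace2_x, trace2_y, trace3_x, trace3_y):
            buf.clear()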
        rod1_line.set_data([], [])
        rod2_line.set_data([], [])
        rod3_line.set_data([], [])
        spring1_line.set_data([], [])
        spring2_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        mass3_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])
        trace3_line.set_data([], [])
        time_text.set_text("")
        info_box.set_text("")
        return (
            rod1_line,
            rod2_line,
            rod3_line,
            spring1_line,
            spring2_line,
            mass1_plot,
            mass2_plot,
            mass3_plot,
            trace1_line,
            trace2_line,
            trace3_line,
            time_text,
            info_box,
        )

    def animate(frame):
        """Update animation for each frame."""
        # Current time and angles
        current_time = t_eval[frame]
        theta1_deg = theta1_deg_all[frame]
        theta2_deg = theta2_deg_all[frame]
        theta3_deg = theta3_deg_all[frame]

        # Current positions
        x1, y1 = x1_all[frame], y1_all[frame]
        x2, y2 = x2_all[frame], y2_all[frame]
        x3, y3 = x3_all[frame], y3_all[frame]

        # Update rods
        rod1_line.set_data([pivot1_x, x1], [pivot_y, y1])
        rod2_line.set_data([pivot2_x, x2], [pivot_y, y2])
        rod3_line.set_data([pivot3_x, x3], [pivot_y, y3])

        spring1_x, spring1_y = draw_spring_with_hook(
            start_pos=(x1, y1),
            end_pos=(x2, y2),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring1_line.set_data(spring1_x, spring1_y)

        spring2_x, spring2_y = draw_spring_with_hook(
            start_pos=(x2, y2),
            end_pos=(x3, y3),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring2_line.set_data(spring2_x, spring2_y)

        # Update masses
        mass1_plot.set_data([x1], [y1])
        mass2_plot.set_data([x2], [y2])
        mass3_plot.set_data([x3], [y3])

        # Update traces
        trace1_x.append(x1)
        trace1_y.append(y1)
        trace2_x.append(x2)
        trace2_y.append(y2)
        trace3_x.append(x3)
        trace3_y.append(y3)

        # Limit trace length for performance
        max_trace = 50
        if len(trace1_x) > max_trace:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)
            trace3_x.pop(0)
            trace3_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)
        trace3_line.set_data(trace3_x, trace3_y)

        # Update timer
        time_text.set_text(f"Time: {current_time:.2f} s")

        # Update info box with angular velocities
        omega1_val = omega1[frame]
        omega2_val = omega2[frame]
        omega3_val = omega3[frame]

        # Using precomputed spring extensions and energies
        spring1_ext = spring_ext_all_1[frame]
        spring2_ext = spring_ext_all_2[frame]

        KE = KE_all[frame]
        PE_grav = PE_grav_all[frame]
        PE_spring = PE_spring_all[frame]
        total_E = total_E_all[frame]

        info_text = (
            f"$\\theta_1 = {theta1_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_2 = {theta2_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_3 = {theta3_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\dot{{\\theta}}_1 = {omega1_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_2 = {omega2_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_3 = {omega3_val:+6.3f}$ rad/s\n"
            f"───────────────\n"
            f"Spring 1: $\\Delta x = {spring1_ext:+.3f}$ m\n"
            f"Spring 2: $\\Delta x = {spring2_ext:+.3f}$ m\n"
            f"KE = {KE:.3f} J\n"
            f"PE = {PE_grav + PE_spring:.3f} J\n"
            f"Total E = {total_E:.3f} J"
        )
        info_box.set_text(info_text)

        return (
            rod1_line,
            rod2_line,
            rod3_line,
            spring1_line,
            spring2_line,
            mass1_plot,
            mass2_plot,
            mass3_plot,
            trace1_line,
            trace2_line,
            trace3_line,
            time_text,
            info_box,
        )

    # Create animation
    print("\nCreating animation...")
    anim = FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(t_eval),
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            # Create default filename
            filename = (
                f"three_coupled_pendulum_m1_{m1}_m2_{m2}_m3_{m3}_k1_{k1}_k2_{k2}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}_theta3_{theta_3_init}.gif"
            )

        # Ensure ANIMATIONS directory exists
        anim_dir = "OUTPUTS/ANIMATIONS/coupled_pendulum/three_pendulum"
        os.makedirs(anim_dir, exist_ok=True)

        save_path = os.path.join(anim_dir, filename)
        print(f"Saving animation to: {save_path}")

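        try:
            if save_path.lower().endswith(".gif"):
                # GIFs are written with Pillow, which needs no external binary
                anim.save(save_path, writer="pillow", fps=fps, dpi=100)
            else:
                # Other containers (e.g. .mp4) need ffmpeg on the PATH
                anim.save(save_path, writer="ffmpeg", fps=fps, dpi=100)
            print("Animation saved successfully!")
        except Exception as e:
            print(f"Error saving animation: {e}")
            print("Trying pillow writer (GIF)...")
            try:
                gif_path = os.path.splitext(save_path)[0] + ".gif"
                anim.save(gif_path, writer="pillow", fps=fps)
                print(f"Animation saved successfully as GIF: {gif_path}")
            except Exception as e2:
                print(f"Failed to save animation: {e2}")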

        plt.close(fig)
    else:
        plt.show()

    return fig, anim


# =============================================================================
# EXAMPLE USAGE AND DEMONSTRATIONS
# =============================================================================


def demo_normal_modes_three():
    """
    Print the theoretical normal-mode frequencies, periods, and frequency ratios
    for an equal-mass, equal-spring three-pendulum chain.
    """
    print("=" * 60)
    print("THREE COUPLED PENDULUM - NORMAL MODE DEMONSTRATIONS")
    print("=" * 60)

    # Parameters (equal masses and springs for simplicity)
    m1, m2, m3 = 1.0, 1.0, 1.0
    k1, k2 = 10.0, 10.0
    L, g = 2.0, 9.81

    # Calculate theoretical frequencies
    omega1, omega2, omega3, f1, f2, f3 = th_normal_modes_triple_pendulum(
        g, L, k1, k2, m1, m2, m3
    )
    T1 = 1 / f1
    T2 = 1 / f2
    T3 = 1 / f3

    print("\nSystem Parameters:")
    print(f"  m₁ = m₂ = m₃ = {m1} kg")
    print(f"  k₁ = k₂ = {k1} N/m")
    print(f"  L = {L} m")
    print(f"  g = {g} m/s²")
    print("\nTheoretical Normal Modes:")
    print(f"  Mode 1 (sloshing):       ω₁ = {omega1:.4f} rad/s, T₁ = {T1:.4f} s")
    print(f"  Mode 2 (anti-symmetric): ω₂ = {omega2:.4f} rad/s, T₂ = {T2:.4f} s")
    print(f"  Mode 3 (breathing):      ω₃ = {omega3:.4f} rad/s, T₃ = {T3:.4f} s")
    print(
        f"  Frequency ratios: ω₂/ω₁ = {omega2 / omega1:.4f}, ω₃/ω₁ = {omega3 / omega1:.4f}"
    )

    return omega1, omega2, omega3


def main():
    """Main function to run simulations and demonstrations."""
    header = "STARTING SPRING-COUPLED TRIPLE PENDULUM SIMULATION"
    print("\n" + "=" * len(header))
    print(header)
    print("=" * len(header))

    default_params = {
        "theta_1_init": 0.0,
        "theta_2_init": 0.0,
        "theta_3_init": 10.0,
        "m1": 1.0,
        "m2": 1.0,
        "m3": 1.0,
        "k1": 5.0,
        "k2": 5.0,
        "L": 2.0,
        "g": 9.81,
        "simulation_time": 10.0,
        "fps": 30,
    }

    use_default = input("Use default simulation parameters? (y/n): ").strip().lower()
    if use_default == "y":
        params = default_params.copy()
    else:
        params = {}
        try:
            params["theta_1_init"] = float(
                input(
                    f"Enter initial angle for pendulum 1 (degrees) [{default_params['theta_1_init']}]: "
                )
                or default_params["theta_1_init"]
            )
            params["theta_2_init"] = float(
                input(
                    f"Enter initial angle for pendulum 2 (degrees) [{default_params['theta_2_init']}]: "
                )
                or default_params["theta_2_init"]
            )
            params["theta_3_init"] = float(
                input(
                    f"Enter initial angle for pendulum 3 (degrees) [{default_params['theta_3_init']}]: "
                )
                or default_params["theta_3_init"]
            )
            params["m1"] = float(
                input(f"Enter mass of the first pendulum (kg) [{default_params['m1']}]: ")
                or default_params["m1"]
            )
            params["m2"] = float(
                input(f"Enter mass of the second pendulum (kg) [{default_params['m2']}]: ")
                or default_params["m2"]
            )
            params["m3"] = float(
                input(f"Enter mass of the third pendulum (kg) [{default_params['m3']}]: ")
                or default_params["m3"]
            )
            params["k1"] = float(
                input(f"Enter spring constant k1 (N/m) [{default_params['k1']}]: ")
                or default_params["k1"]
            )
            params["k2"] = float(
                input(f"Enter spring constant k2 (N/m) [{default_params['k2']}]: ")
                or default_params["k2"]
            )
            params["L"] = float(
                input(f"Enter length of the pendulum rods (m) [{default_params['L']}]: ")
                or default_params["L"]
            )
            params["g"] = float(
                input(
                    f"Enter gravitational acceleration (m/s²) [{default_params['g']}]: "
                )
                or default_params["g"]
            )
            params["simulation_time"] = float(
                input(
                    f"Enter simulation time (seconds) [{default_params['simulation_time']}]: "
                )
                or default_params["simulation_time"]
            )
            params["fps"] = int(
                input(
                    f"Enter frames per second for animation [{default_params['fps']}]: "
                )
                or default_params["fps"]
            )
        except ValueError:
            print("Invalid input. Using default parameters.")
            params = default_params.copy()

    save_anim = input("Save animation to file? (y/n): ").strip().lower() == "y"
    filename = None
    if save_anim:
        filename_input = input(
            "Enter filename for animation (or press Enter for default): "
        ).strip()
        if filename_input:
            if not filename_input.endswith(".gif"):
                filename_input += ".gif"
            filename = filename_input

    print("\nRunning simulation with parameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")

    demo_normal_modes_three()

    animation = simulate_three_coupled_pendulum(
        theta_1_init=params["theta_1_init"],
        theta_2_init=params["theta_2_init"],
        theta_3_init=params["theta_3_init"],
        m1=params["m1"],
        m2=params["m2"],
        m3=params["m3"],
        k1=params["k1"],
        k2=params["k2"],
        L=params["L"],
        g=params["g"],
        simulation_time=params["simulation_time"],
        fps=params["fps"],
        save_anim=save_anim,
        filename=filename,
    )

    return animation


if __name__ == "__main__":
    animation = main()
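The theoretical frequencies printed above come from `th_normal_modes_triple_pendulum`, whose body is not shown in this section. As an independent cross-check, the sketch below solves the small-angle problem directly as a generalized eigenproblem $K\mathbf{v} = \omega^2 M\mathbf{v}$. It assumes the same spring model as the energy expressions in Section 2.3 (extensions $L(\sin\theta_j - \sin\theta_i)$, linearized to $L(\theta_j - \theta_i)$); `normal_modes_three` and `modal_amplitudes` are hypothetical helpers, not functions defined elsewhere in this notebook. Projecting the initial angles onto the $M$-orthonormal mode shapes then shows which modes a given release excites.

```python
import numpy as np
from scipy.linalg import eigh


def normal_modes_three(g, L, k1, k2, m1, m2, m3):
    """Small-angle normal modes of three spring-coupled pendulums (sketch).

    Linearizes the spring extensions L*(sin(theta_j) - sin(theta_i)) to
    L*(theta_j - theta_i). The common factor L**2 is divided out of both
    matrices, so K = (g/L) * M + K_spring with M = diag(m1, m2, m3).
    Returns the angular frequencies (ascending) and M-orthonormal mode shapes.
    """
    M = np.diag([m1, m2, m3])
    # Stiffness of the spring chain 1-2-3 (k1 couples 1-2, k2 couples 2-3)
    K_spring = np.array(
        [
            [k1, -k1, 0.0],
            [-k1, k1 + k2, -k2],
            [0.0, -k2, k2],
        ]
    )
    K = (g / L) * M + K_spring
    # Generalized symmetric eigenproblem K v = omega^2 M v; eigh returns the
    # eigenvalues in ascending order and normalizes so that V.T @ M @ V = I.
    omega_sq, modes = eigh(K, M)
    return np.sqrt(omega_sq), modes


def modal_amplitudes(theta0, modes, masses):
    """Modal coordinates q = V.T @ M @ theta0 of an initial angle vector."""
    M = np.diag(masses)
    return modes.T @ M @ np.asarray(theta0, dtype=float)


# Reference parameters of demo_normal_modes_three, pendulum 3 released from 10 degrees
omega_demo, modes_demo = normal_modes_three(9.81, 2.0, 10.0, 10.0, 1.0, 1.0, 1.0)
q_demo = modal_amplitudes(np.radians([0.0, 0.0, 10.0]), modes_demo, [1.0, 1.0, 1.0])
excited = np.abs(q_demo) > 1e-6 * np.abs(q_demo).max()
print("omega (rad/s):", omega_demo)
print("modal amplitudes:", q_demo, "excited:", excited)
```

For equal masses and springs the eigenvalues reduce to $\omega^2 = g/L$, $g/L + k/m$ and $g/L + 3k/m$ with mode shapes $(1,1,1)$, $(1,0,-1)$ and $(1,-2,1)$; these should agree with `demo_normal_modes_three` if `th_normal_modes_triple_pendulum` uses the same linearization. Because the projection does not assume equal parameters, it could also replace the hard-coded motion-type checks in Section 2.3.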


### 2.3 Time Series Diagnostics (Angles)

In [None]:
def three_coupled_pendulum_animation_with_plots(
    theta_1_init=0.0,
    theta_2_init=10.0,
    theta_3_init=0.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=5.0,
    k2=5.0,
    L=2.0,
    g=9.81,
    simulation_time=20.0,
    fps=30,
    save_format="gif",
    save_anim=False,
    filename=None,
):
    """
    Simulate three spring-coupled pendulums and animate them next to live
    time-series plots of the three angles.

    Parameters
    ----------
    theta_1_init : float
        Initial angle of pendulum 1 in degrees
    theta_2_init : float
        Initial angle of pendulum 2 in degrees
    theta_3_init : float
        Initial angle of pendulum 3 in degrees
    m1, m2, m3 : float
        Masses of the pendulum bobs (kg)
    k1, k2 : float
        Spring constants (N/m); k1 couples pendulums 1 and 2, k2 couples 2 and 3
    L : float
        Length of the pendulum rods (m)
    g : float
        Gravitational acceleration (m/s²)
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for the animation (also sets the output sampling rate)
    save_format : str
        Format to save the animation ('gif' or 'mp4')
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for the saved animation (auto-generated if None)

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object
    """

    # =========================================================================
    # NUMERICAL SOLUTION
    # =========================================================================

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)
    theta3_0 = np.radians(theta_3_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂, θ₃, ω₃] (starting from rest)
    y0 = [theta1_0, 0.0, theta2_0, 0.0, theta3_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        three_coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k1, k2, m1, m2, m3),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if solution.sol is None:
        raise RuntimeError(
            "solve_ivp did not return a dense solution (solution.sol is None)"
        )

    y = solution.sol(t=t_eval)
    theta1 = y[0]
    omega1 = y[1]
    theta2 = y[2]
    omega2 = y[3]
    theta3 = y[4]
    omega3 = y[5]
    print(f"Solution evaluated at {len(t_eval)} output time points")

    # =========================================================================
    # PRECOMPUTE SPRING EXTENSION + ENERGIES
    # =========================================================================

    theta1_deg_all = np.degrees(theta1)
    theta2_deg_all = np.degrees(theta2)
    theta3_deg_all = np.degrees(theta3)

    spring_ext_all_1 = L * (np.sin(theta2) - np.sin(theta1))
    spring_ext_all_2 = L * (np.sin(theta3) - np.sin(theta2))

    KE_all = (
        0.5 * m1 * (L * omega1) ** 2
        + 0.5 * m2 * (L * omega2) ** 2
        + 0.5 * m3 * (L * omega3) ** 2
    )
    PE_grav_all = (
        -m1 * g * L * np.cos(theta1)
        - m2 * g * L * np.cos(theta2)
        - m3 * g * L * np.cos(theta3)
    )
    PE_spring_all = 0.5 * k1 * spring_ext_all_1**2 + 0.5 * k2 * spring_ext_all_2**2
    total_E_all = KE_all + PE_grav_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, omega3_theory, f1_theory, f2_theory, f3_theory = (
        th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3)
    )

    # Numerical frequency estimation
    f1_num, f2_num, f3_num = est_freqs_triple_pendulum(
        t_eval, theta1, theta2, theta3, g, L, k1, k2, m1, m2, m3
    )

    def _fmt_freq(freq):
        return f"{freq:.4f} Hz" if freq > 0 else "N/A (Not Excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)
    f3_str = _fmt_freq(f3_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"               ω₃ = {omega3_theory:.4f} rad/s (f₃ = {f3_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}, f₃ ≈ {f3_str}")

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Pivot positions
    pivot_separation = 1.5
    pivot1_x = -pivot_separation
    pivot2_x = 0
    pivot3_x = pivot_separation
    pivot_y = 0

    ceiling_length = 2 * pivot_separation + 0.5
    separation_ratio = pivot_separation / ceiling_length
    spring_hook_length = calculate_fixed_hook_length_three_pendulum(
        ceiling_length, L, separation_ratio, hook_ratio=0.15
    )
    spring_radius = 0.02 * ceiling_length
    spring_num_coils = 10
    ceiling_height = 0.05 * ceiling_length

    # Precompute mass positions
    x1_all = pivot1_x + L * np.sin(theta1)
    y1_all = pivot_y - L * np.cos(theta1)
    x2_all = pivot2_x + L * np.sin(theta2)
    y2_all = pivot_y - L * np.cos(theta2)
    x3_all = pivot3_x + L * np.sin(theta3)
    y3_all = pivot_y - L * np.cos(theta3)

    # Figure setup with gridspec
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(
        3,
        2,
        figure=fig,
        width_ratios=[1.5, 1],
        left=0.05,
        right=0.98,
        top=0.88,
        bottom=0.08,
        hspace=0.35,
        wspace=0.5,
    )

    # Main pendulum animation (left column)
    ax = fig.add_subplot(gs[:, 0])

    # Time series plots (right column)
    ax_theta1 = fig.add_subplot(gs[0, 1])
    ax_theta2 = fig.add_subplot(gs[1, 1])
    ax_theta3 = fig.add_subplot(gs[2, 1])

    fig.suptitle("Three Spring-Coupled Pendulum System", fontsize=16, fontweight="bold")

    xlim, ylim = calculate_plot_limits_3(
        ceiling_length,
        np.min(x1_all),
        np.max(x3_all),
        np.min([y1_all.min(), y2_all.min(), y3_all.min()]),
        np.min([y1_all.min(), y2_all.min(), y3_all.min()]),
        padding=0.2,
        middle_mass_x=np.mean(x2_all),
        middle_mass_y=np.mean(y2_all),
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

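    # Label the initial condition by the normal mode it matches. The mode
    # shapes (1, 1, 1), (1, 0, -1) and (1, -2, 1) assume m1 = m2 = m3 and
    # k1 = k2; for unequal parameters the label is only a heuristic.
    if max(abs(theta_1_init), abs(theta_2_init), abs(theta_3_init)) < 1e-6:
        motion_type = "At Rest (No Initial Displacement)"
    elif (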
        abs(theta_1_init - theta_2_init) < 1e-6
        and abs(theta_2_init - theta_3_init) < 1e-6
    ):
        motion_type = "Sloshing Mode (Normal Mode 1 - All in phase)"
    elif abs(theta_1_init + theta_3_init) < 1e-6 and abs(theta_2_init) < 1e-6:
        motion_type = "Anti-Symmetric Mode (Normal Mode 2 - Middle at rest)"
    elif (
        abs(theta_1_init - theta_3_init) < 1e-6
        and abs(theta_2_init + 2 * theta_1_init) < 1e-6
    ):
        motion_type = "Breathing Mode (Normal Mode 3 - Maximum spring stretch)"
    else:
        motion_type = "Mixed-Mode Oscillation (Superposition)"

    title_text = (
        f"{motion_type}\n"
        f"Normal Mode Frequencies\n"
        f"$\\omega_1={omega1_theory:.3f}$ rad/s, "
        f"$\\omega_2={omega2_theory:.3f}$ rad/s, "
        f"$\\omega_3={omega3_theory:.3f}$ rad/s"
    )

    ax.set_title(title_text, fontsize=12, fontweight="bold")

    draw_ceiling(ax, ceiling_length, ceiling_height)

    # Pivot points
    ax.plot(pivot1_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)
    ax.plot(pivot2_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)
    ax.plot(pivot3_x, pivot_y, marker="o", color="black", markersize=15, zorder=6)

    # Reference lines for equilibrium positions
    ax.axvline(
        pivot1_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )
    ax.axvline(
        pivot2_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )
    ax.axvline(
        pivot3_x, ymin=0, ymax=1, color="gray", ls="--", lw=1, zorder=0, alpha=0.5
    )

    # Initialize plot elements
    # Rods
    (rod1_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod2_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)
    (rod3_line,) = ax.plot([], [], "k-", lw=2.5, zorder=3)

    # Springs
    (spring1_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)
    (spring2_line,) = ax.plot([], [], "gray", lw=1.5, zorder=2)

    # Masses (as markers)
    mass_size_min = 20
    mass_size_ref = 25
    avg_mass = (m1 + m2 + m3) / 3.0 if (m1 + m2 + m3) > 0 else 1.0
    mass_1_size = max(mass_size_min, mass_size_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(mass_size_min, mass_size_ref * np.cbrt(m2 / avg_mass))
    mass_3_size = max(mass_size_min, mass_size_ref * np.cbrt(m3 / avg_mass))

    (mass1_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="crimson",
        markersize=mass_1_size,
        zorder=4,
        label=f"$m_1$={m1} kg",
    )
    (mass2_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="green",
        markersize=mass_2_size,
        zorder=4,
        label=f"$m_2$={m2} kg",
    )
    (mass3_plot,) = ax.plot(
        [],
        [],
        marker="o",
        color="royalblue",
        markersize=mass_3_size,
        zorder=4,
        label=f"$m_3$={m3} kg",
    )

    # Traces
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    trace3_x, trace3_y = [], []
    (trace1_line,) = ax.plot([], [], "crimson", lw=0.8, alpha=0.5, zorder=1)
    (trace2_line,) = ax.plot([], [], "green", lw=0.8, alpha=0.5, zorder=1)
    (trace3_line,) = ax.plot([], [], "royalblue", lw=0.8, alpha=0.5, zorder=1)

    # Trace markers (front of trace)
    (trace1_marker,) = ax.plot(
        [], [], "o", color="crimson", markersize=8, zorder=5, alpha=0.8
    )
    (trace2_marker,) = ax.plot(
        [], [], "o", color="green", markersize=8, zorder=5, alpha=0.8
    )
    (trace3_marker,) = ax.plot(
        [], [], "o", color="royalblue", markersize=8, zorder=5, alpha=0.8
    )

    # Setup time series plots
    # Theta1 vs time
    ax_theta1.set_xlim(0, 1.1 * simulation_time)
    ax_theta1.set_ylim(np.degrees(theta1).min() - 5, np.degrees(theta1).max() + 5)
    ax_theta1.set_xlabel("Time (s)", fontsize=10)
    ax_theta1.set_ylabel("$\\theta_1$ (deg)", fontsize=10)
    ax_theta1.grid(True, alpha=0.3)
    ax_theta1.set_title("$\\theta_1$ vs $t$", fontsize=12, fontweight="bold")
    (theta1_time_line,) = ax_theta1.plot([], [], "crimson", lw=2, label="$\\theta_1$")
    (theta1_current_point,) = ax_theta1.plot(
        [], [], "o", color="crimson", markersize=8, zorder=5
    )
    ax_theta1.legend(loc="upper right", fontsize=9)

    # Theta2 vs time
    ax_theta2.set_xlim(0, 1.1 * simulation_time)
    ax_theta2.set_ylim(np.degrees(theta2).min() - 5, np.degrees(theta2).max() + 5)
    ax_theta2.set_xlabel("Time (s)", fontsize=10)
    ax_theta2.set_ylabel("$\\theta_2$ (deg)", fontsize=10)
    ax_theta2.grid(True, alpha=0.3)
    ax_theta2.set_title("$\\theta_2$ vs $t$", fontsize=12, fontweight="bold")
    (theta2_time_line,) = ax_theta2.plot([], [], "green", lw=2, label="$\\theta_2$")
    (theta2_current_point,) = ax_theta2.plot(
        [], [], "o", color="green", markersize=8, zorder=5
    )
    ax_theta2.legend(loc="upper right", fontsize=9)

    # Theta3 vs time
    ax_theta3.set_xlim(0, 1.1 * simulation_time)
    ax_theta3.set_ylim(np.degrees(theta3).min() - 5, np.degrees(theta3).max() + 5)
    ax_theta3.set_xlabel("Time (s)", fontsize=10)
    ax_theta3.set_ylabel("$\\theta_3$ (deg)", fontsize=10)
    ax_theta3.grid(True, alpha=0.3)
    ax_theta3.set_title("$\\theta_3$ vs $t$", fontsize=12, fontweight="bold")
    (theta3_time_line,) = ax_theta3.plot([], [], "royalblue", lw=2, label="$\\theta_3$")
    (theta3_current_point,) = ax_theta3.plot(
        [], [], "o", color="royalblue", markersize=8, zorder=5
    )
    ax_theta3.legend(loc="upper right", fontsize=9)

    # Store time series data
    time_history = []
    theta1_history = []
    theta2_history = []
    theta3_history = []

    # Dynamic text annotations
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        va="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # Info box with angles and frequencies
    info_box = ax.text(
        1.02,
        0.99,
        "",
        transform=ax.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightcyan", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = (
        f"Frequency Comparison:\n"
        f"{'-' * 20}\n"
        f"Theoretical (linear):\n"
        f"  $f_1={f1_theory:.3f}$ Hz (sloshing)\n"
        f"  $f_2={f2_theory:.3f}$ Hz (anti-sym)\n"
        f"  $f_3={f3_theory:.3f}$ Hz (breathing)\n"
        f"Numerical:\n"
        f"  $f_1\\approx${f1_str}\n"
        f"  $f_2\\approx${f2_str}\n"
        f"  $f_3\\approx${f3_str}"
    )

    sys_info_text = (
        f"System Parameters:\n"
        f"{'-' * 20}\n"
        f"Masses: \n"
        f"$m_1$={m1} kg\n"
        f"$m_2$={m2} kg\n"
        f"$m_3$={m3} kg\n"
        f"Springs:\n"
        f"$k_1$={k1} N/m\n"
        f"$k_2$={k2} N/m\n"
        f"Length: $L$={L} m\n"
        f"Gravity: $g$={g} m/s²\n"
        f"Initial Angles:\n"
        f"$\\theta_1={theta_1_init:.1f}^{{\\circ}}$\n"
        f"$\\theta_2={theta_2_init:.1f}^{{\\circ}}$\n"
        f"$\\theta_3={theta_3_init:.1f}^{{\\circ}}$"
    )

    ax.text(
        1.02,
        0.01,
        freq_text,
        transform=ax.transAxes,
        fontsize=8,
        va="bottom",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="honeydew", alpha=0.8),
    )
    ax.text(
        1.02,
        0.5,
        sys_info_text,
        transform=ax.transAxes,
        fontsize=9,
        va="center",
        ha="left",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lavender", alpha=0.8),
    )

    # Legend
    ax.legend(loc="lower right", fontsize=9, ncol=3, framealpha=0.9, markerscale=0.5)

    # =========================================================================
    # ANIMATION FUNCTIONS
    # =========================================================================

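    def init():
        """Initialize animation and reset the accumulated history buffers."""
        # Clear the traces and time-series histories so that re-initializing
        # the animation (for example when saving) starts from empty buffers.
        for history in (
            trace1_x,
            trace1_y,
            trace2_x,
            trace2_y,
            trace3_x,
            trace3_y,
            time_history,
            theta1_history,
            theta2_history,
            theta3_history,
        ):
            history.clear()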
        rod1_line.set_data([], [])
        rod2_line.set_data([], [])
        rod3_line.set_data([], [])
        spring1_line.set_data([], [])
        spring2_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        mass3_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])
        trace3_line.set_data([], [])
        trace1_marker.set_data([], [])
        trace2_marker.set_data([], [])
        trace3_marker.set_data([], [])
        theta1_time_line.set_data([], [])
        theta2_time_line.set_data([], [])
        theta3_time_line.set_data([], [])
        theta1_current_point.set_data([], [])
        theta2_current_point.set_data([], [])
        theta3_current_point.set_data([], [])
        time_text.set_text("")
        info_box.set_text("")
        return (
            rod1_line,
            rod2_line,
            rod3_line,
            spring1_line,
            spring2_line,
            mass1_plot,
            mass2_plot,
            mass3_plot,
            trace1_line,
            trace2_line,
            trace3_line,
            trace1_marker,
            trace2_marker,
            trace3_marker,
            theta1_time_line,
            theta2_time_line,
            theta3_time_line,
            theta1_current_point,
            theta2_current_point,
            theta3_current_point,
            time_text,
            info_box,
        )

    def animate(frame):
        """Update animation for each frame."""
        # Current time and angles
        current_time = t_eval[frame]
        theta1_deg = theta1_deg_all[frame]
        theta2_deg = theta2_deg_all[frame]
        theta3_deg = theta3_deg_all[frame]

        # Current positions
        x1, y1 = x1_all[frame], y1_all[frame]
        x2, y2 = x2_all[frame], y2_all[frame]
        x3, y3 = x3_all[frame], y3_all[frame]

        # Update rods
        rod1_line.set_data([pivot1_x, x1], [pivot_y, y1])
        rod2_line.set_data([pivot2_x, x2], [pivot_y, y2])
        rod3_line.set_data([pivot3_x, x3], [pivot_y, y3])

        spring1_x, spring1_y = draw_spring_with_hook(
            start_pos=(x1, y1),
            end_pos=(x2, y2),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring1_line.set_data(spring1_x, spring1_y)

        spring2_x, spring2_y = draw_spring_with_hook(
            start_pos=(x2, y2),
            end_pos=(x3, y3),
            num_coils=spring_num_coils,
            radius=spring_radius,
            hook_length=spring_hook_length,
        )
        spring2_line.set_data(spring2_x, spring2_y)

        # Update masses
        mass1_plot.set_data([x1], [y1])
        mass2_plot.set_data([x2], [y2])
        mass3_plot.set_data([x3], [y3])

        # Update traces
        trace1_x.append(x1)
        trace1_y.append(y1)
        trace2_x.append(x2)
        trace2_y.append(y2)
        trace3_x.append(x3)
        trace3_y.append(y3)

        # Limit trace length for performance
        max_trace = 50
        if len(trace1_x) > max_trace:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)
            trace3_x.pop(0)
            trace3_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)
        trace3_line.set_data(trace3_x, trace3_y)

        # Update trace markers (at the front of traces)
        if len(trace1_x) > 0:
            trace1_marker.set_data([trace1_x[-1]], [trace1_y[-1]])
            trace2_marker.set_data([trace2_x[-1]], [trace2_y[-1]])
            trace3_marker.set_data([trace3_x[-1]], [trace3_y[-1]])

        # Update time series data
        time_history.append(current_time)
        theta1_history.append(theta1_deg)
        theta2_history.append(theta2_deg)
        theta3_history.append(theta3_deg)

        # Update time series plots
        theta1_time_line.set_data(time_history, theta1_history)
        theta2_time_line.set_data(time_history, theta2_history)
        theta3_time_line.set_data(time_history, theta3_history)
        # Update current point markers on time series
        theta1_current_point.set_data([current_time], [theta1_deg])
        theta2_current_point.set_data([current_time], [theta2_deg])
        theta3_current_point.set_data([current_time], [theta3_deg])

        # Update time text
        time_text.set_text(f"Time: {current_time:.2f} s")

        # Update info box with angular velocities
        omega1_val = omega1[frame]
        omega2_val = omega2[frame]
        omega3_val = omega3[frame]

        # Using precomputed spring extensions and energies
        spring1_ext = spring_ext_all_1[frame]
        spring2_ext = spring_ext_all_2[frame]

        KE = KE_all[frame]
        PE_grav = PE_grav_all[frame]
        PE_spring = PE_spring_all[frame]
        total_E = total_E_all[frame]

        info_text = (
            f"$\\theta_1 = {theta1_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_2 = {theta2_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\theta_3 = {theta3_deg:+7.2f}^{{\\circ}}$\n"
            f"$\\dot{{\\theta}}_1 = {omega1_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_2 = {omega2_val:+6.3f}$ rad/s\n"
            f"$\\dot{{\\theta}}_3 = {omega3_val:+6.3f}$ rad/s\n"
            f"───────────────\n"
            f"Spring 1: $\\Delta x = {spring1_ext:+.3f}$ m\n"
            f"Spring 2: $\\Delta x = {spring2_ext:+.3f}$ m\n"
            f"KE = {KE:.3f} J\n"
            f"PE = {PE_grav + PE_spring:.3f} J\n"
            f"Total E = {total_E:.3f} J"
        )
        info_box.set_text(info_text)

        return (
            rod1_line,
            rod2_line,
            rod3_line,
            spring1_line,
            spring2_line,
            mass1_plot,
            mass2_plot,
            mass3_plot,
            trace1_line,
            trace2_line,
            trace3_line,
            trace1_marker,
            trace2_marker,
            trace3_marker,
            theta1_time_line,
            theta2_time_line,
            theta3_time_line,
            theta1_current_point,
            theta2_current_point,
            theta3_current_point,
            time_text,
            info_box,
        )

    # Create animation
    print("\nCreating animation...")
    anim = FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(t_eval),
        interval=1000 / fps,
        blit=True,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            ext = ".gif" if save_format == "gif" else ".mp4"
            # Default filename encodes the parameters and all three initial angles
            filename = (
                f"time_series_m1_{m1}_m2_{m2}_m3_{m3}_k1_{k1}_k2_{k2}_L_{L}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}_theta3_{theta_3_init}{ext}"
            )

        # Ensure ANIMATIONS directory exists
        anim_dir = "OUTPUTS/ANIMATIONS/coupled_pendulum/three_pendulum"
        os.makedirs(anim_dir, exist_ok=True)

        save_path = os.path.join(anim_dir, filename)
        print(f"Saving animation to: {save_path}")

        try:
            if save_format == "gif":
                anim.save(save_path, writer="ffmpeg", fps=fps, codec="gif", dpi=120)
            else:
                anim.save(save_path, writer="ffmpeg", fps=fps, dpi=120)
            print("Animation saved successfully!")
        except Exception as e:
            print(f"Error saving animation: {e}")

        plt.close(fig)
    else:
        plt.show()

    return anim


def main():
    """Main function to run simulations and demonstrations."""
    header = "STARTING SPRING-COUPLED TRIPLE PENDULUM SIMULATION"
    print("\n" + "=" * len(header))
    print(header)
    print("=" * len(header))

    default_params = {
        "theta_1_init": 0.0,
        "theta_2_init": 0.0,
        "theta_3_init": 10.0,
        "m1": 1.0,
        "m2": 1.0,
        "m3": 1.0,
        "k1": 5.0,
        "k2": 5.0,
        "L": 2.0,
        "g": 9.81,
        "simulation_time": 10.0,
        "fps": 30,
        "save_format": "gif",
    }

    use_default = input("Use default simulation parameters? (y/n): ").strip().lower()
    if use_default == "y":
        params = default_params.copy()
    else:
        params = {}
        try:
            params["theta_1_init"] = float(
                input(
                    f"Enter initial angle for pendulum 1 (degrees) [{default_params['theta_1_init']}]: "
                )
                or default_params["theta_1_init"]
            )
            params["theta_2_init"] = float(
                input(
                    f"Enter initial angle for pendulum 2 (degrees) [{default_params['theta_2_init']}]: "
                )
                or default_params["theta_2_init"]
            )
            params["theta_3_init"] = float(
                input(
                    f"Enter initial angle for pendulum 3 (degrees) [{default_params['theta_3_init']}]: "
                )
                or default_params["theta_3_init"]
            )
            params["m1"] = float(
                input(f"Enter mass of the first pendulum (kg) [{default_params['m1']}]: ")
                or default_params["m1"]
            )
            params["m2"] = float(
                input(f"Enter mass of the second pendulum (kg) [{default_params['m2']}]: ")
                or default_params["m2"]
            )
            params["m3"] = float(
                input(f"Enter mass of the third pendulum (kg) [{default_params['m3']}]: ")
                or default_params["m3"]
            )
            params["k1"] = float(
                input(f"Enter spring constant k1 (N/m) [{default_params['k1']}]: ")
                or default_params["k1"]
            )
            params["k2"] = float(
                input(f"Enter spring constant k2 (N/m) [{default_params['k2']}]: ")
                or default_params["k2"]
            )
            params["L"] = float(
                input(f"Enter length of the pendulum rods (m) [{default_params['L']}]: ")
                or default_params["L"]
            )
            params["g"] = float(
                input(
                    f"Enter gravitational acceleration (m/s²) [{default_params['g']}]: "
                )
                or default_params["g"]
            )
            params["simulation_time"] = float(
                input(
                    f"Enter simulation time (seconds) [{default_params['simulation_time']}]: "
                )
                or default_params["simulation_time"]
            )
            params["fps"] = int(
                input(
                    f"Enter frames per second for animation [{default_params['fps']}]: "
                )
                or default_params["fps"]
            )
            params["save_format"] = (
                input(
                    f"Enter animation save format ('gif' or 'mp4') [{default_params['save_format']}]: "
                )
                .strip()
                .lower()
                or default_params["save_format"]
            )
            if params["save_format"] not in ["gif", "mp4"]:
                print("Invalid format. Using default 'gif'.")
                params["save_format"] = "gif"
        except ValueError:
            print("Invalid input. Using default parameters.")
            params = default_params.copy()

    save_anim = input("Save animation to file? (y/n): ").strip().lower() == "y"
    filename = None
    if save_anim:
        filename_input = input(
            "Enter filename for animation (or press Enter for default): "
        ).strip()
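        if filename_input:
            if filename_input.endswith((".gif", ".mp4")):
                # Keep the writer format consistent with the chosen extension
                params["save_format"] = filename_input.rsplit(".", 1)[-1]
            else:
                filename_input += f".{params['save_format']}"
            filename = filename_input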

    print("\nRunning simulation with parameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")

    demo_normal_modes_three()

    animation = three_coupled_pendulum_animation_with_plots(
        theta_1_init=params["theta_1_init"],
        theta_2_init=params["theta_2_init"],
        theta_3_init=params["theta_3_init"],
        m1=params["m1"],
        m2=params["m2"],
        m3=params["m3"],
        k1=params["k1"],
        k2=params["k2"],
        L=params["L"],
        g=params["g"],
        simulation_time=params["simulation_time"],
        fps=params["fps"],
        save_format=params["save_format"],
        save_anim=save_anim,
        filename=filename,
    )

    return animation


if __name__ == "__main__":
    animation = main()
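
The prompts in `main()` above fall back to `default_params` for every field as soon as a single entry fails to parse, because the `except ValueError` branch replaces the whole `params` dictionary. A per-field fallback keeps the entries that did parse. The sketch below uses a hypothetical helper, `prompt_with_default`, which is not part of this notebook. Its `read` argument defaults to `input` and exists only so the prompts can be scripted. The defaults in the usage lines (20.0 s, 30 fps, 'gif') are illustrative, not the notebook's actual `default_params`.

```python
def prompt_with_default(prompt, default, cast=float, choices=None, read=input):
    """Ask for one value; fall back to `default` on empty or invalid input.

    Hypothetical helper: `cast` converts the raw string, `choices` optionally
    restricts the accepted values, and `read` replaces input() for scripting.
    """
    raw = read(f"{prompt} [{default}]: ").strip()
    if not raw:
        return default
    try:
        value = cast(raw)
    except ValueError:
        print(f"Invalid value {raw!r}; using default {default!r}.")
        return default
    if choices is not None and value not in choices:
        print(f"{value!r} is not one of {choices}; using default {default!r}.")
        return default
    return value


# Usage sketch with scripted answers instead of the keyboard (illustrative)
answers = iter(["12.5", "abc", "MP4"])


def fake_read(_prompt):
    return next(answers)


sim_time = prompt_with_default("Simulation time (s)", 20.0, read=fake_read)
fps = prompt_with_default("Frames per second", 30, cast=int, read=fake_read)
fmt = prompt_with_default(
    "Save format", "gif", cast=str.lower, choices=("gif", "mp4"), read=fake_read
)
print(sim_time, fps, fmt)
```

In `main()`, each field could then be read with one call, for example `params["fps"] = prompt_with_default("Frames per second", default_params["fps"], cast=int)`, so one bad entry no longer discards the others.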

### 2.4 Frequency-Domain Analysis (FFT)
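
The cell below removes the mean from each angle, applies a Hann window, and takes a real FFT. Two quantities help when reading the resulting plot. The first is the bin spacing Δf = 1/(N Δt), which limits how finely a peak frequency can be located. The second is the Hann window's coherent gain, sum(w)/N ≈ 1/2: dividing |X| by N rather than by sum(w) halves the apparent amplitude of a windowed sinusoid. The NumPy-only sketch below checks both on a synthetic cosine. The amplitude and frequency are illustrative values, while the sampling (1000 points over 50 s) matches the analysis call at the end of this subsection.

```python
import numpy as np

# Synthetic test signal (illustrative values, not taken from the simulation)
A, f0 = 0.15, 0.5  # amplitude in rad, frequency in Hz
n, T = 1000, 50.0  # same sampling as the analysis call in this subsection
t = np.linspace(0.0, T, n)
x = A * np.cos(2 * np.pi * f0 * t)

dt = t[1] - t[0]
freqs = np.fft.rfftfreq(n, dt)
df = freqs[1] - freqs[0]  # bin spacing = 1 / (n * dt)

window = np.hanning(n)
spectrum = np.abs(np.fft.rfft((x - x.mean()) * window))

peak_naive = 2.0 * spectrum.max() / n                 # normalized by N
peak_corrected = 2.0 * spectrum.max() / window.sum()  # normalized by window gain

print(f"bin spacing    : {df:.4f} Hz")
print(f"peak / N       : {peak_naive:.4f} rad (about A/2)")
print(f"peak / sum(w)  : {peak_corrected:.4f} rad (close to A = {A})")
print(f"peak frequency : {freqs[spectrum.argmax()]:.4f} Hz")
```

With this sampling Δf ≈ 0.02 Hz, so the peak frequencies that the cell below prints to four decimals are bin centres roughly 0.02 Hz apart.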

In [None]:
def plot_fft_spectrum_triple_pendulum(
    t,
    theta1,
    theta2,
    theta3,
    g,
    L,
    k1,
    k2,
    m1,
    m2,
    m3,
    theta_1_init,
    theta_2_init,
    theta_3_init,
    prominence=0.01,
    height=None,
    distance=5,
    save_fig=False,
    filename=None,
):
    """
    Plot the FFT spectra of theta_1, theta_2, and theta_3 with peak detection and annotation.
    Also marks the theoretical normal-mode frequencies of the three-pendulum system.

    Parameters
    ----------
    t : array
        Time array
    theta1, theta2, theta3 : array
        Angle arrays from simulation
    g, L, k1, k2, m1, m2, m3 : float
        System parameters for calculating theoretical frequencies
    theta_1_init, theta_2_init, theta_3_init : float
        Initial angles in degrees (for motion type identification)
    prominence : float, optional
        Minimum prominence of peaks (default: 0.01)
    height : float, optional
        Minimum height of peaks (default: None)
    distance : int, optional
        Minimum distance between peaks in samples (default: 5)
    save_fig : bool, optional
        Whether to save the figure (default: False)
    filename : str, optional
        Filename to save the figure (default: None)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure object containing the plots
    axes : array of matplotlib.axes.Axes
        Array of axes objects
    """
    # Calculate theoretical frequencies
    _, _, _, f1_theory, f2_theory, f3_theory = th_normal_modes_triple_pendulum(
        g, L, k1, k2, m1, m2, m3
    )

    # Numerical frequency estimates via modal decomposition (est_freqs_triple_pendulum)
    f1_num, f2_num, f3_num = est_freqs_triple_pendulum(
        t, theta1, theta2, theta3, g, L, k1, k2, m1, m2, m3
    )

    # Format numerical frequencies for display
    def _fmt_freq(f):
        return f"{f:.3f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)
    f3_str = _fmt_freq(f3_num)

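    # Identify motion type from the initial angles. Note: the anti-symmetric
    # (1, 0, -1) and breathing (1, -2, 1) shapes are exact normal modes only
    # for identical pendulums and springs (m1 = m2 = m3, k1 = k2)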
    tol = 1e-6
    if (
        abs(theta_1_init - theta_2_init) < tol
        and abs(theta_2_init - theta_3_init) < tol
    ):
        if abs(theta_1_init) < tol:
            motion_type = "No Initial Displacement"
            excited_mode = "None"
        else:
            motion_type = "Sloshing Mode (All in Phase)"
            excited_mode = "Mode 1 Only"
    elif abs(theta_1_init + theta_3_init) < tol and abs(theta_2_init) < tol:
        motion_type = "Anti-Symmetric Mode"
        excited_mode = "Mode 2 Only"
    elif (
        abs(theta_1_init - theta_3_init) < tol
        and abs(theta_2_init + 2 * theta_1_init) < tol
    ):
        motion_type = "Breathing Mode"
        excited_mode = "Mode 3 Only"
    else:
        motion_type = "Mixed-Mode Oscillation"
        excited_mode = "Multiple Modes (Superposition)"

    # Compute FFT for all three angles
    dt = t[1] - t[0]
    n = len(t)

    # Remove DC offset
    theta1_centered = theta1 - np.mean(theta1)
    theta2_centered = theta2 - np.mean(theta2)
    theta3_centered = theta3 - np.mean(theta3)

    # Apply Hann window to reduce spectral leakage
    window = np.hanning(n)
    theta1_windowed = theta1_centered * window
    theta2_windowed = theta2_centered * window
    theta3_windowed = theta3_centered * window

    # Compute FFT
    freqs = np.fft.rfftfreq(n, dt)
    fft_theta1 = np.fft.rfft(theta1_windowed)
    fft_theta2 = np.fft.rfft(theta2_windowed)
    fft_theta3 = np.fft.rfft(theta3_windowed)

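    # Amplitude spectrum, normalized by the Hann window's coherent gain
    # (sum of the window weights, about n/2) so that a sinusoid of amplitude A
    # produces a peak of height close to A (in rad)
    window_gain = np.sum(window)
    amp_theta1 = 2.0 * np.abs(fft_theta1) / window_gain
    amp_theta2 = 2.0 * np.abs(fft_theta2) / window_gain
    amp_theta3 = 2.0 * np.abs(fft_theta3) / window_gain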

    # Find peaks for each angle
    peaks1, _ = find_peaks(
        amp_theta1, prominence=prominence, height=height, distance=distance
    )
    peaks2, _ = find_peaks(
        amp_theta2, prominence=prominence, height=height, distance=distance
    )
    peaks3, _ = find_peaks(
        amp_theta3, prominence=prominence, height=height, distance=distance
    )

    # Print peak information
    print("\n--- FFT Peak Detection (Three Pendulum System) ---")
    print(f"Motion Type: {motion_type}")
    print(f"Excited Mode: {excited_mode}")

    for peaks, amp, label in [
        (peaks1, amp_theta1, "θ₁"),
        (peaks2, amp_theta2, "θ₂"),
        (peaks3, amp_theta3, "θ₃"),
    ]:
        print(
            f"\n{label} - Found {len(peaks)} peaks (up to five listed, lowest frequency first):"
        )
        if len(peaks) > 0:
            for j, idx in enumerate(peaks[:5]):
                print(
                    f"  Peak {j + 1}: f = {freqs[idx]:.4f} Hz, amplitude = {amp[idx]:.6f}"
                )
        else:
            print(
                f"  No peaks found! Max amplitude = {np.max(amp):.6f} at f = {freqs[np.argmax(amp)]:.4f} Hz"
            )

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 12), sharex=True)
    fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.85, hspace=0.3)

    title_text = (
        f"FFT Spectrum of Three Coupled Pendulum Angles\n"
        f"Motion Type: {motion_type} | Excited: {excited_mode}"
    )
    fig.suptitle(title_text, fontsize=15, fontweight="bold")

    max_freq_display = max(f1_theory, f2_theory, f3_theory) * 2.5

    mode_colors = ["green", "orange", "purple"]

    # =========================================================================
    # Plot 1: FFT for θ₁
    # =========================================================================
    ax1.plot(
        freqs, amp_theta1, "crimson", linewidth=1.5, label="FFT Amplitude", zorder=1
    )

    if len(peaks1) > 0:
        ax1.plot(
            freqs[peaks1],
            amp_theta1[peaks1],
            "o",
            color="blue",
            markeredgecolor="black",
            markeredgewidth=1.5,
            markersize=10,
            label="Detected Peaks",
            zorder=5,
        )

        peak_amps = amp_theta1[peaks1]
        top_peaks_idx = np.argsort(peak_amps)[-5:][::-1]

        for i, local_idx in enumerate(top_peaks_idx):
            peak_idx = peaks1[local_idx]
            peak_freq = freqs[peak_idx]
            peak_amp = amp_theta1[peak_idx]

            ax1.annotate(
                f"{peak_freq:.4f} Hz",
                xy=(peak_freq, peak_amp),
                xytext=(15, -25),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                bbox=dict(
                    boxstyle="round,pad=0.5",
                    facecolor="yellow",
                    edgecolor="black",
                    alpha=0.85,
                ),
                arrowprops=dict(
                    arrowstyle="->",
                    connectionstyle="arc3,rad=0.2",
                    color="blue",
                    lw=1.5,
                ),
                zorder=6,
            )

    for freq, color, label_text in [
        (f1_theory, mode_colors[0], f"$f_1={f1_theory:.4f}$ Hz"),
        (f2_theory, mode_colors[1], f"$f_2={f2_theory:.4f}$ Hz"),
        (f3_theory, mode_colors[2], f"$f_3={f3_theory:.4f}$ Hz"),
    ]:
        ax1.axvline(
            freq,
            color=color,
            linestyle="--",
            linewidth=2.5,
            alpha=0.8,
            label=f"Theory: {label_text}",
            zorder=3,
        )

    textstr1 = (
        f"Numerical Estimates:\n"
        f"(via modal decomposition)\n"
        f"$f_1 \\approx$ {f1_str}\n"
        f"$f_2 \\approx$ {f2_str}\n"
        f"$f_3 \\approx$ {f3_str}\n"
        f"{'─' * 14}\n"
        f"Peaks detected: {len(peaks1)}"
    )
    ax1.text(
        0.99,
        0.05,
        textstr1,
        transform=ax1.transAxes,
        fontsize=9,
        va="bottom",
        ha="right",
        bbox=dict(
            boxstyle="round",
            facecolor="lightblue",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        zorder=6,
    )

    sys_info_text = (
        f"System Parameters:\n"
        f"{'─' * 18}\n"
        f"Initial Angles:\n"
        f"  $\\theta_1^{{(0)}} = {theta_1_init:+.2f}°$\n"
        f"  $\\theta_2^{{(0)}} = {theta_2_init:+.2f}°$\n"
        f"  $\\theta_3^{{(0)}} = {theta_3_init:+.2f}°$\n"
        f"Masses:\n"
        f"  $m_1={m1:.2f}$ kg, $m_2={m2:.2f}$ kg\n"
        f"  $m_3={m3:.2f}$ kg\n"
        f"Springs:\n"
        f"  $k_1={k1:.2f}$ N/m, $k_2={k2:.2f}$ N/m\n"
        f"Length: $L = {L:.2f}$ m\n"
        f"Gravity: $g = {g:.2f}$ m/s²\n"
        f"Simulation: $T = {t[-1]:.1f}$ s\n"
        f"Points: $N = {len(t)}$"
    )
    ax1.text(
        1.02,
        0.99,
        sys_info_text,
        transform=ax1.transAxes,
        fontsize=9,
        verticalalignment="top",
        horizontalalignment="left",
        bbox=dict(
            boxstyle="round",
            facecolor="lavender",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        fontfamily="monospace",
        zorder=6,
    )

    ax1.set_ylabel("Amplitude (rad)", fontsize=12, fontweight="bold")
    ax1.set_title("$\\theta_1$ (Pendulum 1)", fontsize=13, fontweight="bold", pad=10)
    ax1.grid(True, alpha=0.3, linestyle="--")
    ax1.legend(loc="upper right", fontsize=8, framealpha=0.9, ncol=2)
    ax1.set_xlim(0, max_freq_display)
    ax1.set_ylim(bottom=0)

    # =========================================================================
    # Plot 2: FFT for θ₂
    # =========================================================================
    ax2.plot(freqs, amp_theta2, "green", linewidth=1.5, label="FFT Amplitude", zorder=1)

    if len(peaks2) > 0:
        ax2.plot(
            freqs[peaks2],
            amp_theta2[peaks2],
            "o",
            color="red",
            markeredgecolor="black",
            markeredgewidth=1.5,
            markersize=10,
            label="Detected Peaks",
            zorder=5,
        )

        peak_amps = amp_theta2[peaks2]
        top_peaks_idx = np.argsort(peak_amps)[-5:][::-1]

        for i, local_idx in enumerate(top_peaks_idx):
            peak_idx = peaks2[local_idx]
            peak_freq = freqs[peak_idx]
            peak_amp = amp_theta2[peak_idx]

            ax2.annotate(
                f"{peak_freq:.4f} Hz",
                xy=(peak_freq, peak_amp),
                xytext=(15, -25),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                bbox=dict(
                    boxstyle="round,pad=0.5",
                    facecolor="yellow",
                    edgecolor="black",
                    alpha=0.85,
                ),
                arrowprops=dict(
                    arrowstyle="->", connectionstyle="arc3,rad=0.2", color="red", lw=1.5
                ),
                zorder=6,
            )

    for freq, color, label_text in [
        (f1_theory, mode_colors[0], f"$f_1={f1_theory:.4f}$ Hz"),
        (f2_theory, mode_colors[1], f"$f_2={f2_theory:.4f}$ Hz"),
        (f3_theory, mode_colors[2], f"$f_3={f3_theory:.4f}$ Hz"),
    ]:
        ax2.axvline(
            freq,
            color=color,
            linestyle="--",
            linewidth=2.5,
            alpha=0.8,
            label=f"Theory: {label_text}",
            zorder=3,
        )

    textstr2 = (
        f"Numerical Estimates:\n"
        f"(via modal decomposition)\n"
        f"$f_1 \\approx$ {f1_str}\n"
        f"$f_2 \\approx$ {f2_str}\n"
        f"$f_3 \\approx$ {f3_str}\n"
        f"{'─' * 14}\n"
        f"Peaks detected: {len(peaks2)}"
    )
    ax2.text(
        0.99,
        0.05,
        textstr2,
        transform=ax2.transAxes,
        fontsize=9,
        va="bottom",
        ha="right",
        bbox=dict(
            boxstyle="round",
            facecolor="lightblue",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        zorder=6,
    )

    ax2.set_ylabel("Amplitude (rad)", fontsize=12, fontweight="bold")
    ax2.set_title("$\\theta_2$ (Pendulum 2)", fontsize=13, fontweight="bold", pad=10)
    ax2.grid(True, alpha=0.3, linestyle="--")
    ax2.legend(loc="upper right", fontsize=8, framealpha=0.9, ncol=2)
    ax2.set_xlim(0, max_freq_display)
    ax2.set_ylim(bottom=0)

    # =========================================================================
    # Plot 3: FFT for θ₃
    # =========================================================================
    ax3.plot(
        freqs, amp_theta3, "royalblue", linewidth=1.5, label="FFT Amplitude", zorder=1
    )

    if len(peaks3) > 0:
        ax3.plot(
            freqs[peaks3],
            amp_theta3[peaks3],
            "o",
            color="darkviolet",
            markeredgecolor="black",
            markeredgewidth=1.5,
            markersize=10,
            label="Detected Peaks",
            zorder=5,
        )

        peak_amps = amp_theta3[peaks3]
        top_peaks_idx = np.argsort(peak_amps)[-5:][::-1]

        for i, local_idx in enumerate(top_peaks_idx):
            peak_idx = peaks3[local_idx]
            peak_freq = freqs[peak_idx]
            peak_amp = amp_theta3[peak_idx]

            ax3.annotate(
                f"{peak_freq:.4f} Hz",
                xy=(peak_freq, peak_amp),
                xytext=(15, -25),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                bbox=dict(
                    boxstyle="round,pad=0.5",
                    facecolor="yellow",
                    edgecolor="black",
                    alpha=0.85,
                ),
                arrowprops=dict(
                    arrowstyle="->",
                    connectionstyle="arc3,rad=0.2",
                    color="darkviolet",
                    lw=1.5,
                ),
                zorder=6,
            )

    for freq, color, label_text in [
        (f1_theory, mode_colors[0], f"$f_1={f1_theory:.4f}$ Hz"),
        (f2_theory, mode_colors[1], f"$f_2={f2_theory:.4f}$ Hz"),
        (f3_theory, mode_colors[2], f"$f_3={f3_theory:.4f}$ Hz"),
    ]:
        ax3.axvline(
            freq,
            color=color,
            linestyle="--",
            linewidth=2.5,
            alpha=0.8,
            label=f"Theory: {label_text}",
            zorder=3,
        )

    textstr3 = (
        f"Numerical Estimates:\n"
        f"(via modal decomposition)\n"
        f"$f_1 \\approx$ {f1_str}\n"
        f"$f_2 \\approx$ {f2_str}\n"
        f"$f_3 \\approx$ {f3_str}\n"
        f"{'─' * 14}\n"
        f"Peaks detected: {len(peaks3)}"
    )
    ax3.text(
        0.99,
        0.05,
        textstr3,
        transform=ax3.transAxes,
        fontsize=9,
        va="bottom",
        ha="right",
        bbox=dict(
            boxstyle="round",
            facecolor="lightblue",
            edgecolor="black",
            alpha=0.6,
            linewidth=1.5,
        ),
        zorder=6,
    )

    ax3.set_xlabel("Frequency (Hz)", fontsize=12, fontweight="bold")
    ax3.set_ylabel("Amplitude (rad)", fontsize=12, fontweight="bold")
    ax3.set_title("$\\theta_3$ (Pendulum 3)", fontsize=13, fontweight="bold", pad=10)
    ax3.grid(True, alpha=0.3, linestyle="--")
    ax3.legend(loc="upper right", fontsize=8, framealpha=0.9, ncol=2)
    ax3.set_xlim(0, max_freq_display)
    ax3.set_ylim(bottom=0)

    # Save before calling plt.show(), while the figure is still open
    if save_fig:
        if filename is None:
            filename = (
                f"fft_triple_pendulum_m1_{m1}_m2_{m2}_m3_{m3}_k1_{k1}_k2_{k2}_L_{L}_"
                f"theta1_{theta_1_init}_theta2_{theta_2_init}_theta3_{theta_3_init}.png"
            )

        # Ensure FIGURES directory exists
        fig_dir = "OUTPUTS/FIGURES/coupled_pendulum/triple_pendulum"
        os.makedirs(fig_dir, exist_ok=True)

        save_path = os.path.join(fig_dir, filename)
        print(f"Saving FFT figure to: {save_path}")

        try:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            print("Figure saved successfully!")
        except Exception as e:
            print(f"Error saving figure: {e}")

    plt.show()

    return fig, (ax1, ax2, ax3)


def analyze_triple_pendulum_fft(
    theta_1_init=0.0,
    theta_2_init=0.0,
    theta_3_init=10.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=5.0,
    k2=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    save_fig=False,
    filename=None,
):
    """
    Solve the equations of motion for three coupled pendulums and plot the FFT spectrum.

    Parameters
    ----------
    theta_1_init, theta_2_init, theta_3_init : float
        Initial angles in degrees
    m1, m2, m3 : float
        Masses (kg)
    k1, k2 : float
        Spring constants (N/m)
    L : float
        Pendulum length (m)
    g : float
        Gravitational acceleration (m/s²)
    simulation_time : float
        Simulation duration in seconds
    n_points : int
        Number of time points
    save_fig : bool, optional
        Whether to save the FFT figure (default: False)
    filename : str, optional
        Filename to save the FFT figure (default: None)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with FFT plots
    """
    print("=" * 60)
    print("FFT ANALYSIS OF THREE COUPLED PENDULUM SYSTEM")
    print("=" * 60)

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)
    theta3_0 = np.radians(theta_3_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂, θ₃, ω₃]
    y0 = [theta1_0, 0.0, theta2_0, 0.0, theta3_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    # Solve the ODE
    print("\nSolving differential equations...")
    print(
        f"  Initial angles: θ₁={theta_1_init}°, θ₂={theta_2_init}°, θ₃={theta_3_init}°"
    )
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of points: {n_points}")

    solution = solve_ivp(
        three_coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k1, k2, m1, m2, m3),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-6,
        atol=1e-8,
    )

    theta1 = solution.y[0]
    theta2 = solution.y[2]
    theta3 = solution.y[4]
    t = solution.t

    print("Solution computed successfully!")

    # Calculate theoretical frequencies
    omega1_th, omega2_th, omega3_th, f1_th, f2_th, f3_th = (
        th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3)
    )

    print("\nTheoretical Normal Mode Frequencies:")
    print(f"  ω₁ = {omega1_th:.4f} rad/s  →  f₁ = {f1_th:.4f} Hz (sloshing mode)")
    print(f"  ω₂ = {omega2_th:.4f} rad/s  →  f₂ = {f2_th:.4f} Hz (anti-symmetric mode)")
    print(f"  ω₃ = {omega3_th:.4f} rad/s  →  f₃ = {f3_th:.4f} Hz (breathing mode)")

    print("\nPlotting FFT spectrum...")

    # Plot FFT spectrum with a lower peak-prominence threshold than the default (0.005 vs 0.01)
    fig, axes = plot_fft_spectrum_triple_pendulum(
        t,
        theta1,
        theta2,
        theta3,
        g,
        L,
        k1,
        k2,
        m1,
        m2,
        m3,
        theta_1_init,
        theta_2_init,
        theta_3_init,
        prominence=0.005,
        distance=5,
        save_fig=save_fig,
        filename=filename,
    )

    return fig


# Run the analysis
fig = analyze_triple_pendulum_fft(
    theta_1_init=5.0,
    theta_2_init=-10.0,
    theta_3_init=5.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=5.0,
    k2=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    save_fig=True,
    filename="fft_triple_pendulum_mode_3.png",
)
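
The FFT plot above (and the phase-space plot in the next subsection) labels the motion by comparing the initial angles with the fixed shapes (1, 1, 1), (1, 0, -1) and (1, -2, 1). The last two are exact normal modes only when m1 = m2 = m3 and k1 = k2, so for other parameters the label can be wrong, and "Mixed-Mode Oscillation" never says which modes are present. Projecting the initial displacement onto the eigenvectors of the linearized system gives that breakdown directly. The sketch below uses a hypothetical helper, `modal_content`. It assumes a small-angle model M θ̈ = -K θ with spring k1 joining pendulums 1 and 2, spring k2 joining 2 and 3, and K = diag(mᵢ g/L) + K_springs. If the notebook's equations of motion couple the pendulums differently, `K_springs` would need to change.

```python
import numpy as np


def modal_content(theta0_deg, g=9.81, L=2.0, k1=5.0, k2=5.0, m=(1.0, 1.0, 1.0)):
    """Hypothetical helper: split an initial displacement into normal-mode parts.

    ASSUMED linearized chain model  M theta'' = -K theta  with M = diag(m) and
    K = diag(m_i g / L) + K_springs (k1 joins pendulums 1-2, k2 joins 2-3).
    Returns mode frequencies in Hz (ascending), mode shapes (columns), and the
    fraction of the initial potential energy carried by each mode.
    """
    m = np.asarray(m, dtype=float)
    K_springs = np.array([[k1, -k1, 0.0],
                          [-k1, k1 + k2, -k2],
                          [0.0, -k2, k2]])
    K = np.diag(m * g / L) + K_springs

    # Symmetric form: M^(-1/2) K M^(-1/2) u = w^2 u, mode shapes v = M^(-1/2) u
    m_inv_sqrt = 1.0 / np.sqrt(m)
    K_tilde = K * np.outer(m_inv_sqrt, m_inv_sqrt)
    omega_sq, U = np.linalg.eigh(K_tilde)  # eigenvalues in ascending order
    shapes = U * m_inv_sqrt[:, None]

    # Modal coordinates of the initial displacement: q = U^T M^(1/2) theta0
    theta0 = np.radians(np.asarray(theta0_deg, dtype=float))
    q = U.T @ (np.sqrt(m) * theta0)

    # Potential energy in mode i is proportional to omega_i^2 * q_i^2
    energy = omega_sq * q**2
    total = energy.sum()
    fractions = energy / total if total > 0 else np.zeros_like(energy)
    freqs_hz = np.sqrt(omega_sq) / (2 * np.pi)
    return freqs_hz, shapes, fractions


# Usage: the same initial angles as the analysis call above
freqs, shapes, frac = modal_content([5.0, -10.0, 5.0])
print("f (Hz):", np.round(freqs, 4))
print("energy fraction per mode:", np.round(frac, 4))
```

Under the stated model, the (5°, -10°, 5°) start puts essentially all of its energy into the highest-frequency mode, which agrees with the "Breathing Mode" label; with unequal masses or springs the fractions show which modes a given start actually excites.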


### 2.5 Phase Space and Configuration-Space Trajectories
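
The function below colours each trajectory by time with one `ax.plot` call per segment. That creates one `Line2D` artist per time step in each of its six panels and can make the figure slow to draw for long runs. Matplotlib's `LineCollection` draws the same time-coloured curve as a single artist. The sketch below assumes NumPy and Matplotlib. `add_time_colored_line` is a hypothetical helper, and the trajectory in the usage lines is a toy harmonic oscillation, not simulation output.

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def add_time_colored_line(ax, x, y, t, cmap="viridis", linewidth=1.5):
    """Draw (x, y) as one LineCollection coloured by t (hypothetical helper)."""
    points = np.column_stack([x, y])
    # Consecutive point pairs -> array of shape (n - 1, 2, 2)
    segments = np.stack([points[:-1], points[1:]], axis=1)
    lc = LineCollection(segments, cmap=cmap, linewidths=linewidth)
    lc.set_array(t[:-1])  # colour of each segment = its start time
    lc.set_clim(t[0], t[-1])
    ax.add_collection(lc)
    ax.autoscale_view()  # make sure the view limits cover the new collection
    return lc


# Usage: toy undamped oscillation in (theta [deg], omega [rad/s]) space
t = np.linspace(0.0, 20.0, 1000)
theta = np.degrees(0.1 * np.cos(2.2 * t))
omega = -0.22 * np.sin(2.2 * t)

fig, ax = plt.subplots()
lc = add_time_colored_line(ax, theta, omega, t)
fig.colorbar(lc, ax=ax, label="Time (s)")
```

Replacing the per-segment loops with a helper of this kind is a possible refactor, not what the notebook currently does. The existing `alpha` and line-width settings would need to be passed through to keep the same look.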

In [None]:
def plot_phase_and_config_space_triple(
    t,
    theta1,
    omega1,
    theta2,
    omega2,
    theta3,
    omega3,
    g,
    L,
    k1,
    k2,
    m1,
    m2,
    m3,
    theta_1_init,
    theta_2_init,
    theta_3_init,
    arrow_density=8,
    save_fig=False,
    filename=None,
):
    """
    Plot phase-space diagrams for the three pendulums and pairwise configuration-space trajectories (Lissajous-like figures).

    Layout: 2 rows x 3 columns
    - Row 1: Phase space for each pendulum (θ vs ω)
    - Row 2: Configuration space pairs (θ₂ vs θ₁), (θ₃ vs θ₁), (θ₃ vs θ₂)

    Parameters
    ----------
    t : array
        Time array
    theta1, omega1, theta2, omega2, theta3, omega3 : array
        Angles and angular velocities for three pendulums
    g, L, k1, k2, m1, m2, m3 : float
        System parameters
    theta_1_init, theta_2_init, theta_3_init : float
        Initial angles in degrees
    arrow_density : int, optional
        Number of arrows to display along trajectory (default: 8)
    save_fig : bool, optional
        Whether to save the figure (default: False)
    filename : str or None, optional
        Filename for saved figure. Auto-generated if None (default: None)

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure object
    """
    # Calculate theoretical frequencies
    omega1_theory, omega2_theory, omega3_theory, f1_theory, f2_theory, f3_theory = (
        th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3)
    )

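    # Identify motion type (same checks as in plot_fft_spectrum_triple_pendulum).
    # The breathing (1, -2, 1) and anti-symmetric (1, 0, -1) shapes are exact
    # normal modes only for identical pendulums and springs (m1 = m2 = m3, k1 = k2)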
    tol = 1e-6
    if (
        abs(theta_1_init - theta_2_init) < tol
        and abs(theta_2_init - theta_3_init) < tol
    ):
        if abs(theta_1_init) < tol:
            motion_type = "No Initial Displacement"
        else:
            motion_type = "Sloshing Mode (All in Phase)"
    elif (
        abs(theta_1_init - theta_3_init) < tol
        and abs(theta_2_init + 2 * theta_1_init) < tol
    ):
        motion_type = "Breathing Mode"
    elif abs(theta_1_init + theta_3_init) < tol and abs(theta_2_init) < tol:
        motion_type = "Anti-Symmetric Mode"
    else:
        motion_type = "Mixed-Mode Oscillation"

    # Create figure with 2 rows x 3 columns
    fig, axes = plt.subplots(2, 3, figsize=(16, 12))
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.85, hspace=0.35, wspace=0.3
    )

    # Create subplots
    ax_phase1 = axes[0, 0]  # Phase space θ₁
    ax_phase2 = axes[0, 1]  # Phase space θ₂
    ax_phase3 = axes[0, 2]  # Phase space θ₃

    ax_config1 = axes[1, 0]  # Config space: θ₂ vs θ₁
    ax_config2 = axes[1, 1]  # Config space: θ₃ vs θ₁
    ax_config3 = axes[1, 2]  # Config space: θ₃ vs θ₂

    # Main title
    fig.suptitle(
        f"Phase Space and Configuration Space Analysis - Three Coupled Pendulums\n"
        f"Motion Type: {motion_type}",
        fontsize=16,
        fontweight="bold",
    )

    # Color map for time evolution
    colors = plt.cm.viridis(np.linspace(0, 1, len(t)))  # type: ignore
    n_points = len(t)

    # Convert angles to degrees
    theta1_deg = np.degrees(theta1)
    theta2_deg = np.degrees(theta2)
    theta3_deg = np.degrees(theta3)

    # Arrow indices
    arrow_indices = np.linspace(0, n_points - 1, arrow_density, dtype=int)

    # =========================================================================
    # ROW 1: PHASE SPACE PLOTS
    # =========================================================================

    phase_data = [
        (
            ax_phase1,
            theta1_deg,
            omega1,
            "$\\theta_1$",
            "$\\dot{\\theta}_1$",
        ),
        (
            ax_phase2,
            theta2_deg,
            omega2,
            "$\\theta_2$",
            "$\\dot{\\theta}_2$",
        ),
        (
            ax_phase3,
            theta3_deg,
            omega3,
            "$\\theta_3$",
            "$\\dot{\\theta}_3$",
        ),
    ]

    for (
        ax,
        theta_deg,
        omega,
        theta_label,
        omega_label,
    ) in phase_data:
        # Plot trajectory with color gradient
        for i in range(len(t) - 1):
            ax.plot(
                theta_deg[i : i + 2],
                omega[i : i + 2],
                color=colors[i],
                linewidth=1.5,
                alpha=0.7,
            )

        # Mark initial and final points
        ax.plot(
            theta_deg[0],
            omega[0],
            "o",
            color="lime",
            ms=10,
            mec="black",
            mew=2,
            label="Start",
            zorder=5,
        )
        ax.plot(
            theta_deg[-1],
            omega[-1],
            "o",
            color="red",
            ms=10,
            mec="black",
            mew=2,
            label="End",
            zorder=5,
        )

        # Set preliminary limits
        ax.set_xlim(
            theta_deg.min() - 0.1 * np.ptp(theta_deg),
            theta_deg.max() + 0.1 * np.ptp(theta_deg),
        )
        ax.set_ylim(
            omega.min() - 0.1 * np.ptp(omega), omega.max() + 0.1 * np.ptp(omega)
        )

        # Add direction arrows. Each step is normalized by the axis spans
        # (degrees vs rad/s) so arrow size does not depend on either axis's units
        theta_span = np.ptp(theta_deg)
        omega_span = np.ptp(omega)

        for idx in arrow_indices[:-1]:
            if idx + 1 < n_points:
                dx_data = theta_deg[idx + 1] - theta_deg[idx]
                dy_data = omega[idx + 1] - omega[idx]

                dx_norm = dx_data / theta_span if theta_span > 0 else 0
                dy_norm = dy_data / omega_span if omega_span > 0 else 0

                norm = np.sqrt(dx_norm**2 + dy_norm**2)

                if norm > 1e-9:
                    dx_norm = dx_norm / norm
                    dy_norm = dy_norm / norm

                    arrow_scale = 0.08
                    arrow_len_x = dx_norm * arrow_scale * theta_span
                    arrow_len_y = dy_norm * arrow_scale * omega_span

                    arrow = FancyArrowPatch(
                        posA=(theta_deg[idx], omega[idx]),
                        posB=(theta_deg[idx] + arrow_len_x, omega[idx] + arrow_len_y),
                        arrowstyle="-|>",
                        mutation_scale=20,
                        color=colors[idx],
                        linewidth=2,
                        zorder=10,
                        alpha=0.8,
                    )
                    ax.add_patch(arrow)

        ax.set_xlabel(f"{theta_label} (degrees)", fontsize=11, fontweight="bold")
        ax.set_ylabel(f"{omega_label} (rad/s)", fontsize=11, fontweight="bold")
        ax.set_title(
            f"Phase Space: Pendulum {theta_label[-2]}", fontsize=12, fontweight="bold"
        )
        ax.grid(True, alpha=0.3, linestyle="--")
        ax.legend(loc="best", fontsize=9, framealpha=0.9)
        ax.axhline(0, color="black", linewidth=0.5, alpha=0.5)
        ax.axvline(0, color="black", linewidth=0.5, alpha=0.5)

    # =========================================================================
    # ROW 2: CONFIGURATION SPACE PLOTS (Lissajous-like figures)
    # =========================================================================

    config_data = [
        (
            ax_config1,
            theta1_deg,
            theta2_deg,
            "$\\theta_1$",
            "$\\theta_2$",
            "Configuration Space: $\\theta_2$ vs $\\theta_1$",
        ),
        (
            ax_config2,
            theta1_deg,
            theta3_deg,
            "$\\theta_1$",
            "$\\theta_3$",
            "Configuration Space: $\\theta_3$ vs $\\theta_1$",
        ),
        (
            ax_config3,
            theta2_deg,
            theta3_deg,
            "$\\theta_2$",
            "$\\theta_3$",
            "Configuration Space: $\\theta_3$ vs $\\theta_2$",
        ),
    ]

    for ax, theta_x, theta_y, xlabel, ylabel, title in config_data:
        # Plot trajectory with color gradient
        for i in range(len(t) - 1):
            ax.plot(
                theta_x[i : i + 2],
                theta_y[i : i + 2],
                color=colors[i],
                linewidth=2,
                alpha=0.8,
            )

        # Mark initial and final points
        ax.plot(
            theta_x[0],
            theta_y[0],
            "o",
            color="lime",
            ms=10,
            mec="black",
            mew=2,
            zorder=5,
        )
        ax.plot(
            theta_x[-1],
            theta_y[-1],
            "o",
            color="red",
            ms=10,
            mec="black",
            mew=2,
            zorder=5,
        )

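        # Add direction arrows. The panel uses an equal aspect ratio, so one
        # length scale (the larger of the two angle spans) serves both axes;
        # this keeps arrows visible even when one angle barely moves
        arrow_span = max(np.ptp(theta_x), np.ptp(theta_y))

        for idx in arrow_indices[:-1]:
            if idx + 1 < n_points:
                dx = theta_x[idx + 1] - theta_x[idx]
                dy = theta_y[idx + 1] - theta_y[idx]
                norm = np.sqrt(dx**2 + dy**2)

                if norm > 1e-9:
                    dx_norm = dx / norm
                    dy_norm = dy / norm

                    arrow_len_x = dx_norm * (arrow_span * 0.08)
                    arrow_len_y = dy_norm * (arrow_span * 0.08)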

                    arrow = FancyArrowPatch(
                        posA=(theta_x[idx], theta_y[idx]),
                        posB=(theta_x[idx] + arrow_len_x, theta_y[idx] + arrow_len_y),
                        arrowstyle="-|>",
                        mutation_scale=20,
                        color=colors[idx],
                        linewidth=2,
                        zorder=10,
                        alpha=0.8,
                    )
                    ax.add_patch(arrow)

        # Add diagonal reference lines for special modes
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        max_range = max(abs(np.array([*xlim, *ylim])))

        # Diagonal line (in-phase: θ_x = θ_y)
        ax.plot(
            [-max_range, max_range],
            [-max_range, max_range],
            "g--",
            linewidth=1.5,
            alpha=0.5,
            zorder=1,
        )

        # Anti-diagonal line (out-of-phase: θ_x = -θ_y)
        ax.plot(
            [-max_range, max_range],
            [max_range, -max_range],
            "orange",
            linestyle="--",
            linewidth=1.5,
            alpha=0.5,
            zorder=1,
        )

        # Annotating the diagonal lines
        # In Phase Line
        ax.annotate(
            f"{xlabel} = {ylabel}",
            xy=(0.62 * max_range, 0.6 * max_range),
            xytext=(8, 0),
            textcoords="offset points",
            color="green",
            fontsize=9,
            fontweight="bold",
            rotation=45,
            va="center",
            ha="left",
            backgroundcolor="white",
        )
        # Out of Phase Line
        ax.annotate(
            f"{xlabel} = -{ylabel}",
            xy=(0.62 * max_range, -0.6 * max_range),
            xytext=(8, 0),
            textcoords="offset points",
            color="orange",
            fontsize=9,
            fontweight="bold",
            rotation=-45,
            va="center",
            ha="left",
            backgroundcolor="white",
        )

        ax.set_xlabel(f"{xlabel} (degrees)", fontsize=11, fontweight="bold")
        ax.set_ylabel(f"{ylabel} (degrees)", fontsize=11, fontweight="bold")
        ax.set_title(title, fontsize=12, fontweight="bold")
        ax.grid(True, alpha=0.3, linestyle="--")
        ax.axhline(0, color="black", linewidth=0.8, alpha=0.5)
        ax.axvline(0, color="black", linewidth=0.8, alpha=0.5)
        ax.set_aspect("equal", adjustable="box")

    # Add system info box to the right of the third phase-space panel
    sys_info_text = (
        f"System Parameters:\n"
        f"{'─' * 16}\n"
        f"Initial Angles:\n"
        f"  $\\theta_1^{{(0)}} = {theta_1_init:+.2f}°$\n"
        f"  $\\theta_2^{{(0)}} = {theta_2_init:+.2f}°$\n"
        f"  $\\theta_3^{{(0)}} = {theta_3_init:+.2f}°$\n"
        f"$m_1={m1:.2f}$, $m_2={m2:.2f}$,\n$m_3={m3:.2f}$ kg\n"
        f"$k_1={k1:.2f}$, $k_2={k2:.2f}$ N/m\n"
        f"$L = {L:.2f}$ m, $g = {g:.2f}$ m/s²\n"
        f"$T_{{sim}} = {t[-1]:.1f}$ s\n"
        f"{'─' * 16}\n"
        f"Normal Modes:\n"
        f"$f_1 = {f1_theory:.4f}$ Hz\n"
        f"$f_2 = {f2_theory:.4f}$ Hz\n"
        f"$f_3 = {f3_theory:.4f}$ Hz"
    )
    ax_phase3.text(
        1.05,
        0.99,
        sys_info_text,
        transform=ax_phase3.transAxes,
        fontsize=9,
        va="top",
        ha="left",
        bbox=dict(
            boxstyle="round",
            facecolor="lightyellow",
            alpha=0.6,
            edgecolor="black",
            linewidth=1.5,
        ),
        fontfamily="monospace",
    )

    # Add note about Lissajous figures to the right of the third configuration-space panel (ax_config3)
    lissajous_note = (
        f"Lissajous-like Figures\n"
        f"{'─' * 16}\n"
        f"These plots show the\n"
        f"relationship between\n"
        f"pairs of pendulum angles.\n"
        f"Closed curves indicate\n"
        f"periodic or quasi-periodic\n"
        f"motion in that subspace."
    )
    ax_config3.text(
        1.05,
        0.01,
        lissajous_note,
        transform=ax_config3.transAxes,
        fontsize=9,
        va="bottom",
        ha="left",
        bbox=dict(
            boxstyle="round",
            facecolor="lightcyan",
            alpha=0.6,
            edgecolor="black",
            linewidth=1.5,
        ),
        fontfamily="monospace",
    )

    # Add colorbar for time
    sm = plt.cm.ScalarMappable(cmap="viridis", norm=plt.Normalize(vmin=0, vmax=t[-1]))  # type: ignore
    sm.set_array([])
    cbar = fig.colorbar(
        sm,
        ax=axes,
        orientation="horizontal",
        pad=0.1,
        aspect=50,
        shrink=0.5,
    )
    cbar.set_label("Time (s)", fontsize=11, fontweight="bold")

    # Save figure if requested
    if save_fig:
        if filename is None:
            filename = (
                f"phase_config_space_triple_"
                f"m1_{m1}_m2_{m2}_m3_{m3}_"
                f"k1_{k1}_k2_{k2}_L_{L}_"
                f"theta1_{theta_1_init:.1f}_theta2_{theta_2_init:.1f}_theta3_{theta_3_init:.1f}_"
                f"t_{t[-1]:.0f}s.png"
            )

        output_dir = "OUTPUTS/FIGURES/coupled_pendulum/triple_pendulum"
        os.makedirs(output_dir, exist_ok=True)

        save_path = os.path.join(output_dir, filename)

        print(f"\nSaving figure to: {save_path}")
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print("Figure saved successfully!")

    return fig


def analyze_triple_pendulum_phase_space(
    theta_1_init=0.0,
    theta_2_init=0.0,
    theta_3_init=10.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=5.0,
    k2=5.0,
    L=2.0,
    g=9.81,
    simulation_time=50.0,
    n_points=1000,
    arrow_density=8,
    save_fig=False,
    filename=None,
):
    """
    Solve three coupled pendulum equations and plot phase/configuration space.

    Parameters
    ----------
    theta_1_init, theta_2_init, theta_3_init : float
        Initial angles in degrees
    m1, m2, m3 : float
        Masses
    k1, k2 : float
        Spring constants
    L : float
        Pendulum length
    g : float
        Gravitational acceleration
    simulation_time : float
        Simulation duration in seconds
    n_points : int
        Number of time points
    arrow_density : int
        Number of arrows along trajectory
    save_fig : bool, optional
        Whether to save the figure
    filename : str or None, optional
        Filename for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with phase space plots
    """
    print("=" * 60)
    print("PHASE SPACE ANALYSIS - THREE COUPLED PENDULUMS")
    print("=" * 60)

    # Convert initial angles to radians
    theta1_0 = np.radians(theta_1_init)
    theta2_0 = np.radians(theta_2_init)
    theta3_0 = np.radians(theta_3_init)

    # Initial state: [θ₁, ω₁, θ₂, ω₂, θ₃, ω₃]
    y0 = [theta1_0, 0.0, theta2_0, 0.0, theta3_0, 0.0]

    # Time span
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    # Solve the ODE
    print("\nSolving differential equations...")
    print(
        f"  Initial angles: θ₁={theta_1_init}°, θ₂={theta_2_init}°, θ₃={theta_3_init}°"
    )
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of points: {n_points}")

    solution = solve_ivp(
        three_coupled_pendulum_derivatives,
        t_span,
        y0,
        args=(g, L, k1, k2, m1, m2, m3),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-5,
        atol=1e-7,
    )

    theta1 = solution.y[0]
    omega1 = solution.y[1]
    theta2 = solution.y[2]
    omega2 = solution.y[3]
    theta3 = solution.y[4]
    omega3 = solution.y[5]
    t = solution.t

    print("Solution computed successfully!")

    # Calculate theoretical frequencies
    omega1_th, omega2_th, omega3_th, f1_th, f2_th, f3_th = (
        th_normal_modes_triple_pendulum(g, L, k1, k2, m1, m2, m3)
    )

    print("\nTheoretical Normal Mode Frequencies:")
    print(f"  ω₁ = {omega1_th:.4f} rad/s  →  f₁ = {f1_th:.4f} Hz (sloshing mode)")
    print(f"  ω₂ = {omega2_th:.4f} rad/s  →  f₂ = {f2_th:.4f} Hz (anti-symmetric mode)")
    print(f"  ω₃ = {omega3_th:.4f} rad/s  →  f₃ = {f3_th:.4f} Hz (breathing mode)")

    print("\nPlotting phase and configuration space...")

    # Plot phase space
    fig = plot_phase_and_config_space_triple(
        t,
        theta1,
        omega1,
        theta2,
        omega2,
        theta3,
        omega3,
        g,
        L,
        k1,
        k2,
        m1,
        m2,
        m3,
        theta_1_init,
        theta_2_init,
        theta_3_init,
        arrow_density=arrow_density,
        save_fig=save_fig,
        filename=filename,
    )

    return fig


# Run the analysis
fig = analyze_triple_pendulum_phase_space(
    theta_1_init=0.0,
    theta_2_init=10.0,
    theta_3_init=0.0,
    m1=1.0,
    m2=1.0,
    m3=1.0,
    k1=2,
    k2=2,
    L=2.0,
    g=9.81,
    simulation_time=20.0,
    n_points=1000,
    arrow_density=20,
    save_fig=True,
    filename="triple_pendulum_phase_space_demo_2",
)


## 3. Two Coupled Mass-Spring Oscillators

Apart from pendulums, another classic coupled-oscillator system is a chain of masses connected by springs. Here we derive the equations of motion for a 2-DOF mass–spring system.

**Geometry and generalized coordinates**

The mass–spring system considered here is the classic 2-DOF chain

$$
\text{Left wall}\;-\,[k_1]\,-\,m_1\,-\,[k_2]\,-\,m_2\,-\,[k_3]\,-\;\text{Right wall}.
$$

Let $x_1(t)$ and $x_2(t)$ be the **small longitudinal displacements from equilibrium** of masses $m_1$ and $m_2$ (positive to the right). With this convention:

- spring 1 extension is $x_1$
- spring 2 extension is $(x_2-x_1)$
- spring 3 extension is $-x_2$ (but the energy uses the square, so it contributes $\tfrac12 k_3 x_2^2$)

This is exactly the energy model used in the code.


**Kinetic energy**

Each mass moves on a line, so the kinetic energy is

$$
\boxed{
T=\frac12 m_1 \dot x_1^{\,2}+\frac12 m_2 \dot x_2^{\,2}.
}
$$

**Potential energy (springs)**

Using the extensions above, the spring potential energy is

$$
\boxed{
V=\frac12 k_1 x_1^2+\frac12 k_2(x_2-x_1)^2+\frac12 k_3 x_2^2.
}
$$


**Lagrangian**

$$
\boxed{\mathcal{L}=T-V =\frac12 m_1 \dot x_1^{\,2}+\frac12 m_2 \dot x_2^{\,2}
-\left[\frac12 k_1 x_1^2+\frac12 k_2(x_2-x_1)^2+\frac12 k_3 x_2^2\right].
}
$$


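**Euler–Lagrange equations (equations of motion)**

For $x_1$:

$$
\begin{aligned}
\dfrac{\partial\mathcal{L}}{\partial \dot x_1}&=m_1\dot x_1 \;\Rightarrow\; \dfrac{d}{dt}\!\left(\dfrac{\partial\mathcal{L}}{\partial \dot x_1}\right)=m_1\ddot x_1,\\[6pt]
\dfrac{\partial V}{\partial x_1}&=k_1x_1+k_2(x_2-x_1)(-1)=k_1x_1-k_2(x_2-x_1).
\end{aligned}
$$

Since $T$ does not depend on $x_1$, we have $\partial\mathcal{L}/\partial x_1=-\partial V/\partial x_1$, so the Euler–Lagrange equation $\frac{d}{dt}\frac{\partial\mathcal{L}}{\partial \dot x_1}-\frac{\partial\mathcal{L}}{\partial x_1}=0$ becomes

$$
m_1\ddot x_1+\dfrac{\partial V}{\partial x_1}=0,
$$

hence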

$$
\boxed{
m_1\ddot x_1+(k_1+k_2)x_1-k_2x_2=0.
}
$$

For $x_2$:

$$
\begin{aligned}
\dfrac{\partial\mathcal{L}}{\partial \dot x_2}&=m_2\dot x_2 \Rightarrow \dfrac{d}{dt}(\dfrac{\partial\mathcal{L}}{\partial \dot x_2})=m_2\ddot x_2\\[6pt]
\dfrac{\partial V}{\partial x_2}&=k_2(x_2-x_1)+k_3x_2
\end{aligned}
$$

So

$$
\boxed{
m_2\ddot x_2+(k_2+k_3)x_2-k_2x_1=0.
}
$$

These are the coupled linear ODEs that the mass–spring simulation below integrates.


**Normal modes from the eigenvalue problem**

Write the system in matrix form:

$$
\boxed{M\ddot{\mathbf{x}}+K\mathbf{x}=0,}
\qquad
\mathbf{x}=\begin{bmatrix}x_1\\x_2\end{bmatrix},
\quad
M=\begin{bmatrix}m_1&0\\0&m_2\end{bmatrix},
\quad
K=\begin{bmatrix}k_1+k_2&-k_2\\-k_2&k_2+k_3\end{bmatrix}.
$$
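As a quick numerical check, the spring energy is the quadratic form $V=\tfrac12\mathbf{x}^TK\mathbf{x}$. Its gradient is therefore $K\mathbf{x}$, and the Euler–Lagrange equations collapse to $M\ddot{\mathbf{x}}=-K\mathbf{x}$. This is a sketch: the helpers `spring_potential`, `stiffness_matrix`, and `numerical_gradient` and the parameter values are illustrative and do not come from the notebook.

```python
import numpy as np

def spring_potential(x, k1, k2, k3):
    """Spring energy V = 1/2 k1 x1^2 + 1/2 k2 (x2 - x1)^2 + 1/2 k3 x2^2."""
    x1, x2 = x
    return 0.5 * k1 * x1**2 + 0.5 * k2 * (x2 - x1) ** 2 + 0.5 * k3 * x2**2

def stiffness_matrix(k1, k2, k3):
    """K from the matrix form M x'' + K x = 0."""
    return np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

def numerical_gradient(f, x, h=1e-6):
    """Central-difference gradient of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        grad[i] = (f(x + e) - f(x - e)) / (2 * h)
    return grad

k1, k2, k3 = 10.0, 4.0, 7.0          # arbitrary test values
x = np.array([0.3, -0.2])            # arbitrary displacements
K = stiffness_matrix(k1, k2, k3)

V = spring_potential(x, k1, k2, k3)
grad_V = numerical_gradient(lambda z: spring_potential(z, k1, k2, k3), x)

print(np.isclose(V, 0.5 * x @ K @ x))   # V equals the quadratic form 1/2 x^T K x
print(np.allclose(grad_V, K @ x))        # so the restoring force is -K x
```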

Seek normal-mode solutions $\mathbf{x}(t)=\mathbf{a}\,e^{i\omega t}$. Then $\ddot{\mathbf{x}}=-\omega^2\mathbf{x}$ and

$$
(K-\omega^2 M)\mathbf{a}=0.
$$

Nontrivial mode shapes $\mathbf{a}\neq 0$ require

$$
\boxed{\det(K-\omega^2M)=0.}
$$

Compute:

$$
\det\!\begin{bmatrix}
k_1+k_2-\omega^2 m_1 & -k_2\\
-k_2 & k_2+k_3-\omega^2 m_2
\end{bmatrix}=0
$$

$$
\Rightarrow\;
(k_1+k_2-\omega^2 m_1)(k_2+k_3-\omega^2 m_2)-k_2^2=0.
$$

This is a quadratic in $\omega^2$. Solving gives

$$
\boxed{
\omega^2 =
\frac{m_2(k_1+k_2)+m_1(k_2+k_3)\pm \sqrt{\Delta}}{2m_1m_2},
}
$$

where

$$
\boxed{
\Delta=\left[m_2(k_1+k_2)+m_1(k_2+k_3)\right]^2
-4m_1m_2\left[(k_1+k_2)(k_2+k_3)-k_2^2\right].
}
$$
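The closed-form roots can be cross-checked against a numerical eigensolver. Note also that $\Delta=[m_2(k_1+k_2)-m_1(k_2+k_3)]^2+4m_1m_2k_2^2\ge 0$, so both roots are always real. The sketch below uses NumPy only; `closed_form_omega_sq` is an illustrative name. It compares the formula with the eigenvalues of $M^{-1}K$, which satisfy the same determinant condition.

```python
import numpy as np

def closed_form_omega_sq(m1, m2, k1, k2, k3):
    """Return (omega_1^2, omega_2^2) from the quadratic formula, lower root first."""
    b = m2 * (k1 + k2) + m1 * (k2 + k3)
    delta = b**2 - 4 * m1 * m2 * ((k1 + k2) * (k2 + k3) - k2**2)  # never negative
    root = np.sqrt(delta)
    return (b - root) / (2 * m1 * m2), (b + root) / (2 * m1 * m2)

m1, m2, k1, k2, k3 = 1.0, 2.0, 10.0, 3.0, 5.0   # arbitrary, deliberately asymmetric
M = np.diag([m1, m2])
K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

# det(K - w^2 M) = 0  <=>  w^2 is an eigenvalue of M^{-1} K
numeric = np.sort(np.linalg.eigvals(np.linalg.solve(M, K)).real)
analytic = np.array(closed_form_omega_sq(m1, m2, k1, k2, k3))

print(analytic, numeric)
```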

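**Mode shapes (in-phase vs out-of-phase)**

For each eigenfrequency $\omega_j$, the eigenvector $\mathbf{a}^{(j)}=[a_1^{(j)},a_2^{(j)}]^T$ satisfies, for example, the first row of $(K-\omega_j^2M)\mathbf{a}^{(j)}=0$:

$$
\begin{aligned}
&(k_1+k_2-\omega_j^2 m_1)a_1^{(j)}-k_2 a_2^{(j)}=0 \\
\Rightarrow\; & \frac{a_2^{(j)}}{a_1^{(j)}}=\frac{k_1+k_2-\omega_j^2 m_1}{k_2}.
\end{aligned}
$$

Evaluating the characteristic polynomial at $\omega^2=(k_1+k_2)/m_1$ gives $-k_2^2<0$. The polynomial is an upward-opening quadratic in $\omega^2$ (leading coefficient $m_1m_2>0$), so this point lies strictly between its roots: $\omega_1^2<(k_1+k_2)/m_1<\omega_2^2$. For a physical coupling spring ($k_2>0$), the ratio above therefore has a definite sign:

- The **lower** frequency mode has $a_1$ and $a_2$ with the **same sign** (in-phase).
- The **higher** frequency mode has $a_1$ and $a_2$ with **opposite signs** (out-of-phase).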
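A short numerical check of the sign pattern for one asymmetric parameter set with $k_2>0$. This is a sketch with arbitrary values. It compares each computed eigenvector's ratio $a_2/a_1$ with the first-row formula $(k_1+k_2-\omega_j^2m_1)/k_2$. The ratio does not depend on how the eigensolver scales or signs the eigenvectors.

```python
import numpy as np

m1, m2, k1, k2, k3 = 1.0, 2.0, 10.0, 3.0, 5.0   # arbitrary, asymmetric, k2 > 0
M = np.diag([m1, m2])
K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

# Eigenpairs of M^{-1} K, sorted so the lower mode comes first
omega_sq, vecs = np.linalg.eig(np.linalg.solve(M, K))
order = np.argsort(omega_sq.real)
omega_sq, vecs = omega_sq.real[order], vecs.real[:, order]

ratios = vecs[1] / vecs[0]                      # a2/a1 for each mode (column)
predicted = (k1 + k2 - omega_sq * m1) / k2      # first-row formula
for j in range(2):
    print(f"mode {j + 1}: a2/a1 = {ratios[j]:+.4f}, predicted = {predicted[j]:+.4f}")
```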

A very clear physical case is the symmetric system $m_1=m_2=m$ and $k_1=k_3=k$:

- **In-phase** ($x_1=x_2$): the middle spring does not stretch ($x_2-x_1=0$), so it contributes no restoring force. The frequency becomes
  $$
  \boxed{\omega_1^2=\frac{k}{m}.}
  $$
- **Out-of-phase** ($x_1=-x_2$): the middle spring's extension is $x_2-x_1=-2x_1$, so mass 1 feels an extra restoring force $-2k_2x_1$ in addition to $-kx_1$ (and symmetrically for mass 2), giving
  $$
  \boxed{\omega_2^2=\frac{k+2k_2}{m}.}
  $$
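For the symmetric case, a short check confirms both the frequencies and the mode shapes $(1,1)$ and $(1,-1)$ by testing $K\mathbf{v}=\omega^2M\mathbf{v}$ directly. This is a sketch; the values of $m$, $k$, and $k_2$ are arbitrary.

```python
import numpy as np

m, k, k2 = 1.5, 8.0, 2.0        # symmetric chain: m1 = m2 = m, k1 = k3 = k
M = np.diag([m, m])
K = np.array([[k + k2, -k2], [-k2, k2 + k]])

omega_sq = np.sort(np.linalg.eigvals(np.linalg.solve(M, K)).real)
expected = np.array([k / m, (k + 2 * k2) / m])

in_phase = np.array([1.0, 1.0])
out_of_phase = np.array([1.0, -1.0])
print(omega_sq, expected)
print(np.allclose(K @ in_phase, expected[0] * M @ in_phase))          # middle spring idle
print(np.allclose(K @ out_of_phase, expected[1] * M @ out_of_phase))  # middle spring doubled
```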


**General solution for the displacements $x_1(t),x_2(t)$**

Because the equations are linear (exactly so for ideal Hookean springs; no small-amplitude approximation is needed), the solution is a superposition of the two normal modes:

$$
\boxed{
\mathbf{x}(t) =
C_1\,\mathbf{v}_1\cos(\omega_1 t+\phi_1)
+
C_2\,\mathbf{v}_2\cos(\omega_2 t+\phi_2)
},
$$

where $\mathbf{v}_1,\mathbf{v}_2$ are the eigenvectors (mode shapes) of $K\mathbf{v}=\omega^2 M\mathbf{v}$.

So explicitly,

$$
x_1(t)=C_1 v_{1,1}\cos(\omega_1 t+\phi_1)+C_2 v_{2,1}\cos(\omega_2 t+\phi_2),
$$

$$
x_2(t)=C_1 v_{1,2}\cos(\omega_1 t+\phi_1)+C_2 v_{2,2}\cos(\omega_2 t+\phi_2),
$$

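where $v_{j,i}$ is the $i$-th component of $\mathbf{v}_j$. The four constants $C_1,C_2,\phi_1,\phi_2$ are fixed by the four initial conditions $x_1(0),x_2(0),\dot x_1(0),\dot x_2(0)$.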
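One convenient way to fix the constants is to project onto the modes. This is a sketch rather than the notebook's own method. It assumes the mode vectors are normalized so that $V^TMV=I$. Then $q_j(0)=\mathbf{v}_j^TM\mathbf{x}(0)$ and $\dot q_j(0)=\mathbf{v}_j^TM\dot{\mathbf{x}}(0)$, and each term is $C_j\cos(\omega_jt+\phi_j)=q_j(0)\cos\omega_jt+\dot q_j(0)\sin(\omega_jt)/\omega_j$. Equivalently, $C_j\cos\phi_j=q_j(0)$ and $C_j\sin\phi_j=-\dot q_j(0)/\omega_j$. The result is compared with a hand-written fixed-step RK4 integration of $M\ddot{\mathbf{x}}=-K\mathbf{x}$. The helper names and parameter values are illustrative.

```python
import numpy as np

m1, m2, k1, k2, k3 = 1.0, 2.0, 10.0, 3.0, 5.0   # arbitrary test parameters
M = np.diag([m1, m2])
K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])
x0 = np.array([0.4, -0.1])                      # x1(0), x2(0)
v0 = np.array([0.0, 0.3])                       # x1'(0), x2'(0)

# Mode shapes from M^{-1} K, each column rescaled so that v^T M v = 1
omega_sq, V = np.linalg.eig(np.linalg.solve(M, K))
omega_sq, V = omega_sq.real, V.real
V = V / np.sqrt(np.diag(V.T @ M @ V))
omega = np.sqrt(omega_sq)

# Modal initial conditions: q(0) = V^T M x(0), q'(0) = V^T M x'(0)
q0 = V.T @ M @ x0
qdot0 = V.T @ M @ v0

def modal_solution(t):
    """x(t) = sum_j v_j [q_j(0) cos(w_j t) + q_j'(0) sin(w_j t) / w_j]."""
    q = q0 * np.cos(omega * t) + qdot0 * np.sin(omega * t) / omega
    return V @ q

def rk4_reference(t_end, dt=1e-3):
    """Fixed-step RK4 on the state (x, v) with acceleration -M^{-1} K x."""
    accel = -np.linalg.solve(M, K)

    def rhs(s):
        return np.concatenate([s[2:], accel @ s[:2]])

    s = np.concatenate([x0, v0])
    for _ in range(int(round(t_end / dt))):
        k_a = rhs(s)
        k_b = rhs(s + 0.5 * dt * k_a)
        k_c = rhs(s + 0.5 * dt * k_b)
        k_d = rhs(s + dt * k_c)
        s = s + dt / 6 * (k_a + 2 * k_b + 2 * k_c + k_d)
    return s[:2]

t_check = 2.0
print(modal_solution(t_check), rk4_reference(t_check))
```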


**Beats / energy transfer (why it happens, and explicit formulas)**

"Beats" occur when **both modes are excited** (so both $C_1$ and $C_2$ are nonzero). Each displacement then contains two frequencies. When these are close (weak coupling, for example $k_2\ll k$ in the symmetric case below), their sum shows a slow amplitude modulation and an apparent exchange of energy between the masses.

For the symmetric case $m_1=m_2=m$, $k_1=k_3=k$, a classic initial condition is
$$
x_1(0)=x_0,\quad x_2(0)=0,\quad \dot x_1(0)=\dot x_2(0)=0,
$$
which excites both modes equally and yields

$$
\boxed{
x_1(t)=\frac{x_0}{2}\left(\cos(\omega_1 t)+\cos(\omega_2 t)\right),
\qquad
x_2(t)=\frac{x_0}{2}\left(\cos(\omega_1 t)-\cos(\omega_2 t)\right).
}
$$

Using $\cos A+\cos B=2\cos\frac{A+B}{2}\cos\frac{A-B}{2}$ and $\cos A-\cos B=-2\sin\frac{A+B}{2}\sin\frac{A-B}{2}$ with $A=\omega_1 t$, $B=\omega_2 t$:

$$
\boxed{
\begin{aligned}
x_1(t)&=x_0\cos\!\left(\frac{\omega_1+\omega_2}{2}t\right)\cos\!\left(\frac{\omega_2-\omega_1}{2}t\right)\\
x_2(t)&=x_0\sin\!\left(\frac{\omega_1+\omega_2}{2}t\right)\sin\!\left(\frac{\omega_2-\omega_1}{2}t\right)
\end{aligned}
}
$$

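The slow envelope has angular frequency $\tfrac{\omega_2-\omega_1}{2}$. For weak coupling, the energy of each mass tracks the squared envelope ($\cos^2$ for $m_1$, $\sin^2$ for $m_2$). The squared envelope repeats twice as fast, so energy passes from $m_1$ to $m_2$ and back once per beat period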

$$
\boxed{T_{\mathrm{beat}}=\frac{2\pi}{\omega_2-\omega_1}.}
$$
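The beat timing can be checked numerically for the symmetric case. This sketch uses arbitrary parameters with a weak coupling spring $k_2\ll k$. It evaluates the sum-of-cosines form of $x_1(t)$ and $x_2(t)$, confirms that $x_1$ is exactly zero at $t=T_{\mathrm{beat}}/2$ (where its envelope vanishes), and checks that both displacements stay inside their slow envelopes.

```python
import numpy as np

m, k, k2 = 1.0, 10.0, 0.5        # symmetric chain, weak coupling (k2 << k)
x0 = 0.2                         # x1(0) = x0, x2(0) = 0, both masses at rest
omega1 = np.sqrt(k / m)
omega2 = np.sqrt((k + 2 * k2) / m)
T_beat = 2 * np.pi / (omega2 - omega1)

def x1(t):
    return 0.5 * x0 * (np.cos(omega1 * t) + np.cos(omega2 * t))

def x2(t):
    return 0.5 * x0 * (np.cos(omega1 * t) - np.cos(omega2 * t))

t = np.linspace(0.0, 2 * T_beat, 20001)
half_delta = 0.5 * (omega2 - omega1)

print(f"T_beat = {T_beat:.3f} s")
print(abs(x1(T_beat / 2)))       # ~0: envelope of x1 is zero, motion sits in m2
print(np.all(np.abs(x1(t)) <= x0 * np.abs(np.cos(half_delta * t)) + 1e-12))
print(np.all(np.abs(x2(t)) <= x0 * np.abs(np.sin(half_delta * t)) + 1e-12))
```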


**From displacements to actual positions in the animation**

The code treats $x_1,x_2$ as **displacements from equilibrium** and then constructs actual positions as

$$
\boxed{
X_1(t)=X_{1,\mathrm{eq}}+x_1(t),\qquad X_2(t)=X_{2,\mathrm{eq}}+x_2(t).
}
$$

In the animation, the equilibrium locations are chosen as
$$
X_{1,\mathrm{eq}}=\frac{L_{\text{system}}}{3},\qquad X_{2,\mathrm{eq}}=\frac{2L_{\text{system}}}{3},
$$
so at equilibrium the masses are equally spaced between the walls and each spring has length $L_{\text{system}}/3$. The drawing code also uses this value as the relaxed spring length.


### 3.1 Geometry and Parameters

We define the wall–spring–mass layout, the sign conventions for displacements, and a simple drawing routine to confirm the geometry before integrating the equations of motion.

In [None]:
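def calculate_fixed_hook_length_mass_spring(relaxed_spring_length, hook_ratio=0.12):
    """
    Split the relaxed spring length into two end hooks plus a coiled section.

    Each hook is `hook_ratio` times the coiled length, so
    relaxed_spring_length = coil * (1 + 2 * hook_ratio). Returns the hook length,
    which the drawing code passes unchanged every frame while the coils stretch.
    """
    coiled_spring_len = relaxed_spring_length / (1 + 2 * hook_ratio)
    hook_len = hook_ratio * coiled_spring_len
    return hook_len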


def draw_wall(ax, x_position, y_center, wall_width, wall_height, color="saddlebrown"):
    """
    Draw a vertical wall at the specified x position.
    """
    rect = Rectangle(
        (x_position - wall_width / 2, y_center - wall_height / 2),
        wall_width,
        wall_height,
        color=color,
        zorder=5,
    )
    ax.add_patch(rect)
    return rect


def draw_coupled_mass_spring_system(
    x1=0.0,
    x2=0.0,
    system_length=12.0,
    y_position=0.0,
    num_coils=8,
    mass_size=25,
    padding=0.5,
):
    """
    Draw a coupled mass-spring system with horizontal motion only.
    Configuration: Left Wall → Spring1 → Mass1 → Spring2 → Mass2 → Spring3 → Right Wall

    Masses are in equilibrium at L/3 and 2L/3.
    x1 and x2 are DISPLACEMENTS from equilibrium.

    Parameters:
    -----------
    x1 : float
        Displacement of the first mass from equilibrium (L/3)
    x2 : float
        Displacement of the second mass from equilibrium (2L/3)
    system_length : float
        Total length of the system (distance between walls)
    y_position : float
        Y position of the horizontal system (all masses at same height)
    num_coils : int
        Number of coils in each spring
    mass_size : float
        Size of the mass markers
    padding : float
        Padding around the plot
    """
    # Wall positions
    left_wall_x = 0.0
    right_wall_x = system_length

    # Wall dimensions
    wall_width = 0.4
    wall_height = 1.5

    # Equilibrium positions
    eq_x1 = system_length / 3.0
    eq_x2 = 2.0 * system_length / 3.0

    # Actual positions based on displacement
    pos_x1 = eq_x1 + x1
    pos_x2 = eq_x2 + x2

    # Relaxed spring length (distance between equilibrium points)
    relaxed_spring_length = system_length / 3.0

    # Calculate hook length based on relaxed spring length
    hook_length = calculate_fixed_hook_length_mass_spring(relaxed_spring_length)

    # Spring radius (visual thickness)
    spring_radius = 0.15

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 4))

    # Draw left wall
    draw_wall(ax, left_wall_x, y_position, wall_width, wall_height)

    # Draw right wall
    draw_wall(ax, right_wall_x, y_position, wall_width, wall_height)

    # Draw equilibrium position reference lines
    # These show where masses would be at rest
    ax.axvline(x=eq_x1, color="gray", linestyle="--", lw=1, alpha=0.5)
    ax.axvline(x=eq_x2, color="gray", linestyle="--", lw=1, alpha=0.5)

    # Draw horizontal reference line
    ax.axhline(y=y_position, color="gray", linestyle=":", lw=1, alpha=0.3)

    # Draw Spring 1: Left Wall to Mass 1
    spring1_start = (left_wall_x + wall_width / 2, y_position)
    spring1_end = (pos_x1, y_position)
    x_spring1, y_spring1 = draw_spring_with_hook(
        spring1_start,
        spring1_end,
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring1, y_spring1, color="darkgrey", lw=2)

    # Draw Spring 2: Mass 1 to Mass 2
    spring2_start = (pos_x1, y_position)
    spring2_end = (pos_x2, y_position)
    x_spring2, y_spring2 = draw_spring_with_hook(
        spring2_start,
        spring2_end,
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring2, y_spring2, color="darkgrey", lw=2)

    # Draw Spring 3: Mass 2 to Right Wall
    spring3_start = (pos_x2, y_position)
    spring3_end = (right_wall_x - wall_width / 2, y_position)
    x_spring3, y_spring3 = draw_spring_with_hook(
        spring3_start,
        spring3_end,
        num_coils=num_coils,
        radius=spring_radius,
        hook_length=hook_length,
    )
    ax.plot(x_spring3, y_spring3, color="darkgrey", lw=2)

    # Draw masses
    ax.plot(pos_x1, y_position, "ro", ms=mass_size, label="m₁", zorder=10)
    ax.plot(pos_x2, y_position, "bo", ms=mass_size, label="m₂", zorder=10)

    # Set plot limits
    x_min = left_wall_x - padding
    x_max = right_wall_x + padding
    y_min = y_position - wall_height / 2 - padding
    y_max = y_position + wall_height / 2 + padding

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_title(
        f"Coupled Mass-Spring System: x₁ = {x1:.2f} (disp), x₂ = {x2:.2f} (disp)",
        fontsize=14,
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.legend(loc="upper right", ncol=2, fontsize=12, markerscale=0.5)

    return fig, ax


fig, ax = draw_coupled_mass_spring_system(x1=-1, x2=1, system_length=12.0)
plt.show()


### 3.2 ODE Model, Normal Modes, and Animation

Just as with the pendulum systems, we express the dynamics as a first-order ODE system and integrate it numerically. For comparison with the simulation results, we also compute the theoretical normal-mode frequencies from the eigenvalue problem above. These are exact here, since the mass–spring model is already linear.

#### 3.2.1 Equations of Motion (State-Space Form)

The function `two_mass_spring_derivatives` implements the ODE system in first-order form with state $y=[x_1,v_1,x_2,v_2]$ and accelerations
  $$
  \ddot x_1=\frac{-(k_1+k_2)x_1+k_2x_2}{m_1},\qquad
  \ddot x_2=\frac{k_2x_1-(k_2+k_3)x_2}{m_2}.
  $$
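Equivalently, the state-space model is linear: $\dot{\mathbf{y}}=A\mathbf{y}$ with $A=\begin{bmatrix}0&I\\-M^{-1}K&0\end{bmatrix}$. This is a sketch, not how the notebook's function is written, and it uses the ordering $\mathbf{y}=[x_1,x_2,\dot x_1,\dot x_2]^T$, which differs from the notebook's $[x_1,v_1,x_2,v_2]$. The eigenvalues of $A$ are $\pm i\omega_j$, which ties the first-order form back to the normal-mode frequencies. The parameters below are the notebook's default values.

```python
import numpy as np

m1, m2, k1, k2, k3 = 1.0, 1.0, 10.0, 10.0, 10.0   # notebook defaults
M = np.diag([m1, m2])
K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

# Block state matrix for y = [x1, x2, v1, v2]
A = np.block([
    [np.zeros((2, 2)), np.eye(2)],
    [-np.linalg.solve(M, K), np.zeros((2, 2))],
])

eigs = np.linalg.eigvals(A)
omegas_from_A = np.sort(np.abs(eigs.imag))[::2]   # each w_j appears as +i w_j and -i w_j
print(np.max(np.abs(eigs.real)))                  # ~0: undamped, purely oscillatory
print(omegas_from_A)                              # symmetric case: sqrt(k/m), sqrt(3k/m)
```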


In [None]:
def two_mass_spring_derivatives(t, y, m1, m2, k1, k2, k3):
    """
    Compute the derivatives for the two-mass three-spring system.

    System: Left Wall -[k1]- m1 -[k2]- m2 -[k3]- Right Wall

    Equations of motion (derived in Section 3):
    m₁ẍ₁ + (k₁ + k₂)x₁ - k₂x₂ = 0
    m₂ẍ₂ + (k₂ + k₃)x₂ - k₂x₁ = 0

    Parameters
    ----------
    t : float
        Time (not used explicitly, required by solve_ivp)
    y : array-like
        State vector [x₁, v₁, x₂, v₂] where v = dx/dt
    m1, m2 : float
        Masses
    k1, k2, k3 : float
        Spring constants

    Returns
    -------
    list
        Derivatives [dx₁/dt, dv₁/dt, dx₂/dt, dv₂/dt]
    """
    x1, v1, x2, v2 = y

    # Accelerations from equations of motion
    a1 = (-(k1 + k2) * x1 + k2 * x2) / m1
    a2 = (k2 * x1 - (k2 + k3) * x2) / m2

    return [v1, a1, v2, a2]


#### 3.2.2 Theoretical Normal Mode Frequencies 

The function `th_normal_modes_mass_spring_system` builds $M$ and $K$ and solves the generalized eigenproblem $K\mathbf{v}=\omega^2 M\mathbf{v}$ using `eigh`, returning $\omega_{1,2}$, $f_{1,2}$, and the mode-shape vectors.
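For a diagonal $M$, what `eigh(K, M)` computes can be reproduced with NumPy alone. Rescaling by $M^{-1/2}$ turns the generalized problem into a standard symmetric one. This is a sketch; `generalized_modes_diag_mass` is an illustrative name. The returned columns satisfy $V^TMV=I$, which is the normalization SciPy documents for `scipy.linalg.eigh(a, b)`. This normalization is also what makes the modal projection in the next subsection work.

```python
import numpy as np

def generalized_modes_diag_mass(K, masses):
    """Solve K v = w^2 M v for M = diag(masses) using NumPy only.

    Returns ascending w^2 and mode shapes V with V^T M V = I.
    """
    m_inv_sqrt = 1.0 / np.sqrt(np.asarray(masses, dtype=float))
    K_tilde = K * np.outer(m_inv_sqrt, m_inv_sqrt)   # M^{-1/2} K M^{-1/2}
    omega_sq, U = np.linalg.eigh(K_tilde)            # ascending eigenvalues
    V = m_inv_sqrt[:, None] * U                      # back to physical coordinates
    return omega_sq, V

m1, m2, k1, k2, k3 = 1.0, 2.0, 10.0, 3.0, 5.0        # arbitrary test parameters
K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])
M = np.diag([m1, m2])
omega_sq, V = generalized_modes_diag_mass(K, [m1, m2])

print(np.allclose(V.T @ M @ V, np.eye(2)))           # M-orthonormal mode shapes
print(np.allclose(V.T @ K @ V, np.diag(omega_sq)))   # K is diagonalized
print(np.sqrt(omega_sq) / (2 * np.pi))               # f1, f2 in Hz
```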


In [None]:
def th_normal_modes_mass_spring_system(m1, m2, k1, k2, k3):
    """
    Calculate the theoretical normal mode frequencies for the two-mass system.

    Solves the generalized eigenvalue problem: K*v = ω²*M*v

    Closed form (derived in Section 3):
    ω² = [(m₂(k₁+k₂) + m₁(k₂+k₃)) ± √Δ] / (2m₁m₂)
    where Δ = [m₂(k₁+k₂) + m₁(k₂+k₃)]² - 4m₁m₂[(k₁+k₂)(k₂+k₃) - k₂²]

    Parameters
    ----------
    m1, m2 : float
        Masses
    k1, k2, k3 : float
        Spring constants

    Returns
    -------
    tuple
        (omega1, omega2, f1, f2, mode1, mode2)
        Angular frequencies (rad/s), frequencies (Hz), and mode shapes
    """
    # Mass matrix
    M = np.diag([m1, m2])

    # Stiffness matrix
    K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

    # Solve generalized eigenvalue problem
    omega_sq, modes = eigh(K, M)

    # Sort by frequency (should already be sorted)
    idx = np.argsort(omega_sq)
    omega_sq = omega_sq[idx]
    modes = modes[:, idx]

    # Clip tiny negative round-off so that np.sqrt returns real values
    omega_sq = np.maximum(omega_sq, 0.0)

    omega1 = np.sqrt(omega_sq[0])
    omega2 = np.sqrt(omega_sq[1])

    # Convert to Hz
    f1 = omega1 / (2 * np.pi)
    f2 = omega2 / (2 * np.pi)

    # Mode shapes; scipy's eigh(K, M) normalizes them so that v^T M v = 1
    mode1 = modes[:, 0]
    mode2 = modes[:, 1]

    return omega1, omega2, f1, f2, mode1, mode2


#### 3.2.3 Numerical Frequency Estimation from Simulation Data


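The function `est_freqs_mass_spring_system` estimates the dominant frequencies from simulated $x_1(t),x_2(t)$ by:
  1) solving the same generalized eigenproblem to get eigenvectors $V$, which `eigh` normalizes so that $V^TMV=I$;
  2) projecting onto modal coordinates: since $\mathbf{x}=V\mathbf{q}$, this normalization gives $\mathbf{q}(t)=V^T M\mathbf{x}(t)$;
  3) applying a Hann window and an FFT (`rfft`), then picking the largest non-DC peak per mode. Modes whose RMS amplitude is below $10^{-3}$ of the strongest mode are reported as not excited (frequency $0$).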
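The peak-picking step can be tried on a synthetic signal with a known frequency. In this sketch, `dominant_frequency` is an illustrative, condensed version of step 3 rather than notebook code, and the test signal is arbitrary. The Hann-windowed `rfft` peak lands within one bin of the true frequency. The bin width is $1/(n\,\Delta t)$, so with the animation's default sampling (20 s at 30 fps) numerical estimates are only resolved to about 0.05 Hz.

```python
import numpy as np

def dominant_frequency(signal, dt, use_hann_window=True):
    """Largest non-DC rfft peak of a 1-D signal, in Hz."""
    s = np.asarray(signal, dtype=float)
    s = s - s.mean()                                # remove DC offset
    if use_hann_window:
        s = s * np.hanning(s.size)                  # taper ends to reduce leakage
    amp = np.abs(np.fft.rfft(s))
    freqs = np.fft.rfftfreq(s.size, dt)
    return freqs[np.argmax(amp[1:]) + 1]            # skip the DC bin

fps, duration = 30, 20.0                            # mirrors the animation defaults
t = np.linspace(0.0, duration, int(duration * fps) + 1)
dt = t[1] - t[0]
f_true = 0.503                                      # arbitrary test frequency in Hz
q = 1.3 * np.cos(2 * np.pi * f_true * t + 0.4)

f_est = dominant_frequency(q, dt)
bin_width = 1.0 / (t.size * dt)
print(f"estimate {f_est:.4f} Hz, true {f_true} Hz, bin width {bin_width:.4f} Hz")
```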

In [None]:
def est_freqs_mass_spring_system(t, x1, x2, m1, m2, k1, k2, k3, use_hann_window=True):
    """
    Estimate the normal mode frequencies from simulation data using FFT.

    Projects the simulation data onto the normal mode coordinates and
    performs FFT to extract dominant frequencies.

    Parameters
    ----------
    t : array
        Time array
    x1, x2 : array
        Displacement arrays (from equilibrium) for both masses
    m1, m2 : float
        Masses
    k1, k2, k3 : float
        Spring constants
    use_hann_window : bool
        Whether to apply Hann window to reduce spectral leakage

    Returns
    -------
    tuple
        (f1_estimated, f2_estimated) in Hz
    """
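    t = np.asarray(t)
    dt = t[1] - t[0]  # assumes a uniform time grid (as produced by np.linspace)
    n = t.size

    # Stack displacements into a 2 x n array and remove each mean (DC offset)
    x = np.vstack([np.asarray(x1), np.asarray(x2)])
    x = x - x.mean(axis=1, keepdims=True)

    M = np.diag([m1, m2])
    K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])

    # eigh returns M-orthonormal eigenvectors (V^T M V = I), so x = V q
    # inverts to q = V^T M x
    omega_sq, V = eigh(K, M)
    q = V.T @ (M @ x)

    # Taper each modal signal to reduce spectral leakage
    if use_hann_window:
        window = np.hanning(n)
        q = q * window

    # One-sided amplitude spectrum of each modal coordinate
    freqs = np.fft.rfftfreq(n, dt)
    Q = np.fft.rfft(q, axis=1)
    amp = np.abs(Q)

    # Modes whose RMS is below 0.1% of the strongest are treated as not excited
    rms = np.sqrt(np.mean(q**2, axis=1))
    rms_rel = rms / (np.max(rms) + 1e-30)
    f_est = []
    for i in range(2):
        if rms_rel[i] < 1e-3:
            f_est.append(0.0)
            continue

        # Largest peak, skipping the DC bin at index 0
        peak_idx = np.argmax(amp[i, 1:]) + 1
        f_est.append(freqs[peak_idx])

    return f_est[0], f_est[1]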


#### 3.2.4 Animating the Motion of the Coupled Mass-Spring System

The function `simulate_coupled_mass_spring_system` below runs `solve_ivp`, precomputes positions, energies, and spring extensions, and prints theoretical versus numerical frequencies. It then animates the system with `FuncAnimation`: the walls are drawn once, and each frame redraws the springs and masses (plus the time-dependent text) from the precomputed arrays.
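The model has no damping, so the precomputed `total_E_all` should stay constant up to integrator error. Its relative drift is therefore a cheap accuracy check before animating. The sketch below uses a hand-written fixed-step RK4 so that it is self-contained; the notebook itself uses `solve_ivp` with `RK45`. It runs with the notebook's default parameters and initial conditions. `derivatives` and `total_energy` are illustrative re-statements of `two_mass_spring_derivatives` and the `KE_all + PE_spring_all` expression.

```python
import numpy as np

m1, m2, k1, k2, k3 = 1.0, 1.0, 10.0, 10.0, 10.0   # notebook defaults
y = np.array([0.5, 0.0, -0.5, 0.0])                # [x1, v1, x2, v2]

def derivatives(y):
    """Same right-hand side as two_mass_spring_derivatives, without t and args."""
    x1, v1, x2, v2 = y
    a1 = (-(k1 + k2) * x1 + k2 * x2) / m1
    a2 = (k2 * x1 - (k2 + k3) * x2) / m2
    return np.array([v1, a1, v2, a2])

def total_energy(y):
    """Kinetic energy plus spring potential energy."""
    x1, v1, x2, v2 = y
    ke = 0.5 * m1 * v1**2 + 0.5 * m2 * v2**2
    pe = 0.5 * k1 * x1**2 + 0.5 * k2 * (x2 - x1) ** 2 + 0.5 * k3 * x2**2
    return ke + pe

dt, n_steps = 1e-3, 20_000                         # 20 s of simulated time
E0 = total_energy(y)
energies = np.empty(n_steps + 1)
energies[0] = E0
for i in range(n_steps):
    k_a = derivatives(y)
    k_b = derivatives(y + 0.5 * dt * k_a)
    k_c = derivatives(y + 0.5 * dt * k_b)
    k_d = derivatives(y + dt * k_c)
    y = y + dt / 6 * (k_a + 2 * k_b + 2 * k_c + k_d)
    energies[i + 1] = total_energy(y)

relative_drift = np.max(np.abs(energies - E0)) / E0
print(f"E0 = {E0:.4f} J, max relative drift = {relative_drift:.2e}")
```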

In [None]:
def simulate_coupled_mass_spring_system(
    x1_init=0.5,
    x2_init=-0.5,
    v1_init=0.0,
    v2_init=0.0,
    m1=1.0,
    m2=1.0,
    k1=10.0,
    k2=10.0,
    k3=10.0,
    system_length=12.0,
    simulation_time=20.0,
    fps=30,
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a coupled mass-spring system.

    System: Left Wall -[k1]- m1 -[k2]- m2 -[k3]- Right Wall

    Parameters
    ----------
    x1_init, x2_init : float
        Initial displacements from equilibrium
    v1_init, v2_init : float
        Initial velocities
    m1, m2 : float
        Masses
    k1, k2, k3 : float
        Spring constants
    system_length : float
        Distance between walls
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for animation
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for saved animation

    Returns
    -------
    tuple
        (fig, anim) - Figure and animation objects
    """

    # =========================================================================
    # NUMERICAL SOLUTION
    # =========================================================================

    # Initial state: [x₁, v₁, x₂, v₂]
    y0 = [x1_init, v1_init, x2_init, v2_init]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        two_mass_spring_derivatives,
        t_span,
        y0,
        args=(m1, m2, k1, k2, k3),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if not solution.success or solution.sol is None:
        print(f"ERROR: ODE solver failed! ({solution.message})")
        return None, None

    y = solution.sol(t=t_eval)
    x1 = y[0]
    v1 = y[1]
    x2 = y[2]
    v2 = y[3]

    print(f"Solution computed: {len(t_eval)} time steps")

    # =========================================================================
    # PRECOMPUTE PHYSICAL QUANTITIES
    # =========================================================================

    # Equilibrium positions
    eq_x1 = system_length / 3.0
    eq_x2 = 2.0 * system_length / 3.0

    # Actual positions
    pos_x1 = eq_x1 + x1
    pos_x2 = eq_x2 + x2

    # Spring extensions/compressions
    spring1_ext = x1  # extension of spring 1 from natural length
    spring2_ext = x2 - x1  # extension of spring 2
    spring3_ext = -x2  # extension of spring 3 (negative x2 means extension)

    # Energies
    KE_all = 0.5 * m1 * v1**2 + 0.5 * m2 * v2**2  # Kinetic Energy
    PE_spring_all = (
        0.5 * k1 * x1**2 + 0.5 * k2 * (x2 - x1) ** 2 + 0.5 * k3 * x2**2
    )  # Potential Energy in Springs
    total_E_all = KE_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, f1_theory, f2_theory, mode1, mode2 = (
        th_normal_modes_mass_spring_system(m1, m2, k1, k2, k3)
    )

    # Numerical frequency estimation
    f1_num, f2_num = est_freqs_mass_spring_system(t_eval, x1, x2, m1, m2, k1, k2, k3)

    def _fmt_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}")

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Geometry parameters
    left_wall_x = 0.0
    right_wall_x = system_length
    y_position = 0.0
    wall_width = 0.4
    wall_height = 1.5

    # Spring parameters
    relaxed_spring_length = system_length / 3.0
    hook_length = calculate_fixed_hook_length_mass_spring(relaxed_spring_length)
    spring_radius = 0.15
    spring_num_coils = 8

    # Mass marker sizes scale with the cube root of each mass (constant density), with a floor of 15
    mass_size_ref = 25
    avg_mass = (m1 + m2) / 2.0
    mass_1_size = max(15, mass_size_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(15, mass_size_ref * np.cbrt(m2 / avg_mass))

    # Figure setup
    fig, ax = plt.subplots(figsize=(14, 7))
    fig.subplots_adjust(top=0.9, bottom=0.3, left=0.02, right=0.98)

    # Set plot limits
    padding = 0.5
    x_min = left_wall_x - padding
    x_max = right_wall_x + padding
    y_min = y_position - wall_height / 2 - padding
    y_max = y_position + wall_height / 2 + padding

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

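    # Label the motion from the initial displacements only (velocities are ignored).
    # These labels are exact for the symmetric case m1 = m2, k1 = k3, but not for arbitrary parameters.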
    if abs(x1_init - x2_init) < 1e-6:
        motion_type = "In-Phase Mode (Both move together)"
    elif abs(x1_init + x2_init) < 1e-6:
        motion_type = "Out-of-Phase Mode (Opposite directions)"
    elif x1_init == 0.0 or x2_init == 0.0:
        motion_type = "Energy Transfer (Beats Phenomenon)"
    else:
        motion_type = "Mixed Mode (Superposition)"

    # Title
    title_text = (
        f"Coupled Mass-Spring System: {motion_type}\n"
        f"Normal Modes: $\\omega_1$={omega1_theory:.3f} rad/s, $\\omega_2$={omega2_theory:.3f} rad/s"
    )
    ax.set_title(title_text, fontsize=14, fontweight="bold", pad=10)

    # Initialize plot elements
    # Walls
    draw_wall(ax, left_wall_x, y_position, wall_width, wall_height)
    draw_wall(ax, right_wall_x, y_position, wall_width, wall_height)

    # Equilibrium reference lines
    ax.axvline(eq_x1, color="gray", linestyle="--", lw=1, alpha=0.5, zorder=0)
    ax.axvline(eq_x2, color="gray", linestyle="--", lw=1, alpha=0.5, zorder=0)
    ax.axhline(y_position, color="gray", linestyle=":", lw=1, alpha=0.3, zorder=0)

    # Springs (will be updated)
    (spring1_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)
    (spring2_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)
    (spring3_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)

    # Masses
    (mass1_plot,) = ax.plot(
        [], [], "o", color="crimson", markersize=mass_1_size, zorder=10
    )
    (mass2_plot,) = ax.plot(
        [], [], "o", color="royalblue", markersize=mass_2_size, zorder=10
    )

    # Traces
    trace_length = min(300, n_frames // 2)
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    (trace1_line,) = ax.plot([], [], "r-", alpha=0.3, lw=1, zorder=1)
    (trace2_line,) = ax.plot([], [], "b-", alpha=0.3, lw=1, zorder=1)

    # =========================================================================
    # TEXT ANNOTATIONS
    # =========================================================================

    # Timer
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # System info box
    info_text = ax.text(
        0.006,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )

    # Dynamic info box
    dynamic_text = ax.text(
        0.38,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = ax.text(
        0.845,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.8),
    )

    # Set static system info
    info_str = (
        "SYSTEM PARAMETERS\n"
        "─────────────────\n"
        f"$m_1$ = {m1:.2f} kg, "
        f"$m_2$ = {m2:.2f} kg\n"
        f"$k_1$ = {k1:.2f} N/m, "
        f"$k_2$ = {k2:.2f} N/m, "
        f"$k_3$ = {k3:.2f} N/m\n"
        "\n"
        "INITIAL CONDITIONS\n"
        "──────────────────\n"
        f"$x_1$(0) = {x1_init:+.3f} m, "
        f"$x_2$(0) = {x2_init:+.3f} m\n"
        f"$\\dot{{x}}_1$(0) = {v1_init:+.3f} m/s, "
        f"$\\dot{{x}}_2$(0) = {v2_init:+.3f} m/s"
    )
    info_text.set_text(info_str)

    # Set static frequency comparison
    freq_str = (
        "FREQUENCY ANALYSIS\n"
        "──────────────────\n"
        "Mode 1 (Low):\n"
        f"  Theory: {f1_theory:.4f} Hz\n"
        f"  Simul.: {f1_str}\n"
        "\n"
        "Mode 2 (High):\n"
        f"  Theory: {f2_theory:.4f} Hz\n"
        f"  Simul.: {f2_str}"
    )
    freq_text.set_text(freq_str)

    # =========================================================================
    # ANIMATION FUNCTION
    # =========================================================================

    def init():
        """Initialize animation"""
        spring1_line.set_data([], [])
        spring2_line.set_data([], [])
        spring3_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])
        time_text.set_text("")
        dynamic_text.set_text("")
        return (
            spring1_line,
            spring2_line,
            spring3_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            time_text,
            dynamic_text,
        )

    def update(frame):
        """Update animation for each frame"""
        # Current positions
        x1_curr = pos_x1[frame]
        x2_curr = pos_x2[frame]

        # Draw springs
        # Spring 1: Left wall to mass 1
        x_s1, y_s1 = draw_spring_with_hook(
            (left_wall_x + wall_width / 2, y_position),
            (x1_curr, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring1_line.set_data(x_s1, y_s1)

        # Spring 2: Mass 1 to mass 2
        x_s2, y_s2 = draw_spring_with_hook(
            (x1_curr, y_position),
            (x2_curr, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring2_line.set_data(x_s2, y_s2)

        # Spring 3: Mass 2 to right wall
        x_s3, y_s3 = draw_spring_with_hook(
            (x2_curr, y_position),
            (right_wall_x - wall_width / 2, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring3_line.set_data(x_s3, y_s3)

        # Update masses
        mass1_plot.set_data([x1_curr], [y_position])
        mass2_plot.set_data([x2_curr], [y_position])

        # Update traces
        trace1_x.append(x1_curr)
        trace1_y.append(y_position)
        trace2_x.append(x2_curr)
        trace2_y.append(y_position)

        if len(trace1_x) > trace_length:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)

        # Update timer
        time_text.set_text(f"Time: {t_eval[frame]:.2f} s")

        # Update dynamic info
        dynamic_str = (
            "CURRENT STATE\n"
            "─────────────\n"
            f"$x_1$ = {x1[frame]:+.3f} m, "
            f"$x_2$ = {x2[frame]:+.3f} m\n"
            f"$\\dot{{x}}_1$ = {v1[frame]:+.3f} m/s, "
            f"$\\dot{{x}}_2$ = {v2[frame]:+.3f} m/s\n"
            "\n"
            "SPRING EXTENSIONS\n"
            "─────────────────\n"
            f"$\\Delta s_1$ = {spring1_ext[frame]:+.3f} m, "
            f"$\\Delta s_2$ = {spring2_ext[frame]:+.3f} m, "
            f"$\\Delta s_3$ = {spring3_ext[frame]:+.3f} m\n"
            "\n"
            "ENERGY\n"
            "──────\n"
            f"KE  = {KE_all[frame]:.3f} J, "
            f"PE  = {PE_spring_all[frame]:.3f} J, "
            f"Tot = {total_E_all[frame]:.3f} J"
        )
        dynamic_text.set_text(dynamic_str)

        return (
            spring1_line,
            spring2_line,
            spring3_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            time_text,
            dynamic_text,
        )

    # Create animation
    print("Creating animation...")
    anim = FuncAnimation(
        fig,
        update,
        init_func=init,
        frames=n_frames,
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            filename = (
                f"mass_spring_x1={x1_init}_x2={x2_init}_m1={m1}_m2={m2}_"
                f"k1={k1}_k2={k2}_k3={k3}.gif"
            )

        save_dir = "OUTPUTS/ANIMATIONS/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving animation to {filepath}...")
            anim.save(filepath, writer="ffmpeg", fps=fps, codec="gif", dpi=100)
            print("Animation saved!")
            plt.close(fig)
        except Exception as e:
            print(f"Error saving animation: {e}")
    else:
        plt.show()

    return anim


def main_mass_spring_simulation():
    """Main function to run mass-spring simulation with user-defined parameters"""
    print("\n" + "=" * 60)
    print("COUPLED MASS-SPRING SYSTEM SIMULATION")
    print("=" * 60)

    default_params = {
        "x1_init": 0.5,
        "x2_init": -0.5,
        "v1_init": 0.0,
        "v2_init": 0.0,
        "m1": 1.0,
        "m2": 1.0,
        "k1": 10.0,
        "k2": 10.0,
        "k3": 10.0,
        "system_length": 12.0,
        "simulation_time": 15.0,
        "fps": 30,
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"
    if use_defaults:
        params = default_params
    else:
        params = {}
        for key, val in default_params.items():
            user_input = input(f"Enter value for {key} (default={val}): ").strip()
            if user_input == "":
                params[key] = val
            else:
                try:
                    params[key] = float(user_input)
                except ValueError:
                    print(f"Invalid input for {key}. Using default value {val}.")
                    params[key] = val

    save_animation = input("Save animation? (y/n): ").strip().lower() == "y"
    filename = None
    if save_animation:
        filename_input = input(
            "Enter filename for animation (e.g., 'mass_spring.gif'): "
        ).strip()
        if filename_input:
            if not filename_input.lower().endswith(".gif"):
                filename_input += ".gif"
            filename = filename_input

    print("\nStarting simulation with parameters:")
    print(f"    Initial Displacement of mass 1 (x1_init): {params['x1_init']} m")
    print(f"    Initial Displacement of mass 2 (x2_init): {params['x2_init']} m")
    print(f"    Initial Velocity of mass 1 (v1_init): {params['v1_init']} m/s")
    print(f"    Initial Velocity of mass 2 (v2_init): {params['v2_init']} m/s")
    print(f"    Mass 1 (m1): {params['m1']} kg")
    print(f"    Mass 2 (m2): {params['m2']} kg")
    print(f"    Spring Constant k1: {params['k1']} N/m")
    print(f"    Spring Constant k2: {params['k2']} N/m")
    print(f"    Spring Constant k3: {params['k3']} N/m")
    print(f"    System Length: {params['system_length']} m")
    print(f"    Simulation Time: {params['simulation_time']} s")
    print(f"    Frames per Second (fps): {params['fps']}")

    anim = simulate_coupled_mass_spring_system(
        x1_init=params["x1_init"],
        x2_init=params["x2_init"],
        v1_init=params["v1_init"],
        v2_init=params["v2_init"],
        m1=params["m1"],
        m2=params["m2"],
        k1=params["k1"],
        k2=params["k2"],
        k3=params["k3"],
        system_length=params["system_length"],
        simulation_time=params["simulation_time"],
        fps=int(params["fps"]),
        save_anim=save_animation,
        filename=filename,
    )

    return anim


if __name__ == "__main__":
    main_mass_spring_simulation()
    print("\nSimulation complete.")
    print(
        "You can modify the parameters in the 'main_mass_spring_simulation' function to run different scenarios."
    )

### 3.3 Time Series and Phase Space Visualizations
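The defaults in `main_mass_spring_simulation_with_plots` (equal masses, $k_1 = k_3 = 10$ N/m, a weak middle spring $k_2 = 0.5$ N/m, $x_1(0) = 0.1$ m, everything else at rest) set up a beats experiment. For a symmetric chain this case has a closed-form solution, so the time series produced below can be checked directly. The following is a standalone sketch under those assumptions: it re-implements the two-mass equations of motion instead of calling the notebook's `two_mass_spring_derivatives`, and the names `rhs`, `w1`, `w2`, and `t_transfer` are illustrative only.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Assumed parameters mirroring main_mass_spring_simulation_with_plots:
# symmetric chain (m1 = m2 = m, k1 = k3 = k) with a weak middle spring kc.
m, k, kc = 1.0, 10.0, 0.5
A = 0.1  # x1(0) = A, x2(0) = 0, both masses released from rest

# Normal-mode angular frequencies of the symmetric chain
w1 = np.sqrt(k / m)               # in-phase mode: middle spring never stretches
w2 = np.sqrt((k + 2.0 * kc) / m)  # out-of-phase mode: middle spring stretched twice as much


def rhs(t, y):
    """Wall -[k]- m -[kc]- m -[k]- Wall, state ordered [x1, v1, x2, v2]."""
    x1, v1, x2, v2 = y
    a1 = (-k * x1 + kc * (x2 - x1)) / m
    a2 = (-k * x2 - kc * (x2 - x1)) / m
    return [v1, a1, v2, a2]


t = np.linspace(0.0, 25.0, 5001)
sol = solve_ivp(rhs, (t[0], t[-1]), [A, 0.0, 0.0, 0.0],
                t_eval=t, rtol=1e-9, atol=1e-12)
x1_num, x2_num = sol.y[0], sol.y[2]

# Closed-form beats: slow envelope at (w2 - w1)/2, fast carrier at (w1 + w2)/2
x1_exact = A * np.cos(0.5 * (w2 - w1) * t) * np.cos(0.5 * (w1 + w2) * t)
x2_exact = A * np.sin(0.5 * (w2 - w1) * t) * np.sin(0.5 * (w1 + w2) * t)

max_err = max(np.max(np.abs(x1_num - x1_exact)),
              np.max(np.abs(x2_num - x2_exact)))

# The envelope of x1 first reaches zero (all energy on mass 2) at t = pi / (w2 - w1)
t_transfer = np.pi / (w2 - w1)

print(f"w1 = {w1:.4f} rad/s, w2 = {w2:.4f} rad/s")
print(f"max |numeric - closed form| = {max_err:.2e} m")
print(f"full energy transfer after about {t_transfer:.2f} s")
```

With these values the hand-off takes roughly 20 s, so the 15 s run configured in `main_mass_spring_simulation_with_plots` ends with most, but not all, of the energy on mass 2; a longer `simulation_time` shows the complete exchange.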

In [None]:
def coupled_mass_spring_system_animation_with_plots(
    x1_init=0.5,
    x2_init=-0.5,
    v1_init=0.0,
    v2_init=0.0,
    m1=1.0,
    m2=1.0,
    k1=10.0,
    k2=10.0,
    k3=10.0,
    system_length=12.0,
    simulation_time=20.0,
    fps=30,
    save_format="gif",
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a coupled mass-spring system alongside live plots of
    x1(t), x2(t), and the phase-space trajectories (velocity vs. position).

    System: Left Wall --[k1]-- m1 --[k2]-- m2 --[k3]-- Right Wall

    Parameters
    ----------
    x1_init, x2_init : float
        Initial displacements from equilibrium (m)
    v1_init, v2_init : float
        Initial velocities (m/s)
    m1, m2 : float
        Masses (kg)
    k1, k2, k3 : float
        Spring constants (N/m)
    system_length : float
        Distance between the walls (m)
    simulation_time : float
        Total simulation time (s)
    fps : int
        Frames per second for the animation
    save_format : str
        Format to save the animation ('gif' or 'mp4')
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for the saved animation; generated from the parameters if None

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object
    """

    # =========================================================================
    # NUMERICAL SOLUTION
    # =========================================================================

    # Initial state: [x₁, v₁, x₂, v₂]
    y0 = [x1_init, v1_init, x2_init, v2_init]

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        two_mass_spring_derivatives,
        t_span,
        y0,
        args=(m1, m2, k1, k2, k3),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if not solution.success or solution.sol is None:
        print(f"ERROR: ODE solver failed: {solution.message}")
        return None

    y = solution.sol(t=t_eval)
    x1 = y[0]
    v1 = y[1]
    x2 = y[2]
    v2 = y[3]

    print(f"Solution computed: {len(t_eval)} time steps")

    # =========================================================================
    # PRECOMPUTE PHYSICAL QUANTITIES
    # =========================================================================

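    # Equilibrium positions: each spring's relaxed length is system_length / 3
    # (see relaxed_spring_length below), so at rest the masses sit at the thirds
    eq_x1 = system_length / 3.0
    eq_x2 = 2.0 * system_length / 3.0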

    # Actual positions
    pos_x1 = eq_x1 + x1
    pos_x2 = eq_x2 + x2

    # Spring extensions/compressions
    spring1_ext = x1
    spring2_ext = x2 - x1
    spring3_ext = -x2

    # Energies
    KE_all = 0.5 * m1 * v1**2 + 0.5 * m2 * v2**2
    PE_spring_all = 0.5 * k1 * x1**2 + 0.5 * k2 * (x2 - x1) ** 2 + 0.5 * k3 * x2**2
    total_E_all = KE_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, f1_theory, f2_theory, mode1, mode2 = (
        th_normal_modes_mass_spring_system(m1, m2, k1, k2, k3)
    )

    # Numerical frequency estimation
    f1_num, f2_num = est_freqs_mass_spring_system(t_eval, x1, x2, m1, m2, k1, k2, k3)

    def _fmt_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}")

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Geometry parameters
    left_wall_x = 0.0
    right_wall_x = system_length
    y_position = 0.0
    wall_width = 0.4
    wall_height = 1.5

    # Spring parameters
    relaxed_spring_length = system_length / 3.0
    hook_length = calculate_fixed_hook_length_mass_spring(relaxed_spring_length)
    spring_radius = 0.15
    spring_num_coils = 8

    mass_size_ref = 25
    avg_mass = (m1 + m2) / 2.0
    mass_1_size = max(15, mass_size_ref * np.cbrt(m1 / avg_mass))
    mass_2_size = max(15, mass_size_ref * np.cbrt(m2 / avg_mass))

    # Figure setup
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(
        2,
        3,
        left=0.08,
        right=0.92,
        top=0.92,
        bottom=0.06,
        height_ratios=[1.5, 1],
        hspace=0.77,
        wspace=0.3,
    )

    # Animation subplot
    ax = fig.add_subplot(gs[0, :])

    # Time domain subplots
    ax_x1 = fig.add_subplot(gs[1, 0])
    ax_x2 = fig.add_subplot(gs[1, 1])
    ax_phase = fig.add_subplot(gs[1, 2])

    # Set plot limits for animation
    padding = 0.5
    x_min = left_wall_x - padding
    x_max = right_wall_x + padding
    y_min = y_position - wall_height / 2 - padding
    y_max = y_position + wall_height / 2 + padding

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

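    # Determine motion type (heuristic based on the initial displacements only;
    # the mode labels are exact only for a symmetric chain, m1 == m2 and
    # k1 == k3, released from rest, and clear beats need a weak middle spring k2)
    if abs(x1_init - x2_init) < 1e-6:
        motion_type = "In-Phase Mode (Both move together)"
    elif abs(x1_init + x2_init) < 1e-6:
        motion_type = "Out-of-Phase Mode (Opposite directions)"
    elif x1_init == 0.0 or x2_init == 0.0:
        motion_type = "Energy Transfer (Beats Phenomenon)"
    else:
        motion_type = "Mixed Mode (Superposition)"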

    # Title
    title_text = (
        f"Coupled Mass-Spring System: {motion_type}\n"
        f"Normal Modes: $\\omega_1$={omega1_theory:.3f} rad/s, $\\omega_2$={omega2_theory:.3f} rad/s"
    )
    ax.set_title(title_text, fontsize=14, fontweight="bold", pad=10)

    # Initialize plot elements for animation
    # Walls
    draw_wall(ax, left_wall_x, y_position, wall_width, wall_height)
    draw_wall(ax, right_wall_x, y_position, wall_width, wall_height)

    # Equilibrium reference lines
    ax.axvline(eq_x1, color="gray", linestyle="--", lw=1, alpha=0.5, zorder=0)
    ax.axvline(eq_x2, color="gray", linestyle="--", lw=1, alpha=0.5, zorder=0)
    ax.axhline(y_position, color="gray", linestyle=":", lw=1, alpha=0.3, zorder=0)

    # Springs (will be updated)
    (spring1_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)
    (spring2_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)
    (spring3_line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)

    # Masses
    (mass1_plot,) = ax.plot(
        [], [], "o", color="crimson", markersize=mass_1_size, zorder=10
    )
    (mass2_plot,) = ax.plot(
        [], [], "o", color="royalblue", markersize=mass_2_size, zorder=10
    )

    # Traces
    trace_length = min(300, n_frames // 2)
    trace1_x, trace1_y = [], []
    trace2_x, trace2_y = [], []
    (trace1_line,) = ax.plot([], [], "r-", alpha=0.3, lw=1, zorder=1)
    (trace2_line,) = ax.plot([], [], "b-", alpha=0.3, lw=1, zorder=1)

    # =========================================================================
    # SUBPLOTS SETUP
    # =========================================================================

    # x1 vs t
    ax_x1.set_title("Displacement $x_1$ vs Time", fontsize=12, fontweight="bold")
    ax_x1.set_xlabel("Time (s)", fontsize=10)
    ax_x1.set_ylabel("$x_1$ (m)", fontsize=10)
    ax_x1.set_xlim(0, 1.15 * simulation_time)
    # Symmetric limits avoid a zero-height axis when x1 stays at 0
    x1_lim = 1.1 * max(np.max(np.abs(x1)), 1e-6)
    ax_x1.set_ylim(-x1_lim, x1_lim)
    ax_x1.grid(True, alpha=0.3)
    (line_x1,) = ax_x1.plot([], [], "r-", lw=1.5)
    (marker_x1,) = ax_x1.plot([], [], "ro", markersize=8)

    # x2 vs t
    ax_x2.set_title("Displacement $x_2$ vs Time", fontsize=12, fontweight="bold")
    ax_x2.set_xlabel("Time (s)", fontsize=10)
    ax_x2.set_ylabel("$x_2$ (m)", fontsize=10)
    ax_x2.set_xlim(0, 1.15 * simulation_time)
    x2_lim = 1.1 * max(np.max(np.abs(x2)), 1e-6)
    ax_x2.set_ylim(-x2_lim, x2_lim)
    ax_x2.grid(True, alpha=0.3)
    (line_x2,) = ax_x2.plot([], [], "b-", lw=1.5)
    (marker_x2,) = ax_x2.plot([], [], "bo", markersize=8)

    # Phase Space (velocity vs position)
    ax_phase.set_title(
        "Phase Space ($\\dot{x}$ vs $x$)", fontsize=12, fontweight="bold"
    )
    ax_phase.set_xlabel("Position (m)", fontsize=10)
    ax_phase.set_ylabel("Velocity (m/s)", fontsize=10)
    # Symmetric limits about the origin (the motion is centered on equilibrium)
    x_lim = 1.2 * max(np.max(np.abs(x1)), np.max(np.abs(x2)), 1e-6)
    v_lim = 1.2 * max(np.max(np.abs(v1)), np.max(np.abs(v2)), 1e-6)
    ax_phase.set_xlim(-x_lim, x_lim)
    ax_phase.set_ylim(-v_lim, v_lim)
    ax_phase.grid(True, alpha=0.3)
    (line_phase1,) = ax_phase.plot([], [], "r-", lw=1.5, label="Mass 1", alpha=0.7)
    (line_phase2,) = ax_phase.plot([], [], "b-", lw=1.5, label="Mass 2", alpha=0.7)
    (marker_phase1,) = ax_phase.plot([], [], "ro", markersize=8)
    (marker_phase2,) = ax_phase.plot([], [], "bo", markersize=8)
    ax_phase.legend(loc="upper right", fontsize=8)

    # =========================================================================
    # TEXT ANNOTATIONS
    # =========================================================================

    # Timer
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # System info box
    info_text = ax.text(
        0.01,
        -0.03,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )

    # Dynamic info box
    dynamic_text = ax.text(
        0.38,
        -0.03,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )

    # Frequency comparison box
    freq_text = ax.text(
        0.82,
        -0.03,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.8),
    )

    # Set static system info
    info_str = (
        "SYSTEM PARAMETERS\n"
        "─────────────────\n"
        f"$m_1$ = {m1:.2f} kg, "
        f"$m_2$ = {m2:.2f} kg\n"
        f"$k_1$ = {k1:.2f} N/m, "
        f"$k_2$ = {k2:.2f} N/m, "
        f"$k_3$ = {k3:.2f} N/m\n"
        "\n"
        "INITIAL CONDITIONS\n"
        "──────────────────\n"
        f"$x_1$(0) = {x1_init:+.3f} m, "
        f"$x_2$(0) = {x2_init:+.3f} m\n"
        f"$\\dot{{x}}_1$(0) = {v1_init:+.3f} m/s, "
        f"$\\dot{{x}}_2$(0) = {v2_init:+.3f} m/s"
    )
    info_text.set_text(info_str)

    # Set static frequency comparison
    freq_str = (
        "FREQUENCY ANALYSIS\n"
        "──────────────────\n"
        "Mode 1 (Low):\n"
        f"  Theory: {f1_theory:.4f} Hz\n"
        f"  Simul.: {f1_str}\n"
        "\n"
        "Mode 2 (High):\n"
        f"  Theory: {f2_theory:.4f} Hz\n"
        f"  Simul.: {f2_str}"
    )
    freq_text.set_text(freq_str)

    # =========================================================================
    # ANIMATION FUNCTION
    # =========================================================================

    def init():
        """Initialize animation"""
        spring1_line.set_data([], [])
        spring2_line.set_data([], [])
        spring3_line.set_data([], [])
        mass1_plot.set_data([], [])
        mass2_plot.set_data([], [])
        trace1_line.set_data([], [])
        trace2_line.set_data([], [])

        line_x1.set_data([], [])
        line_x2.set_data([], [])
        line_phase1.set_data([], [])
        line_phase2.set_data([], [])
        marker_x1.set_data([], [])
        marker_x2.set_data([], [])
        marker_phase1.set_data([], [])
        marker_phase2.set_data([], [])

        time_text.set_text("")
        dynamic_text.set_text("")
        return (
            spring1_line,
            spring2_line,
            spring3_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            line_x1,
            line_x2,
            line_phase1,
            line_phase2,
            marker_x1,
            marker_x2,
            marker_phase1,
            marker_phase2,
            time_text,
            dynamic_text,
        )

    def update(frame):
        """Update animation for each frame"""
        # Current positions
        x1_curr = pos_x1[frame]
        x2_curr = pos_x2[frame]

        # Draw springs
        # Spring 1: Left wall to mass 1
        x_s1, y_s1 = draw_spring_with_hook(
            (left_wall_x + wall_width / 2, y_position),
            (x1_curr, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring1_line.set_data(x_s1, y_s1)

        # Spring 2: Mass 1 to mass 2
        x_s2, y_s2 = draw_spring_with_hook(
            (x1_curr, y_position),
            (x2_curr, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring2_line.set_data(x_s2, y_s2)

        # Spring 3: Mass 2 to right wall
        x_s3, y_s3 = draw_spring_with_hook(
            (x2_curr, y_position),
            (right_wall_x - wall_width / 2, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring3_line.set_data(x_s3, y_s3)

        # Update masses
        mass1_plot.set_data([x1_curr], [y_position])
        mass2_plot.set_data([x2_curr], [y_position])

        # Update traces
        trace1_x.append(x1_curr)
        trace1_y.append(y_position)
        trace2_x.append(x2_curr)
        trace2_y.append(y_position)

        if len(trace1_x) > trace_length:
            trace1_x.pop(0)
            trace1_y.pop(0)
            trace2_x.pop(0)
            trace2_y.pop(0)

        trace1_line.set_data(trace1_x, trace1_y)
        trace2_line.set_data(trace2_x, trace2_y)

        # Update subplots
        # We show history up to current frame
        current_t = t_eval[: frame + 1]
        current_x1 = x1[: frame + 1]
        current_x2 = x2[: frame + 1]
        current_v1 = v1[: frame + 1]
        current_v2 = v2[: frame + 1]

        line_x1.set_data(current_t, current_x1)
        line_x2.set_data(current_t, current_x2)
        line_phase1.set_data(current_x1, current_v1)
        line_phase2.set_data(current_x2, current_v2)

        # Update markers to show current position
        marker_x1.set_data([t_eval[frame]], [x1[frame]])
        marker_x2.set_data([t_eval[frame]], [x2[frame]])
        marker_phase1.set_data([x1[frame]], [v1[frame]])
        marker_phase2.set_data([x2[frame]], [v2[frame]])

        # Update timer
        time_text.set_text(f"Time: {t_eval[frame]:.2f} s")

        # Update dynamic info
        dynamic_str = (
            "CURRENT STATE\n"
            "─────────────\n"
            f"$x_1$ = {x1[frame]:+.3f} m, "
            f"$x_2$ = {x2[frame]:+.3f} m\n"
            f"$\\dot{{x}}_1$ = {v1[frame]:+.3f} m/s, "
            f"$\\dot{{x}}_2$ = {v2[frame]:+.3f} m/s\n"
            "\n"
            "SPRING EXTENSIONS\n"
            "─────────────────\n"
            f"$\\Delta s_1$ = {spring1_ext[frame]:+.3f} m, "
            f"$\\Delta s_2$ = {spring2_ext[frame]:+.3f} m, "
            f"$\\Delta s_3$ = {spring3_ext[frame]:+.3f} m\n"
            "\n"
            "ENERGY\n"
            "──────\n"
            f"KE  = {KE_all[frame]:.3f} J, "
            f"PE  = {PE_spring_all[frame]:.3f} J, "
            f"Tot = {total_E_all[frame]:.3f} J"
        )
        dynamic_text.set_text(dynamic_str)

        return (
            spring1_line,
            spring2_line,
            spring3_line,
            mass1_plot,
            mass2_plot,
            trace1_line,
            trace2_line,
            line_x1,
            line_x2,
            line_phase1,
            line_phase2,
            marker_x1,
            marker_x2,
            marker_phase1,
            marker_phase2,
            time_text,
            dynamic_text,
        )

    # Create animation
    print("Creating animation...")
    anim = FuncAnimation(
        fig,
        update,
        init_func=init,
        frames=n_frames,
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            ext = "gif" if save_format.lower() == "gif" else "mp4"
            filename = f"mass_spring_m1={m1}_m2={m2}_k1={k1}_k2={k2}_k3={k3}.{ext}"

        save_dir = "OUTPUTS/ANIMATIONS/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving animation to {filepath}...")
            if save_format.lower() == "gif":
                anim.save(filepath, writer="ffmpeg", fps=fps, codec="gif", dpi=120)
            else:
                anim.save(filepath, writer="ffmpeg", fps=fps, dpi=100)
            print("Animation saved!")
            plt.close(fig)
        except Exception as e:
            print(f"Error saving animation: {e}")
    else:
        plt.show()

    return anim


def main_mass_spring_simulation_with_plots():
    """Main function to run mass-spring simulation with user-defined parameters"""
    print("\n" + "=" * 60)
    print("COUPLED MASS-SPRING SYSTEM SIMULATION WITH PLOTS")
    print("=" * 60)

    default_params = {
        "x1_init": 0.1,
        "x2_init": 0.0,
        "v1_init": 0.0,
        "v2_init": 0.0,
        "m1": 1.0,
        "m2": 1.0,
        "k1": 10.0,
        "k2": 0.5,
        "k3": 10.0,
        "system_length": 12.0,
        "simulation_time": 15.0,
        "fps": 30,
        "save_format": "mp4",
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"
    if use_defaults:
        params = default_params
    else:
        params = {}
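        for key, val in default_params.items():
            user_input = input(f"Enter value for {key} (default={val}): ").strip()
            if user_input == "":
                params[key] = val
            elif isinstance(val, str):
                # save_format is text: accept only the formats the saver supports
                choice = user_input.lower()
                params[key] = choice if choice in ("gif", "mp4") else val
            else:
                try:
                    params[key] = float(user_input)
                except ValueError:
                    print(f"Invalid input for {key}. Using default value {val}.")
                    params[key] = val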

    save_animation = input("Save animation? (y/n): ").strip().lower() == "y"
    filename = None
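    if save_animation:
        filename_input = input(
            "Enter filename for animation (e.g., 'mass_spring.mp4'): "
        ).strip()
        if filename_input:
            if filename_input.lower().endswith((".mp4", ".gif")):
                # An explicit extension decides which writer settings are used
                params["save_format"] = filename_input.rsplit(".", 1)[1].lower()
            else:
                filename_input += f".{params['save_format']}"
            filename = filename_input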

    print("\nStarting simulation with parameters:")
    print(f"    Initial Displacement of mass 1 (x1_init): {params['x1_init']} m")
    print(f"    Initial Displacement of mass 2 (x2_init): {params['x2_init']} m")
    print(f"    Initial Velocity of mass 1 (v1_init): {params['v1_init']} m/s")
    print(f"    Initial Velocity of mass 2 (v2_init): {params['v2_init']} m/s")
    print(f"    Mass 1 (m1): {params['m1']} kg")
    print(f"    Mass 2 (m2): {params['m2']} kg")
    print(f"    Spring Constant k1: {params['k1']} N/m")
    print(f"    Spring Constant k2: {params['k2']} N/m")
    print(f"    Spring Constant k3: {params['k3']} N/m")
    print(f"    System Length: {params['system_length']} m")
    print(f"    Simulation Time: {params['simulation_time']} s")
    print(f"    Frames per Second (fps): {params['fps']}")

    anim = coupled_mass_spring_system_animation_with_plots(
        x1_init=params["x1_init"],
        x2_init=params["x2_init"],
        v1_init=params["v1_init"],
        v2_init=params["v2_init"],
        m1=params["m1"],
        m2=params["m2"],
        k1=params["k1"],
        k2=params["k2"],
        k3=params["k3"],
        system_length=params["system_length"],
        simulation_time=params["simulation_time"],
        fps=int(params["fps"]),
        save_format=params["save_format"],
        save_anim=save_animation,
        filename=filename,
    )

    return anim


if __name__ == "__main__":
    main_mass_spring_simulation_with_plots()
    print("\nSimulation complete.")
    print(
        "You can modify the parameters in the 'main_mass_spring_simulation_with_plots' function to run different scenarios."
    )


### 3.4 Coupled Mass-Spring System Analysis with Two Masses
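This analysis compares the theoretical normal modes from `th_normal_modes_mass_spring_system` with frequencies estimated from the simulated motion by `est_freqs_mass_spring_system`; both helpers are defined earlier in the notebook and are not reproduced here. The sketch below is an independent illustration of one way to carry out both steps, not those helpers' implementation: it solves the generalized eigenproblem $K\boldsymbol{\phi} = \omega^2 M\boldsymbol{\phi}$ with `scipy.linalg.eigh`, projects the simulated displacements onto the mass-normalized mode shapes, and takes the largest FFT peak of each modal coordinate. The masses and spring constants mirror the defaults of `coupled_mass_spring_system_analysis`; the initial condition is changed to a single displaced mass so that both modes are excited, and `rhs` and `peak_frequency` are hypothetical names.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.linalg import eigh

# Assumed parameters (defaults of coupled_mass_spring_system_analysis)
m1, m2 = 1.0, 1.0
k1, k2, k3 = 10.0, 10.0, 10.0

M = np.diag([m1, m2])
K = np.array([[k1 + k2, -k2],
              [-k2, k2 + k3]])  # M x'' = -K x for the wall-k1-m1-k2-m2-k3-wall chain

# Generalized eigenproblem K phi = w^2 M phi: eigh returns ascending w^2 and
# mode shapes normalized so that Phi.T @ M @ Phi = I
w_sq, Phi = eigh(K, M)
f_theory = np.sqrt(w_sq) / (2.0 * np.pi)


def rhs(t, y):
    """State ordered [x1, v1, x2, v2]; accelerations from a = -M^{-1} K x."""
    x = y[[0, 2]]
    v = y[[1, 3]]
    a = -np.linalg.solve(M, K @ x)
    return [v[0], a[0], v[1], a[1]]


# Displacing only mass 1 excites both normal modes
t = np.linspace(0.0, 60.0, 6000, endpoint=False)
sol = solve_ivp(rhs, (t[0], t[-1]), [0.5, 0.0, 0.0, 0.0],
                t_eval=t, rtol=1e-9, atol=1e-12)
X = sol.y[[0, 2]]  # displacements, shape (2, N)

# Modal coordinates q = Phi^T M x: each row oscillates at a single frequency
Q = Phi.T @ M @ X


def peak_frequency(signal, dt):
    """Frequency of the largest FFT magnitude, skipping the DC bin."""
    spectrum = np.abs(np.fft.rfft(signal - signal.mean()))
    freqs = np.fft.rfftfreq(signal.size, d=dt)
    return freqs[1 + np.argmax(spectrum[1:])]


dt = t[1] - t[0]
f_numeric = np.array([peak_frequency(q, dt) for q in Q])
print("theory :", np.round(f_theory, 4), "Hz")
print("numeric:", np.round(f_numeric, 4), "Hz")
```

Plain peak picking is only as fine as the FFT bin spacing $1/(N\,\Delta t)$, about 0.017 Hz for this 60 s record; a longer record or interpolation around the peak tightens the estimate.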

In [None]:
def coupled_mass_spring_system_analysis(
    x1_init=0.5,
    x2_init=-0.5,
    v1_init=0.0,
    v2_init=0.0,
    m1=1.0,
    m2=1.0,
    k1=10.0,
    k2=10.0,
    k3=10.0,
    simulation_time=20.0,
    n_points=1000,
    save_fig=False,
    filename=None,
):
    """
    Analyze and visualize a coupled mass-spring system with static plots.

    System: Left Wall --[k1]-- m1 --[k2]-- m2 --[k3]-- Right Wall

    Parameters
    ----------
    x1_init, x2_init : float
        Initial displacements from equilibrium
    v1_init, v2_init : float
        Initial velocities
    m1, m2 : float
        Masses
    k1, k2, k3 : float
        Spring constants
    simulation_time : float
        Total simulation time in seconds
    n_points : int
        Number of time points to compute
    save_fig : bool
        Whether to save the figure
    filename : str or None
        Filename for saved figure

    Returns
    -------
    fig : Figure
        Matplotlib figure object
    """

    # =========================================================================
    # NUMERICAL SOLUTION
    # =========================================================================

    # Initial state: [x₁, v₁, x₂, v₂]
    y0 = [x1_init, v1_init, x2_init, v2_init]

    # Time span
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    # Solve the ODE
    print("Solving differential equations...")
    solution = solve_ivp(
        two_mass_spring_derivatives,
        t_span,
        y0,
        args=(m1, m2, k1, k2, k3),
        method="RK45",
        rtol=1e-6,
        atol=1e-8,
        t_eval=t_eval,
    )

    if not solution.success:
        print("ERROR: ODE solver failed!")
        return None

    t = solution.t
    x1 = solution.y[0]
    v1 = solution.y[1]
    x2 = solution.y[2]
    v2 = solution.y[3]

    print(f"Solution computed: {len(t)} time steps")

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omega1_theory, omega2_theory, f1_theory, f2_theory, mode1, mode2 = (
        th_normal_modes_mass_spring_system(m1, m2, k1, k2, k3)
    )

    # Numerical frequency estimation
    f1_num, f2_num = est_freqs_mass_spring_system(t, x1, x2, m1, m2, k1, k2, k3)

    def _fmt_freq(f):
        return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

    f1_str = _fmt_freq(f1_num)
    f2_str = _fmt_freq(f2_num)

    print("\nNormal Mode Frequencies:")
    print(f"  Theoretical: ω₁ = {omega1_theory:.4f} rad/s (f₁ = {f1_theory:.4f} Hz)")
    print(f"               ω₂ = {omega2_theory:.4f} rad/s (f₂ = {f2_theory:.4f} Hz)")
    print(f"  Numerical:   f₁ ≈ {f1_str}, f₂ ≈ {f2_str}")

    # =========================================================================
    # CREATE 2x2 GRID OF PLOTS
    # =========================================================================

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.subplots_adjust(
        left=0.08, right=0.8, top=0.9, bottom=0.06, hspace=0.25, wspace=0.25
    )
    ax_x1 = axes[0, 0]
    ax_x2 = axes[0, 1]
    ax_phase = axes[1, 0]
    ax_config = axes[1, 1]

    # Determine motion type
    if abs(x1_init - x2_init) < 1e-6:
        motion_type = "In-Phase Mode"
    elif abs(x1_init + x2_init) < 1e-6:
        motion_type = "Out-of-Phase Mode"
    elif x1_init == 0.0 or x2_init == 0.0:
        motion_type = "Energy Transfer Mode (Beats Phenomenon)"
    else:
        motion_type = "Mixed Mode"

    # Main title
    title_text = (
        f"Coupled Mass-Spring System Analysis: {motion_type}\n"
        f"Normal Modes: $\\omega_1$={omega1_theory:.3f} rad/s, $\\omega_2$={omega2_theory:.3f} rad/s"
    )
    fig.suptitle(title_text, fontsize=14, fontweight="bold")

    # -------------------------------------------------------------------------
    # [0, 0]: x1 vs t
    # -------------------------------------------------------------------------
    ax_x1.plot(t, x1, "r-", lw=1.5)
    ax_x1.set_xlabel("Time (s)", fontsize=11)
    ax_x1.set_ylabel("Displacement $x_1$ (m)", fontsize=11)
    ax_x1.set_title("Mass 1 Displacement vs Time", fontsize=12, fontweight="bold")
    ax_x1.grid(True, alpha=0.3)
    ax_x1.axhline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)

    # -------------------------------------------------------------------------
    # [0, 1]: x2 vs t
    # -------------------------------------------------------------------------
    ax_x2.plot(t, x2, "b-", lw=1.5)
    ax_x2.set_xlabel("Time (s)", fontsize=11)
    ax_x2.set_ylabel("Displacement $x_2$ (m)", fontsize=11)
    ax_x2.set_title("Mass 2 Displacement vs Time", fontsize=12, fontweight="bold")
    ax_x2.grid(True, alpha=0.3)
    ax_x2.axhline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)

    # -------------------------------------------------------------------------
    # [1, 0]: Phase Space (v vs x)
    # -------------------------------------------------------------------------
    ax_phase.plot(x1, v1, "r-", lw=1.5, label="Mass 1", alpha=0.7)
    ax_phase.plot(x2, v2, "b-", lw=1.5, label="Mass 2", alpha=0.7)
    ax_phase.set_xlabel("Position (m)", fontsize=11)
    ax_phase.set_ylabel("Velocity (m/s)", fontsize=11)
    ax_phase.set_title("Phase Space (v vs x)", fontsize=12, fontweight="bold")
    ax_phase.grid(True, alpha=0.3)
    ax_phase.axhline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)
    ax_phase.axvline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)
    ax_phase.legend(loc="upper right", fontsize=9)

    # -------------------------------------------------------------------------
    # [1, 1]: Configuration Space (x2 vs x1)
    # -------------------------------------------------------------------------
    ax_config.plot(x1, x2, "g-", lw=1.5)
    ax_config.plot(x1[0], x2[0], "yo", markersize=8, label="Start")
    ax_config.plot(x1[-1], x2[-1], "ro", markersize=8, label="End")
    ax_config.set_xlabel("$x_1$ (m)", fontsize=11)
    ax_config.set_ylabel("$x_2$ (m)", fontsize=11)
    ax_config.set_title(
        "Configuration Space ($x_2$ vs $x_1$)", fontsize=12, fontweight="bold"
    )
    ax_config.grid(True, alpha=0.3)
    ax_config.axhline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)
    ax_config.axvline(0, color="k", linestyle="--", lw=0.8, alpha=0.5)
    legend_loc = (
        "upper center"
        if motion_type != "Energy Transfer Mode (Beats Phenomenon)"
        else "upper right"
    )
    ax_config.legend(loc=legend_loc, fontsize=9)

    # Add text box with parameters
    param_text = (
        f"System Parameters:\n"
        f"$m_1$={m1:.2f} kg\n$m_2$={m2:.2f} kg\n"
        f"$k_1$={k1:.2f} N/m\n$k_2$={k2:.2f} N/m\n$k_3$={k3:.2f} N/m\n"
        f"ICs:\n"
        f"$x_1(0)$={x1_init:+.2f} m\n$x_2(0)$={x2_init:+.2f} m\n"
        f"$\\dot{{x}}_1(0)$={v1_init:+.2f} m/s\n$\\dot{{x}}_2(0)$={v2_init:+.2f} m/s"
    )

    ax_x2.text(
        1.02,
        0.98,
        param_text,
        transform=ax_x2.transAxes,
        fontsize=9,
        family="monospace",
        va="top",
        ha="left",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.6),
    )

    # Save figure if requested. Saving before plt.show() matches the order used in
    # parametric_sweep_coupling_conditions; some backends close the figure on show.
    if save_fig:
        if filename is None:
            filename = (
                f"mass_spring_analysis_x1={x1_init}_x2={x2_init}_m1={m1}_m2={m2}_"
                f"k1={k1}_k2={k2}_k3={k3}.png"
            )

        save_dir = "OUTPUTS/FIGURES/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving figure to {filepath}...")
            fig.savefig(filepath, dpi=300, bbox_inches="tight")
            print("Figure saved!")
        except Exception as e:
            print(f"Error saving figure: {e}")

    plt.show()

    return fig


def main_mass_spring_system_analysis():
    """Main function to run mass-spring system analysis with user-defined parameters"""
    print("\n" + "=" * 60)
    print("COUPLED MASS-SPRING SYSTEM ANALYSIS WITH STATIC PLOTS")
    print("=" * 60)

    default_params = {
        "x1_init": 1.0,
        "x2_init": 0.0,
        "v1_init": 0.0,
        "v2_init": 0.0,
        "m1": 1.0,
        "m2": 1.0,
        "k1": 10.0,
        "k2": 0.5,
        "k3": 10.0,
        "simulation_time": 80.0,
        "n_points": 1000,
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"
    if use_defaults:
        params = default_params
    else:
        params = {}
        for key, val in default_params.items():
            user_input = input(f"Enter value for {key} (default={val}): ").strip()
            if user_input == "":
                params[key] = val
            else:
                try:
                    params[key] = float(user_input)
                except ValueError:
                    print(f"Invalid input for {key}. Using default value {val}.")
                    params[key] = val

    save_figure = input("Save figure? (y/n): ").strip().lower() == "y"
    filename = None
    if save_figure:
        filename_input = input(
            "Enter filename for figure (e.g., 'mass_spring_analysis.png'): "
        ).strip()
        if filename_input:
            if not filename_input.lower().endswith((".png", ".jpg", ".jpeg")):
                filename_input += ".png"
            filename = filename_input

    fig = coupled_mass_spring_system_analysis(
        x1_init=params["x1_init"],
        x2_init=params["x2_init"],
        v1_init=params["v1_init"],
        v2_init=params["v2_init"],
        m1=params["m1"],
        m2=params["m2"],
        k1=params["k1"],
        k2=params["k2"],
        k3=params["k3"],
        simulation_time=params["simulation_time"],
        n_points=int(params["n_points"]),
        save_fig=save_figure,
        filename=filename,
    )
    return fig


if __name__ == "__main__":
    main_mass_spring_system_analysis()


### 3.5 Coupling Strength Effects

To study the effect of coupling strength, we vary the middle spring constant $k_2$ while keeping $k_1$ and $k_3$ fixed. This changes the normal-mode frequencies and hence the beat period. We simulate the system for three $k_2$ values (weak, medium, and strong coupling) and plot the resulting time series and phase-space trajectories for each case. The function `parametric_sweep_coupling_conditions` below implements this analysis: for each $k_2$ it integrates the equations of motion, checks energy conservation, compares theoretical and numerical mode frequencies, and fills one row of a $3\times3$ grid with $x_1(t)$, $x_2(t)$, and the phase-space curves ($v_1$ vs $x_1$ and $v_2$ vs $x_2$).

In [None]:
def parametric_sweep_coupling_conditions(
    k1=10.0,
    k3=10.0,
    k2_weak=2.0,
    k2_medium=10.0,
    k2_strong=30.0,
    x1_init=0.5,
    x2_init=-0.5,
    v1_init=0.0,
    v2_init=0.0,
    m1=1.0,
    m2=1.0,
    system_length=12.0,
    simulation_time=15.0,
    fps=30,
    save_fig=False,
    filename=None,
):
    """
    Parametric sweep to test different coupling conditions by varying k2.

    Creates a 3x3 grid of subplots showing time series of x1 and x2 displacements
    and phase space trajectories for three coupling conditions: weak, medium, and strong.

    Parameters
    ----------
    k1, k3 : float
        Fixed spring constants for outer springs (N/m)
    k2_weak, k2_medium, k2_strong : float
        Coupling spring constant values for weak, medium, and strong coupling
    x1_init, x2_init : float
        Initial displacements from equilibrium (m)
    v1_init, v2_init : float
        Initial velocities (m/s)
    m1, m2 : float
        Masses (kg)
    system_length : float
        Distance between walls (m); not used by this static analysis
    simulation_time : float
        Total simulation time (s)
    fps : int
        Temporal resolution (frames per second)
    save_fig : bool
        Whether to save the figure
    filename : str or None
        Filename for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object containing all subplots
    """

    # Define coupling conditions
    k2_values = [k2_weak, k2_medium, k2_strong]
    coupling_labels = ["Weak Coupling", "Medium Coupling", "Strong Coupling"]

    # Create figure with 3x3 subplots (3 rows, 3 columns)
    fig, axes = plt.subplots(3, 3, figsize=(18, 12), sharex="col")
    fig.suptitle(
        "Parametric Sweep: Effect of Coupling Strength on Coupled Mass-Spring System\n"
        f"Fixed: $m_1$={m1:.2f} kg, $m_2$={m2:.2f} kg, $k_1$={k1:.2f} N/m, $k_3$={k3:.2f} N/m",
        fontsize=16,
        fontweight="bold",
        y=0.995,
    )

    # Adjust spacing
    fig.subplots_adjust(
        left=0.12, right=0.85, top=0.9, bottom=0.1, hspace=0.25, wspace=0.3
    )

    # Colors for x1 and x2
    color_x1 = "crimson"
    color_x2 = "royalblue"

    # Simulate for each coupling condition
    for row, (k2, coupling_label) in enumerate(zip(k2_values, coupling_labels)):
        print(f"\n{'=' * 60}")
        print(f"Simulating: {coupling_label} (k2 = {k2:.2f} N/m)")
        print(f"{'=' * 60}")

        # =====================================================================
        # NUMERICAL SOLUTION
        # =====================================================================

        # Initial state: [x₁, v₁, x₂, v₂]
        y0 = [x1_init, v1_init, x2_init, v2_init]

        # Time span
        t_span = (0, simulation_time)
        n_points = int(simulation_time * fps) + 1
        t_eval = np.linspace(0, simulation_time, n_points)

        # Solve the ODE
        solution = solve_ivp(
            two_mass_spring_derivatives,
            t_span,
            y0,
            args=(m1, m2, k1, k2, k3),
            method="RK45",
            rtol=1e-6,
            atol=1e-8,
            dense_output=True,
        )

        if not solution.success:
            print(f"ERROR: ODE solver failed for k2={k2}: {solution.message}")
            continue

        y = solution.sol(t=t_eval)
        x1 = y[0]
        v1 = y[1]
        x2 = y[2]
        v2 = y[3]

        # =====================================================================
        # COMPUTE PHYSICAL QUANTITIES
        # =====================================================================

        # Energies
        KE = 0.5 * m1 * v1**2 + 0.5 * m2 * v2**2
        PE = 0.5 * k1 * x1**2 + 0.5 * k2 * (x2 - x1) ** 2 + 0.5 * k3 * x2**2
        total_E = KE + PE

        # Energy conservation (relative error)
        E0 = total_E[0]
        energy_error = np.abs((total_E - E0) / E0) * 100  # percentage
        max_energy_error = np.max(energy_error)

        # =====================================================================
        # FREQUENCY ANALYSIS
        # =====================================================================

        # Theoretical normal modes
        omega1_th, omega2_th, f1_th, f2_th, mode1, mode2 = (
            th_normal_modes_mass_spring_system(m1, m2, k1, k2, k3)
        )

        # Numerical frequency estimation
        f1_num, f2_num = est_freqs_mass_spring_system(
            t_eval, x1, x2, m1, m2, k1, k2, k3
        )

        def _fmt_freq(f):
            return f"{f:.4f} Hz" if f > 0 else "N/A (Not Excited)"

        print("  Normal Mode Frequencies:")
        print(f"    Mode 1: ω₁ = {omega1_th:.4f} rad/s (f₁ = {f1_th:.4f} Hz)")
        print(f"    Mode 2: ω₂ = {omega2_th:.4f} rad/s (f₂ = {f2_th:.4f} Hz)")
        print(f"  Energy conservation: Max error = {max_energy_error:.4f}%")

        # =====================================================================
        # PLOTTING - Column 0: x1 vs time
        # =====================================================================

        ax = axes[row, 0]
        ax.plot(t_eval, x1, color=color_x1, lw=1.5)
        ax.axhline(0, color="gray", linestyle="--", lw=0.8, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle=":")
        ax.set_ylabel("$x_1$ (m)", fontsize=11, fontweight="bold")

        # Column title (top row only)
        if row == 0:
            ax.set_title("Mass 1 Displacement", fontsize=12, fontweight="bold", pad=10)

        # Add row label
        ax.text(
            -0.22,
            0.5,
            coupling_label,
            transform=ax.transAxes,
            fontsize=12,
            fontweight="bold",
            rotation=90,
            va="center",
            ha="center",
        )

        if row == 2:
            ax.set_xlabel("Time (s)", fontsize=11, fontweight="bold")

        # =====================================================================
        # PLOTTING - Column 1: x2 vs time
        # =====================================================================

        ax = axes[row, 1]
        ax.plot(t_eval, x2, color=color_x2, lw=1.5)
        ax.axhline(0, color="gray", linestyle="--", lw=0.8, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle=":")
        ax.set_ylabel("$x_2$ (m)", fontsize=11, fontweight="bold")

        if row == 0:
            ax.set_title("Mass 2 Displacement", fontsize=12, fontweight="bold", pad=10)

        if row == 2:
            ax.set_xlabel("Time (s)", fontsize=11, fontweight="bold")

        # =====================================================================
        # PLOTTING - Column 2: Phase Space (velocity vs position)
        # =====================================================================

        ax = axes[row, 2]

        # Plot phase space trajectories
        ax.plot(x1, v1, color=color_x1, lw=1.5, label="Mass 1", alpha=0.8)
        ax.plot(x2, v2, color=color_x2, lw=1.5, label="Mass 2", alpha=0.8)

        # Mark initial points
        ax.plot(
            x1[0],
            v1[0],
            "o",
            color=color_x1,
            markersize=8,
            markeredgecolor="black",
            markeredgewidth=1.5,
            label="Mass 1 start",
            zorder=10,
        )
        ax.plot(
            x2[0],
            v2[0],
            "s",
            color=color_x2,
            markersize=8,
            markeredgecolor="black",
            markeredgewidth=1.5,
            label="Mass 2 start",
            zorder=10,
        )

        # Mark final points
        ax.plot(
            x1[-1],
            v1[-1],
            "x",
            color=color_x1,
            markersize=10,
            markeredgewidth=2.5,
            label="Mass 1 end",
            alpha=0.7,
        )
        ax.plot(
            x2[-1],
            v2[-1],
            "x",
            color=color_x2,
            markersize=10,
            markeredgewidth=2.5,
            label="Mass 2 end",
            alpha=0.7,
        )

        # Add reference lines
        ax.axhline(0, color="gray", linestyle="--", lw=0.8, alpha=0.5)
        ax.axvline(0, color="gray", linestyle="--", lw=0.8, alpha=0.5)

        ax.grid(True, alpha=0.3, linestyle=":")
        if row == 2:
            ax.set_xlabel("Position $x$ (m)", fontsize=11, fontweight="bold")
        ax.set_ylabel("Velocity $\\dot{x}$ (m/s)", fontsize=11, fontweight="bold")

        if row == 0:
            ax.set_title(
                "Phase Space (Both Masses)", fontsize=12, fontweight="bold", pad=10
            )

        # Per-row text box: coupling constant, energy conservation, mode frequencies
        param_sweep_text = (
            f"Spring Constant:\n"
            f"  $k_2$ = {k2:.2f} N/m\n"
            f"ENERGY\n"
            f"  $E_0$ = {E0:.4f} J\n"
            f"  Max Err = {max_energy_error:.4f}%\n"
            f"Mode 1:\n"
            f"  Theo: {f1_th:.4f} Hz\n"
            f"  Num: {_fmt_freq(f1_num)}\n"
            f"Mode 2:\n"
            f"  Theo: {f2_th:.4f} Hz\n"
            f"  Num: {_fmt_freq(f2_num)}"
        )
        ax.text(
            1.04,
            0.01,
            param_sweep_text,
            transform=ax.transAxes,
            fontsize=8,
            va="bottom",
            ha="left",
            bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.6),
            family="monospace",
        )

        # System parameters and initial conditions (drawn once, on the top row)
        sys_info = (
            f"SYSTEM PARAMETERS\n"
            f"{'-' * 20}\n"
            f"Masses: $m_1$={m1:.2f} kg, $m_2$={m2:.2f} kg\n"
            f"Springs: $k_1$={k1:.2f} N/m, $k_3$={k3:.2f} N/m\n"
            f"INITIAL CONDITIONS\n"
            f"{'-' * 22}\n"
            f"$x_1$(0) = {x1_init:+.3f} m, $x_2$(0) = {x2_init:+.3f} m\n"
            f"$\\dot{{x}}_1$(0) = {v1_init:+.3f} m/s, $\\dot{{x}}_2$(0) = {v2_init:+.3f} m/s"
        )

        if row == 0:
            ax.text(
                1.04,
                0.99,
                sys_info,
                transform=ax.transAxes,
                fontsize=8,
                va="top",
                ha="left",
                bbox=dict(boxstyle="round", facecolor="lightgray", alpha=0.6),
                family="monospace",
            )

        # Legend
        if row == 1:
            ax.legend(
                loc="upper right",
                bbox_to_anchor=(1.34, 1),
                fontsize=8,
                framealpha=0.9,
                ncol=1,
                markerscale=0.8,
            )

    # Save figure if requested
    if save_fig:
        if filename is None:
            filename = (
                f"parametric_sweep_k2_weak={k2_weak}_medium={k2_medium}_"
                f"strong={k2_strong}_phase_space.png"
            )

        save_dir = "OUTPUTS/FIGURES/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)

        try:
            print(f"\nSaving figure to {filepath}...")
            fig.savefig(filepath, dpi=200, bbox_inches="tight")
            print("Figure saved successfully!")
        except Exception as e:
            print(f"Error saving figure: {e}")

    plt.show()

    return fig


# Run the parametric sweep (k2 = 1, 3, 8 N/m; only mass 1 displaced initially)
fig = parametric_sweep_coupling_conditions(
    k1=10.0,
    k3=10.0,
    k2_weak=1.0,  # Weak coupling
    k2_medium=3.0,  # Medium coupling
    k2_strong=8.0,  # Strong coupling
    x1_init=0.5,
    x2_init=0.0,
    v1_init=0.0,
    v2_init=0.0,
    m1=1.0,
    m2=1.0,
    system_length=12.0,
    simulation_time=40.0,
    fps=30,
    save_fig=True,
    filename="coupling_strength_parametric_sweep_mass_spring.png",
)

**Analysis of the Beats Phenomenon in Coupled Mass-Spring Systems**

Beats in a coupled mass-spring system are a *superposition effect*: when the initial conditions excite **both normal modes**, each mass oscillates with two frequencies (close or not), and the interference between them produces a slowly varying amplitude envelope (beats).

**1) Normal-mode picture (why beats happen)**  
For the 2-mass / 3-spring system (wall–$k_1$–$m_1$–$k_2$–$m_2$–$k_3$–wall), the equations are
$$
m_1\ddot x_1+(k_1+k_2)x_1-k_2x_2=0,\qquad
m_2\ddot x_2+(k_2+k_3)x_2-k_2x_1=0.
$$
Write in matrix form:
$$
M\ddot{\mathbf x}+K\mathbf x=0,\quad
M=\begin{pmatrix}m_1&0\\0&m_2\end{pmatrix},\;
K=\begin{pmatrix}k_1+k_2&-k_2\\-k_2&k_2+k_3\end{pmatrix}.
$$
Normal modes satisfy the generalized eigenproblem
$$
K\mathbf v=\omega^2 M\mathbf v,
$$
giving two eigenfrequencies $\omega_1<\omega_2$.

A convenient closed form (general case) is:
$$
\omega_{1,2}^2=\frac{A\mp\sqrt{A^2-4B}}{2},
$$
with
$$
A=\frac{k_1+k_2}{m_1}+\frac{k_2+k_3}{m_2},\qquad
B=\frac{(k_1+k_2)(k_2+k_3)-k_2^2}{m_1m_2}.
$$
So changing **only $k_2$** shifts $\omega_1$ and $\omega_2$, and therefore their separation.
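As a quick check of the closed form, the sketch below compares it with SciPy's generalized symmetric eigensolver `scipy.linalg.eigh(K, M)`. The helper names and the asymmetric parameter values are illustrative assumptions, not part of the notebook; the minus sign inside the quadratic formula gives the lower frequency $\omega_1$.

```python
import numpy as np
from scipy.linalg import eigh


def closed_form_omegas(m1, m2, k1, k2, k3):
    """(omega_1, omega_2) from omega^4 - A omega^2 + B = 0 (hypothetical helper)."""
    A = (k1 + k2) / m1 + (k2 + k3) / m2
    B = ((k1 + k2) * (k2 + k3) - k2**2) / (m1 * m2)
    # A^2 - 4B = (a - d)^2 + 4 k2^2 / (m1 m2) >= 0, so the root is always real
    disc = np.sqrt(A**2 - 4 * B)
    return np.sqrt((A - disc) / 2), np.sqrt((A + disc) / 2)


def eigh_omegas(m1, m2, k1, k2, k3):
    """Solve K v = omega^2 M v directly (hypothetical helper)."""
    M = np.diag([m1, m2])
    K = np.array([[k1 + k2, -k2], [-k2, k2 + k3]])
    omega_sq = eigh(K, M, eigvals_only=True)  # eigenvalues in ascending order
    return np.sqrt(omega_sq)


# Asymmetric example, so the symmetric-ends shortcut does not apply
params = dict(m1=1.0, m2=2.0, k1=10.0, k2=3.0, k3=6.0)
w_closed = np.array(closed_form_omegas(**params))
w_eig = eigh_omegas(**params)
print(w_closed, w_eig)
```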

**2) Beat formula (envelope period comes from frequency splitting)**  
If the motion of (say) mass 1 contains both mode frequencies,
$$
x_1(t)=C_1\cos(\omega_1 t)+C_2\cos(\omega_2 t),
$$
then, when the two frequencies are not too far apart (classic beats) and $C_1\approx C_2\equiv C$, the sum-to-product identity gives (exactly when $C_1=C_2$):
$$
x_1(t)\approx 2C\cos\!\Big(\frac{\omega_2-\omega_1}{2}t\Big)\;
\cos\!\Big(\frac{\omega_1+\omega_2}{2}t\Big).
$$
- The fast oscillation (“carrier”) is at $\omega_{\text{avg}}=\frac{\omega_1+\omega_2}{2}$.
- The slowly varying envelope is set by $\Delta\omega=\omega_2-\omega_1$.

**Beat period**
$$
T_{\text{beat}}=\frac{2\pi}{\Delta\omega}
\quad\text{equivalently}\quad
f_{\text{beat}}=\big|f_2-f_1\big|
\;\; \text{with}\;\; f_i=\frac{\omega_i}{2\pi}.
$$

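**Energy-transfer time** (from one mass to the other) is about half the beat period:
$$
T_{\text{transfer}}\approx \frac{T_{\text{beat}}}{2}=\frac{\pi}{\Delta\omega},
$$
because the energy of each mass scales like the square of its envelope, and the envelope $|\cos(\Delta\omega\,t/2)|$ of the initially displaced mass falls from its maximum to zero in $\pi/\Delta\omega$.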
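To make the beat and transfer formulas concrete, here is a minimal sketch using the exact normal-mode solution for symmetric ends ($m_1=m_2=m$, $k_1=k_3=k$, so $\omega_1^2=k/m$ and $\omega_2^2=(k+2k_2)/m$) with $x_1(0)=a$, $x_2(0)=0$, and zero initial velocities. The numerical values are illustrative assumptions. It checks that the modal sum equals the carrier-times-envelope product and that, at $T_{\text{transfer}}=\pi/\Delta\omega$, the envelope of $x_1$ vanishes while that of $x_2$ peaks.

```python
import numpy as np

# Symmetric example (illustrative values, same as the weak-coupling case of the sweep)
m, k, k2, a = 1.0, 10.0, 1.0, 0.5
w1 = np.sqrt(k / m)              # in-phase mode
w2 = np.sqrt((k + 2 * k2) / m)   # out-of-phase mode
dw = w2 - w1
w_avg = 0.5 * (w1 + w2)

t = np.linspace(0.0, 40.0, 4001)

# Exact modal superposition: x = (a/2)[1, 1] cos(w1 t) + (a/2)[1, -1] cos(w2 t)
x1_modes = 0.5 * a * (np.cos(w1 * t) + np.cos(w2 * t))
x2_modes = 0.5 * a * (np.cos(w1 * t) - np.cos(w2 * t))

# Envelope x carrier form; exact here because both modal amplitudes equal a/2
x1_beats = a * np.cos(0.5 * dw * t) * np.cos(w_avg * t)
x2_beats = a * np.sin(0.5 * dw * t) * np.sin(w_avg * t)

T_beat = 2 * np.pi / dw
T_transfer = np.pi / dw
# Envelope amplitudes at the transfer time: mass 1's vanishes, mass 2's peaks
env1 = abs(a * np.cos(0.5 * dw * T_transfer))
env2 = abs(a * np.sin(0.5 * dw * T_transfer))
print(f"T_beat = {T_beat:.2f} s, T_transfer = {T_transfer:.2f} s")
print(f"envelope of x1: {env1:.2e} m, envelope of x2: {env2:.3f} m")
```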

**3) Why weak/medium/strong coupling matches our plot trends**

A very clear special case is **symmetric ends**:
$$
m_1=m_2=m,\qquad k_1=k_3=k.
$$
Then the eigenvectors are in-phase $[1,1]$ and out-of-phase $[1,-1]$, with
$$
\omega_1^2=\frac{k}{m}\quad(\text{in-phase, does NOT depend on }k_2),
\qquad
\omega_2^2=\frac{k+2k_2}{m}\quad(\text{out-of-phase, increases with }k_2).
$$
So the splitting is
$$
\Delta\omega
=\sqrt{\frac{k+2k_2}{m}}-\sqrt{\frac{k}{m}},
$$
which **increases** with $k_2$. Therefore
$$
T_{\text{beat}}=\frac{2\pi}{\Delta\omega}\;\;\downarrow\;\;\text{as }k_2\uparrow.
$$

That directly explains our observations:

- **Weak coupling (small $k_2$)**: $\omega_2$ is close to $\omega_1$ ⇒ $\Delta\omega$ small ⇒ $T_{\text{beat}}$ large.  
    In a fixed simulation window, the number of beat “capsules” roughly scales like
    $$
    N \sim \frac{T_{\text{sim}}}{T_{\text{beat}}}=\frac{T_{\text{sim}}\Delta\omega}{2\pi},
    $$
    so we see **fewer** capsules, widely spaced. Beats look **very prominent** because the envelope varies slowly compared to the carrier.

- **Medium coupling**: $\Delta\omega$ larger ⇒ $T_{\text{beat}}$ smaller ⇒ **more** capsules, more closely spaced in time.

- **Strong coupling (large $k_2$)**: $\Delta\omega$ becomes large ⇒ $T_{\text{beat}}$ is the **smallest**, so we get the **maximum** number of capsules and they look **dense**. Beats can appear “less prominent” visually because the envelope is no longer very slow compared to the fast oscillations; the modulation happens quickly and the trace looks more like rapidly varying oscillations than clean, slow packets.
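These trends can be quantified for the values passed to `parametric_sweep_coupling_conditions` in Section 3.5 ($m_1=m_2=1$ kg, $k_1=k_3=10$ N/m, $k_2\in\{1,3,8\}$ N/m, $T_{\text{sim}}=40$ s). This sketch uses only the symmetric-case formulas above, so it estimates $T_{\text{beat}}$ and the approximate beat count $N$ rather than reading them off the simulation.

```python
import numpy as np

# Parameters from the Section 3.5 sweep call (symmetric ends)
m, k, T_sim = 1.0, 10.0, 40.0
k2_values = {"weak": 1.0, "medium": 3.0, "strong": 8.0}

results = {}
for label, k2 in k2_values.items():
    w1 = np.sqrt(k / m)              # in-phase, independent of k2
    w2 = np.sqrt((k + 2 * k2) / m)   # out-of-phase, grows with k2
    dw = w2 - w1
    T_beat = 2 * np.pi / dw
    n_beats = T_sim / T_beat         # N ~ T_sim * dw / (2 pi)
    results[label] = (dw, T_beat, n_beats)
    print(f"{label:>6}: dw = {dw:.3f} rad/s, T_beat = {T_beat:.2f} s, N ~ {n_beats:.1f}")
```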


**Phase-Space Interpretation for Different Coupling Strengths**

Phase space ($x$, $\dot{x}$) shows how position and velocity evolve together. In an undamped coupled mass-spring system, the motion is generally **quasi-periodic** because both normal modes are excited, so the trajectories are not simple ellipses but multi-looped patterns.

**Weak coupling (small $k_2$)**

* Phase trajectories are **nearly elliptical** and thin.
* Orbits evolve slowly, showing gradual deformation.
* Indicates that each mass behaves almost like an independent oscillator.
* **Slow energy exchange** is reflected by the slow shrinking and regrowth of the nearly elliptical loops as each mass's amplitude envelope varies.

**Medium coupling**

* Phase portraits become **thicker rosette-like patterns**.
* Trajectories cover a larger region of phase space.
* Faster alternation between high- and low-energy states of each mass.
* Shows **stronger mixing of the two normal modes**.

**Strong coupling (large $k_2$)**

* Phase space is **densely filled** with tightly wound curves.
* Larger velocity range due to stronger restoring forces.
* Envelope modulation is fast, so individual beat cycles are less distinct.
* Motion reflects a **collective oscillation** rather than two weakly linked masses.

**Overall insight**

As coupling strength increases, phase space evolves from **simple, slowly varying ellipses** to **dense quasi-periodic structures**, directly visualizing faster energy transfer and stronger normal-mode interaction.


## 4. N-Mass-Spring Chain: Normal Modes and Dynamics

In this section, we generalize the 2-mass spring system to $N$ masses connected by $N+1$ springs between two fixed walls. We derive the equations of motion using the Lagrangian formalism. Below we present a step-by-step derivation.

**1) Geometry / coordinates**

Consider $N$ point masses constrained to move on a straight line (1D). They are connected by $N+1$ springs between two rigid walls:

$$
\text{Wall} \;-\; k_1 \;-\; m_1 \;-\; k_2 \;-\; m_2 \;-\; \cdots \;-\; k_N \;-\; m_N \;-\; k_{N+1}\;-\; \text{Wall}.
$$

Let $x_i(t)$ be the displacement of mass $m_i$ from its equilibrium position (positive to the right). Using displacements from equilibrium removes any constant-force terms; only *extensions/compressions* appear.

The fixed walls are equivalent to “ghost” displacements
$$
x_0(t)=0,\qquad x_{N+1}(t)=0.
$$

**2) Kinetic energy**

Each mass has speed $\dot x_i$, so
$$
\boxed{T=\frac12\sum_{i=1}^{N} m_i \dot x_i^2.}
$$

**3) Potential energy**

Each spring stores elastic energy $\tfrac12 k(\Delta \ell)^2$ where $\Delta \ell$ is its extension from equilibrium.

- Spring $k_1$ is between wall and mass 1: extension is $x_1-x_0=x_1$.
- Spring $k_{i}$ for $2\le i\le N$ is between masses $(i-1)$ and $i$: extension is $x_i-x_{i-1}$.
- Spring $k_{N+1}$ is between mass $N$ and wall: extension is $x_{N+1}-x_N=-x_N$.

Thus
$$
\boxed{
V=\frac12 k_1 x_1^2+\frac12\sum_{i=2}^{N} k_i(x_i-x_{i-1})^2+\frac12 k_{N+1}x_N^2.
}
$$

**4) Lagrangian**

$$
\boxed{\mathcal{L}=T-V.}
$$

**5) Euler–Lagrange ⇒ equations of motion**

For each generalized coordinate $x_j$ ($j=1,\dots,N$), the Euler–Lagrange equation is
$$
\frac{d}{dt}\left(\frac{\partial\mathcal{L}}{\partial \dot x_j}\right)-\frac{\partial\mathcal{L}}{\partial x_j}=0.
$$

We have $\frac{\partial\mathcal{L}}{\partial \dot x_j}=m_j\dot x_j \Rightarrow \frac{d}{dt}(\cdot)=m_j\ddot x_j$.

So the dynamics reduce to
$$
m_j\ddot x_j+\frac{\partial V}{\partial x_j}=0.
$$

Now compute $\frac{\partial V}{\partial x_j}$. The coordinate $x_j$ appears in only two spring terms: $k_j(x_j-x_{j-1})^2/2$ and $k_{j+1}(x_{j+1}-x_j)^2/2$ (with endpoints interpreted using $x_0=x_{N+1}=0$). Differentiating yields:

- From the left spring:
$$
\frac{\partial}{\partial x_j}\left(\frac12 k_j(x_j-x_{j-1})^2\right)=k_j(x_j-x_{j-1}).
$$
- From the right spring:
$$
\frac{\partial}{\partial x_j}\left(\frac12 k_{j+1}(x_{j+1}-x_j)^2\right)= -k_{j+1}(x_{j+1}-x_j)=k_{j+1}(x_j-x_{j+1}).
$$

Adding them:
$$
\frac{\partial V}{\partial x_j}=k_j(x_j-x_{j-1})+k_{j+1}(x_j-x_{j+1}).
$$

Therefore, for interior masses ($j=2,\dots,N-1$),
$$
\boxed{
m_j\ddot x_j + (k_j+k_{j+1})x_j - k_j x_{j-1}-k_{j+1}x_{j+1}=0.
}
$$

Endpoints come out naturally using $x_0=x_{N+1}=0$:
- For $j=1$:
$$
\boxed{m_1\ddot x_1+(k_1+k_2)x_1-k_2 x_2=0.}
$$
- For $j=N$:
$$
\boxed{m_N\ddot x_N+(k_N+k_{N+1})x_N-k_N x_{N-1}=0.}
$$
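The per-mass force law can be checked numerically. The sketch below evaluates $V$ from the spring sum with ghost displacements $x_0=x_{N+1}=0$ and compares a central finite-difference gradient against $k_j(x_j-x_{j-1})+k_{j+1}(x_j-x_{j+1})$. The helpers `potential_energy` and `grad_V_analytic` are hypothetical, and the spring and displacement values are arbitrary test data.

```python
import numpy as np


def potential_energy(x, springs):
    """V = 1/2 * sum_j k_j (x_j - x_{j-1})^2 over all N+1 springs (hypothetical helper)."""
    xp = np.concatenate(([0.0], x, [0.0]))   # ghost walls x_0 = x_{N+1} = 0
    return 0.5 * np.sum(springs * np.diff(xp) ** 2)


def grad_V_analytic(x, springs):
    """dV/dx_j = k_j (x_j - x_{j-1}) + k_{j+1} (x_j - x_{j+1}) (hypothetical helper)."""
    xp = np.concatenate(([0.0], x, [0.0]))
    left = springs[:-1] * (xp[1:-1] - xp[:-2])   # springs k_1..k_N
    right = springs[1:] * (xp[1:-1] - xp[2:])    # springs k_2..k_{N+1}
    return left + right


rng = np.random.default_rng(0)
N = 5
springs = rng.uniform(1.0, 10.0, N + 1)   # arbitrary positive test springs
x = rng.normal(size=N)                    # arbitrary test displacements

# Central finite differences (exact up to rounding because V is quadratic)
h = 1e-6
grad_fd = np.array([
    (potential_energy(x + h * e, springs) - potential_energy(x - h * e, springs)) / (2 * h)
    for e in np.eye(N)
])
print(np.max(np.abs(grad_fd - grad_V_analytic(x, springs))))
```

With $m_j\ddot x_j=-\partial V/\partial x_j$, the returned gradient is exactly the restoring-force term in the boxed equations.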


**Matrix equation of motion (compact formalism)**

Define the displacement vector
$$
\mathbf{x}(t)=\begin{bmatrix}x_1\\x_2\\ \vdots\\ x_N\end{bmatrix},
\qquad
M=\mathrm{diag}(m_1,\dots,m_N).
$$

The stiffness matrix $K$ is tridiagonal:
$$
\boxed{
K_{jj}=k_j+k_{j+1},\quad
K_{j,j-1}=-k_j,\quad
K_{j,j+1}=-k_{j+1},
}
$$
with other entries zero.

Then all $N$ equations are
$$
\boxed{M\ddot{\mathbf{x}}+K\mathbf{x}=0.}
$$

This is the central result: dynamics are obtained by assembling $M$ and $K$.
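Assembling $K$ is mechanical enough to vectorize. A minimal sketch, assuming NumPy; `assemble_stiffness` is a hypothetical helper (the notebook's own functions below use explicit loops). It checks that $V=\tfrac12\mathbf x^{\mathsf T}K\mathbf x$ reproduces the spring-by-spring sum and that $K$ is symmetric positive definite for positive springs.

```python
import numpy as np


def assemble_stiffness(springs):
    """Tridiagonal K for N masses from N+1 springs [k_1, ..., k_{N+1}] (hypothetical helper)."""
    k = np.asarray(springs, dtype=float)
    main = k[:-1] + k[1:]    # K_jj = k_j + k_{j+1}
    off = -k[1:-1]           # K_{j,j+1} = K_{j+1,j} = -k_{j+1}
    return np.diag(main) + np.diag(off, 1) + np.diag(off, -1)


springs = np.array([10.0, 2.0, 5.0, 3.0])   # N = 3 masses, 4 springs (illustrative)
K = assemble_stiffness(springs)
x = np.array([0.3, -0.1, 0.2])

xp = np.concatenate(([0.0], x, [0.0]))      # ghost walls x_0 = x_4 = 0
V_springs = 0.5 * np.sum(springs * np.diff(xp) ** 2)
V_matrix = 0.5 * x @ K @ x
print(K)
print(V_springs, V_matrix)
```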


**Normal modes via the eigenvalue equation**

Seek modal solutions
$$
\mathbf{x}(t)=\mathbf{a}\,e^{i\omega t}.
$$
Then $\ddot{\mathbf{x}}=-\omega^2\mathbf{x}$ and
$$
(K-\omega^2 M)\mathbf{a}=0.
$$
Nontrivial mode shapes $\mathbf{a}\ne 0$ require
$$
\boxed{\det(K-\omega^2 M)=0,}
$$
equivalently the generalized eigenproblem
$$
\boxed{K\mathbf{a}=\omega^2 M\mathbf{a}.}
$$

We get $N$ eigenpairs $\{(\omega_r^2,\mathbf{a}^{(r)})\}_{r=1}^{N}$, giving the modal expansion
$$
\boxed{
\mathbf{x}(t)=\sum_{r=1}^{N} C_r\,\mathbf{a}^{(r)}\cos(\omega_r t+\phi_r),
}
$$
and each displacement is
$$
\boxed{
x_i(t)=\sum_{r=1}^{N} C_r\,a^{(r)}_i\cos(\omega_r t+\phi_r).
}
$$
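The constants $C_r,\phi_r$ follow from the initial conditions. A sketch, assuming SciPy: with mass-normalized modes from `eigh(K, M)` (so $A^{\mathsf T}MA=I$), the modal coordinates are $\mathbf q=A^{\mathsf T}M\mathbf x$, and each evolves as $q_r(0)\cos\omega_r t+\dot q_r(0)\,\omega_r^{-1}\sin\omega_r t=C_r\cos(\omega_r t+\phi_r)$. The helper `modal_solution` and the chain parameters are illustrative; it assumes all $\omega_r>0$, which holds with fixed walls and positive springs.

```python
import numpy as np
from scipy.linalg import eigh


def modal_solution(masses, K, x0, v0, t):
    """x(t) = sum_r C_r a^(r) cos(omega_r t + phi_r); returns shape (len(t), N).

    Hypothetical helper; assumes every omega_r > 0.
    """
    M = np.diag(masses)
    omega_sq, A = eigh(K, M)       # columns of A satisfy A.T @ M @ A = I
    omega = np.sqrt(omega_sq)
    q0 = A.T @ M @ x0              # initial modal displacements
    qd0 = A.T @ M @ v0             # initial modal velocities
    C = np.sqrt(q0**2 + (qd0 / omega) ** 2)
    phi = np.arctan2(-qd0 / omega, q0)
    q_t = C * np.cos(np.outer(t, omega) + phi)   # modal coordinates, shape (len(t), N)
    return q_t @ A.T                             # back to physical displacements


masses = np.array([1.0, 2.0, 1.5])
springs = np.array([10.0, 4.0, 6.0, 8.0])
K = (np.diag(springs[:-1] + springs[1:])
     - np.diag(springs[1:-1], 1) - np.diag(springs[1:-1], -1))
x0 = np.array([0.2, 0.0, -0.1])
v0 = np.array([0.0, 0.5, 0.0])
t = np.linspace(0.0, 5.0, 5001)
X = modal_solution(masses, K, x0, v0, t)
```

Because the solution is a closed-form sum, it can be evaluated at any time without time stepping, which makes it a useful reference for checking `solve_ivp` results.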

**Mode “types” (in-phase, out-of-phase, nodes)**

For $N>2$, “in-phase/out-of-phase” generalizes to the *pattern of signs* and the number of internal nodes:
- **Lowest mode (often called in-phase-like):** all masses move with the same sign (no internal nodes), so springs are stretched gently → lowest frequency.
- **Higher modes:** one or more sign changes (nodes) along the chain → larger relative motion between neighbors → higher frequencies.
- **Highest mode (often alternating):** neighboring masses move nearly opposite, maximizing spring deformation → highest frequency.
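The node-counting picture can be checked directly. The sketch below uses an illustrative uniform chain with $N=4$, computes the mode shapes with `scipy.linalg.eigh`, and counts sign changes along the chain; with frequencies in ascending order, mode $r$ shows $r-1$ sign changes in this example. Choosing $N+1=5$ avoids mode entries that are exactly zero, which would make the simple sign-change count ambiguous.

```python
import numpy as np
from scipy.linalg import eigh

# Uniform fixed-fixed chain (illustrative values)
N, m, k = 4, 1.0, 1.0
springs = np.full(N + 1, k)
K = (np.diag(springs[:-1] + springs[1:])
     - np.diag(springs[1:-1], 1) - np.diag(springs[1:-1], -1))
M = m * np.eye(N)
omega_sq, modes = eigh(K, M)   # ascending frequencies, one mode per column

# Internal nodes = sign changes between neighbouring masses in each mode shape
sign_changes = [int(np.sum(np.diff(np.sign(modes[:, r])) != 0)) for r in range(N)]
print(np.sqrt(omega_sq), sign_changes)
```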

**Beats / energy transfer**

If the initial condition is not a pure eigenvector, multiple modes are excited:
$$
x_i(t)=\sum_r (\text{mode }r\text{ contribution}).
$$
Interference between modes produces amplitude modulation (beats). Energy appears to “flow” among masses because the modal phases evolve at different rates $\omega_r$.

**Closed-form special case (uniform chain)**

If $m_i=m$ and all springs are equal $k_j=k$ (fixed ends), then
$$
\boxed{\omega_r=2\sqrt{\frac{k}{m}}\sin\!\left(\frac{r\pi}{2(N+1)}\right),\quad r=1,\dots,N,}
$$
and a common choice of mode shapes is
$$
\boxed{a^{(r)}_j=\sin\!\left(\frac{jr\pi}{N+1}\right).}
$$
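The closed form is easy to verify numerically. A minimal sketch with illustrative values ($N=6$, $m=0.5$, $k=3$): the generalized eigenvalues of the uniform fixed-fixed chain should match $2\sqrt{k/m}\,\sin\big(r\pi/(2(N+1))\big)$, and each sine column $a^{(r)}_j=\sin(jr\pi/(N+1))$ should satisfy $K\mathbf a=\omega_r^2M\mathbf a$.

```python
import numpy as np
from scipy.linalg import eigh

N, m, k = 6, 0.5, 3.0   # illustrative uniform chain
springs = np.full(N + 1, k)
K = (np.diag(springs[:-1] + springs[1:])
     - np.diag(springs[1:-1], 1) - np.diag(springs[1:-1], -1))
M = m * np.eye(N)

omega_num = np.sqrt(eigh(K, M, eigvals_only=True))
r = np.arange(1, N + 1)
omega_formula = 2 * np.sqrt(k / m) * np.sin(r * np.pi / (2 * (N + 1)))

# Closed-form mode shapes: rows are masses j, columns are modes r
j = np.arange(1, N + 1)
shapes = np.sin(np.outer(j, r) * np.pi / (N + 1))
# Column r of the residual is K a^(r) - omega_r^2 M a^(r)
residual = K @ shapes - M @ shapes * omega_formula**2
print(np.max(np.abs(omega_num - omega_formula)), np.max(np.abs(residual)))
```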


### 4.1 State-Space Derivatives for an N-Mass Chain

The `n_mass_spring_derivatives` function builds the tridiagonal $K$ (0-based row index $i$, 1-based spring labels $k_1,\dots,k_{N+1}$) with diagonal entries $K_{ii}=k_{i+1}+k_{i+2}$ (so $K_{00}=k_1+k_2$) and off-diagonals $K_{i,i+1}=K_{i+1,i}=-k_{i+2}$, then computes
$$
\ddot{\mathbf{x}}=-M^{-1}K\mathbf{x}
$$
using elementwise division by `masses` (since $M$ is diagonal).


In [None]:
def n_mass_spring_derivatives(t, y, masses, springs):
    """
    Compute the derivatives for the N-mass (N+1)-spring system.

    System: Wall -[k₁]- m₁ -[k₂]- m₂ -[k₃]- ... -[kₙ]- mₙ -[kₙ₊₁]- Wall

    Uses the matrix formulation from theory.md:
    M*ẍ + K*x = 0

    where M is diagonal mass matrix and K is tridiagonal stiffness matrix.

    Parameters
    ----------
    t : float
        Time (not used explicitly, required by solve_ivp)
    y : array-like
        State vector [x₁, v₁, x₂, v₂, ..., xₙ, vₙ] where v = dx/dt
    masses : array-like
        Array of N masses [m₁, m₂, ..., mₙ]
    springs : array-like
        Array of N+1 spring constants [k₁, k₂, ..., kₙ₊₁]

    Returns
    -------
    array
        Derivatives [dx₁/dt, dv₁/dt, dx₂/dt, dv₂/dt, ..., dxₙ/dt, dvₙ/dt]
    """
    N = len(masses)

    # Extract positions and velocities
    x = np.array([y[2 * i] for i in range(N)])
    v = np.array([y[2 * i + 1] for i in range(N)])

    # Stiffness matrix K (tridiagonal)
    K = np.zeros((N, N))

    # First row
    K[0, 0] = springs[0] + springs[1]
    if N > 1:
        K[0, 1] = -springs[1]

    # Middle rows
    for i in range(1, N - 1):
        K[i, i - 1] = -springs[i]
        K[i, i] = springs[i] + springs[i + 1]
        K[i, i + 1] = -springs[i + 1]

    # Last row
    if N > 1:
        K[N - 1, N - 2] = -springs[N - 1]
        K[N - 1, N - 1] = springs[N - 1] + springs[N]

    # Compute accelerations: a = M⁻¹ (-K x)
    # Since M is diagonal, M⁻¹ reduces to elementwise division by mᵢ
    Kx = K @ x
    a = -Kx / masses

    # Build the interleaved derivative vector [v₁, a₁, v₂, a₂, ...]
    dy = np.empty(2 * N)
    dy[0::2] = v
    dy[1::2] = a

    return dy


### 4.2 Theoretical Normal Modes via Eigenvalue Problem

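The `th_normal_modes_n_mass_system` function constructs the same $M$ and $K$ and solves the generalized eigenvalue problem $K\mathbf{v}=\omega^2 M\mathbf{v}$ with `eigh(K, M)`. It returns three things: the angular frequencies `omegas` ($\omega_r$, in rad/s), the frequencies `freqs` ($f_r=\omega_r/2\pi$, in Hz), and the eigenvectors (mode shapes, one per column). Before the square root is taken, tiny negative eigenvalues caused by round-off are clipped to zero.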
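`scipy.linalg.eigh` already returns eigenvalues in ascending order. For a generalized problem `eigh(a, b)`, it also normalizes the eigenvectors so that $V^T M V=I$, and consequently $V^T K V=\mathrm{diag}(\omega_r^2)$. The short sketch below checks both properties on an unequal 3-mass chain with illustrative, assumed parameters.

```python
import numpy as np
from scipy.linalg import eigh

# Illustrative parameters (assumed): an unequal 3-mass, 4-spring chain.
masses = np.array([1.0, 2.0, 0.5])
springs = np.array([10.0, 5.0, 8.0, 12.0])

M = np.diag(masses)
off = -springs[1:-1]
K = np.diag(springs[:-1] + springs[1:]) + np.diag(off, 1) + np.diag(off, -1)

omega_sq, V = eigh(K, M)

# Generalized eigh normalizes eigenvectors so that V^T M V = I,
# which also diagonalizes the stiffness: V^T K V = diag(omega^2).
orthonormality = V.T @ M @ V
modal_stiffness = V.T @ K @ V
freqs_hz = np.sqrt(omega_sq) / (2 * np.pi)
print(np.round(freqs_hz, 4))
```

This $M$-normalization is what makes the modal projection $q=V^TMx$ an exact inverse of $x=Vq$.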


In [None]:
def th_normal_modes_n_mass_system(masses, springs):
    """
    Calculate the theoretical normal mode frequencies for N-mass system.

    Solves the generalized eigenvalue problem: K*v = ω²*M*v

    From theory.md, the stiffness matrix K is tridiagonal:
    K[i,i] = k[i] + k[i+1]
    K[i,i+1] = K[i+1,i] = -k[i+1]

    Parameters
    ----------
    masses : array-like
        Array of N masses [m₁, m₂, ..., mₙ]
    springs : array-like
        Array of N+1 spring constants [k₁, k₂, ..., kₙ₊₁]

    Returns
    -------
    tuple
        (omegas, freqs, modes)
        Angular frequencies (rad/s), frequencies (Hz), and mode shapes (columns)
    """
    N = len(masses)

    M = np.diag(masses)

    K = np.zeros((N, N))

    K[0, 0] = springs[0] + springs[1]
    if N > 1:
        K[0, 1] = -springs[1]

    for i in range(1, N - 1):
        K[i, i - 1] = -springs[i]
        K[i, i] = springs[i] + springs[i + 1]
        K[i, i + 1] = -springs[i + 1]

    if N > 1:
        K[N - 1, N - 2] = -springs[N - 1]
        K[N - 1, N - 1] = springs[N - 1] + springs[N]

    # Solve generalized eigenvalue problem
    omega_sq, modes = eigh(K, M)

    idx = np.argsort(omega_sq)
    omega_sq = omega_sq[idx]
    modes = modes[:, idx]

    omega_sq = np.maximum(omega_sq, 0.0)

    omegas = np.sqrt(omega_sq)

    freqs = omegas / (2 * np.pi)

    return omegas, freqs, modes


### 4.3 Numerical Frequency Estimates (FFT)

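`est_freqs_n_mass_system` takes simulated position data $x_i(t)$ and removes each coordinate's mean (DC offset). It then computes the mode shapes $V$ from the same generalized eigenproblem and projects onto modal coordinates
$$
q(t)=V^T M x(t).
$$
The projection decouples the modes because `eigh(K, M)` returns $M$-orthonormal eigenvectors ($V^T M V=I$), so $x=Vq$ implies $q=V^T M x$. Each $q_r(t)$ is optionally multiplied by a Hann window and then Fourier-transformed, and the dominant non-DC peak is taken as that mode's frequency. If a mode's (windowed) RMS amplitude falls below $10^{-3}$ of the largest modal RMS, the mode is treated as not excited and returned as $0$. Because each estimate is a single peak bin, its resolution is the bin spacing $\Delta f=1/(n\,\Delta t)$, so longer simulations give finer estimates.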
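The projection-plus-FFT idea can be tested on synthetic data whose answer is known in advance. The sketch below assumes a uniform 3-mass chain ($m=1$ kg, $k=10$ N/m) whose motion is a chosen mix of modes 1 and 3. It mirrors the steps of `est_freqs_n_mass_system` without calling it: mean removal, projection, Hann window, non-DC peak, and the $10^{-3}$ RMS threshold.

```python
import numpy as np
from scipy.linalg import eigh

# Assumed uniform 3-mass chain (m = 1 kg, k = 10 N/m), used only for illustration.
masses = np.ones(3)
springs = 10.0 * np.ones(4)
M = np.diag(masses)
off = -springs[1:-1]
K = np.diag(springs[:-1] + springs[1:]) + np.diag(off, 1) + np.diag(off, -1)
omega_sq, V = eigh(K, M)
omega = np.sqrt(omega_sq)

# Synthetic motion with modes 1 and 3 excited and mode 2 at rest.
t = np.linspace(0.0, 50.0, 5001)  # dt = 0.01 s
x = (np.outer(V[:, 0], np.cos(omega[0] * t))
     + 0.5 * np.outer(V[:, 2], np.cos(omega[2] * t)))

# Modal projection: q = V^T M x separates the modes because V^T M V = I.
q = V.T @ (M @ x)
q = q - q.mean(axis=1, keepdims=True)
rms = np.sqrt(np.mean(q**2, axis=1))

# Hann-windowed FFT of each modal coordinate; skip the DC bin when peak picking.
dt = t[1] - t[0]
f = np.fft.rfftfreq(t.size, dt)
amp = np.abs(np.fft.rfft(q * np.hanning(t.size), axis=1))
peaks = f[np.argmax(amp[:, 1:], axis=1) + 1]
peaks[rms < 1e-3 * rms.max()] = 0.0  # "not excited" threshold
df = f[1] - f[0]  # bin spacing 1 / (n dt)

print("true  (Hz):", np.round(omega / (2 * np.pi), 4))
print("peaks (Hz):", np.round(peaks, 4), " resolution:", round(df, 4))
```

Mode 2 is reported as 0 because its modal coordinate is zero up to round-off. The recovered peaks for modes 1 and 3 should lie within one bin $\Delta f$ of the true frequencies.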

In [None]:
def est_freqs_n_mass_system(t, x_data, masses, springs, use_hann_window=True):
    """
    Estimate the normal mode frequencies for N-mass system using FFT.

    Projects the simulation data onto the normal mode coordinates and
    performs FFT to extract dominant frequencies.

    Parameters
    ----------
    t : array
        Time array
    x_data : array
        Position data, shape (N, n_timesteps)
    masses : array-like
        Array of N masses
    springs : array-like
        Array of N+1 spring constants
    use_hann_window : bool
        Whether to apply Hann window to reduce spectral leakage

    Returns
    -------
    array
        Estimated frequencies for each mode in Hz
    """
    t = np.asarray(t)
    dt = t[1] - t[0]
    n = t.size
    N = len(masses)

    x = np.asarray(x_data)

    x = x - x.mean(axis=1, keepdims=True)

    M = np.diag(masses)

    # Build stiffness matrix
    K = np.zeros((N, N))
    K[0, 0] = springs[0] + springs[1]
    if N > 1:
        K[0, 1] = -springs[1]
    for i in range(1, N - 1):
        K[i, i - 1] = -springs[i]
        K[i, i] = springs[i] + springs[i + 1]
        K[i, i + 1] = -springs[i + 1]
    if N > 1:
        K[N - 1, N - 2] = -springs[N - 1]
        K[N - 1, N - 1] = springs[N - 1] + springs[N]

    # Get mode shapes
    omega_sq, V = eigh(K, M)

    # Project onto modal coordinates: q = V^T M x
    q = V.T @ (M @ x)

    # Apply Hann window if requested
    if use_hann_window:
        window = np.hanning(n)
        q = q * window

    # FFT on each modal coordinate
    freqs = np.fft.rfftfreq(n, dt)
    Q = np.fft.rfft(q, axis=1)
    amp = np.abs(Q)

    rms = np.sqrt(np.mean(q**2, axis=1))
    rms_rel = rms / (np.max(rms) + 1e-30)

    f_est = []
    for i in range(N):
        if rms_rel[i] < 1e-3:
            f_est.append(0.0)
            continue

        peak_idx = np.argmax(amp[i, 1:]) + 1
        f_est.append(freqs[peak_idx])

    return np.array(f_est)


### 4.4 Animation of the N-Mass Chain

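The `simulate_N_coupled_mass_spring_system` function defined below performs these steps:
- validates the inputs and fills in defaults for `masses` (length $N$, 1 kg each), `springs` (length $N+1$, 10 N/m each), `x_init` (alternating $\pm 0.5$ m), and `v_init` (zeros);
- integrates the first-order system with `solve_ivp(n_mass_spring_derivatives, ...)` and samples the dense solution at the animation frame times;
- precomputes the following quantities:
  - equilibrium positions and actual positions;
  - spring extensions: $x_1$ for the left wall spring, $x_i-x_{i-1}$ between masses, and $-x_N$ for the right wall spring;
  - kinetic, potential, and total energies;
- prints the theoretical frequencies next to the FFT estimates;
- animates the chain with `FuncAnimation`, updating $N+1$ spring artists and $N$ mass artists per frame, and returns the animation object.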
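The spring extensions and the spring potential energy computed in this function are two views of the same quadratic form. Write the extensions as $\Delta\mathbf{s}=D\mathbf{x}$ with an $(N+1)\times N$ difference matrix $D$. Then $K=D^T\,\mathrm{diag}(k)\,D$, and hence $\tfrac12\sum_j k_j\Delta s_j^2=\tfrac12\mathbf{x}^TK\mathbf{x}$. The minimal sketch below verifies this on an assumed random chain. `spring_extensions` and `difference_matrix` are illustrative helpers, not notebook functions.

```python
import numpy as np


def spring_extensions(x):
    """Extensions [x_1, x_2 - x_1, ..., x_N - x_{N-1}, -x_N] of the N+1 springs."""
    x = np.asarray(x, dtype=float)
    return np.concatenate(([x[0]], np.diff(x), [-x[-1]]))


def difference_matrix(N):
    """(N+1) x N matrix D such that D @ x equals spring_extensions(x)."""
    D = np.zeros((N + 1, N))
    D[np.arange(N), np.arange(N)] = 1.0           # +x of the mass at the spring's right end
    D[np.arange(1, N + 1), np.arange(N)] = -1.0   # -x of the mass at the spring's left end
    return D


# Illustrative random chain (assumed values).
rng = np.random.default_rng(0)
N = 4
k = rng.uniform(1.0, 10.0, N + 1)
x = rng.normal(size=N)

D = difference_matrix(N)
K = D.T @ np.diag(k) @ D  # reproduces the tridiagonal K
PE_springs = 0.5 * np.sum(k * spring_extensions(x) ** 2)
PE_matrix = 0.5 * x @ K @ x
print(PE_springs, PE_matrix)
```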

In [None]:
def simulate_N_coupled_mass_spring_system(
    N_mass=3,
    x_init=None,
    v_init=None,
    masses=None,
    springs=None,
    system_length=12.0,
    simulation_time=20.0,
    fps=30,
    save_anim=False,
    filename=None,
):
    """
    Simulate and animate a coupled N-mass spring system.

    System: Wall ─[k₁]─ m₁ ─[k₂]─ m₂ ─[k₃]─ ... ─[kₙ]─ mₙ ─[kₙ₊₁]─ Wall

    Parameters
    ----------
    N_mass : int
        Number of masses (default=3, max=6)
    x_init : array-like or None
        Initial displacements from equilibrium for each mass.
        If None, uses alternating pattern.
    v_init : array-like or None
        Initial velocities for each mass. If None, uses zeros.
    masses : array-like or None
        Array of N masses. If None, uses uniform masses of 1.0 kg.
    springs : array-like or None
        Array of N+1 spring constants. If None, uses uniform springs of 10.0 N/m.
    system_length : float
        Distance between walls
    simulation_time : float
        Total simulation time in seconds
    fps : int
        Frames per second for animation
    save_anim : bool
        Whether to save the animation
    filename : str or None
        Filename for saved animation

    Returns
    -------
    matplotlib.animation.FuncAnimation or None
        The animation object, or None if the ODE solver fails
    """

    # Validate N_mass
    if N_mass < 1:
        raise ValueError("N_mass must be at least 1")
    if N_mass > 6:
        raise ValueError("N_mass is limited to 6 due to complexity")

    # Set default parameters
    if masses is None:
        masses = np.ones(N_mass)
    else:
        masses = np.array(masses)
        if len(masses) != N_mass:
            raise ValueError(f"masses array must have length {N_mass}")

    if springs is None:
        springs = 10.0 * np.ones(N_mass + 1)
    else:
        springs = np.array(springs)
        if len(springs) != N_mass + 1:
            raise ValueError(f"springs array must have length {N_mass + 1}")

    if x_init is None:
        # Default: alternating displacements
        x_init = np.array([0.5 * (-1) ** i for i in range(N_mass)])
    else:
        x_init = np.array(x_init)
        if len(x_init) != N_mass:
            raise ValueError(f"x_init array must have length {N_mass}")

    if v_init is None:
        v_init = np.zeros(N_mass)
    else:
        v_init = np.array(v_init)
        if len(v_init) != N_mass:
            raise ValueError(f"v_init array must have length {N_mass}")

    # =========================================================================
    # NUMERICAL SOLUTION
    # =========================================================================

    # Initial state: [x₁, v₁, x₂, v₂, ..., xₙ, vₙ]
    y0 = []
    for i in range(N_mass):
        y0.extend([x_init[i], v_init[i]])

    # Time span
    t_span = (0, simulation_time)
    n_frames = int(simulation_time * fps) + 1
    t_eval = np.linspace(0, simulation_time, n_frames)

    # Solve the ODE
    print(f"Solving differential equations for {N_mass}-mass system...")
    solution = solve_ivp(
        n_mass_spring_derivatives,
        t_span,
        y0,
        args=(masses, springs),
        method="RK45",
        rtol=1e-5,
        atol=1e-7,
        dense_output=True,
    )

    if not solution.success:
        print(f"ERROR: ODE solver failed: {solution.message}")
        return None

    y = solution.sol(t=t_eval)

    # Extract positions and velocities for each mass
    x = np.zeros((N_mass, n_frames))
    v = np.zeros((N_mass, n_frames))
    for i in range(N_mass):
        x[i] = y[2 * i]
        v[i] = y[2 * i + 1]

    print(f"Solution computed: {len(t_eval)} time steps")

    # =========================================================================
    # PRECOMPUTE PHYSICAL QUANTITIES
    # =========================================================================

    # Equilibrium positions (evenly spaced between walls)
    eq_positions = np.linspace(
        system_length / (N_mass + 1), N_mass * system_length / (N_mass + 1), N_mass
    )

    # Actual positions
    pos_x = eq_positions[:, np.newaxis] + x

    # Spring extensions/compressions
    spring_ext = np.zeros((N_mass + 1, n_frames))
    spring_ext[0] = x[0]  # First spring: left wall to m1
    for i in range(1, N_mass):
        spring_ext[i] = x[i] - x[i - 1]  # Spring between masses
    spring_ext[N_mass] = -x[N_mass - 1]  # Last spring: last mass to right wall

    # Energies
    KE_all = np.sum(0.5 * masses[:, np.newaxis] * v**2, axis=0)

    PE_spring_all = 0.5 * springs[0] * x[0] ** 2
    for i in range(1, N_mass):
        PE_spring_all += 0.5 * springs[i] * (x[i] - x[i - 1]) ** 2
    PE_spring_all += 0.5 * springs[N_mass] * x[N_mass - 1] ** 2

    total_E_all = KE_all + PE_spring_all

    # =========================================================================
    # FREQUENCY ANALYSIS
    # =========================================================================

    # Theoretical normal modes
    omegas_theory, freqs_theory, modes = th_normal_modes_n_mass_system(masses, springs)

    # Numerical frequency estimation
    freqs_num = est_freqs_n_mass_system(t_eval, x, masses, springs)

    print(f"\nNormal Mode Frequencies for {N_mass}-mass system:")
    for i in range(N_mass):
        simul_val = freqs_num[i]
        if simul_val == 0.0:
            simul_str = "N/A"
        else:
            simul_str = f"{simul_val:.4f} Hz"

        print(
            f"  Mode {i + 1}: ω = {omegas_theory[i]:.4f} rad/s (Theory: {freqs_theory[i]:.4f} Hz, Simul: {simul_str})"
        )

    # =========================================================================
    # ANIMATION SETUP
    # =========================================================================

    # Geometry parameters
    left_wall_x = 0.0
    right_wall_x = system_length
    y_position = 0.0
    wall_width = 0.25
    wall_height = 1

    # Spring parameters
    relaxed_spring_length = system_length / (N_mass + 1)
    hook_length = calculate_fixed_hook_length_mass_spring(relaxed_spring_length)
    spring_radius = 0.08 if N_mass <= 3 else 0.05
    spring_num_coils = 8

    # Mass visualization sizes (proportional to mass)
    mass_size_ref = 25 if N_mass <= 3 else 20
    avg_mass = np.mean(masses)
    mass_sizes = [max(15, mass_size_ref * np.cbrt(m / avg_mass)) for m in masses]

    # Color palette for masses
    colors = ["crimson", "royalblue", "green", "orange", "purple", "brown"]
    mass_colors = [colors[i % len(colors)] for i in range(N_mass)]

    # Figure setup
    fig, ax = plt.subplots(figsize=(14, 7))
    fig.subplots_adjust(
        top=0.92, bottom=0.37 if N_mass > 3 else 0.3, left=0.02, right=0.98
    )

    # Set plot limits
    padding = 0.25
    x_min = left_wall_x - padding
    x_max = right_wall_x + padding
    y_min = y_position - wall_height / 2 - 0.5 * padding
    y_max = y_position + wall_height / 2 + 0.5 * padding

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Determine motion type
    if N_mass == 1:
        motion_type = "Single Mass Oscillation"
    elif np.allclose(x_init, x_init[0]):
        motion_type = "In-Phase Mode (All masses move together)"
    elif np.allclose(x_init[::2], x_init[0]) and np.allclose(x_init[1::2], -x_init[0]):
        motion_type = "Alternating Mode (Opposite directions)"
    else:
        motion_type = "Mixed Mode (Superposition)"

    # Title
    omega_str = ", ".join(
        [f"$\\omega_{i + 1}={omegas_theory[i]:.2f}$" for i in range(min(3, N_mass))]
    )
    if N_mass > 3:
        omega_str += "..."
    title_text = (
        f"{N_mass}-Mass Coupled Spring System: {motion_type}\n"
        f"Normal Modes: {omega_str} rad/s"
    )
    ax.set_title(title_text, fontsize=14, fontweight="bold", pad=10)

    # Initialize plot elements
    # Walls
    draw_wall(ax, left_wall_x, y_position, wall_width, wall_height)
    draw_wall(ax, right_wall_x, y_position, wall_width, wall_height)

    # Equilibrium reference lines
    for eq_pos in eq_positions:
        ax.axvline(eq_pos, color="gray", linestyle="--", lw=1, alpha=0.5, zorder=0)
    ax.axhline(y_position, color="gray", linestyle=":", lw=1, alpha=0.3, zorder=0)

    # Springs (will be updated) - N+1 springs
    spring_lines = []
    for i in range(N_mass + 1):
        (line,) = ax.plot([], [], color="dimgrey", lw=2, zorder=2)
        spring_lines.append(line)

    # Masses - N masses
    mass_plots = []
    for i in range(N_mass):
        (plot,) = ax.plot(
            [], [], "o", color=mass_colors[i], markersize=mass_sizes[i], zorder=10
        )
        mass_plots.append(plot)

    # Traces
    trace_length = min(300, n_frames // 2)
    traces_x = [[] for _ in range(N_mass)]
    traces_y = [[] for _ in range(N_mass)]
    trace_lines = []
    for i in range(N_mass):
        (line,) = ax.plot([], [], "-", color=mass_colors[i], alpha=0.3, lw=1, zorder=1)
        trace_lines.append(line)

    # =========================================================================
    # TEXT ANNOTATIONS
    # =========================================================================

    # Timer
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=12,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # System info box (left side)
    info_text = ax.text(
        0.005,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )

    # Dynamic info box (middle)
    dynamic_text = ax.text(
        0.38,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )

    # Frequency comparison box (right)
    freq_text = ax.text(
        0.76,
        -0.02,
        "",
        fontsize=9,
        transform=ax.transAxes,
        va="top",
        ha="left",
        family="monospace",
        bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.8),
    )

    # Set static system info
    def format_info_list(label, items, force_break=False):
        """Helper to format list with optional line break and indentation"""
        if not force_break or len(items) <= 3:
            return f"{label}: " + ", ".join(items)
        else:
            mid = (len(items) + 1) // 2
            line1 = ", ".join(items[:mid])
            line2 = ", ".join(items[mid:])
            # Indent second line to align with first value
            indent = " " * (len(label) + 2)
            return f"{label}: {line1}\n{indent}{line2}"

    # Generate item strings
    m_items = [f"$m_{i + 1}={m:.2f}$" for i, m in enumerate(masses)]
    k_items = [f"$k_{i + 1}={k:.1f}$" for i, k in enumerate(springs)]
    x_items = [f"$x_{i + 1}={x:+.2f}$" for i, x in enumerate(x_init)]
    v_items = [f"$v_{i + 1}={v:+.2f}$" for i, v in enumerate(v_init)]

    should_break = N_mass > 3

    # Use slightly shorter labels for N>3 to save space
    l_mass = "Masses" if N_mass <= 3 else "Mass"
    l_spring = "Springs" if N_mass <= 3 else "Spr"
    l_pos = "Positions" if N_mass <= 3 else "Pos(0)"
    l_vel = "Velocities" if N_mass <= 3 else "Vel(0)"

    mass_str = format_info_list(l_mass, m_items, should_break)
    spring_str = format_info_list(l_spring, k_items, should_break)
    x_init_str = format_info_list(l_pos, x_items, should_break)
    v_init_str = format_info_list(l_vel, v_items, should_break)

    info_str = (
        f"SYSTEM PARAMETERS (N={N_mass})\n"
        "─────────────────\n"
        f"{mass_str}\n"
        f"{spring_str}\n"
        f"\n"
        f"INITIAL CONDITIONS\n"
        f"──────────────────\n"
        f"{x_init_str}\n"
        f"{v_init_str}\n"
    )
    info_text.set_text(info_str)

    # Set static frequency comparison
    freq_str = f"FREQUENCY ANALYSIS (N={N_mass} modes)\n" + "─" * 25 + "\n"
    for i in range(N_mass):
        simul_val = freqs_num[i]
        if simul_val == 0.0:
            simul_str = "N/A"
        else:
            simul_str = f"{simul_val:.4f} Hz"
        freq_str += (
            f"Mode {i + 1}: Theory={freqs_theory[i]:.4f} Hz, Simul={simul_str}\n"
        )
    freq_text.set_text(freq_str)

    # =========================================================================
    # ANIMATION FUNCTION
    # =========================================================================

    def init():
        """Initialize animation"""
        for spring_line in spring_lines:
            spring_line.set_data([], [])
        for mass_plot in mass_plots:
            mass_plot.set_data([], [])
        for trace_line in trace_lines:
            trace_line.set_data([], [])
        time_text.set_text("")
        dynamic_text.set_text("")
        return (*spring_lines, *mass_plots, *trace_lines, time_text, dynamic_text)

    def update(frame):
        """Update animation for each frame"""
        # Draw first spring: Left wall to first mass
        x_s, y_s = draw_spring_with_hook(
            (left_wall_x + wall_width / 2, y_position),
            (pos_x[0, frame], y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring_lines[0].set_data(x_s, y_s)

        # Draw intermediate springs: mass i to mass i+1
        for i in range(N_mass - 1):
            x_s, y_s = draw_spring_with_hook(
                (pos_x[i, frame], y_position),
                (pos_x[i + 1, frame], y_position),
                spring_num_coils,
                spring_radius,
                hook_length,
            )
            spring_lines[i + 1].set_data(x_s, y_s)

        # Draw last spring: last mass to right wall
        x_s, y_s = draw_spring_with_hook(
            (pos_x[N_mass - 1, frame], y_position),
            (right_wall_x - wall_width / 2, y_position),
            spring_num_coils,
            spring_radius,
            hook_length,
        )
        spring_lines[N_mass].set_data(x_s, y_s)

        # Update masses
        for i in range(N_mass):
            mass_plots[i].set_data([pos_x[i, frame]], [y_position])

        # Update traces
        for i in range(N_mass):
            traces_x[i].append(pos_x[i, frame])
            traces_y[i].append(y_position)

            if len(traces_x[i]) > trace_length:
                traces_x[i].pop(0)
                traces_y[i].pop(0)

            trace_lines[i].set_data(traces_x[i], traces_y[i])

        # Update timer
        time_text.set_text(f"Time: {t_eval[frame]:.2f} s")

        def format_dyn_list(label, values, force_break=False):
            items = [f"{val:+.2f}" for val in values]
            if not force_break:
                return f"{label}: " + ", ".join(items)
            else:
                mid = (len(items) + 1) // 2
                line1 = ", ".join(items[:mid])
                line2 = ", ".join(items[mid:])
                # Indent to align with first value (after ": ")
                indent = " " * (len(label) + 2)
                return f"{label}: {line1}\n{indent}{line2}"

        if N_mass <= 3:
            pos_str = ", ".join(
                [f"$x_{i + 1}$={x[i, frame]:+.2f}" for i in range(N_mass)]
            )
            vel_str = ", ".join(
                [f"$v_{i + 1}$={v[i, frame]:+.2f}" for i in range(N_mass)]
            )
            spring_ext_str = ", ".join(
                [
                    f"$\\Delta s_{i + 1}$={spring_ext[i, frame]:+.2f}"
                    for i in range(N_mass + 1)
                ]
            )

            # Construct display strings with full headers
            pos_display = f"Positions (m): {pos_str}"
            vel_display = f"Velocities (m/s): {vel_str}"
            spring_display = spring_ext_str
        else:
            # Compact pretty format
            pos_display = format_dyn_list("Pos", x[:, frame], force_break=True)
            vel_display = format_dyn_list("Vel", v[:, frame], force_break=True)

            # For spring extensions, format without label prefix since it goes under header
            s_items = [f"{spring_ext[i, frame]:+.2f}" for i in range(N_mass + 1)]
            mid = (len(s_items) + 1) // 2
            s_line1 = ", ".join(s_items[:mid])
            s_line2 = ", ".join(s_items[mid:])
            spring_display = f"{s_line1}\n{s_line2}"

        dynamic_str = (
            "CURRENT STATE\n"
            "─────────────\n"
            f"{pos_display}\n"
            f"{vel_display}\n"
            f"\n"
            f"SPRING EXTENSIONS (m)\n"
            f"─────────────────────\n"
            f"{spring_display}\n"
            f"\n"
            f"ENERGY\n"
            f"──────\n"
            f"KE = {KE_all[frame]:.3f} J, "
            f"PE = {PE_spring_all[frame]:.3f} J, "
            f"Total = {total_E_all[frame]:.3f} J"
        )
        dynamic_text.set_text(dynamic_str)

        return (*spring_lines, *mass_plots, *trace_lines, time_text, dynamic_text)

    # Create animation
    print("Creating animation...")
    anim = FuncAnimation(
        fig,
        update,
        init_func=init,
        frames=n_frames,
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    # Save animation if requested
    if save_anim:
        if filename is None:
            x_str = "_".join([f"x{i + 1}={x_init[i]:.1f}" for i in range(N_mass)])
            m_str = "_".join([f"m{i + 1}={masses[i]:.1f}" for i in range(N_mass)])
            filename = f"N{N_mass}_mass_spring_{x_str}_{m_str}.gif"

        save_dir = "OUTPUTS/ANIMATIONS/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving animation to {filepath}...")
            anim.save(filepath, writer="ffmpeg", fps=fps, codec="gif", dpi=100)
            print("Animation saved!")
            plt.close(fig)
        except Exception as e:
            print(f"Error saving animation: {e}")
    else:
        plt.show()

    return anim


def main_N_mass_spring_system_simulation():
    """Main function to run N-mass spring system simulation with user-defined parameters"""
    print("\n" + "=" * 60)
    print("COUPLED N-MASS SPRING SYSTEM SIMULATION AND ANIMATION")
    print("=" * 60)

    default_params = {
        "N_mass": 4,
        "system_length": 12.0,
        "simulation_time": 10.0,
        "fps": 30,
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"

    if use_defaults:
        N_mass = int(default_params["N_mass"])
        system_length = float(default_params["system_length"])
        simulation_time = float(default_params["simulation_time"])
        fps = int(default_params["fps"])

        # Let simulate_N_coupled_mass_spring_system pick its built-in defaults
        x_init = None
        v_init = None
        masses = None
        springs = None
    else:
        n_in = input(
            f"Enter number of masses (1-6) (default={default_params['N_mass']}): "
        ).strip()
        try:
            N_mass = int(n_in) if n_in else int(default_params["N_mass"])
        except ValueError:
            print("Invalid input. Using default number of masses.")
            N_mass = int(default_params["N_mass"])

        if N_mass < 1 or N_mass > 6:
            print("N_mass must be between 1 and 6. Using default.")
            N_mass = int(default_params["N_mass"])

        def _get_float(prompt, default):
            s = input(f"{prompt} (default={default}): ").strip()
            if s == "":
                return float(default)
            try:
                return float(s)
            except ValueError:
                print(f"Invalid input. Using default {default}.")
                return float(default)

        system_length = _get_float(
            "Enter system length (distance between walls)",
            default_params["system_length"],
        )
        simulation_time = _get_float(
            "Enter total simulation time (s)", default_params["simulation_time"]
        )
        fps = int(_get_float("Enter FPS", default_params["fps"]))

        def _get_comma_separated_array(label, n, default_val):
            """Get array values from comma-separated input"""
            use_custom = (
                input(f"Provide custom {label}? (y/n): ").strip().lower() == "y"
            )
            if not use_custom:
                return None

            # Show example format
            example = ", ".join([str(default_val)] * n)
            prompt = (
                f"  Enter {n} values for {label} (comma-separated, e.g., {example}): "
            )

            while True:
                user_input = input(prompt).strip()

                if not user_input:
                    # Empty input means use all defaults
                    return None

                try:
                    # Split by comma and strip whitespace from each value
                    values = [float(val.strip()) for val in user_input.split(",")]

                    if len(values) != n:
                        print(
                            f"Error: Expected {n} values, but got {len(values)}. Please try again."
                        )
                        continue

                    return np.array(values, dtype=float)

                except ValueError:
                    print(
                        "Invalid input. Please enter numeric values separated by commas."
                    )
                    continue

        x_init = _get_comma_separated_array("initial displacements x (m)", N_mass, 0.0)
        v_init = _get_comma_separated_array("initial velocities v (m/s)", N_mass, 0.0)
        masses = _get_comma_separated_array("masses m (kg)", N_mass, 1.0)
        springs = _get_comma_separated_array(
            "spring constants k (N/m)", N_mass + 1, 10.0
        )

    save_animation = input("Save animation? (y/n): ").strip().lower() == "y"
    filename = None
    if save_animation:
        filename_input = input(
            "Enter filename for animation (e.g., 'n_mass_spring.gif'): "
        ).strip()
        if filename_input:
            if not filename_input.lower().endswith(".gif"):
                filename_input += ".gif"
            filename = filename_input

    print("\nStarting simulation with the following parameters:")
    print(f"  Number of masses: {N_mass}")
    print(f"  System length: {system_length} m")
    print(f"  Simulation time: {simulation_time} s")
    print(f"  FPS: {fps}")
    if x_init is not None:
        print(f"  Initial displacements: {x_init}")
    if v_init is not None:
        print(f"  Initial velocities: {v_init}")
    if masses is not None:
        print(f"  Masses: {masses}")
    if springs is not None:
        print(f"  Spring constants: {springs}")
    print(f"  Save animation: {'Yes' if save_animation else 'No'}")

    anim = simulate_N_coupled_mass_spring_system(
        N_mass=N_mass,
        x_init=x_init,
        v_init=v_init,
        masses=masses,
        springs=springs,
        system_length=system_length,
        simulation_time=simulation_time,
        fps=fps,
        save_anim=save_animation,
        filename=filename,
    )
    return anim


if __name__ == "__main__":
    animation = main_N_mass_spring_system_simulation()

### 4.5 Phase Portraits and Configuration-Space Plots (Example: N=3)

- **Phase space**: for each mass $i$, the 2D phase portrait is $(x_i,\;v_i=\dot x_i)$. A point is the instantaneous state of that *single* degree of freedom; the curve shows how that state evolves in time.  
- **Configuration space**: the system's position state only. For 3 masses it's the 3D space $(x_1,x_2,x_3)$. One point is the *whole system shape* at an instant; the trajectory shows how the shape changes (e.g., energy sloshing between masses corresponds to moving along different directions in this 3D space).
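The configuration-space picture has a useful consequence. Starting from a single normal mode keeps the system on a straight line through the origin along that mode shape. A generic start, by contrast, sweeps out a genuinely 3D curve; here the generic start displaces one mass. The minimal sketch below assumes a uniform 3-mass chain and measures the "dimension" of each curve by its singular values. `rhs` and `config_trajectory` are hypothetical helpers, separate from the notebook's `n_mass_spring_derivatives`.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.linalg import eigh

# Assumed uniform 3-mass chain (m = 1 kg, k = 10 N/m) for illustration.
masses = np.ones(3)
springs = 10.0 * np.ones(4)
off = -springs[1:-1]
K = np.diag(springs[:-1] + springs[1:]) + np.diag(off, 1) + np.diag(off, -1)


def rhs(t, y):
    """First-order form of M x'' + K x = 0 with state y = [x, v] (not interleaved)."""
    x, v = y[:3], y[3:]
    return np.concatenate((v, -(K @ x) / masses))


def config_trajectory(x0, t_end=10.0, n=1001):
    """Return the curve (x1(t), x2(t), x3(t)) as a (3, n) array, starting from rest."""
    sol = solve_ivp(rhs, (0.0, t_end), np.concatenate((x0, np.zeros(3))),
                    t_eval=np.linspace(0.0, t_end, n), rtol=1e-9, atol=1e-12)
    return sol.y[:3]


_, V = eigh(K, np.diag(masses))
pure = config_trajectory(0.5 * V[:, 0] / np.abs(V[:, 0]).max())  # single-mode start
mixed = config_trajectory(np.array([0.5, 0.0, 0.0]))              # one mass displaced

# Singular values count how many independent directions each curve explores.
s_pure = np.linalg.svd(pure, compute_uv=False)
s_mixed = np.linalg.svd(mixed, compute_uv=False)
print(s_pure / s_pure[0], s_mixed / s_mixed[0])
```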

**Energy contours**

In the phase plot for mass $i$, the code overlays contours of the **reference single-oscillator energy**
$$
E_{\text{ref},i}(x_i,v_i)=\frac12 m_i v_i^2+\frac12(k_{\text{left}}+k_{\text{right}})\,x_i^2.
$$
Why this formula makes sense for a *2D* $(x_i,v_i)$ plot:
- The **true** potential energy involving mass $i$ depends on neighbor positions too (e.g., $(x_i-x_{i-1})^2$, $(x_{i+1}-x_i)^2$), so it lives in higher-dimensional space and can't be contoured exactly on a 2D $(x_i,v_i)$ plane.
- The contour formula comes from a **fixed-neighbor approximation** for visualization: assume the neighbors are held at equilibrium $(x_{i-1}=0,\;x_{i+1}=0)$. Then the net restoring force on $x_i$ is
$$
m_i\ddot x_i=-k_{\text{left}}(x_i-0)-k_{\text{right}}(x_i-0)=-(k_{\text{left}}+k_{\text{right}})x_i,
$$
which corresponds to an effective harmonic potential $V_{\text{eff}}=\tfrac12 k_{\text{eff}}x_i^2$ with $k_{\text{eff}}=k_{\text{left}}+k_{\text{right}}$; for mass $i$, $k_{\text{left}}=k_i$ and $k_{\text{right}}=k_{i+1}$. This is exactly what the contours show: the energy level curves this mass would have if it were an isolated oscillator with stiffness $k_{\text{eff}}$.
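$E_{\text{ref},i}$ is a sum of two squares. Each contour $E_{\text{ref},i}=E$ is therefore an ellipse with semi-axes $a=\sqrt{2E/k_{\text{eff}}}$ along $x_i$ and $b=\sqrt{2E/m_i}$ along $v_i$, where $k_{\text{eff}}=k_{\text{left}}+k_{\text{right}}$. The minimal sketch below parameterizes one such contour and checks that it lies on the level set. The function names are illustrative and separate from the notebook's `compute_energy_levels`.

```python
import numpy as np


def reference_energy(x, v, m, k_left, k_right):
    """E_ref = 1/2 m v^2 + 1/2 (k_left + k_right) x^2 (fixed-neighbor picture)."""
    return 0.5 * m * v**2 + 0.5 * (k_left + k_right) * x**2


def contour_ellipse(E, m, k_left, k_right, n=200):
    """Points (x, v) on the level set E_ref = E, parameterized by an angle theta."""
    k_eff = k_left + k_right
    a = np.sqrt(2.0 * E / k_eff)  # semi-axis along x: turning point of the isolated oscillator
    b = np.sqrt(2.0 * E / m)      # semi-axis along v: speed when passing x = 0
    theta = np.linspace(0.0, 2.0 * np.pi, n)
    return a * np.cos(theta), b * np.sin(theta)


xs, vs = contour_ellipse(E=2.0, m=1.5, k_left=10.0, k_right=5.0)
energies = reference_energy(xs, vs, 1.5, 10.0, 5.0)
print(energies.min(), energies.max())
```

Both semi-axes grow like $\sqrt{E}$. Their ratio $b/a=\sqrt{k_{\text{eff}}/m_i}$ is the same for every level, so the contours form a family of similar, nested ellipses.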

**How this differs from the total energy**  
- $E_{\text{ref},i}$ is **not** the total energy of the coupled system, and it is **not conserved** for mass $i$ in the coupled dynamics (because energy is exchanged with neighbors).  
- The **total mechanical energy** (for the ideal undamped mass–spring chain) uses *all* coordinates and the *true* spring extensions. For 3 masses with fixed walls and springs $(k_1,k_2,k_3,k_4)$, the total energy is
$$
E_{\text{tot}}=\frac12 m_1 v_1^2+\frac12 m_2 v_2^2+\frac12 m_3 v_3^2
+\frac12 k_1 x_1^2+\frac12 k_2(x_2-x_1)^2+\frac12 k_3(x_3-x_2)^2+\frac12 k_4 x_3^2.
$$
This depends on $(x_1,x_2,x_3,v_1,v_2,v_3)$, so we can't represent its level sets on a single $(x_i,v_i)$ plot without fixing additional variables.
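The difference between the two energies can be seen numerically. The minimal sketch below assumes three equal masses and four equal springs, with only mass 1 displaced and all masses starting at rest. It builds the exact solution from the normal modes, $\mathbf{x}(t)=V\,[q_r(0)\cos\omega_r t]$ with $q(0)=V^TM\mathbf{x}(0)$, and evaluates both energies along it.

```python
import numpy as np
from scipy.linalg import eigh

# Assumed example: 3 equal masses (1 kg), 4 equal springs (10 N/m).
m = np.ones(3)
k = 10.0 * np.ones(4)
M = np.diag(m)
off = -k[1:-1]
K = np.diag(k[:-1] + k[1:]) + np.diag(off, 1) + np.diag(off, -1)
omega_sq, V = eigh(K, M)
omega = np.sqrt(omega_sq)

# Exact solution of M x'' + K x = 0 from rest: x(t) = V [q_r(0) cos(omega_r t)].
x0 = np.array([0.5, 0.0, 0.0])
q0 = V.T @ M @ x0
t = np.linspace(0.0, 10.0, 2001)
phase = np.outer(omega, t)  # shape (3, n_times)
x = V @ (q0[:, None] * np.cos(phase))
v = V @ (-(q0 * omega)[:, None] * np.sin(phase))

# Total energy uses every coordinate and the true spring extensions (x^T K x).
E_tot = 0.5 * np.sum(m[:, None] * v**2, axis=0) + 0.5 * np.einsum("it,ij,jt->t", x, K, x)
# Reference energy of mass 1 with k_left = k_1, k_right = k_2 (neighbors ignored).
E_ref_1 = 0.5 * m[0] * v[0] ** 2 + 0.5 * (k[0] + k[1]) * x[0] ** 2

print(f"E_tot   : {E_tot.min():.6f} .. {E_tot.max():.6f} J")
print(f"E_ref_1 : {E_ref_1.min():.6f} .. {E_ref_1.max():.6f} J")
```

At $t=0$ both energies agree, because the neighbors of mass 1 start at equilibrium. After that, only $E_{\text{tot}}$ stays constant, while $E_{\text{ref},1}$ drops as energy flows into masses 2 and 3.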

**How it's implemented in the code**  
- **Energy contours**: `compute_energy_levels(...)` builds a meshgrid $(X,V)$ and computes $E=\tfrac12 m V^2+\tfrac12(k_{\text{left}}+k_{\text{right}})X^2$, then each phase subplot calls `ax.contour(X,V,E,...)` to draw dashed contour lines.
- **Phase space trajectories**: `plot_3mass_phase_configuration(...)` solves the ODE with `solve_ivp(n_mass_spring_derivatives, ...)`, extracts $(x_i,v_i)$ time series, and plots the trajectory with a time colormap plus start/end markers and direction arrows.
- **Configuration space**: the 3D plot draws the parametric curve $(x_1(t),x_2(t),x_3(t))$ with time-coded color, start/end markers, and a few 3D direction arrows.

In [None]:
class Arrow3D(FancyArrowPatch):
    """Helper class for 3D arrows"""

    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        xs3d, ys3d, zs3d = self._verts3d
        M = self.axes.get_proj()  # type: ignore
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return np.min(zs)


def compute_energy_levels(x_range, v_range, mass, spring_left, spring_right):
    """
    Compute energy contour levels for phase space plot.

    Energy: E = 0.5 * m * v**2 + 0.5 * (k_left + k_right) * x**2
    """
    X, V = np.meshgrid(x_range, v_range)
    # Potential energy from both adjacent springs
    k_eff = spring_left + spring_right
    E = 0.5 * mass * V**2 + 0.5 * k_eff * X**2
    return X, V, E


def plot_3mass_phase_configuration(
    masses=None,
    springs=None,
    x_init=None,
    v_init=None,
    simulation_time=20.0,
    n_points=1000,
    colormap="viridis",
    arrow_density=30,
    energy_contours=8,
    fps=40,
    save_fig=False,
    filename=None,
):
    """
    Generates 4 subplots for a 3-mass system (three phase portraits plus a
    3D configuration-space view that rotates as an animation) with:
    - Time-coded color gradients
    - Start/end markers
    - Direction arrows
    - Energy contours (phase space)
    - Projection shadows (3D configuration space)

    Parameters
    ----------
    masses : array-like, optional
        Mass values [m1, m2, m3]
    springs : array-like, optional
        Spring constants [k1, k2, k3, k4]
    x_init : array-like, optional
        Initial displacements [x1, x2, x3]
    v_init : array-like, optional
        Initial velocities [v1, v2, v3]
    simulation_time : float
        Total simulation time
    n_points : int
        Number of time points
    colormap : str
        Colormap name for time progression
    arrow_density : int
        Number of direction arrows to plot
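    energy_contours : int
        Number of energy contour levels
    fps : int
        Frame rate of the rotating 3D view (and of the saved GIF)
    save_fig : bool
        Whether to save the animation as a GIF
    filename : str, optional
        Filename to save the figure

    Returns
    -------
    fig : matplotlib.figure.Figure or None
    anim : matplotlib.animation.FuncAnimation or None
        Both are None if the ODE solver fails.
    """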
    N_mass = 3

    # Default parameters if None
    if masses is None:
        masses = np.ones(N_mass)
    else:
        masses = np.array(masses)
        if len(masses) != N_mass:
            raise ValueError(f"masses array must have length {N_mass}")

    if springs is None:
        springs = 10.0 * np.ones(N_mass + 1)
    else:
        springs = np.array(springs)
        if len(springs) != N_mass + 1:
            raise ValueError(f"springs array must have length {N_mass + 1}")

    if x_init is None:
        x_init = np.array([0.5 * (-1) ** i for i in range(N_mass)])
    else:
        x_init = np.array(x_init)
        if len(x_init) != N_mass:
            raise ValueError(f"x_init array must have length {N_mass}")

    if v_init is None:
        v_init = np.zeros(N_mass)
    else:
        v_init = np.array(v_init)
        if len(v_init) != N_mass:
            raise ValueError(f"v_init array must have length {N_mass}")

    # Initial state vector
    y0 = []
    for i in range(N_mass):
        y0.extend([x_init[i], v_init[i]])

    # Solve ODE
    print("Solving ODE for enhanced phase space plots...")
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    sol = solve_ivp(
        n_mass_spring_derivatives,
        t_span,
        y0,
        args=(masses, springs),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-5,
        atol=1e-7,
    )

    if not sol.success:
        print("ODE solver failed for phase space plots.")
        return None, None

    y = sol.y

    # Extract data
    x1, v1 = y[0], y[1]
    x2, v2 = y[2], y[3]
    x3, v3 = y[4], y[5]

    # Create figure
    fig = plt.figure(figsize=(14, 11))
    fig.suptitle(
        (
            f"Phase Space and Configuration Space for Coupled Mass-Spring System with {N_mass} Masses\n"
            f"ICs: $x$=[{x_init[0]:.2f}, {x_init[1]:.2f}, {x_init[2]:.2f}], "
            f"$v$=[{v_init[0]:.2f}, {v_init[1]:.2f}, {v_init[2]:.2f}], $T_{{sim}}={simulation_time:.1f}$ s"
        ),
        fontsize=15,
        fontweight="bold",
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.08, left=0.08, right=0.96, hspace=0.3, wspace=0.3
    )

    # Colormap for time progression
    cmap = plt.get_cmap(colormap)
    colors = cmap(np.linspace(0, 1, n_points))
    time_norm = plt.Normalize(0, simulation_time)  # type: ignore

    # =========================================================================
    # PHASE SPACE PLOTS WITH ENHANCEMENTS
    # =========================================================================

    phase_data = [
        (x1, v1, masses[0], springs[0], springs[1], "Mass 1", (2, 2, 1)),
        (x2, v2, masses[1], springs[1], springs[2], "Mass 2", (2, 2, 2)),
        (x3, v3, masses[2], springs[2], springs[3], "Mass 3", (2, 2, 3)),
    ]

    for (
        x_data,
        v_data,
        mass,
        k_left,
        k_right,
        label,
        subplot_pos,
    ) in phase_data:
        ax = fig.add_subplot(*subplot_pos)

        # 1. Plot trajectory with time-coded colors
        for i in range(n_points - 1):
            ax.plot(
                x_data[i : i + 2],
                v_data[i : i + 2],
                color=colors[i],
                linewidth=1.5,
                alpha=0.8,
            )

        # 2. Start and end markers
        ax.plot(
            x_data[0],
            v_data[0],
            "o",
            color="lime",
            markersize=10,
            markeredgecolor="black",
            markeredgewidth=2,
            label="Start",
            zorder=5,
        )
        ax.plot(
            x_data[-1],
            v_data[-1],
            "o",
            color="red",
            markersize=10,
            markeredgecolor="black",
            markeredgewidth=2,
            label="End",
            zorder=5,
        )

        # 3. Direction arrows (tangent measured in axis-normalized units so
        #    arrows have a consistent on-screen length along x and v)
        arrow_indices = np.linspace(0, n_points - 1, arrow_density, dtype=int)
        x_span = x_data.max() - x_data.min()
        v_span = v_data.max() - v_data.min()

        for idx in arrow_indices[:-1]:
            if idx + 1 < n_points and x_span > 0 and v_span > 0:
                dx = (x_data[idx + 1] - x_data[idx]) / x_span
                dy = (v_data[idx + 1] - v_data[idx]) / v_span
                norm = np.sqrt(dx**2 + dy**2)

                if norm > 1e-9:
                    dx_norm = dx / norm
                    dy_norm = dy / norm

                    arrow_len_x = dx_norm * (x_span * 0.1)
                    arrow_len_y = dy_norm * (v_span * 0.1)

                    arrow = FancyArrowPatch(
                        posA=(x_data[idx], v_data[idx]),
                        posB=(x_data[idx] + arrow_len_x, v_data[idx] + arrow_len_y),
                        arrowstyle="-|>",
                        mutation_scale=25,  # Large head size in points
                        color=colors[idx],
                        linewidth=1.5,
                        zorder=10,
                    )
                    ax.add_patch(arrow)

        # 4. Energy contours
        x_range = np.linspace(x_data.min() * 1.2, x_data.max() * 1.2, 100)
        v_range = np.linspace(v_data.min() * 1.2, v_data.max() * 1.2, 100)
        X, V, E = compute_energy_levels(x_range, v_range, mass, k_left, k_right)

        contours = ax.contour(
            X,
            V,
            E,
            levels=energy_contours,
            colors="gray",
            linewidths=0.8,
            alpha=0.4,
            linestyles="--",
            zorder=1,
        )
        ax.clabel(contours, inline=True, fontsize=7, fmt="E=%.2f")

        ax.set_title(
            f"Phase Space: {label} ($x_{label[-1]}, v_{label[-1]}$)",
            fontweight="bold",
            fontsize=12,
            pad=10,
        )
        ax.set_xlabel(f"$x_{label[-1]}$ (m)", fontsize=11)
        ax.set_ylabel(f"$v_{label[-1]}$ (m/s)", fontsize=11)
        ax.grid(True, alpha=0.3, linestyle=":")
        ax.legend(loc="upper right", fontsize=9)

    # =========================================================================
    # 3D CONFIGURATION SPACE WITH ENHANCEMENTS
    # =========================================================================

    ax4 = fig.add_subplot(2, 2, 4, projection="3d")

    # 1. Plot trajectory with time-coded colors
    for i in range(n_points - 1):
        ax4.plot(
            x1[i : i + 2],
            x2[i : i + 2],
            x3[i : i + 2],
            color=colors[i],
            linewidth=2,
            alpha=0.8,
        )

    # 2. Start and end markers
    ax4.scatter(
        x1[0],
        x2[0],
        x3[0],
        c="lime",
        s=150,  # type: ignore
        marker="o",
        edgecolors="black",
        linewidths=2,
        label="Start",
        zorder=10,
    )
    ax4.scatter(
        x1[-1],
        x2[-1],
        x3[-1],
        c="red",
        s=150,  # type: ignore
        marker="o",
        edgecolors="black",
        linewidths=2,
        label="End",
        zorder=10,
    )

    # 3. Direction arrows
    arrow_3d_indices = np.linspace(0, n_points - 1, arrow_density // 2, dtype=int)
    for idx in arrow_3d_indices[:-1]:
        if idx + 5 < n_points:
            arrow = Arrow3D(
                [x1[idx], x1[idx + 5]],
                [x2[idx], x2[idx + 5]],
                [x3[idx], x3[idx + 5]],
                mutation_scale=20,
                lw=2,
                arrowstyle="-|>",
                color=colors[idx],
                alpha=0.8,
            )
            ax4.add_artist(arrow)

    # 4. Projection shadows on floor planes

    x1_min, x1_max = x1.min(), x1.max()
    x2_min, x2_max = x2.min(), x2.max()
    x3_min, x3_max = x3.min(), x3.max()

    padding = 0.05
    x1_padding = (x1_max - x1_min) * padding
    x2_padding = (x2_max - x2_min) * padding
    x3_padding = (x3_max - x3_min) * padding

    ax4.plot(
        x1,
        x2,
        zs=x3_min - x3_padding,
        zdir="z",
        color="gray",
        alpha=0.3,
        lw=1,
        linestyle="--",
    )
    ax4.plot(
        x1,
        x3,
        zs=x2_max + x2_padding,
        zdir="y",
        color="gray",
        alpha=0.3,
        lw=1,
        linestyle="--",
    )
    ax4.plot(
        x2,
        x3,
        zs=x1_min - x1_padding,
        zdir="x",
        color="gray",
        alpha=0.3,
        lw=1,
        linestyle="--",
    )

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=time_norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax4, pad=0.2, shrink=0.8)
    cbar.set_label("Time (s)", fontsize=10)

    ax4.set_title(
        "Configuration Space ($x_1, x_2, x_3$)", fontweight="bold", fontsize=12, pad=10
    )
    ax4.set_xlabel("$x_1$ (m)", fontsize=11)
    ax4.set_ylabel("$x_2$ (m)", fontsize=11)
    ax4.set_zlabel("$x_3$ (m)", fontsize=11)  # type: ignore
    ax4.legend(loc="upper right", fontsize=9)

    ax4.set_xlim(x1_min - x1_padding, x1_max + x1_padding)
    ax4.set_ylim(x2_min - x2_padding, x2_max + x2_padding)
    ax4.set_zlim(x3_min - x3_padding, x3_max + x3_padding)  # type: ignore
    ax4.grid(True, alpha=0.3, linestyle=":")

    def animate(frame):
        ax4.view_init(elev=20, azim=frame)  # type: ignore
        return (ax4,)

    anim = FuncAnimation(
        fig,
        animate,
        frames=np.arange(0, 360, 2),
        interval=int(1000 / fps),
        blit=False,
        repeat=False,
    )

    if save_fig:
        if filename is None:
            filename = "3mass_phase_configuration_demo_0.gif"

        save_dir = "OUTPUTS/ANIMATIONS/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving figure to {filepath}...")
            anim.save(filepath, writer="pillow", fps=fps)
            print("Figure saved successfully.")
        except Exception as e:
            print(f"Error saving figure: {e}")

    return fig, anim


# fig, anim = plot_3mass_phase_configuration(
#     x_init=[0.0, 0.0, -0.5],
#     simulation_time=20.0,
#     n_points=1000,
#     colormap="viridis",
#     arrow_density=18,
#     energy_contours=8,
#     save_fig=True,
#     filename="3mass_phase_configuration_demo_0.gif",
#     fps=30,
# )


def main_3Mass_phase_configuration_plots():
    """Main function to run enhanced phase space and configuration space plots for 3-mass system"""
    print("\n" + "=" * 70)
    print("ENHANCED PHASE SPACE AND CONFIGURATION SPACE PLOTS FOR 3-MASS SPRING SYSTEM")
    print("=" * 70)

    default_params = {
        "simulation_time": 20.0,
        "n_points": 1000,
        "colormap": "viridis",
        "arrow_density": 30,
        "energy_contours": 8,
        "fps": 40,
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"

    if use_defaults:
        simulation_time = float(default_params["simulation_time"])
        n_points = int(default_params["n_points"])
        colormap = default_params["colormap"]
        arrow_density = int(default_params["arrow_density"])
        energy_contours = int(default_params["energy_contours"])
        fps = int(default_params["fps"])

        x_init = None
        v_init = None
        masses = None
        springs = None
    else:

        def _get_float(prompt, default):
            s = input(f"{prompt} (default={default}): ").strip()
            if s == "":
                return float(default)
            try:
                return float(s)
            except ValueError:
                print(f"Invalid input. Using default {default}.")
                return float(default)

        simulation_time = _get_float(
            "Enter total simulation time (s)", default_params["simulation_time"]
        )
        n_points = int(
            _get_float("Enter number of time points", default_params["n_points"])
        )
        colormap = input(
            f"Enter colormap name (default={default_params['colormap']}): "
        ).strip()
        if colormap == "":
            colormap = default_params["colormap"]
        arrow_density = int(
            _get_float(
                "Enter number of direction arrows", default_params["arrow_density"]
            )
        )
        energy_contours = int(
            _get_float(
                "Enter number of energy contour levels",
                default_params["energy_contours"],
            )
        )
        fps = int(_get_float("Enter FPS", default_params["fps"]))

        def _get_comma_separated_array(label, n, default_val):
            """Get array values from comma-separated input"""
            use_custom = (
                input(f"Provide custom {label}? (y/n): ").strip().lower() == "y"
            )
            if not use_custom:
                return None

            example = ", ".join([str(default_val)] * n)
            prompt = (
                f"  Enter {n} values for {label} (comma-separated, e.g., {example}): "
            )

            while True:
                user_input = input(prompt).strip()

                if not user_input:
                    return None

                try:
                    values = [float(val.strip()) for val in user_input.split(",")]

                    if len(values) != n:
                        print(
                            f"Error: Expected {n} values, but got {len(values)}. Please try again."
                        )
                        continue

                    return np.array(values, dtype=float)

                except ValueError:
                    print(
                        "Invalid input. Please enter numeric values separated by commas."
                    )
                    continue

        x_init = _get_comma_separated_array("initial displacements x (m)", 3, 0.0)
        v_init = _get_comma_separated_array("initial velocities v (m/s)", 3, 0.0)
        masses = _get_comma_separated_array("masses m (kg)", 3, 1.0)
        springs = _get_comma_separated_array("spring constants k (N/m)", 4, 10.0)

    save_figure = input("Save figure? (y/n): ").strip().lower() == "y"
    filename = None
    if save_figure:
        filename_input = input(
            "Enter filename for figure (e.g., '3mass_phase_configuration.gif'): "
        ).strip()
        if filename_input:
            if not filename_input.lower().endswith(".gif"):
                filename_input += ".gif"
            filename = filename_input

    print(
        "\nStarting enhanced phase space and configuration space plots with the following parameters:"
    )
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of time points: {n_points}")
    print(f"  Colormap: {colormap}")
    print(f"  Arrow density: {arrow_density}")
    print(f"  Energy contours: {energy_contours}")
    print(f"  FPS: {fps}")
    if x_init is not None:
        print(f"  Initial displacements: {x_init}")
    if v_init is not None:
        print(f"  Initial velocities: {v_init}")
    if masses is not None:
        print(f"  Masses: {masses}")
    if springs is not None:
        print(f"  Spring constants: {springs}")
    print(f"  Save figure: {'Yes' if save_figure else 'No'}")

    fig, anim = plot_3mass_phase_configuration(
        masses=masses,
        springs=springs,
        x_init=x_init,
        v_init=v_init,
        simulation_time=simulation_time,
        n_points=n_points,
        colormap=colormap,
        arrow_density=arrow_density,
        energy_contours=energy_contours,
        fps=fps,
        save_fig=save_figure,
        filename=filename,
    )
    return fig, anim


if __name__ == "__main__":
    fig, anim = main_3Mass_phase_configuration_plots()
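With Jupyter's default inline backend, the `FuncAnimation` returned above is typically rendered only as a static frame. One way to view the rotating 3D panel inline is to embed the frames in a JavaScript player. This sketch assumes matplotlib's `Animation.to_jshtml()` and IPython's `HTML` display class are available. The stand-in curve and the `rotate` callback are hypothetical. With the notebook's own output, you would call `anim.to_jshtml()` on the returned `anim` directly.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Stand-in 3D curve; in the notebook, reuse `fig, anim` from
# plot_3mass_phase_configuration(...) instead.
t = np.linspace(0.0, 10.0, 500)
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot(np.cos(t), np.sin(2 * t), np.cos(3 * t))


def rotate(frame):
    # Same idea as animate() above: only the camera azimuth changes
    ax.view_init(elev=20, azim=frame)
    return (ax,)


# Few frames so the sketch renders quickly
anim = FuncAnimation(fig, rotate, frames=np.arange(0, 360, 60), blit=False)

# Embed every frame in a self-contained JavaScript/HTML player
html = anim.to_jshtml(fps=10)

# In Jupyter:
# from IPython.display import HTML
# HTML(html)
```

Embedding the notebook's 180 frames (`np.arange(0, 360, 2)`) produces a large HTML payload. Matplotlib caps its size through `rcParams["animation.embed_limit"]`, which is given in MB.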

### 4.6 Poincaré Sections (Example: N=3)

A **Poincaré section** is a way to reduce a continuous-time dynamical system to a "stroboscopic" discrete map. Instead of plotting the full trajectory in a high-dimensional state space, you record the system state only when it crosses a chosen **surface of section** (a "trigger condition"). The resulting set of intersection points is much easier to interpret than the full trajectory.

**Physical insights you can draw**
- **Periodic motion**: the section shows a small finite set of points (often 1 point, or a few points if you hit the orbit multiple times per cycle).
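- **Quasi-periodic motion (incommensurate frequencies)**: with two independent frequencies, the section points fill out a closed curve (the intersection of a 2-torus with the section). With three or more, the torus is higher-dimensional, and a 2D projection of the section points typically fills a band or region rather than a single curve.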
- **Chaotic motion** (in nonlinear systems): the section fills an area with scattered points ("chaotic sea"), sometimes mixed with remaining invariant curves (KAM-type structure).

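For our **3-mass linear mass–spring system**, the dynamics are conservative and integrable: every solution is a superposition of 3 normal modes, so the motion is never chaotic. The expected section structure depends on which modes are excited:
- **Two modes excited**: the sections typically show closed curves.
- **Commensurate excited frequencies**: the sections show finite repeating patterns.
- **Generic initial condition**: all three incommensurate modes are excited, and the projected section points can fill a region that is still quasi-periodic rather than chaotic.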
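The normal-mode frequencies behind this statement can be checked directly. The minimal sketch below assumes the fixed-wall chain with hypothetical default-like values $m_i=1$ kg and $k_j=10$ N/m. It builds the stiffness matrix as the Hessian of the potential part of $E_{\text{tot}}$ and solves $K\phi=\omega^2 M\phi$ with `scipy.linalg.eigh`. For equal masses and springs, the squared frequencies are $(k/m)(2-\sqrt2)$, $2k/m$ and $(k/m)(2+\sqrt2)$. The ratio $\omega_3/\omega_1=1+\sqrt2$ is therefore irrational, so the modes are incommensurate.

```python
import numpy as np
from scipy.linalg import eigh

# Hypothetical parameters matching the defaults (m_i = 1 kg, k_j = 10 N/m)
m = np.array([1.0, 1.0, 1.0])
k = np.array([10.0, 10.0, 10.0, 10.0])

# Stiffness matrix for wall-k1-m1-k2-m2-k3-m3-k4-wall (Hessian of the potential)
K = np.array([
    [k[0] + k[1], -k[1], 0.0],
    [-k[1], k[1] + k[2], -k[2]],
    [0.0, -k[2], k[2] + k[3]],
])
M = np.diag(m)

# Generalized symmetric eigenproblem K phi = omega^2 M phi (ascending eigenvalues)
omega_sq, modes = eigh(K, M)
omega = np.sqrt(omega_sq)

print("omega (rad/s):     ", np.round(omega, 4))
print("ratios to omega_1: ", np.round(omega / omega[0], 4))
```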

**What "target/trigger conditions" are used in the code**

In `plot_poincare_section_grid(...)` we build **four** different sections (a 2×2 grid). Each is defined by a different crossing test, and each plots a different pair of coordinates at those instants:

1) **Trigger:** $x_2 = 0$ with **upward crossing** ($x_2$ goes from negative to positive, which implies $v_2>0$ at the crossing)  
    **Plot:** $(x_1, v_1)$

2) **Trigger:** $x_1 = 0$ with **upward crossing** ($x_1$ negative → positive, so $v_1>0$)  
    **Plot:** $(x_2, v_2)$

3) **Trigger:** $x_1 = x_2$ (a "configuration synchronization" condition; detects any sign change of $x_1-x_2$, i.e., crossing in either direction)  
    **Plot:** $(x_3, v_3)$

4) **Trigger:** $v_1 = v_2$ with **crossing from below** ($v_1-v_2$ negative → positive)  
    **Plot:** $(x_1, x_2)$

**How it's implemented (in code terms)**
- The code first integrates the ODE with `solve_ivp(n_mass_spring_derivatives, ...)` to get sampled arrays $x_i(t)$, $v_i(t)$.
- For each trigger, it scans consecutive samples and detects a crossing by checking a sign change (e.g., `x2[i] < 0 and x2[i+1] > 0`).
- It then uses **linear interpolation** to estimate the crossing time fraction
  $$
  \text{fraction}=\frac{-s_i}{s_{i+1}-s_i}
  $$
  where $s$ is the trigger signal (like $x_2$, $x_1$, $x_1-x_2$, or $v_1-v_2$). That fraction is used to interpolate the plotted variables (and the time) at the crossing.
- Finally, it scatter-plots the section points, colored by crossing time ("plasma" colormap). It also annotates each panel with the number of section points found.
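The four hand-written loops share one pattern: detect a sign change of a trigger signal, then interpolate with the fraction formula above. That pattern can be collapsed into a single vectorized function. `section_points` is a hypothetical helper, not part of the notebook. Its `direction` argument encodes the upward or either-way rules.

```python
import numpy as np


def section_points(t, trigger, *signals, direction=1):
    """Hypothetical helper: Poincare-section points from sampled time series.

    Finds sign changes of `trigger` between consecutive samples and linearly
    interpolates `t` and every array in `signals` at each crossing.
    direction=1: negative -> positive; direction=-1: positive -> negative;
    direction=0: either way (like the x1 = x2 trigger).
    """
    t = np.asarray(t, dtype=float)
    s = np.asarray(trigger, dtype=float)
    s0, s1 = s[:-1], s[1:]

    if direction > 0:
        mask = (s0 < 0) & (s1 > 0)
    elif direction < 0:
        mask = (s0 > 0) & (s1 < 0)
    else:
        mask = s0 * s1 < 0
    idx = np.flatnonzero(mask)

    # fraction = -s_i / (s_{i+1} - s_i), the same formula as above
    frac = -s0[idx] / (s1[idx] - s0[idx])

    def interp(a):
        a = np.asarray(a, dtype=float)
        return a[idx] + frac * (a[idx + 1] - a[idx])

    return interp(t), tuple(interp(sig) for sig in signals)
```

With the arrays extracted in the function, the four panels would become:
- `section_points(t_eval, x2, x1, v1)` for panel 1;
- `section_points(t_eval, x1, x2, v2)` for panel 2;
- `section_points(t_eval, x1 - x2, x3, v3, direction=0)` for panel 3;
- `section_points(t_eval, v1 - v2, x1, x2)` for panel 4.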

In [None]:
def plot_poincare_section_grid(
    masses=None,
    springs=None,
    x_init=None,
    v_init=None,
    simulation_time=200.0,
    n_points=2000,
    save_fig=False,
    filename=None,
):
    """
    Generates a 2x2 grid of Poincaré sections for a 3-mass system
    with different trigger conditions.
    """
    N_mass = 3

    # Default parameters if None
    if masses is None:
        masses = np.ones(N_mass)
    else:
        masses = np.array(masses)

    if springs is None:
        springs = 10.0 * np.ones(N_mass + 1)
    else:
        springs = np.array(springs)

    if x_init is None:
        x_init = np.array([0.5, -0.3, 0.2])
    else:
        x_init = np.array(x_init)

    if v_init is None:
        v_init = np.zeros(N_mass)
    else:
        v_init = np.array(v_init)

    # Initial state vector
    y0 = []
    for i in range(N_mass):
        y0.extend([x_init[i], v_init[i]])

    # Solve ODE
    print("Solving ODE for Poincaré sections (this may take a moment)...")
    t_span = (0, simulation_time)
    t_eval = np.linspace(0, simulation_time, n_points)

    sol = solve_ivp(
        n_mass_spring_derivatives,
        t_span,
        y0,
        args=(masses, springs),
        method="RK45",
        t_eval=t_eval,
        rtol=1e-5,
        atol=1e-7,
    )

    if not sol.success:
        print("ODE solver failed.")
        return

    y = sol.y
    # Extract all state variables
    x1 = y[0]
    v1 = y[1]
    x2 = y[2]
    v2 = y[3]
    x3 = y[4]
    v3 = y[5]

    # Create figure with 2x2 grid
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    fig.suptitle(
        f"Poincaré Sections for Coupled Mass-Spring System with 3 Masses\n"
        f"Initial: x=[{x_init[0]:.2f}, {x_init[1]:.2f}, {x_init[2]:.2f}], "
        f"v=[{v_init[0]:.2f}, {v_init[1]:.2f}, {v_init[2]:.2f}], T={simulation_time}s",
        fontsize=14,
        fontweight="bold",
    )
    fig.subplots_adjust(hspace=0.3, wspace=0.25)

    # =========================================================================
    # SUBPLOT 1: x₂ = 0, v₂ > 0 → Plot (x₁, v₁)
    # =========================================================================
    poincare_1 = {"x": [], "y": [], "time": []}

    for i in range(len(x2) - 1):
        if x2[i] < 0 and x2[i + 1] > 0:
            fraction = -x2[i] / (x2[i + 1] - x2[i])
            poincare_1["x"].append(x1[i] + fraction * (x1[i + 1] - x1[i]))
            poincare_1["y"].append(v1[i] + fraction * (v1[i + 1] - v1[i]))
            poincare_1["time"].append(
                t_eval[i] + fraction * (t_eval[i + 1] - t_eval[i])
            )

    ax1 = axes[0, 0]
    if len(poincare_1["x"]) > 0:
        # Color by time
        scatter = ax1.scatter(
            poincare_1["x"],
            poincare_1["y"],
            c=poincare_1["time"],
            cmap="plasma",
            s=50,
            alpha=0.7,
            edgecolors="black",
            linewidths=0.5,
        )

    ax1.set_title(
        "Trigger: $x_2 = 0$, $v_2 > 0$\nPlot: $(x_1, v_1)$",
        fontsize=12,
        fontweight="bold",
    )
    ax1.set_xlabel("$x_1$ (m)", fontsize=11)
    ax1.set_ylabel("$v_1$ (m/s)", fontsize=11)
    ax1.grid(True, alpha=0.3, linestyle="--")
    ax1.axhline(0, color="k", linewidth=0.5, alpha=0.4)
    ax1.axvline(0, color="k", linewidth=0.5, alpha=0.4)
    ax1.text(
        0.02,
        0.98,
        f"N = {len(poincare_1['x'])} points",
        transform=ax1.transAxes,
        fontsize=9,
        va="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # =========================================================================
    # SUBPLOT 2: x₁ = 0, v₁ > 0 → Plot (x₂, v₂)
    # =========================================================================
    poincare_2 = {"x": [], "y": [], "time": []}

    for i in range(len(x1) - 1):
        if x1[i] < 0 and x1[i + 1] > 0:
            fraction = -x1[i] / (x1[i + 1] - x1[i])
            poincare_2["x"].append(x2[i] + fraction * (x2[i + 1] - x2[i]))
            poincare_2["y"].append(v2[i] + fraction * (v2[i + 1] - v2[i]))
            poincare_2["time"].append(
                t_eval[i] + fraction * (t_eval[i + 1] - t_eval[i])
            )

    ax2 = axes[0, 1]
    if len(poincare_2["x"]) > 0:
        scatter = ax2.scatter(
            poincare_2["x"],
            poincare_2["y"],
            c=poincare_2["time"],
            cmap="plasma",
            s=50,
            alpha=0.7,
            edgecolors="black",
            linewidths=0.5,
        )
        fig.colorbar(scatter, ax=axes[:, 1], label="Time (s)", shrink=0.5)

    ax2.set_title(
        "Trigger: $x_1 = 0$, $v_1 > 0$\nPlot: $(x_2, v_2)$",
        fontsize=12,
        fontweight="bold",
    )
    ax2.set_xlabel("$x_2$ (m)", fontsize=11)
    ax2.set_ylabel("$v_2$ (m/s)", fontsize=11)
    ax2.grid(True, alpha=0.3, linestyle="--")
    ax2.axhline(0, color="k", linewidth=0.5, alpha=0.4)
    ax2.axvline(0, color="k", linewidth=0.5, alpha=0.4)
    ax2.text(
        0.02,
        0.98,
        f"N = {len(poincare_2['x'])} points",
        transform=ax2.transAxes,
        fontsize=9,
        va="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # =========================================================================
    # SUBPLOT 3: x₁ = x₂ (configuration synchronization) → Plot (x₃, v₃)
    # =========================================================================
    poincare_3 = {"x": [], "y": [], "time": []}

    for i in range(len(x1) - 1):
        diff_curr = x1[i] - x2[i]
        diff_next = x1[i + 1] - x2[i + 1]
        # Detect when x1 crosses x2 (either direction)
        if diff_curr * diff_next < 0:  # Sign change indicates crossing
            # Linear interpolation to find exact crossing point
            fraction = -diff_curr / (diff_next - diff_curr)
            poincare_3["x"].append(x3[i] + fraction * (x3[i + 1] - x3[i]))
            poincare_3["y"].append(v3[i] + fraction * (v3[i + 1] - v3[i]))
            poincare_3["time"].append(
                t_eval[i] + fraction * (t_eval[i + 1] - t_eval[i])
            )

    ax3 = axes[1, 0]
    if len(poincare_3["x"]) > 0:
        scatter = ax3.scatter(
            poincare_3["x"],
            poincare_3["y"],
            c=poincare_3["time"],
            cmap="plasma",
            s=50,
            alpha=0.7,
            edgecolors="black",
            linewidths=0.5,
        )

    ax3.set_title(
        "Trigger: $x_1 = x_2$ (config sync)\nPlot: $(x_3, v_3)$",
        fontsize=12,
        fontweight="bold",
    )
    ax3.set_xlabel("$x_3$ (m)", fontsize=11)
    ax3.set_ylabel("$v_3$ (m/s)", fontsize=11)
    ax3.grid(True, alpha=0.3, linestyle="--")
    ax3.axhline(0, color="k", linewidth=0.5, alpha=0.4)
    ax3.axvline(0, color="k", linewidth=0.5, alpha=0.4)
    ax3.text(
        0.02,
        0.98,
        f"N = {len(poincare_3['x'])} points",
        transform=ax3.transAxes,
        fontsize=9,
        va="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    # =========================================================================
    # SUBPLOT 4: v₁ = v₂ (crossing from below) → Plot (x₁, x₂)
    # =========================================================================
    poincare_4 = {"x": [], "y": [], "time": []}

    for i in range(len(v1) - 1):
        diff_curr = v1[i] - v2[i]
        diff_next = v1[i + 1] - v2[i + 1]
        if diff_curr < 0 and diff_next > 0:  # v1 crosses v2 from below
            fraction = -diff_curr / (diff_next - diff_curr)
            poincare_4["x"].append(x1[i] + fraction * (x1[i + 1] - x1[i]))
            poincare_4["y"].append(x2[i] + fraction * (x2[i + 1] - x2[i]))
            poincare_4["time"].append(
                t_eval[i] + fraction * (t_eval[i + 1] - t_eval[i])
            )

    ax4 = axes[1, 1]
    if len(poincare_4["x"]) > 0:
        scatter = ax4.scatter(
            poincare_4["x"],
            poincare_4["y"],
            c=poincare_4["time"],
            cmap="plasma",
            s=50,
            alpha=0.7,
            edgecolors="black",
            linewidths=0.5,
        )

    ax4.set_title(
        "Trigger: $v_1 = v_2$ (from below)\nPlot: $(x_1, x_2)$",
        fontsize=12,
        fontweight="bold",
    )
    ax4.set_xlabel("$x_1$ (m)", fontsize=11)
    ax4.set_ylabel("$x_2$ (m)", fontsize=11)
    ax4.grid(True, alpha=0.3, linestyle="--")
    ax4.axhline(0, color="k", linewidth=0.5, alpha=0.4)
    ax4.axvline(0, color="k", linewidth=0.5, alpha=0.4)
    ax4.text(
        0.02,
        0.98,
        f"N = {len(poincare_4['x'])} points",
        transform=ax4.transAxes,
        fontsize=9,
        va="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )

    plt.show()

    if save_fig:
        if filename is None:
            x_str = "_".join([f"{xi:.2f}" for xi in x_init])
            filename = f"poincare_section_xinit_{x_str}.png"

        save_dir = "OUTPUTS/FIGURES/mass_spring_systems"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, filename)
        try:
            print(f"Saving figure to {filepath}...")
            fig.savefig(filepath, dpi=300, bbox_inches="tight")
            print("Figure saved successfully.")
        except Exception as e:
            print(f"Error saving figure: {e}")

    return fig


def main_poincare_section_plots():
    """Main function to run Poincaré section grid plots for 3-mass system"""
    print("\n" + "=" * 70)
    print("POINCARÉ SECTION GRID PLOTS FOR 3-MASS SPRING SYSTEM")
    print("=" * 70)

    default_params = {
        "simulation_time": 200.0,
        "n_points": 2000,
    }

    use_defaults = input("Use default parameters? (y/n): ").strip().lower() == "y"

    if use_defaults:
        simulation_time = float(default_params["simulation_time"])
        n_points = int(default_params["n_points"])

        x_init = None
        v_init = None
        masses = None
        springs = None
    else:

        def _get_float(prompt, default):
            s = input(f"{prompt} (default={default}): ").strip()
            if s == "":
                return float(default)
            try:
                return float(s)
            except ValueError:
                print(f"Invalid input. Using default {default}.")
                return float(default)

        simulation_time = _get_float(
            "Enter total simulation time (s)", default_params["simulation_time"]
        )
        n_points = int(
            _get_float("Enter number of time points", default_params["n_points"])
        )

        def _get_comma_separated_array(label, n, default_val):
            """Get array values from comma-separated input"""
            use_custom = (
                input(f"Provide custom {label}? (y/n): ").strip().lower() == "y"
            )
            if not use_custom:
                return None

            example = ", ".join([str(default_val)] * n)
            prompt = (
                f"  Enter {n} values for {label} (comma-separated, e.g., {example}): "
            )

            while True:
                user_input = input(prompt).strip()

                if not user_input:
                    return None

                try:
                    values = [float(val.strip()) for val in user_input.split(",")]

                    if len(values) != n:
                        print(
                            f"Error: Expected {n} values, but got {len(values)}. Please try again."
                        )
                        continue

                    return np.array(values, dtype=float)

                except ValueError:
                    print(
                        "Invalid input. Please enter numeric values separated by commas."
                    )
                    continue

        x_init = _get_comma_separated_array("initial displacements x (m)", 3, 0.0)
        v_init = _get_comma_separated_array("initial velocities v (m/s)", 3, 0.0)
        masses = _get_comma_separated_array("masses m (kg)", 3, 1.0)
        springs = _get_comma_separated_array("spring constants k (N/m)", 4, 10.0)

    save_figure = input("Save figure? (y/n): ").strip().lower() == "y"
    filename = None
    if save_figure:
        filename_input = input(
            "Enter filename for figure (e.g., 'poincare_section.png'): "
        ).strip()
        if filename_input:
            if not filename_input.lower().endswith(".png"):
                filename_input += ".png"
            filename = filename_input

    print("\nStarting Poincaré section grid plots with the following parameters:")
    print(f"  Simulation time: {simulation_time} s")
    print(f"  Number of time points: {n_points}")
    if x_init is not None:
        print(f"  Initial displacements: {x_init}")
    if v_init is not None:
        print(f"  Initial velocities: {v_init}")
    if masses is not None:
        print(f"  Masses: {masses}")
    if springs is not None:
        print(f"  Spring constants: {springs}")
    print(f"  Save figure: {'Yes' if save_figure else 'No'}")

    fig = plot_poincare_section_grid(
        masses=masses,
        springs=springs,
        x_init=x_init,
        v_init=v_init,
        simulation_time=simulation_time,
        n_points=n_points,
        save_fig=save_figure,
        filename=filename,
    )
    return fig


if __name__ == "__main__":
    figure = main_poincare_section_plots()
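
The driver above hands all of the physics to `plot_poincare_section_grid`, which is defined elsewhere in the notebook. To make the idea of a Poincaré section concrete, here is a minimal self-contained sketch under stated assumptions: the three masses and four spring constants are taken to form a fixed-fixed chain (wall, m1, m2, m3, wall), and the section surface is chosen, purely for illustration, as x2 = 0 crossed with v2 > 0. The helper names `chain_rhs`, `x2_upward_crossing`, and `poincare_section` are hypothetical and need not match the notebook's own implementation. SciPy's `solve_ivp` event detection locates each crossing, and the state at that instant becomes one section point.

```python
import numpy as np
from scipy.integrate import solve_ivp


def chain_rhs(t, y, m, k):
    """Hypothetical RHS for a linear wall-m1-m2-m3-wall chain (assumed topology).

    y = [x1, x2, x3, v1, v2, v3]; m holds 3 masses, k holds 4 spring constants.
    """
    x, v = y[:3], y[3:]
    xp = np.concatenate(([0.0], x, [0.0]))  # walls fixed at zero displacement
    tension = k * np.diff(xp)  # k_i times the extension of spring i
    accel = (tension[1:] - tension[:-1]) / m  # right spring pulls +, left pulls -
    return np.concatenate((v, accel))


def x2_upward_crossing(t, y, m, k):
    """Illustrative section surface (an assumption): x2 = 0."""
    return y[1]


x2_upward_crossing.direction = 1.0  # record only crossings with v2 > 0


def poincare_section(m, k, x0, v0, t_end=200.0):
    """Return the full state [x1, x2, x3, v1, v2, v3] at each section crossing."""
    y0 = np.concatenate((x0, v0))
    sol = solve_ivp(
        chain_rhs, (0.0, t_end), y0, method="DOP853",
        args=(m, k), events=x2_upward_crossing, rtol=1e-9, atol=1e-12,
    )
    # reshape keeps a (0, 6) array if no crossing occurs
    return np.asarray(sol.y_events[0]).reshape(-1, 6)


# Usage with the example values suggested by the prompts (1 kg masses, 10 N/m springs).
m = np.array([1.0, 1.0, 1.0])
k = np.full(4, 10.0)
states = poincare_section(m, k, np.array([0.1, 0.02, -0.05]), np.zeros(3))
x1_sec, v1_sec = states[:, 0], states[:, 3]  # points for one (x1, v1) panel
print(f"{len(states)} crossings recorded")
```

Setting `direction = 1.0` on the event function keeps only upward crossings, so every recorded point pierces the section in the same sense; the tight `rtol`/`atol` keep integration error from smearing the section points over long runs.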

## 5. References

1. [Idema, T. (TU Delft Open). *8.4: Coupled Oscillators*. Physics LibreTexts.](https://phys.libretexts.org/Bookshelves/University_Physics/Mechanics_and_Relativity_(Idema)/08%3A_Oscillations/8.04%3A_Coupled_Oscillators)
2. [*Lecture 4: Coupled Oscillators*. Department of Physics, University of Toronto.](https://www.physics.utoronto.ca/~sandra/PHY238Y/Lectures/Lect4_Coupl_osc.pdf)
3. [Fendt, W. *Coupled Pendula Simulation*. Walter Fendt Physics Applets.](https://www.walter-fendt.de/html5/phen/coupledpendula_en.htm)
4. [Physics LibreTexts. *8.4: Coupled Oscillators and Normal Modes*. University of California, Davis.](https://phys.libretexts.org/Courses/University_of_California_Davis/UCD%3A_Physics_9HA__Classical_Mechanics/8%3A_Small_Oscillations/8.4%3A_Coupled_Oscillators_and_Normal_Modes)
5. [*Coupled Pendulum - Normal Modes & Frequencies | Lagrangian Approach*. YouTube.](https://www.youtube.com/watch?v=M-hLBxx7MeE)
6. [*The Dance of Coupled Oscillators | Understanding Normal Modes & Frequencies*. YouTube.](https://www.youtube.com/watch?v=4WhrNjg3I_o)



