This notebook demonstrates:

- The difference between object-oriented (OOP) and functional (FP) programming, especially for simulations
- How to simulate a physical system (a damped harmonic oscillator) in both OOP and pure functional styles
- Handling mutable vs. immutable state in numerical code
- Advantages of pure functions for reproducibility and parallelism in JAX/Flax


#### OOP approach for simulating a simple oscillator

```python
import matplotlib.pyplot as plt
import numpy as np

# --- OOP APPROACH ---
class OscillatorSim:
    def __init__(self, k=1.0, gamma=0.1, x0=1.0, v0=0.0):
        self.k = k
        self.gamma = gamma
        self.x = x0
        self.v = v0
        self.history = []  # Side effect: coupling storage with computation

    def step(self, dt=0.1):
        # Semi-implicit Euler
        acc = -self.k * self.x - self.gamma * self.v
        self.v += acc * dt  # Mutation
        self.x += self.v * dt  # Mutation
        self.history.append((self.x, self.v))  # Hidden state growth

    def run(self, steps):
        for _ in range(steps):
            self.step()
        return self.history

# Usage
sim = OscillatorSim(k=2.0)
traj = sim.run(100)
# Problem: 'sim' is now "dirty". To run again with different k,
# I must instantiate a NEW object. I cannot easily "rewind".
```
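To make the "dirty" state concrete, the sketch below re-declares the same `OscillatorSim` class so it runs on its own, then calls `run` twice on one object. The second call does not restart from `x0`; it silently continues from where the first call stopped, and because `run` returns the internal `history` list itself, the result of the first call keeps growing too.

```python
class OscillatorSim:
    def __init__(self, k=1.0, gamma=0.1, x0=1.0, v0=0.0):
        self.k = k
        self.gamma = gamma
        self.x = x0
        self.v = v0
        self.history = []

    def step(self, dt=0.1):
        acc = -self.k * self.x - self.gamma * self.v
        self.v += acc * dt
        self.x += self.v * dt
        self.history.append((self.x, self.v))

    def run(self, steps):
        for _ in range(steps):
            self.step()
        return self.history


sim = OscillatorSim(k=2.0)
first = sim.run(100)
n_after_first = len(first)  # 100 entries at this point
second = sim.run(100)

# `run` hands back the same list object both times, so `first` grew as well.
print(n_after_first, len(second), first is second)  # 100 200 True

# A fresh object's first step differs from the second run's first step,
# because the second run started from the end of the first one, not from x0.
fresh = OscillatorSim(k=2.0).run(1)
print(fresh[0] == second[100])  # False
```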

#### Functional approach for simulating a simple oscillator

```python
from dataclasses import dataclass

# 1. Define State explicitly (Data, not Object)
# frozen=True means we CANNOT do state.x = 5: it raises dataclasses.FrozenInstanceError.
@dataclass(frozen=True)
class State:
    x: float
    v: float

@dataclass(frozen=True)
class Params:
    k: float
    gamma: float

# 2. Pure Update Function (Dynamics)
# Notice: No "self". Just inputs -> outputs.
def step_physics(state: State, params: Params, dt: float = 0.1) -> State:
    acc = -params.k * state.x - params.gamma * state.v

    # We calculate NEW values (semi-implicit Euler, as in the OOP version)
    new_v = state.v + acc * dt
    new_x = state.x + new_v * dt

    # We return a NEW State object. The old 'state' is untouched.
    return State(x=new_x, v=new_v)

# Usage
initial_state = State(x=1.0, v=0.0)
params = Params(k=2.0, gamma=0.1)

# Explicit Loop (We control the history, not the object)
trajectory = [initial_state]
current_state = initial_state

for _ in range(100):
    current_state = step_physics(current_state, params)
    trajectory.append(current_state)
```
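Because `step_physics` is pure and `State` is frozen, "rewinding" amounts to reusing `initial_state`. The minimal check below re-declares the definitions so it runs on its own; the `simulate` helper is hypothetical, introduced only to wrap the explicit loop. It shows that two runs from the same inputs give identical trajectories, that a different `k` needs no new object, and that an attempted mutation is rejected.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class State:
    x: float
    v: float


@dataclass(frozen=True)
class Params:
    k: float
    gamma: float


def step_physics(state: State, params: Params, dt: float = 0.1) -> State:
    acc = -params.k * state.x - params.gamma * state.v
    new_v = state.v + acc * dt
    new_x = state.x + new_v * dt
    return State(x=new_x, v=new_v)


def simulate(state: State, params: Params, steps: int = 100) -> list:
    # Hypothetical helper: the explicit loop from above, wrapped in a function.
    trajectory = [state]
    for _ in range(steps):
        state = step_physics(state, params)
        trajectory.append(state)
    return trajectory


initial_state = State(x=1.0, v=0.0)
params = Params(k=2.0, gamma=0.1)

run_a = simulate(initial_state, params)
run_b = simulate(initial_state, params)  # "rewind": start again from the same value
run_c = simulate(initial_state, Params(k=3.0, gamma=0.1))  # new k, same initial state

print(run_a == run_b)  # True: same inputs, same outputs
print(initial_state)   # State(x=1.0, v=0.0): never modified

try:
    initial_state.x = 5.0
except FrozenInstanceError:
    print("mutation blocked")
```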

#### Run an ensemble of simulations in OOP

```python
# OOP Ensemble
k_values = [1.0, 2.0, 3.0, 4.0, 5.0]
sims = [OscillatorSim(k=k) for k in k_values]
results = []

for sim in sims:
    results.append(sim.run(100))  # Sequential execution
```

#### Run an ensemble of simulations in FP

```python
# FP Ensemble
k_values = [1.0, 2.0, 3.0, 4.0, 5.0]

def run_simulation(k_val):
    # Pure: the output depends only on k_val, and nothing outside the function is modified
    p = Params(k=k_val, gamma=0.1)
    s = State(x=1.0, v=0.0)
    traj = []
    for _ in range(100):
        s = step_physics(s, p)
        traj.append(s)
    return traj

# map() is a functional building block: it applies the pure function to each k independently.
# This single line replaces the object creation and loop management of the OOP ensemble.
ensemble_results = list(map(run_simulation, k_values))
```
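Because `step_physics` only does arithmetic on its inputs, its body works unchanged when `x`, `v`, and `k` are NumPy arrays instead of floats, so one call can advance the whole ensemble at once. The sketch below assumes NumPy and re-declares the containers with array annotations; it is a hand-rolled analogue of what `jax.vmap` automates, not JAX itself.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class State:
    x: np.ndarray  # one entry per ensemble member
    v: np.ndarray


@dataclass(frozen=True)
class Params:
    k: np.ndarray  # one spring constant per ensemble member
    gamma: float


def step_physics(state: State, params: Params, dt: float = 0.1) -> State:
    # Same body as the scalar version; NumPy broadcasting handles the arrays.
    acc = -params.k * state.x - params.gamma * state.v
    new_v = state.v + acc * dt
    new_x = state.x + new_v * dt
    return State(x=new_x, v=new_v)


k_values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
state = State(x=np.ones_like(k_values), v=np.zeros_like(k_values))
params = Params(k=k_values, gamma=0.1)

positions = []
for _ in range(100):
    state = step_physics(state, params)
    positions.append(state.x)

# Row i holds the 100 positions of the run with spring constant k_values[i].
positions = np.stack(positions, axis=1)
print(positions.shape)  # (5, 100)
```

The loop over time remains sequential, as it must for a time-stepping scheme, but the loop over ensemble members has disappeared.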

#### Discussion: Why write "uglier" code?

You might notice that the FP version requires more boilerplate.
We have to define `State` and `Params` containers and pass state explicitly into and out of every function.
The OOP version hides all that.

So why do we do it?

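*Transparency*: In OOP, `sim.step()` changes things "under the hood." In FP, `new_state = step_physics(old_state, params)`
makes the data flow fully visible: everything the step reads is an argument, and everything it produces is the return value.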

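*The wall*: OOP is easy for one simulation.
But suppose you want to run 10,000 simulations in parallel on a GPU, or compute the
derivative of the simulation's output with respect to one of its parameters, such as the sensitivity of the final
position to the spring constant, $\partial x_T / \partial k$. Both tasks treat the simulation as a function from inputs to outputs, and JAX's
transformations (`jax.vmap`, `jax.grad`, `jax.jit`) are designed for functionally pure code: side effects such as
`self.v += ...` or `self.history.append(...)` are not tracked (under `jax.jit`, for example, they run only once, while the function is traced). This is where OOP hits a wall.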
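To make the derivative concrete without any framework: because the simulation is pure, its final position $x_T$ is an ordinary function of the spring constant $k$, so $\partial x_T / \partial k$ can be estimated with a central finite difference, $(x_T(k+h) - x_T(k-h)) / 2h$. The helpers `final_position` and `central_difference` are hypothetical, and the step `h = 1e-5` is an arbitrary choice. This is only an approximation; `jax.grad` would instead compute the derivative by automatic differentiation.

```python
def final_position(k: float, gamma: float = 0.1, steps: int = 100, dt: float = 0.1) -> float:
    # Hypothetical helper: the same semi-implicit Euler update as step_physics,
    # collapsed into a function that returns only the final position x_T.
    x, v = 1.0, 0.0
    for _ in range(steps):
        acc = -k * x - gamma * v
        v = v + acc * dt  # velocity first ...
        x = x + v * dt    # ... then position, using the new velocity
    return x


def central_difference(f, k: float, h: float = 1e-5) -> float:
    # (f(k + h) - f(k - h)) / 2h: error shrinks like h**2 until round-off dominates.
    return (f(k + h) - f(k - h)) / (2 * h)


dxT_dk = central_difference(final_position, 2.0)
print(f"d x_T / d k at k = 2.0 is approximately {dxT_dk:.4f}")
```

With the OOP class, the same estimate would require building two fresh objects and trusting that nothing else shares their state; with a pure function it is just two calls.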

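*The payoff*: The "clunky" FP code we just wrote is essentially JAX-ready: `step_physics` is a pure function of `(state, params)`, which is exactly the kind of function `jax.lax.scan`, `jax.vmap`, and `jax.grad` are built to transform.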
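A minimal sketch of what "JAX-ready" means in practice, assuming JAX is installed. Two changes go beyond the code above: `State` and `Params` become `NamedTuple`s, which JAX treats as pytrees automatically (a plain frozen dataclass would first need to be registered with `jax.tree_util`), and the Python time loop becomes `jax.lax.scan`. The physics function itself is unchanged. The `simulate` helper is hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    x: jax.Array
    v: jax.Array


class Params(NamedTuple):
    k: jax.Array
    gamma: float


def step_physics(state: State, params: Params, dt: float = 0.1) -> State:
    # Same pure update as before; only the containers changed.
    acc = -params.k * state.x - params.gamma * state.v
    new_v = state.v + acc * dt
    new_x = state.x + new_v * dt
    return State(x=new_x, v=new_v)


def simulate(params: Params, steps: int = 100) -> jax.Array:
    # Hypothetical helper: positions x_1..x_steps, starting from x=1, v=0.
    # ones_like/zeros_like keep the carry's dtype consistent with k.
    init = State(x=jnp.ones_like(params.k), v=jnp.zeros_like(params.k))

    def body(state, _):
        new_state = step_physics(state, params)
        return new_state, new_state.x  # (next carry, per-step output)

    _, xs = jax.lax.scan(body, init, xs=None, length=steps)
    return xs


# Ensemble: one vectorized call instead of a Python loop over k.
k_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
ensemble = jax.vmap(lambda k: simulate(Params(k=k, gamma=0.1)))(k_values)
print(ensemble.shape)  # (5, 100)

# Sensitivity of the final position to the spring constant, via autodiff.
dxT_dk = jax.grad(lambda k: simulate(Params(k=k, gamma=0.1))[-1])(2.0)
print(dxT_dk)
```

Nothing in `step_physics` had to change to support batching or differentiation; that is the practical payoff of keeping state explicit.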
