# 1. Functional Programming vs. Object-Oriented Programming (OOP)

## Introduction: Why are we here?
Before diving into Flax and JAX, we must shift our mental model. Most of us are trained 
in **Object-Oriented Programming (OOP)**, where we bundle data and logic together into "objects."

However, modern Machine Learning (ML) frameworks, specifically JAX, rely
heavily on **Functional Programming (FP)**: transformations such as `jax.jit` and `jax.grad`
assume that the functions you pass to them are pure. If you try to write JAX code using standard
OOP habits, you will fight the framework at every step.

This notebook will clarify the difference between these two paradigms and explain 
*why* FP is the preferred choice for high-performance scientific computing.

---

## The Two Mental Models

### 1. The OOP Model: "Stateful Objects"
In OOP, objects encapsulate state (variables) and behavior (methods). 
When you call a method, the object often changes internally. Any change a function makes
beyond returning a value, such as mutating an object's attributes, is called a **side effect**.

* **Analogy:** A car's odometer. You drive the car (method), and the mileage counter 
updates (internal state change). You don't need to manually calculate the new mileage; 
the car "remembers" it.
* **The Trap:** If you call the same method twice with the same arguments, you might get
different results because the internal state changed in between. This makes debugging
physics simulations or distributed training loops difficult.

### 2. The Functional Model: "Pure Functions"
In FP, data and behavior are strictly separated.
1.  **Pure Functions:** A function's output depends *only* on its inputs, and it has no side effects. 
It has no memory of history: for a given $x$, $f(x)$ always returns the same $y$.
2.  **Immutability:** We do not modify variables in place. Instead, we generate 
*new* variables representing the updated state.

* **Analogy:** A physics equation. $\text{Position}(t)$ gives you the location at time $t$. 
The equation doesn't "change" the universe; it simply describes a value based on inputs.
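The analogy translates directly into a pure Python function. A minimal sketch, assuming constant-velocity motion; the name `position` and the parameters `x0` and `v0` are illustrative choices, not part of any library:

```python
def position(t: float, x0: float = 0.0, v0: float = 2.0) -> float:
    """Position at time t for constant-velocity motion: x(t) = x0 + v0 * t.

    Pure: the result depends only on the arguments, and nothing outside
    the function is read or modified.
    """
    return x0 + v0 * t

# Calling it in any order, any number of times, gives the same answers.
print(position(3.0))  # 6.0
print(position(1.0))  # 2.0
print(position(3.0))  # 6.0 again; there is no hidden history
```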


```python
# OOP Approach
class StatefulCounter:
    def __init__(self, count=0):
        self.count = count

    def increment(self):
        self.count += 1  # Side effect: internal state mutation
        return self.count

counter = StatefulCounter()
print(counter.increment()) # 1
print(counter.increment()) # 2
# Problem: If I run counter.increment() again, the result depends on history.
```

```python
# Functional Approach
# State is plain data. Here it is just an int; a dict, tuple, or dataclass works the same way.
State = int

def increment(count: State) -> State:
    return count + 1  # Returns a new value; the input is left untouched

# The user manages the state explicitly
current_count = 0
next_count = increment(current_count)

print(current_count) # 0 (Immutability preserved)
print(next_count)    # 1
```
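The functional cell mentions a dataclass as one way to hold richer state. A minimal sketch of that option, assuming a hypothetical `TrainState` with a step counter and a learning rate: `frozen=True` makes instances immutable, and the standard library's `dataclasses.replace` builds an updated copy instead of mutating the original.

```python
from dataclasses import dataclass, replace, FrozenInstanceError

@dataclass(frozen=True)
class TrainState:
    # Hypothetical state container, for illustration only.
    step: int
    learning_rate: float

def increment_step(state: TrainState) -> TrainState:
    # Return a new state; the input object is never modified.
    return replace(state, step=state.step + 1)

s0 = TrainState(step=0, learning_rate=0.1)
s1 = increment_step(s0)
print(s0.step, s1.step)  # 0 1

try:
    s0.step = 5  # In-place mutation is rejected on a frozen dataclass
except FrozenInstanceError:
    print("cannot mutate a frozen TrainState")
```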

### Why Dependence on History Is a Problem

In the OOP `StatefulCounter` example, the method `increment()` takes no arguments 
but returns a different value every time you call it.

```python
counter = StatefulCounter()  # A fresh counter, starting at 0
counter.increment() # Returns 1
counter.increment() # Returns 2
counter.increment() # Returns 3
```

Here is why this "dependence on history" is a problem, specifically in the context of 
Scientific Machine Learning:

#### The "Hidden State" Trap

In the OOP version, the output depends on a variable (`self.count`) that lives inside 
the object. 
To understand what `counter.increment()` will return, you have to know exactly how 
many times it was called previously.

*The Scientific Issue:* Imagine you are running a physics simulation or training a 
neural network in a Jupyter notebook.

You run the cell that updates the weights once. The loss goes down.

You accidentally run the cell again (or run cells out of order). The weights update again.

Your model state is now different from what you think it is, and you cannot reproduce 
that specific result without restarting the kernel and running everything in exact order.
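A pure update step avoids this particular failure mode. A minimal sketch, assuming plain gradient descent on a list of weights; `sgd_step`, `params`, and `grads` are hypothetical names, not a library API. Because `sgd_step` returns new weights and leaves `params` untouched, re-running the cell recomputes the same `new_params` instead of applying the update a second time:

```python
def sgd_step(params: list[float], grads: list[float], lr: float) -> list[float]:
    # Pure gradient-descent update: builds a new list and never edits `params`.
    return [p - lr * g for p, g in zip(params, grads)]

params = [1.0, -2.0]
grads = [0.5, -1.0]

# Imagine this line is a notebook cell that you accidentally run twice.
new_params = sgd_step(params, grads, lr=0.5)
new_params = sgd_step(params, grads, lr=0.5)  # Same inputs, same result

print(params)      # [1.0, -2.0] (unchanged)
print(new_params)  # [0.75, -1.5]
```

The catch is that you must explicitly feed `new_params` into the next step yourself; that bookkeeping is the price of reproducibility.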

#### Purity and Reproducibility

In math and science, we expect functions to be pure. If we define a function $f(x)=x+1$, 
then $f(2)$ is always $3$. It doesn't matter if it's Tuesday, or if you calculated $f(100)$ 
five minutes ago.

**OOP (Impure)**: `counter.increment()` behaves like a function `f()` whose output changes every time you call it, even though its (empty) input never changes.

**FP (Pure)**: `increment(2)` is always `3`.
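One practical consequence of purity: since $f(2)$ is always $3$, a computed result can be cached and reused safely. A minimal sketch using the standard library's `functools.lru_cache`; `cache_info()` reports how many calls were answered from the cache:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def f(x: int) -> int:
    # Pure: the result depends only on x, so a cached answer is always valid.
    return x + 1

f(100)       # An unrelated earlier call; it cannot influence f(2)
print(f(2))  # 3 (computed)
print(f(2))  # 3 (served from the cache)
print(f.cache_info().hits)  # 1
```

Caching `counter.increment()` the same way would be wrong: it would keep returning the first value, because the true answer depends on hidden state.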


#### Parallelization (The "Race Condition")

Scientific ML requires massive parallelization (running on thousands of GPU cores).

*OOP*: If two threads (or devices) run `counter.increment()` on the same object 
at the same time, they can race over `self.count` (a **race condition**). `self.count += 1` is a
read-modify-write: if both workers read `0` before either writes, both write `1`, and you end 
up with `1` when you expected `2`.
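To see how a `1` can appear, the sketch below replays, by hand, one unlucky interleaving of two workers sharing a counter. It uses no real threads, so the outcome is deterministic; `SharedCounter` and `shared` are hypothetical names, and the three manual steps model the read, add, and write that `count += 1` consists of.

```python
class SharedCounter:
    # Hypothetical mutable counter shared by two workers.
    def __init__(self, count: int = 0):
        self.count = count

shared = SharedCounter()

# Both workers intend to run `shared.count += 1`.
# Interleave their read and write steps by hand:
a_read = shared.count      # Worker A reads 0
b_read = shared.count      # Worker B reads 0, before A has written
shared.count = a_read + 1  # Worker A writes 1
shared.count = b_read + 1  # Worker B also writes 1, overwriting A's update

print(shared.count)  # 1, although two increments were requested
```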

*FP*: Since `increment(count)` returns a new value rather than modifying shared state, there is 
nothing for parallel workers to fight over: independent calls can run simultaneously on as many cores as you have without interfering with each other.

#### Summary

When the result depends on history (implicit state), the same call no longer guarantees the same answer, so you lose reproducibility.

**OOP**: "Update the world." (Easier to write, harder to debug/parallelize).

**FP**: "Given the world as it is now, calculate the next world." (More explicit bookkeeping, easier to reproduce/debug/parallelize).
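The FP motto can be sketched as a tiny simulation loop: a pure `step` function maps the current state to the next one, and the loop threads that state through explicitly. A minimal sketch, assuming a hypothetical `Particle` state (position and velocity) under constant acceleration, advanced with a simple explicit Euler step:

```python
from typing import NamedTuple

class Particle(NamedTuple):
    # Hypothetical immutable state: NamedTuple fields cannot be reassigned.
    x: float  # position
    v: float  # velocity

def step(state: Particle, dt: float, accel: float) -> Particle:
    # Explicit Euler update: compute the next world from the current one.
    return Particle(x=state.x + state.v * dt, v=state.v + accel * dt)

def simulate(initial: Particle, n_steps: int,
             dt: float = 0.5, accel: float = -1.0) -> list[Particle]:
    states = [initial]
    for _ in range(n_steps):
        states.append(step(states[-1], dt, accel))
    return states

start = Particle(x=0.0, v=2.0)
run1 = simulate(start, n_steps=3)
run2 = simulate(start, n_steps=3)
print(run1 == run2)  # True: same initial world, same trajectory
print(run1[-1])      # Particle(x=2.25, v=0.5)
```

Re-running the simulation from `start` always reproduces the same trajectory, and `start` itself is never modified.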