# Quick Start Guide



## Basics

Import `mlx.core` and make an `array`:

In [1]:
import mlx.core as mx
a = mx.array([1, 2, 3, 4])
a.shape

(4,)

In [3]:
a.dtype

mlx.core.int32

In [6]:
b = mx.array([1.0, 2.0, 3.0, 4.0])
b.dtype
b.shape

(4,)

Operations in MLX are lazy. The outputs of MLX operations are not computed until they are needed. To force an array to be evaluated, use `eval()`. Arrays are also evaluated automatically in a few cases: inspecting a scalar with `array.item()`, printing an array, or converting an array to a `numpy.ndarray` all trigger evaluation.

In [7]:
c = a+b  # c is not evaluated yet
mx.eval(c)  # c is evaluated now
print(c)
# or 

c = a+b # c is not evaluated yet
print(c) # c is evaluated now

# or

c = a + b  # c is not evaluated yet
import numpy as np
np.array(c)  # converting to NumPy evaluates c
print(c)  # c was already evaluated by the conversion above

array([2, 4, 6, 8], dtype=float32)
array([2, 4, 6, 8], dtype=float32)
array([2, 4, 6, 8], dtype=float32)


### Why Lazy Evaluation?

When you perform operations in MLX, no computation actually happens. Instead a compute graph is recorded. The actual computation only happens if an `eval()` is performed.

MLX uses lazy evaluation because it enables several useful features, described in the sections that follow.


### Transforming Compute Graphs

- What is a Compute Graph?

A compute graph is like a map of all the operations needed to perform a computation. It shows how inputs are processed step-by-step to produce outputs.

- What is Lazy Evaluation?

Lazy evaluation means that instead of performing calculations immediately, we "record" the operations in a compute graph and wait to execute them until they are actually needed.

- Why is This Useful?

Recording operations lets us do powerful things like:

1. Transform Functions: Tools like `grad()` (for computing gradients) and `vmap()` (for vectorized computations) can use the graph to modify or optimize the original function.
2. Optimize the Graph: The graph can be reorganized to make computations faster or more efficient before running them.

- How is This Done in MLX?

MLX does not pre-compile or store reusable compute graphs. Instead, it builds them dynamically as your program runs. Lazy evaluation also makes it easier to add features like compilation in the future, which could improve performance.

### Only Compute What You Use

- What Does This Mean?

In MLX, if an output from the compute graph is not needed, it won't be computed. For example, if you write a function with multiple outputs but only use one of them, MLX skips the calculations that feed only the unused outputs.

- Why Is This Helpful?

It saves time and computational resources because you don’t have to manually control or optimize which outputs are computed. MLX handles this automatically.
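
The two ideas above (recording operations in a graph and running only what is requested) can be illustrated with a toy example in plain Python. This is a minimal sketch for intuition only and is not how MLX is implemented; the `Lazy` class, the `constant` helper, and the node names are all hypothetical.

```python
class Lazy:
    """A node in a toy compute graph: an operation plus its input nodes."""

    def __init__(self, name, fn, *inputs):
        self.name = name
        self.fn = fn
        self.inputs = inputs
        self._value = None
        self._done = False

    def eval(self, log):
        # Evaluate the inputs first, then this node. Each node runs at most once.
        if not self._done:
            args = [node.eval(log) for node in self.inputs]
            self._value = self.fn(*args)
            self._done = True
            log.append(self.name)
        return self._value


def constant(name, value):
    # Hypothetical helper: a leaf node that just returns a fixed value.
    return Lazy(name, lambda: value)


# Building the graph records operations but runs none of them.
x = constant("x", 3.0)
cheap = Lazy("cheap", lambda v: v + 1.0, x)
expensive = Lazy("expensive", lambda v: sum(v for _ in range(1_000_000)), cheap)

log = []                  # names of the nodes that actually ran
result = cheap.eval(log)  # request only the cheap output

print(result)  # 4.0
print(log)     # ['x', 'cheap']; the 'expensive' node never ran
```

Because `expensive` is never requested, its function is never called, which mirrors how an unused output in a lazily evaluated graph costs nothing.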

### When to use `eval()`?

The key decision in using `eval()` is balancing the size of the compute graph and the overhead of graph evaluation. 


- **Small Graphs with Frequent `eval()` Calls:**
    - If you call `eval()` after every small computation, you incur its fixed overhead repeatedly.
    - Example: the following snippet is a bad idea because it evaluates after every single operation:

In [None]:
for _ in range(100):
    a = a + b
    mx.eval(a)  # Frequent evaluations
    b = b * 2
    mx.eval(b)  # Frequent evaluations


- Drawbacks:

    - High overhead due to repeated graph evaluations.
    - Inefficient use of resources.


- **Large Graphs with Delayed `eval()` Calls:**
    - If you delay `eval()` too long, the compute graph can become excessively large.
    - While large graphs are computationally correct, they can become costly due to increased memory usage and graph traversal overhead.

Use `eval()` at a balance point where the compute graph is large enough to minimize frequent overhead, but not so large that it becomes inefficient.

## Function and Graph Transformations

MLX has standard function transformations like `grad()` and `vmap()`. Transformations can be composed arbitrarily. For example, `grad(vmap(grad(fn)))` is allowed.

- `grad(fn)` (Gradient):
    - Computes the derivative of the function `fn` with respect to its input.

Example:

In [8]:
x = mx.array(0.0)
mx.grad(mx.sin)(x)  # Computes the derivative of sin(x) at x = 0
# Output: array(1, dtype=float32)


array(1, dtype=float32)

- `vmap(fn)` (Vectorized Map):

    - Applies `fn` over a batch of inputs simultaneously, vectorizing the computation for efficiency.

You can combine transformations to achieve more complex behaviors.

In [9]:
mx.grad(mx.grad(mx.sin))(x)
# This computes the second derivative of sin(x), which is -sin(x) at x = 0.
# Output: array(-0, dtype=float32)


array(-0, dtype=float32)

### Advanced Gradient Transformations

MLX offers advanced transformations for working with gradients, such as `vjp` (Vector-Jacobian Product) and `jvp` (Jacobian-Vector Product). These tools are powerful for sensitivity analysis and computational efficiency when working with derivatives in machine learning, physics simulations, and optimization tasks.


### **a. `vjp(fn)` - Vector-Jacobian Product**

- **What It Does:**
  - Computes the product of a vector with the Jacobian of a function $ f(x) $.
  - If $ J $ is the Jacobian matrix of $ f(x) $ and $ v $ is the vector, it calculates $ v \cdot J $.

- **Use Case:**
  - Efficient for propagating gradients backward in reverse-mode differentiation.


In [11]:
# Define a simple function
def fn(x):
    return x ** 2

# Input
x = mx.array(3.0)
v = mx.array(2.0)  # The cotangent vector we want to multiply with the Jacobian

# Compute vjp
output, vjp_result = mx.vjp(fn, [x], [v])

print(f"Output of fn(x): {output}")  # Output: 9.0
print(f"Vector-Jacobian Product (v * J): {vjp_result}")  # Output: 12.0


Output of fn(x): [array(9, dtype=float32)]
Vector-Jacobian Product (v * J): [array(12, dtype=float32)]


- **Explanation:**
  - The function $ f(x) = x^2 $ has a derivative (Jacobian) $ J = 2x $.
  - For $ x = 3 $, $ J = 6 $, and $ v = 2 $, the `vjp` calculates $ v \cdot J = 2 \cdot 6 = 12 $.


### **b. `jvp(fn)` - Jacobian-Vector Product**

- **What It Does:**
  - Computes the product of the Jacobian of a function $ f(x) $ with a vector $ v $.
  - If $ J $ is the Jacobian, it calculates $ J \cdot v $.

- **Use Case:**
  - Useful for forward-mode differentiation or analyzing how small changes in inputs affect outputs.


In [12]:
# Define a simple function
def fn(x):
    return x ** 2

# Input
x = mx.array(3.0)
v = mx.array(1.0)  # The vector we want to multiply with the Jacobian

# Compute jvp
output, jvp_result = mx.jvp(fn, (x,), (v,))

print(f"Output of fn(x): {output}")  # Output: 9.0
print(f"Jacobian-Vector Product (J * v): {jvp_result}")  # Output: 6.0


Output of fn(x): [array(9, dtype=float32)]
Jacobian-Vector Product (J * v): [array(6, dtype=float32)]


- **Explanation:**
  - The Jacobian of $ f(x) = x^2 $ is $ J = 2x $. 
  - For $ x = 3 $ and $ v = 1 $, the `jvp` computes $ J \cdot v = 6 \cdot 1 = 6 $.


### **c. `value_and_grad(fn)`**

- **What It Does:**
  - Computes both the value of a function and its gradient in a single call.
  - This is efficient because the forward computation is shared instead of being run once for the value and again for the gradient.

- **Use Case:**
  - Ideal for machine learning tasks where both the loss value and its gradient are required.


In [13]:
# Define a simple function
def loss_fn(x):
    return (x - 3) ** 2

# Input
x = mx.array(5.0)

# Compute value and gradient
value, grad = mx.value_and_grad(loss_fn)(x)

print(f"Value of the function: {value}")  # Output: 4.0
print(f"Gradient of the function: {grad}")  # Output: 4.0


Value of the function: array(4, dtype=float32)
Gradient of the function: array(4, dtype=float32)


- **Explanation:**
  - The loss function $f(x) = (x - 3)^2 $ has:
    - Value $f(5) = (5 - 3)^2 = 4 $.
    - Gradient $f'(x) = 2(x - 3) $, so $f'(5) = 2 \cdot (5 - 3) = 4 $.

---


1. **`vjp(fn)`**: Use for propagating gradients backward efficiently.
2. **`jvp(fn)`**: Use for forward-mode sensitivity analysis.
3. **`value_and_grad(fn)`**: Use to compute both the output and gradient of a function in one step, saving time and resources.





