# 📘 Notebook 1: JAX Fundamentals - Your First Steps

Welcome to your JAX journey! This notebook introduces you to JAX, Google's high-performance numerical computing library.

## 🎯 What You'll Learn (20-30 minutes)

By the end of this notebook, you'll understand:
- ✅ What JAX is and why it exists
- ✅ How JAX arrays work (and how they're like NumPy)
- ✅ The **key difference**: immutability (and why it matters)
- ✅ Basic operations: math, statistics, reshaping
- ✅ How to convert between JAX and NumPy

## 🤔 What is JAX?

Think of JAX as **"NumPy with superpowers"**:

```
JAX = NumPy + GPU/TPU support + Special transformations
```

**The three key capabilities:**
1. **J** - **JIT compilation**: Compiles your functions so they can run much faster (the speedup depends on the workload and hardware)
2. **A** - **Automatic differentiation**: Computes gradients automatically (essential for ML)
3. **X** - **XLA**: Google's optimizing compiler that runs on CPU/GPU/TPU

**In simple terms:** If you know NumPy, you already know most of JAX!

## 🔑 Key Characteristics

### 1. NumPy Compatible
```python
import jax.numpy as jnp  # JAX's version of NumPy

# Works just like NumPy!
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
result = a + b  # [5, 7, 9]
```
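
As a quick check of that compatibility, the sketch below runs the same expression in NumPy and in `jax.numpy` and compares the results. The values are illustrative and not taken from the notebook. It also shows one difference worth knowing early: with default settings, JAX creates 32-bit floats where NumPy creates 64-bit floats.

```python
import numpy as np
import jax.numpy as jnp

values = [1.0, 2.0, 3.0]  # illustrative data, not from the notebook

np_arr = np.array(values)
jax_arr = jnp.array(values)

# The same expression works in both libraries.
np_result = np.sqrt(np_arr) + np_arr.mean()
jax_result = jnp.sqrt(jax_arr) + jax_arr.mean()

# The results agree up to float32 rounding.
results_match = np.allclose(np_result, np.asarray(jax_result), atol=1e-6)
print(results_match)

# Default dtypes differ: NumPy uses float64, JAX uses float32
# (unless 64-bit mode has been enabled in the JAX configuration).
print(np_arr.dtype, jax_arr.dtype)
```

If you compare JAX output against NumPy output in your own code, use a tolerance rather than exact equality because of this precision difference.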

### 2. Immutable Arrays (Functional Programming)
**The #1 difference from NumPy:**
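- NumPy arrays can be modified in place: `arr[0] = 10`  ✅
- JAX arrays **cannot** be modified in place: `arr[0] = 10`  ❌

**Why?** JAX transformations such as `jit` and `grad` assume your functions are pure, with no hidden state changes. Immutable arrays make that assumption hold, which is what lets JAX optimize and parallelize your code safely.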

**How to update?** Use `.at[].set()`:
```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3])

# NumPy way (raises TypeError in JAX):
# arr[0] = 10  # ❌

# JAX way (creates a NEW array; `arr` itself is unchanged):
new_arr = arr.at[0].set(10)  # ✅
```
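
Beyond `.set()`, the `.at[]` property offers other update methods. The following is a minimal sketch with illustrative values; variable names such as `assignment_failed` are made up for this example. It first confirms that in-place assignment raises a `TypeError`, then shows a few updates that each return a new array.

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3])

# In-place assignment raises TypeError because JAX arrays are immutable.
try:
    arr[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True
print(assignment_failed)

# Each .at[...] method returns a new array; `arr` is never changed.
set_result = arr.at[0].set(10)      # replace a value
add_result = arr.at[1].add(5)       # add to a value
mul_result = arr.at[2].multiply(2)  # scale a value
sliced = arr.at[1:].set(0)          # slices work as indices too

print(arr, set_result, add_result, mul_result, sliced)
```

Because every update returns a fresh array, you typically rebind the name (`arr = arr.at[0].set(10)`) when you no longer need the old values.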

### 3. Hardware Accelerated
JAX runs on a GPU or TPU automatically when one is available and an accelerator-enabled build of JAX is installed. Your code doesn't need to change.
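
To see which hardware JAX has picked up, you can ask it directly. This is a small sketch: the output depends on your machine and installation, so none is shown for the device list, and the array values are illustrative.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; without an accelerator this is the CPU.
devices = jax.devices()
backend = jax.default_backend()  # name of the default platform, e.g. "cpu"
print(devices, backend)

# The same code runs unchanged on whichever backend is the default.
x = jnp.arange(4.0)       # [0., 1., 2., 3.]
total = jnp.sum(x * 2.0)  # (0 + 1 + 2 + 3) * 2 = 12
print(total)
```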

### 4. Function Transformations
JAX can transform your functions:
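- `jax.jit()` - Compile a function with XLA so it runs faster
- `jax.grad()` - Compute gradients of a function automatically
- `jax.vmap()` - Vectorize a function over a batch dimension (automatic batching)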

We'll cover these in later notebooks!
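
As a brief preview, here is a minimal sketch that applies all three transformations to one small function. The function `sum_of_squares` and the input values are hypothetical, chosen only so the results are easy to check by hand.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    """f(x) = sum of x_i squared; a hypothetical example function."""
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# jax.jit: compile the function; the result matches the uncompiled version.
fast_sum_of_squares = jax.jit(sum_of_squares)
jit_value = fast_sum_of_squares(x)         # 1 + 4 + 9 = 14

# jax.grad: derivative of a scalar-valued function; here df/dx = 2x.
gradient = jax.grad(sum_of_squares)(x)     # [2., 4., 6.]

# jax.vmap: apply the function to each row of a batch without a Python loop.
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
per_row = jax.vmap(sum_of_squares)(batch)  # [5., 25.]

print(jit_value, gradient, per_row)
```

Each transformation takes a function and returns a new function, so they can also be combined, for example `jax.jit(jax.grad(sum_of_squares))`.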

### 5. Pure Functions Work Best
JAX prefers functions that:
- Take inputs, return outputs
- Don't modify global variables
- Don't have side effects (like printing)

Don't worry if this sounds confusing - we'll show examples throughout!
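
To make the side-effect point concrete, the sketch below shows what can go wrong under `jax.jit`. The names `trace_log`, `impure_double`, and `pure_double` are hypothetical and exist only for this illustration. The side effect runs while JAX traces the function, not on every call, so the log ends up shorter than you might expect.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to observe the side effect

@jax.jit
def impure_double(x):
    # Side effect: runs while JAX traces the function, not on every call.
    trace_log.append("called")
    return x * 2

@jax.jit
def pure_double(x):
    # Pure version: the output depends only on the input.
    return x * 2

x = jnp.array([1.0, 2.0])
first = impure_double(x)
second = impure_double(x)  # same shape and dtype, so the compiled version is reused

print(len(trace_log))  # 1, even though the function was called twice
print(pure_double(x))
```

The returned values are correct in both cases; it is only the side effect that behaves unexpectedly, which is why pure functions are the safer default.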

## 📚 What's in This Notebook?

This notebook covers **JAX basics** - the foundation for everything else:

1. **Creating JAX arrays** (just like NumPy)
2. **Arithmetic operations** (addition, multiplication, etc.)
3. **Mathematical functions** (sin, cos, log, exp, etc.)
4. **Statistical functions** (mean, std, median, etc.)
5. **Array manipulation** (reshaping, stacking, slicing)
6. **Immutability** (the key difference from NumPy)
7. **Converting between JAX and NumPy**

## 🚀 Let's Get Started!

**Ready?** Run each code cell below and read the comments. Every line is explained!

## JAX as NumPy - Basic Operations

If you've used NumPy, this will look familiar. JAX arrays behave like NumPy arrays with one key difference: they're immutable. You can't modify them in place, but you get automatic GPU/TPU support in return.

In [1]:
# =============================================================================
# JAX AS NUMPY - COMPREHENSIVE DEMONSTRATION
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np

print("=" * 70)
print("ARRAY CREATION AND BASIC OPERATIONS")
print("=" * 70)

# -----------------------------------------------------------------------------
# Creating JAX Arrays
# -----------------------------------------------------------------------------
# JAX arrays are created just like NumPy arrays but are immutable
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

print("\n📦 Array Creation:")
print(f"Array a: {a}")
print(f"Array b: {b}")
print(f"Array a shape: {a.shape}, dtype: {a.dtype}")

# -----------------------------------------------------------------------------
# Arithmetic Operations
# -----------------------------------------------------------------------------
print("\n➕ Arithmetic Operations:")
print(f"Sum (a + b):                    {a + b}")
print(f"Element-wise multiplication:    {a * b}")
print(f"Dot product:                    {jnp.dot(a, b)}")
print(f"Matrix multiplication:          {jnp.matmul(a, b)}")  # Same as dot for 1D

# -----------------------------------------------------------------------------
# Mathematical Functions
# -----------------------------------------------------------------------------
print("\n📐 Mathematical Functions:")
print(f"Sine of a:                      {jnp.sin(a)}")
print(f"Exponential of b:               {jnp.exp(b)}")
print(f"Logarithm of a:                 {jnp.log(a)}")
print(f"Square root of b:               {jnp.sqrt(b)}")
print(f"Power (a^2):                    {jnp.power(a, 2)}")

# -----------------------------------------------------------------------------
# Statistical Functions
# -----------------------------------------------------------------------------
print("\n📊 Statistical Functions:")
print(f"Mean of a:                      {jnp.mean(a)}")
print(f"Standard deviation of b:        {jnp.std(b)}")
print(f"Variance of a:                  {jnp.var(a)}")
print(f"Median of b:                    {jnp.median(b)}")
print(f"Maximum value in a:             {jnp.max(a)}")
print(f"Minimum value in b:             {jnp.min(b)}")

# -----------------------------------------------------------------------------
# Aggregation Functions
# -----------------------------------------------------------------------------
print("\n🔢 Aggregation Functions:")
print(f"Sum of all elements in a:       {jnp.sum(a)}")
print(f"Product of all elements in b:   {jnp.prod(b)}")
print(f"Cumulative sum of a:            {jnp.cumsum(a)}")
print(f"Cumulative product of a:        {jnp.cumprod(a)}")

# -----------------------------------------------------------------------------
# Array Manipulation - Reshaping
# -----------------------------------------------------------------------------
print("\n🔄 Array Reshaping:")
a_reshaped = a.reshape((3, 1))
print(f"Reshaped a to (3,1):\n{a_reshaped}")
print(f"Transpose of reshaped b:\n{b.reshape((3, 1)).T}")

# -----------------------------------------------------------------------------
# Array Manipulation - Stacking and Concatenation
# -----------------------------------------------------------------------------
print("\n📚 Stacking and Concatenation:")
print(f"Vertical stack (vstack):\n{jnp.vstack((a, b))}")
print(f"Horizontal stack (hstack):\n{jnp.hstack((a, b))}")
print(f"Concatenate (same as hstack for 1D):\n{jnp.concatenate((a, b))}")

# -----------------------------------------------------------------------------
# Array Query Operations
# -----------------------------------------------------------------------------
print("\n🔍 Array Query Operations:")
print(f"Unique elements in b:           {jnp.unique(b)}")
print(f"Sorted a (descending):          {jnp.sort(a)[::-1]}")
print(f"Indices where a > 2:            {jnp.where(a > 2)}")
print(f"Boolean mask (a > 2):           {a > 2}")
print(f"Elements of a where a > 2:      {a[a > 2]}")

# -----------------------------------------------------------------------------
# Conversion Between JAX and NumPy
# -----------------------------------------------------------------------------
print("\n🔄 JAX ↔ NumPy Conversion:")
numpy_array = np.array(a)
print(f"JAX to NumPy:                   {numpy_array} (type: {type(numpy_array)})")
jax_array = jnp.array(numpy_array)
print(f"NumPy to JAX:                   {jax_array} (type: {type(jax_array)})")

# -----------------------------------------------------------------------------
# IMMUTABILITY - Key Difference from NumPy
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("⚠️  JAX ARRAYS ARE IMMUTABLE")
print("=" * 70)
print("""
Unlike NumPy, JAX arrays cannot be modified in place.
Operations return NEW arrays rather than modifying existing ones.

❌ This will FAIL:
   a[1] = 10.0  # TypeError: JAX arrays are immutable

✅ Use this instead:
   a = a.at[1].set(10.0)  # Returns a new array with index 1 set to 10.0
""")

# Demonstrate immutable update
a_updated = a.at[1].set(10.0)
print(f"Original a:  {a}")
print(f"Updated a:   {a_updated}")
print("Notice: Original 'a' is unchanged!\n")

# -----------------------------------------------------------------------------
# SUMMARY
# -----------------------------------------------------------------------------
print("=" * 70)
print("KEY POINTS")
print("=" * 70)
print("""
✅ JAX provides a NumPy-compatible API (jax.numpy)
✅ Most NumPy operations work identically in JAX
✅ JAX arrays are immutable - use .at[].set() for updates
✅ JAX automatically runs on GPU/TPU when available
✅ Seamlessly convert between JAX and NumPy arrays
✅ JAX is designed for high-performance numerical computing
✅ Perfect for machine learning, scientific computing, and simulations
""")

ARRAY CREATION AND BASIC OPERATIONS

📦 Array Creation:
Array a: [1. 2. 3.]
Array b: [4. 5. 6.]
Array a shape: (3,), dtype: float32

➕ Arithmetic Operations:
Sum (a + b):                    [5. 7. 9.]
Element-wise multiplication:    [ 4. 10. 18.]
Dot product:                    32.0
Matrix multiplication:          32.0

📐 Mathematical Functions:
Sine of a:                      [0.84147096 0.9092974  0.14112   ]
Exponential of b:               [ 54.598152 148.41316  403.4288  ]
Logarithm of a:                 [0.        0.6931472 1.0986123]
Square root of b:               [2.        2.236068  2.4494898]
Power (a^2):                    [1. 4. 9.]

📊 Statistical Functions:
Mean of a:                      2.0
Standard deviation of b:        0.8164966106414795
Variance of a:                  0.6666666865348816
Median of b:                    5.0
Maximum value in a:             3.0
Minimum value in b:             4.0

🔢 Aggregation Functions:
Sum of all elements in a:       6.0
