# JAX for Scientific Machine Learning: From Basics to PINNs and Operator Learning

**Audience:** Master's-level interns with basic Python and ML background  
**Tools:** JAX, Haiku/Flax/Equinox (for NN), Optax (for optimizers), Matplotlib/Plotly (for plotting)

---

## Overview

This tutorial will guide you from the foundations of JAX and neural networks to advanced applications in scientific ML, specifically Physics-Informed Neural Networks (PINNs) and Operator Learning (DeepONet, FNO).  
Each module lists readings and key concepts; Modules 1 and 2 also include hands-on assignments.

---

## Module 1: JAX Basics

**Objective:** Understand JAX's array programming, automatic differentiation, and functional programming style.

### Readings:
- [JAX: The Sharp Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- [JAX: Advanced automatic differentiation](https://docs.jax.dev/en/latest/advanced-autodiff.html#advanced-autodiff)
- [JAX Quickstart](https://docs.jax.dev/en/latest/quickstart.html#)

### Key Concepts:
- `jax.numpy` vs. `numpy`
- Pure functions
- `grad`, `jit`, `vmap`, `pmap`
- Static vs. traced (dynamic) values under `jit`
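The transformations above compose freely. A minimal sketch using only `jax` and `jax.numpy`; the function names (`loss`, `first_n`) and the toy data are illustrative choices, not part of any API:

```python
import jax
import jax.numpy as jnp
from functools import partial

# A pure function: its output depends only on its inputs, with no side effects.
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

# grad differentiates with respect to the first argument (w) by default.
dloss_dw = jax.grad(loss)

# vmap maps over the leading axis of x (w is shared); jit compiles the result with XLA.
batched_grad = jax.jit(jax.vmap(dloss_dw, in_axes=(None, 0)))

w = 0.5
xs = jnp.arange(6.0).reshape(3, 2)   # three samples, each of length 2
print(batched_grad(w, xs))           # one gradient per sample: -1, 3, 23

# JAX arrays are immutable: .at[...].set(...) returns a new array.
a = jnp.zeros(3)
b = a.at[0].set(1.0)                 # a is unchanged, b is a new array

# Arguments that determine array shapes must be static (not traced) under jit.
@partial(jax.jit, static_argnames="n")
def first_n(x, n):
    return x[:n]

print(first_n(jnp.arange(5.0), n=3))
```

Changing a static argument (here `n`) triggers a recompilation, so static arguments should take only a few distinct values.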

### Assignment 1:
- Implement and plot the gradient of a univariate function (e.g., `f(x) = sin(x) + x^2`) using JAX's `grad`.
- Compare the result with finite differences.
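One possible starting point for this assignment, sketched under a few assumptions: double precision is enabled so the finite-difference comparison is not dominated by float32 round-off, a central difference with step `h = 1e-5` is used, and the evaluation grid is arbitrary:

```python
import jax
import jax.numpy as jnp

# Enable float64 before creating arrays (assumption: makes FD errors easier to read).
jax.config.update("jax_enable_x64", True)

def f(x):
    return jnp.sin(x) + x ** 2

# Exact derivative via autodiff (cos(x) + 2x), vectorized over a grid with vmap.
df = jax.vmap(jax.grad(f))

def central_difference(fn, x, h=1e-5):
    # Second-order accurate approximation of fn'(x).
    return (fn(x + h) - fn(x - h)) / (2 * h)

xs = jnp.linspace(-3.0, 3.0, 201)
ad = df(xs)
fd = central_difference(f, xs)
print("max |AD - FD|:", jnp.max(jnp.abs(ad - fd)))

# Plotting (requires matplotlib):
# import matplotlib.pyplot as plt
# plt.plot(xs, ad, label="jax.grad")
# plt.plot(xs, fd, "--", label="central difference")
# plt.legend()
# plt.show()
```

Unlike finite differences, `jax.grad` has no step size to tune; its result is exact up to floating-point rounding.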

---

## Module 2: Building Neural Networks in JAX

**Objective:** Learn to build and train neural networks using JAX libraries.

### Readings:
- [The Flax Philosophy](https://flax-linen.readthedocs.io/en/latest/philosophy.html)
- [Haiku: Basic Tutorial](https://dm-haiku.readthedocs.io/en/latest/)
- [Equinox: Simple neural nets](https://docs.kidger.site/equinox/all-of-equinox/)

### Key Concepts:
- Parameter initialization and management
- Forward function
- Loss computation
- Parameter updates using Optax
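The four concepts above fit in a few lines of plain JAX. A minimal sketch under stated assumptions: a linear model stands in for a network, parameters live in a dict (any pytree works), and a hand-written gradient-descent step stands in for an Optax optimizer. With Optax you would instead call `optimizer.update(grads, opt_state, params)` followed by `optax.apply_updates(params, updates)`.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value

def init_params(key):
    # Parameter initialization: explicit PRNG keys, parameters stored in a pytree.
    k_w, k_b = jax.random.split(key)
    return {"w": jax.random.normal(k_w, ()), "b": jax.random.normal(k_b, ())}

def forward(params, x):
    # Forward function: a pure map from (parameters, inputs) to predictions.
    return params["w"] * x + params["b"]

def mse_loss(params, x, y):
    # Loss computation: mean squared error over the batch.
    return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Parameter update: gradient descent applied leaf by leaf over the pytree.
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

x = jnp.linspace(-1.0, 1.0, 50)
y = 2.0 * x + 1.0                      # toy target: w = 2, b = 1
params = init_params(jax.random.PRNGKey(0))
for step in range(200):
    params, loss = sgd_step(params, x, y)
print(params)                          # w close to 2.0, b close to 1.0
```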

### Assignment 2:
- Implement a fully-connected neural network for 1D regression (e.g., fit `f(x) = sin(3x)`) in Flax/Haiku/Equinox.
- Visualize training loss and predictions.
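Before reaching for a library, it can help to see the whole regression in plain JAX. A minimal sketch, assuming a tanh MLP with two hidden layers of width 32, 128 training points, full-batch gradient descent, and `jax.lax.scan` to run the training loop inside a single compiled function; all sizes and the learning rate are illustrative, and a Flax/Haiku/Equinox module plus an Optax optimizer would replace `init_mlp`, `mlp`, and the update line:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05
NUM_STEPS = 2000

def init_mlp(key, sizes):
    # List of (W, b) pairs; weights scaled by 1/sqrt(fan_in).
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    # Scalar input -> scalar output through tanh hidden layers.
    h = jnp.atleast_1d(x)
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return (W @ h + b)[0]

def loss_fn(params, xs, ys):
    preds = jax.vmap(mlp, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)

@jax.jit
def train(params, xs, ys):
    def step(p, _):
        loss, grads = jax.value_and_grad(loss_fn)(p, xs, ys)
        p = jax.tree_util.tree_map(lambda w, g: w - LEARNING_RATE * g, p, grads)
        return p, loss
    # scan returns the final parameters and the stacked per-step losses.
    return jax.lax.scan(step, params, xs=None, length=NUM_STEPS)

xs = jnp.linspace(-1.0, 1.0, 128)
ys = jnp.sin(3.0 * xs)
params = init_mlp(jax.random.PRNGKey(42), [1, 32, 32, 1])
params, losses = train(params, xs, ys)
print(losses[0], losses[-1])

# For the plots: plt.semilogy(losses) for the loss curve, and
# plt.plot(xs, jax.vmap(mlp, in_axes=(None, 0))(params, xs)) against ys for predictions.
```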

---

## Module 3: Physics-Informed Neural Networks (PINNs)

**Objective:** Understand and implement PINNs for solving differential equations.

### Readings:
- [Original PINN paper (Raissi et al.)](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
- [PINNs project page (Raissi et al.)](https://maziarraissi.github.io/PINNs/)
 

### Key Concepts:
- Defining PDE residuals with automatic differentiation
- Collocation and boundary points
- Total loss as a (weighted) sum of the physics (PDE-residual) loss and the boundary/initial-condition loss
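These three ideas combine into a single loss function. A minimal sketch, assuming a 1D Poisson problem chosen purely for illustration, `u''(x) = -pi^2 sin(pi x)` on `[0, 1]` with `u(0) = u(1) = 0` (exact solution `u(x) = sin(pi x)`); the network size, number of collocation points, and equal loss weights are arbitrary choices. Training then proceeds as in Module 2, by minimizing `pinn_loss`:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    # Scalar x -> scalar u(x); tanh keeps second derivatives smooth.
    h = jnp.atleast_1d(x)
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return (W @ h + b)[0]

def source(x):
    return -(jnp.pi ** 2) * jnp.sin(jnp.pi * x)

def pde_residual(u, x):
    # r(x) = u''(x) - f(x), with u'' obtained by nesting jax.grad.
    u_xx = jax.grad(jax.grad(u))(x)
    return u_xx - source(x)

def pinn_loss(params, x_col, x_bc, u_bc):
    u = lambda x: mlp(params, x)
    physics = jnp.mean(jax.vmap(lambda x: pde_residual(u, x))(x_col) ** 2)
    boundary = jnp.mean((jax.vmap(u)(x_bc) - u_bc) ** 2)
    return physics + boundary          # equal weights (assumption)

x_col = jnp.linspace(0.0, 1.0, 64)     # collocation points spanning the domain
x_bc = jnp.array([0.0, 1.0])           # boundary points
u_bc = jnp.zeros(2)                    # Dirichlet values u(0) = u(1) = 0

params = init_mlp(jax.random.PRNGKey(0), [1, 32, 32, 1])
loss, grads = jax.jit(jax.value_and_grad(pinn_loss))(params, x_col, x_bc, u_bc)

# Sanity check: the exact solution should have (numerically) zero residual.
exact_res = jax.vmap(lambda x: pde_residual(lambda s: jnp.sin(jnp.pi * s), x))(x_col)
print(loss, jnp.max(jnp.abs(exact_res)))
```

Checking the residual on a known solution is a cheap way to catch sign or source-term mistakes before spending time on training.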
 
---

## Module 4: Operator Learning (DeepONet, FNO)

**Objective:** Learn neural operators that map functions to functions, a step beyond standard PINNs.

### Readings:
- [DeepONet (Lu et al., Nature Machine Intelligence, 2021)](https://www.nature.com/articles/s42256-021-00302-5)
- [Fourier Neural Operator (FNO) paper](https://arxiv.org/abs/2010.08895)
- [U-FNO](https://arxiv.org/abs/2109.03697)

### Key Concepts:
- Neural operators: learning mappings between infinite-dimensional spaces
- Training on families of PDEs or parametric equations
- Branch/trunk architecture (DeepONet) and spectral convolution (FNO)

---

## Recommended Additional Readings

- [Neural Operator: Learning Maps Between Function Spaces (Kovachki et al.)](https://arxiv.org/abs/2108.08481)
- [Physics-Informed Machine Learning Review](https://arxiv.org/abs/2308.08468)
- [JAX Scientific ML resources](https://github.com/google/jax#scientific-machine-learning)

---

## Tips

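- Run on a GPU (e.g., a Google Colab GPU runtime) for faster training.
- Keep code modular and functions pure: JAX transformations such as `jit`, `grad`, and `vmap` assume functions without side effects, and pure functions are also easier to debug.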
- For further research, explore [Equinox](https://github.com/patrick-kidger/equinox) and [Diffrax](https://github.com/patrick-kidger/diffrax) for advanced scientific ML with JAX.
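Why purity matters in practice: `jit` traces a function once per input shape and dtype, then reuses the compiled version. A minimal sketch (the `trace_log` list is purely illustrative) showing that a Python side effect inside a jitted function runs only during tracing, not on every call:

```python
import jax
import jax.numpy as jnp

trace_log = []  # a Python side effect that jitted code should not rely on

@jax.jit
def impure_square(x):
    trace_log.append("traced")  # executes only while JAX traces the function
    return x ** 2

impure_square(jnp.ones(3))      # first call: trace, compile, run
impure_square(jnp.ones(3) * 2)  # same shape and dtype: compiled version reused
impure_square(jnp.ones(4))      # new shape: traced again
print(trace_log)                # ['traced', 'traced']
```

To inspect values at run time inside jitted code, use `jax.debug.print` rather than Python `print`.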

---

**Happy Learning!**