# JAX for Scientific Machine Learning: From Basics to PINNs and Operator Learning

**Audience:** Master's-level interns with basic Python and ML background  
**Tools:** JAX, Haiku/Flax/Equinox (for NN), Optax (for optimizers), Matplotlib/Plotly (for plotting)

---

## Overview

This tutorial will guide you from the foundations of JAX and neural networks to advanced applications in scientific ML, specifically Physics-Informed Neural Networks (PINNs) and Operator Learning (DeepONet, FNO).  
Each module lists readings, key concepts, and a coding assignment; a final project ties them together.

---

## Module 1: JAX Basics

**Objective:** Understand JAX's array programming, automatic differentiation, and functional programming style.

### Readings:
- [JAX Quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [JAX: The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- [JAX: The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [A Friendly Introduction to JAX for ML](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial1/Introduction_to_JAX.html)

### Key Concepts:
- `jax.numpy` vs. `numpy`
- Pure functions
- `grad`, `jit`, `vmap`, `pmap`
- Static vs. traced values under `jit` (e.g., array shapes and `static_argnums`)
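The transformations above compose freely. The minimal sketch below (the function `f` and the helper `power_sum` are illustrative names, not from any library) stacks `grad`, `vmap`, and `jit`, shows that JAX arrays are updated functionally, and shows why a loop bound must be marked static under `jit`. `pmap` is left out because it needs multiple devices.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    # Pure function: the output depends only on the input, no side effects
    return jnp.sin(x) + x ** 2


df = jax.grad(f)               # derivative of a scalar -> scalar function
df_batched = jax.vmap(df)      # map df over a 1D array of inputs
df_fast = jax.jit(df_batched)  # compile the batched derivative with XLA

xs = jnp.linspace(0.0, 1.0, 5)
derivs = df_fast(xs)           # equals cos(xs) + 2 * xs

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
a = jnp.zeros(3)
b = a.at[0].set(1.0)           # `a` itself is unchanged


# `n` drives a Python loop, so it must be a concrete (static) value at trace
# time; as a traced argument, range(n) would raise an error under jit.
@partial(jax.jit, static_argnums=1)
def power_sum(x, n):
    return sum(x ** k for k in range(n))   # 1 + x + ... + x^(n-1)
```

Calling `power_sum(2.0, 3)` compiles a version specialized to `n=3`; passing a different `n` triggers a recompilation.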

### Assignment 1:
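- Compute and plot the derivative of a univariate function (e.g., \(f(x) = \sin(x) + x^2\)) using JAX's `grad`.
- Compare the result with a finite-difference approximation (e.g., central differences) and with the analytic derivative \(f'(x) = \cos(x) + 2x\).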
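A minimal sketch of Assignment 1, assuming central differences with step `h = 1e-5` (an arbitrary choice) as the finite-difference baseline. Both approximations are compared against the analytic derivative \(\cos(x) + 2x\); plotting is left out so the block runs headless.

```python
import jax
import jax.numpy as jnp

# float64 makes the finite-difference comparison less sensitive to round-off
jax.config.update("jax_enable_x64", True)


def f(x):
    return jnp.sin(x) + x ** 2


df_ad = jax.vmap(jax.grad(f))   # autodiff derivative, evaluated pointwise


def central_difference(fn, x, h=1e-5):
    # Second-order accurate approximation: (f(x+h) - f(x-h)) / (2h)
    return (fn(x + h) - fn(x - h)) / (2.0 * h)


xs = jnp.linspace(-3.0, 3.0, 201)
ad = df_ad(xs)
fd = central_difference(f, xs)
exact = jnp.cos(xs) + 2.0 * xs

max_err_ad = jnp.max(jnp.abs(ad - exact))   # round-off level
max_err_fd = jnp.max(jnp.abs(fd - exact))   # limited by h (truncation vs. round-off)
```

To finish the assignment, plot `ad` and `fd` against `xs` with Matplotlib. Sweeping `h` over several orders of magnitude shows the finite-difference error first shrinking (truncation) and then growing (round-off), while the autodiff result does not depend on any step size.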

---

## Module 2: Building Neural Networks in JAX

**Objective:** Learn to build and train neural networks using JAX libraries.

### Readings:
- [Flax: Getting Started](https://flax.readthedocs.io/en/latest/getting_started.html)
- [Haiku: Basic Tutorial](https://dm-haiku.readthedocs.io/en/latest/notebooks/haiku_basics.html)
- [Equinox: Simple neural nets](https://docs.kidger.site/equinox/examples/simple_neural_network/)

### Key Concepts:
- Parameter initialization and management
- Forward function
- Loss computation
- Parameter updates using Optax

### Assignment 2:
- Implement a fully-connected neural network for 1D regression (e.g., fit `f(x) = sin(3x)`) in Flax/Haiku/Equinox.
- Visualize training loss and predictions.

---

## Module 3: Physics-Informed Neural Networks (PINNs)

**Objective:** Understand and implement PINNs for solving differential equations.

### Readings:
- [Original PINN paper (Raissi et al.)](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
- [PINNs project page (Raissi et al.)](https://maziarraissi.github.io/PINNs/)
- [A JAX PINN Example: 1D Poisson](https://colab.research.google.com/github/google/jax/blob/main/examples/pinn_poisson.ipynb)

### Key Concepts:
- Defining PDE residuals with automatic differentiation
- Collocation and boundary points
- Loss as a sum of physics and boundary losses
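For the Poisson problem in Assignment 3, one concrete form of this loss uses a network \(u_\theta\), \(N_r\) interior collocation points \(x_i\), and the two boundary points. This is a sketch: equal weighting of the two terms is an assumption, and in practice the weights are often tuned.

\[
\mathcal{L}(\theta) = \frac{1}{N_r}\sum_{i=1}^{N_r}\big(-u_\theta''(x_i) - f(x_i)\big)^2 + \frac{1}{2}\big(u_\theta(0)^2 + u_\theta(1)^2\big)
\]

Here \(u_\theta''\) is obtained by differentiating the network twice with respect to its input via automatic differentiation, so no mesh or finite-difference stencil is needed.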

### Assignment 3:
- Implement a PINN to solve the 1D Poisson equation:  
  \[
  -u''(x) = f(x), \quad x \in (0,1), \quad u(0)=u(1)=0
  \]  
  with \(f(x) = \pi^2 \sin(\pi x)\) (exact solution: \(u(x) = \sin(\pi x)\)).
- Plot the learned solution vs. the exact solution.

---

## Module 4: Operator Learning (DeepONet, FNO)

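**Objective:** Learn neural operators that map functions to functions. A PINN is trained to solve one instance of a PDE; a neural operator learns the solution map for a whole family of inputs (e.g., right-hand sides or initial conditions), so once trained it can produce the solution for a new instance with a single forward pass.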

### Readings:
- [DeepONet (Lu et al.)](https://arxiv.org/abs/1910.03193)
- [Fourier Neural Operator (FNO) paper](https://arxiv.org/abs/2010.08895)
- [Operator Learning with Neural Networks (Review)](https://arxiv.org/abs/2111.05518)
- [JAX DeepONet Example](https://github.com/patrick-kidger/equinox/blob/main/examples/deeponet.py)
- [JAX FNO Example](https://github.com/patrick-kidger/equinox/blob/main/examples/fourier_neural_operator.py)

### Key Concepts:
- Neural operators: learning mappings between infinite-dimensional spaces
- Training on families of PDEs or parametric equations
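- Branch/trunk structure (DeepONet): a branch net encodes the input function sampled at fixed sensor points, and a trunk net encodes the query location
- Spectral convolution (FNO): a learned linear transform applied to the lowest Fourier modes of the input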
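To make "spectral convolution" concrete, the sketch below implements one Fourier layer on a uniform periodic 1D grid for a single sample: take the FFT, multiply the lowest `n_modes` coefficients by learned complex weights (mixing channels), zero the rest, transform back, and add a pointwise linear path. It is a simplified illustration, not the reference FNO code: the parameter names, the initialization scale, and the GELU activation are assumptions, and a full FNO would add lifting and projection layers, stack several such layers, and `jax.vmap` over a batch.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(u, weights, n_modes):
    """FNO-style spectral convolution on a uniform periodic 1D grid (one sample).

    u:       (n_grid, c_in) real values of the input function
    weights: (n_modes, c_in, c_out) complex multipliers for the lowest modes
    """
    n_grid = u.shape[0]
    u_hat = jnp.fft.rfft(u, axis=0)                       # (n_grid // 2 + 1, c_in)
    # Mix channels mode by mode, keeping only the lowest n_modes frequencies
    low = jnp.einsum("mi,mio->mo", u_hat[:n_modes], weights)
    out_hat = jnp.zeros((u_hat.shape[0], weights.shape[-1]), dtype=low.dtype)
    out_hat = out_hat.at[:n_modes].set(low)               # higher modes are zeroed
    return jnp.fft.irfft(out_hat, n=n_grid, axis=0)       # (n_grid, c_out)


def init_spectral_weights(key, n_modes, c_in, c_out):
    k_re, k_im = jax.random.split(key)
    shape = (n_modes, c_in, c_out)
    scale = 1.0 / (c_in * c_out)
    return scale * (jax.random.normal(k_re, shape) + 1j * jax.random.normal(k_im, shape))


def fno_layer(v, params, n_modes):
    # v_{l+1}(x) = gelu(K v_l(x) + W v_l(x)): spectral path plus pointwise linear path
    spectral = spectral_conv_1d(v, params["spectral"], n_modes)
    return jax.nn.gelu(spectral + v @ params["pointwise"])
```

With all-ones weights and an input that contains only frequencies below `n_modes`, `spectral_conv_1d` returns the input unchanged, while higher frequencies are filtered out; this makes a handy sanity check before training.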

### Assignment 4:
- **DeepONet**: Implement DeepONet in JAX to learn the solution operator of the Poisson equation with variable right-hand side \(f(x)\).  
  Generate a dataset of many right-hand sides \(f(x)\) (e.g., sine/cosine combinations) together with their corresponding solutions.
- **(Bonus)**: Implement a simplified FNO for 1D problems.

---

## Final Project: PINN or Operator Learning for a "Real" Problem

- Pick a PDE (e.g., Burgers', heat, wave equation) or a parametric family
- Design and train a PINN or neural operator, visualize results, and compare with classical methods

---

## Recommended Additional Readings

- [Neural Operators: Survey](https://arxiv.org/abs/2111.05518)
- [Physics-Informed Machine Learning Review](https://arxiv.org/abs/2107.09443)
- [JAX Scientific ML resources](https://github.com/google/jax#scientific-machine-learning)

---

## Tips

- Run on a GPU (e.g., a Google Colab GPU runtime) for faster training.
- Keep all code modular and functions pure for easier debugging and JAX compatibility.
- For further research, explore [Equinox](https://github.com/patrick-kidger/equinox) and [Diffrax](https://github.com/patrick-kidger/diffrax) for advanced scientific ML with JAX.

---

**Happy Learning!**