In [13]:
# ========================= Arrays/Tensors & RNGs: JAX vs PyTorch =========================

import numpy as np
import jax
import jax.numpy as jnp
import torch

# --------------------------------- What is a JAX Array? (quick intuition) ---------------------------------
# • A JAX array (jnp.ndarray) is JAX’s n-dimensional array type, closely mirroring NumPy arrays.
# • You use jax.numpy (imported as jnp) instead of np:
#     - jnp.array, jnp.zeros, jnp.ones, jnp.linspace, jnp.sin, jnp.dot, jnp.linalg.norm, etc.
#     - Most of the time, jnp.* has the SAME signature and behavior as the corresponding np.* function.
# • Differences vs plain NumPy:
#     - JAX arrays live on accelerator devices (CPU/GPU/TPU) and are **immutable** (no in-place updates).
#     - All jnp operations are traceable and differentiable (so grad/jit/vmap can see them).
#     - Randomness is handled very differently (pure functional RNG, see RNG block below).
# • Devices:
#     - JAX automatically places arrays on a default device (CPU or GPU/TPU if configured).
#     - You can check with the `.device` attribute (or `.devices()`) and move with jax.device_put if needed, but in practice you mostly ignore it.

# Basic JAX array creation
x_jax = jnp.array([1.0, 2.0, 3.0])          # (3,), float32
M_jax = jnp.arange(12.0).reshape(3, 4)      # (3,4)
print("JAX x:", x_jax, " | shape:", x_jax.shape, " | dtype:", x_jax.dtype, " | device:", x_jax.device)
print("JAX M:\n", M_jax, "\n")

# jnp behaves like NumPy:
print("JAX jnp.sin(x):", jnp.sin(x_jax))
print("JAX M.T shape:", M_jax.T.shape)

# ---------------- RNGs in JAX vs NumPy ----------------
# NumPy RNG (stateful, object/global-state API):
# - Legacy: np.random.seed(...); draws mutate global state.
# - Modern: rng = np.random.default_rng(seed); draws mutate rng’s INTERNAL STATE:
#     rng = np.random.default_rng(0)
#     x = rng.uniform(size=...)               # state ADVANCES implicitly
#     y = rng.normal(size=...)                # next draw uses ADVANCED state, y differs from x
# - You don’t pass the RNG to every call; state changes happen behind the scenes.
# - NumPy’s closest equivalent of “split” is spawning child generators (SeedSequence.spawn / Generator.spawn in
#   recent NumPy, or bit_generator.jumped()), but it is stateful and not as ergonomic as JAX’s split.
# If rng is created with a fixed seed, the initial state is fixed, so the sequence of draws is reproducible across runs.
# If rng is defined without specifying a seed, the initial state differs each call (uses OS entropy), so the sequence of draws differs each run.
#
# JAX RNG (functional, stateless API):
# - You create an explicit PRNGKey and PASS IT IN to every random call.
#     key = jax.random.key(seed)              # create a key (counter-based PRNG); note that a seed must be provided
#     key, sub = jax.random.split(key)        # split into NEW independent keys
#     x = jax.random.uniform(sub, shape, ...) # must provide a key
# - Keys are PURE DATA. Reusing the SAME key → the SAME random numbers (no implicit advance)
#      y = jax.random.uniform(sub, shape, ...) # y == x if reusing 'sub'
#   You must SPLIT to advance the stream. This makes randomness reproducible under jit/vmap/pmap.
# - Splitting is cheap and mathematically well-defined for counter-based PRNGs (e.g., Threefry/Philox).
# - Best practices:
#     • Never reuse a key; always do: key, k1 = jax.random.split(key)
#     • For per-step randomness in training: key = jax.random.fold_in(key, step)  # mixes step into key
#     • For batched randomness: keys = jax.random.split(key, B)  # one subkey per batch element (works with vmap)
#
# Why JAX does this:
# - Pure functional RNG makes behavior deterministic under jit/vmap/pmap and easy to parallelize.
# - You control exactly where/when randomness happens by threading keys through your program.
#
# In the linear regression example:
#   key = jax.random.key(0)                       # initial key
#   x = jax.random.uniform(key, (N, 1))                 # BAD: drawing directly with 'key' → same x on every call
#   key, k_x, k_noise = jax.random.split(key, 3)        # GOOD: advance RNG; 'k_x' and 'k_noise' are fresh, independent subkeys
#   x = jax.random.uniform(k_x, (N, 1))                 # draw with one fresh subkey
#   noise = 0.05 * jax.random.normal(k_noise, (N, 1))   # draw with a different fresh subkey ('key' itself is only ever split)
# Pattern: (key, k1, k2, ...) = jax.random.split(key, n) then use k1, k2,... for each random op.

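# (Note: jax.random.key(seed) creates a new-style typed key; the demo below uses the older jax.random.PRNGKey(seed),
#  which returns a raw uint32 key array. Both follow the functional split-and-pass pattern above, which is the key idea.)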

# Tiny demo of JAX RNG behavior (functional keys)
# One root key, split once: 'key' is carried forward, sub1/sub2 feed the draws below.
# NOTE: x1/x2/x3 are rebound to torch tensors further down in this cell.
key, sub1, sub2 = jax.random.split(jax.random.PRNGKey(0), 3)
x1 = jax.random.normal(sub1, shape=(3,))
x2 = jax.random.normal(sub1, shape=(3,))   # sub1 again → identical draw
x3 = jax.random.normal(sub2, shape=(3,))   # sub2 → a different draw
print("JAX RNG x1:", x1)
print("JAX RNG x2 (same key):", x2)
print("JAX RNG x3 (different key):", x3, "\n")

torch.manual_seed(0)  # seed PyTorch's stateful RNG: reproducible per run, state advances on every draw (as in NumPy)

# --------------------------------- What is a Torch Tensor? (quick intuition) ---------------------------------
# • A torch.Tensor is PyTorch’s fundamental n-dimensional array (like NumPy arrays, but with autograd + devices).
# • Carries: shape (sizes along each axis), dtype (float32, int64, …), device (where it lives: 'cpu', 'cuda', 'mps'), and
#   autograd metadata (requires_grad) for learning.
# • Device matters: you can compute on CPU (default), NVIDIA GPU ('cuda', fastest), or Apple Silicon GPU ('mps', faster). Ops require all inputs
#   on the *same* device; move with .to("cuda") / .to("cpu") / .to("mps") or create directly on that device.
# • Unlike plain Python lists, tensors support fast vectorized math, broadcasting, linear algebra (like NumPy arrays) and backprop.

# -------------------------------------- TENSORS (fundamental object) --------------------------------------
# NOTE: factory functions allocate on the *current default device*. That is the CPU on a fresh kernel,
# but it is whatever torch.set_default_device(...) last selected in this session (e.g. 'mps' or 'cuda'),
# so the printed `device:` field depends on which cells ran before this one.
x0 = torch.tensor([1.0, 2.0, 3.0])                         # (3,), float32
x1 = torch.zeros(2, 3)                                     # (2,3), float32
x2 = torch.randn(4, 5)                                     # ~ N(0,1), (4,5), float32
x3 = torch.rand(4, 5)                                      # ~ U(0,1), (4,5), float32
x4 = torch.ones((4,))                                      # (4,) all ones, float32
x_like = torch.zeros_like(x3)                              # shape, dtype & device like x3

# Print every tensor followed by its metadata (same output as six hand-written print pairs).
for _t in (x0, x1, x2, x3, x4, x_like):
    print(_t)
    print("shape:", _t.shape, "dtype:", _t.dtype, "device:", _t.device)

JAX x: [1. 2. 3.]  | shape: (3,)  | dtype: float32  | device: TFRT_CPU_0
JAX M:
 [[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]] 

JAX jnp.sin(x): [0.84147096 0.9092974  0.14112   ]
JAX M.T shape: (4, 3)
JAX RNG x1: [-2.4424558  -2.0356805   0.20554423]
JAX RNG x2 (same key): [-2.4424558  -2.0356805   0.20554423]
JAX RNG x3 (different key): [ 1.2956359   1.3550105  -0.40960556] 

tensor([1., 2., 3.], device='mps:0')
shape: torch.Size([3]) dtype: torch.float32 device: mps:0
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='mps:0')
shape: torch.Size([2, 3]) dtype: torch.float32 device: mps:0
tensor([[-0.3413, -1.4359,  0.7669, -1.1818,  0.7512],
        [-1.1156,  0.6561, -0.1908, -0.2969,  0.4658],
        [ 0.2387, -0.0734, -1.1044,  0.8299,  0.7745],
        [ 2.1827, -0.2245,  0.0699,  0.1540,  1.3947]], device='mps:0')
shape: torch.Size([4, 5]) dtype: torch.float32 device: mps:0
tensor([[0.1421, 0.2592, 0.4999, 0.1917, 0.5603],
        [0.1563, 0.4288, 0.5015, 0.4667, 0

In [14]:
# ------------------------------ Device & dtype basics ------------------------------
# Pick a device ONCE. By default PyTorch uses CPU; we switch to CUDA/MPS if available.
# Choose the best available backend once; the banner is printed after the choice is made.
if torch.cuda.is_available():
    device, _device_banner = torch.device("cuda"), "Using CUDA GPU"
elif torch.backends.mps.is_available():
    device, _device_banner = torch.device("mps"), "Using Apple Silicon GPU (MPS)"
else:
    device, _device_banner = torch.device("cpu"), "Using CPU"
print(_device_banner)

# From here on, factory functions (torch.zeros, torch.arange, ...) allocate on `device` by default
# (PyTorch ≥ 2.0). This is session-global state: it also affects cells that are (re-)run later.
# torch.set_default_device(...) returns None, so its result is not assigned.
torch.set_default_device(device)

# ------------------------------ Creation respects default device ------------------------------
A = torch.arange(12., dtype=torch.float32).reshape(3, 4)
print("A.device:", A.device)   # mps / cuda / cpu, matching the branch taken above

# ------------------------------ Moving tensors across devices ------------------------------
# Build the float64 tensor explicitly on the CPU, then cast and move in a single .to(...) call.
x64_cpu = torch.randn(2, 2, dtype=torch.float64, device="cpu")
x_dev = x64_cpu.to(dtype=torch.float32, device=device)
print("x_dev:", x_dev.shape, x_dev.dtype, x_dev.device)

# ------------------------------ NumPy interoperability (CPU only shares memory) ------------------------------
a = np.array([1, 2, 3], dtype=np.float32)     # float32 on purpose (NumPy would default to float64)
t_cpu = torch.from_numpy(a)                   # CPU tensor backed by the same memory as `a`

# Copy to the accelerator for compute. On GPU/MPS this is a separate buffer; if `device` is the CPU,
# .to(device) performs no copy, so t_dev still aliases `a`.
t_dev = t_cpu.to(device)

# NumPy conversion is only possible from a CPU tensor, hence .cpu() first.
# `a2` shares memory with that CPU tensor (a fresh copy when coming from GPU/MPS).
a2 = t_dev.cpu().numpy()

# In-place update on the CPU tensor is visible through the original NumPy array.
t_cpu.add_(1)
print("a after t_cpu += 1:", a)               # shows the +1 update

# TL;DR:
# • Use `device = torch.device(...)` once; optionally call torch.set_default_device(device).
# • On MPS, avoid float64: cast to float32/float16/bfloat16 when moving to "mps".
# • .numpy() only works for CPU tensors; use `.cpu().numpy()` when coming from GPU/MPS.
# • NumPy <-> Torch zero-copy sharing happens ONLY on CPU tensors (and same dtype/contiguous layout).

Using Apple Silicon GPU (MPS)
A.device: mps:0
x_dev: torch.Size([2, 2]) torch.float32 mps:0
a after t_cpu += 1: [2. 3. 4.]


In [None]:
#========================= Core Tensor Operations: JAX vs PyTorch =========================
# Indexing/slicing, reshaping, broadcasting, stacking/concat, linear algebra, einsum.

# Reuse a simple helper
def show(name, arr):
    """Print a label, the value itself, and the value's ``.shape``.

    Objects without a ``.shape`` attribute get the placeholder text
    "no .shape" instead. Returns None; the function only prints.
    """
    print(f"{name}:")
    print(arr)
    # getattr with a default covers objects that have no .shape attribute.
    print("shape:", getattr(arr, "shape", "no .shape"), "\n")

# Switch the default device to the CPU so the torch factory calls in this cell allocate there,
# regardless of what an earlier cell selected.
torch.set_default_device('cpu') # reset to CPU for this section

# =========================================================================================
# 1a) INDEXING, SLICING, RESHAPING, TRANSPOSING
# =========================================================================================

print("\n================ 1) Indexing / Reshape / Transpose ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: indexing / reshape / transpose ---")

# Integer argument → integer array (the PyTorch twin below passes 12. and asks for float32).
M_jax = jnp.arange(12).reshape(3, 4)   # [[0,1,2,3],
                                       #  [4,5,6,7],
                                       #  [8,9,10,11]]

# NumPy-style indexing: an integer index drops that axis, a slice keeps it.
row0_jax  = M_jax[0]           # (4,)  first row
col1_jax  = M_jax[:, 1]        # (3,)  second column
block_jax = M_jax[0:2, 1:4]    # (2,3) rows 0-1, columns 1-3

show("M_jax", M_jax)
show("row0_jax = M_jax[0]", row0_jax)
show("col1_jax = M_jax[:, 1]", col1_jax)
show("block_jax = M_jax[0:2, 1:4]", block_jax)

# Reshape / transpose
A_jax   = jnp.arange(24).reshape(2, 3, 4)       # (2,3,4)
A_T_jax = jnp.transpose(A_jax, (0, 2, 1))       # (2,4,3)  full axis permutation: swap the last two axes
A_F_jax = jnp.reshape(A_jax, (2, -1))           # (2,12)   -1 infers the remaining size (3*4)
A_R_jax = A_jax.reshape(6, 4)                   # (6,4)    method form of reshape

print("JAX shapes: A:", A_jax.shape,
      "A_T:", A_T_jax.shape,
      "A_F:", A_F_jax.shape,
      "A_R:", A_R_jax.shape)

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: indexing / reshape / transpose ---")

M_t = torch.arange(12., dtype=torch.float32).reshape(3, 4)

# Same indexing syntax as the JAX block above.
row0_t  = M_t[0]           # (4,)
col1_t  = M_t[:, 1]        # (3,)
block_t = M_t[0:2, 1:4]    # (2,3)

show("M_t", M_t)
show("row0_t = M_t[0]", row0_t)
show("col1_t = M_t[:, 1]", col1_t)
show("block_t = M_t[0:2, 1:4]", block_t)

A_t   = torch.arange(24.).reshape(2, 3, 4)   # (2,3,4)
A_T_t = A_t.transpose(1, 2)                  # (2,4,3) (swap dims 1 and 2)
A_F_t = A_t.flatten(start_dim=1)             # (2,12)  flatten everything from dim 1 onwards
A_R_t = A_t.reshape(6, 4)                    # (6,4)

print("Torch shapes: A:", A_t.shape,
      "A_T:", A_T_t.shape,
      "A_F:", A_F_t.shape,
      "A_R:", A_R_t.shape)

# =========================================================================================
# 1b) UNSQUEEZING/EXPANDING DIMS
# =========================================================================================

print("\n================ 1b) Unsqueeze / Expand Dims ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: expand_dims / newaxis ---")

v_jax = jnp.arange(4)   # (4,)
show("v_jax", v_jax)

# Row vs column views: insert a size-1 axis in front of / behind the existing one
v_row_jax = jnp.expand_dims(v_jax, axis=0)   # (1,4)
v_col_jax = jnp.expand_dims(v_jax, axis=1)   # (4,1)

# Same thing using None / newaxis
v_row2_jax = v_jax[None, :]   # (1,4)
v_col2_jax = v_jax[:, None]   # (4,1)

show("v_row_jax = expand_dims(v_jax, axis=0)", v_row_jax)
show("v_col_jax = expand_dims(v_jax, axis=1)", v_col_jax)
show("v_row2_jax = v_jax[None, :]", v_row2_jax)
show("v_col2_jax = v_jax[:, None]", v_col2_jax)

# Example: broadcasting with an extra batch dim
# NOTE: B_jax is reassigned in section 2 below.
B_jax = jnp.ones((2, 4))             # (2,4)
v_batched_jax = v_jax[None, :]       # (1,4) → broadcast to (2,4) when added
B_plus_v_jax = B_jax + v_batched_jax # (2,4): the size-1 leading axis is repeated for both rows
show("B_jax (2,4)", B_jax)
show("v_batched_jax = v_jax[None, :]", v_batched_jax)
show("B_plus_v_jax = B_jax + v_batched_jax", B_plus_v_jax)

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: unsqueeze ---")

v_t = torch.arange(4., dtype=torch.float32)   # (4,)
show("v_t", v_t)

# Row vs column views: unsqueeze(d) is PyTorch's spelling of expand_dims(axis=d)
v_row_t = v_t.unsqueeze(0)   # (1,4)
v_col_t = v_t.unsqueeze(1)   # (4,1)

# Same idea with indexing
v_row2_t = v_t[None, :]      # (1,4)
v_col2_t = v_t[:, None]      # (4,1)

show("v_row_t = v_t.unsqueeze(0)", v_row_t)
show("v_col_t = v_t.unsqueeze(1)", v_col_t)
show("v_row2_t = v_t[None, :]", v_row2_t)
show("v_col2_t = v_t[:, None]", v_col2_t)

# Example: expand to a batch
# Here the (1,4) row is expanded to (2,4) explicitly before the add, instead of relying on
# implicit broadcasting as the JAX example above does.
# NOTE: B_t is reassigned in section 2 below.
B_t = torch.ones(2, 4)                # (2,4)
v_batched_t = v_t.unsqueeze(0)        # (1,4)
v_expanded_t = v_batched_t.expand(2, 4)  # (2,4) (no data copy)
B_plus_v_t = B_t + v_expanded_t

show("B_t (2,4)", B_t)
show("v_batched_t = v_t.unsqueeze(0)", v_batched_t)
show("v_expanded_t = v_batched_t.expand(2,4)", v_expanded_t)
show("B_plus_v_t = B_t + v_expanded_t", B_plus_v_t)


# =========================================================================================
# 2) BROADCASTING & ELEMENTWISE OPERATIONS
# =========================================================================================

print("\n================ 2) Broadcasting & Elementwise Ops ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: broadcasting & elementwise ---")

# A column (3,1) plus a row (1,4): each size-1 axis is stretched to match the other operand.
B_jax = jnp.linspace(0.0, 1.0, 3).reshape(3, 1)   # (3,1)
C_jax = jnp.linspace(-1.0, 1.0, 4).reshape(1, 4)  # (1,4)
BC_jax = B_jax + C_jax                            # (3,4) via broadcasting

show("B_jax (3,1)", B_jax)
show("C_jax (1,4)", C_jax)
show("B_jax + C_jax (3,4)", BC_jax)

# Elementwise math on a 5-point grid; float(...) turns the mean into a plain Python float for printing.
Z_jax = jnp.linspace(-2.0, 2.0, 5)
E_jax = Z_jax**2 + jnp.sin(Z_jax)
print("JAX E:", E_jax, "| mean(E):", float(jnp.mean(E_jax)))

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: broadcasting & elementwise ---")

# torch.linspace names the point count `steps` (jnp.linspace takes it as the third positional argument).
B_t = torch.linspace(0.0, 1.0, steps=3).reshape(3, 1)   # (3,1)
C_t = torch.linspace(-1.0, 1.0, steps=4).reshape(1, 4)  # (1,4)
BC_t = B_t + C_t                                        # (3,4)

show("B_t (3,1)", B_t)
show("C_t (1,4)", C_t)
show("B_t + C_t (3,4)", BC_t)

# Same elementwise expression; .item() extracts the mean as a Python number.
Z_t = torch.linspace(-2.0, 2.0, steps=5)
E_t = Z_t**2 + torch.sin(Z_t)
print("Torch E:", E_t, "| mean(E):", E_t.mean().item())

# =========================================================================================
# 3) STACKING vs CONCATENATION
# =========================================================================================

print("\n================ 3) Stacking vs Concatenation ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: jnp.stack vs jnp.concatenate ---")

v1_jax = jnp.array([1., 2., 3.])
v2_jax = jnp.array([4., 5., 6.])

# stack: add a NEW axis
stack0_jax = jnp.stack([v1_jax, v2_jax], axis=0)   # (2,3)  inputs become rows
stack1_jax = jnp.stack([v1_jax, v2_jax], axis=1)   # (3,2)  inputs become columns

# concat: join along existing axis
cat0_jax = jnp.concatenate([v1_jax, v2_jax], axis=0)  # (6,)

show("stack0_jax (axis=0)", stack0_jax)
show("stack1_jax (axis=1)", stack1_jax)
show("cat0_jax (axis=0)", cat0_jax)

# Same comparison with 2-D inputs of shape (2,3)
A_jax2 = jnp.arange(6).reshape(2, 3)
B_jax2 = A_jax2 + 10

# Stacking two (2,3) arrays: the new size-2 axis lands at position `axis`.
S0_jax = jnp.stack([A_jax2, B_jax2], axis=0)  # (2,2,3)
S1_jax = jnp.stack([A_jax2, B_jax2], axis=1)  # (2,2,3)
S2_jax = jnp.stack([A_jax2, B_jax2], axis=2)  # (2,3,2)

print("JAX S0,S1,S2 shapes:", S0_jax.shape, S1_jax.shape, S2_jax.shape)

# Concatenating keeps the rank at 2 and grows the chosen axis instead.
C_row_jax = jnp.concatenate([A_jax2, B_jax2], axis=0)  # (4,3)
C_col_jax = jnp.concatenate([A_jax2, B_jax2], axis=1)  # (2,6)
print("JAX C_row.shape:", C_row_jax.shape, "C_col.shape:", C_col_jax.shape)

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: torch.stack vs torch.cat ---")

v1_t = torch.tensor([1., 2., 3.])
v2_t = torch.tensor([4., 5., 6.])

# Mirrors the JAX block: torch uses `dim=` where jnp uses `axis=`, and `cat` for concatenate.
stack0_t = torch.stack([v1_t, v2_t], dim=0)  # (2,3)
stack1_t = torch.stack([v1_t, v2_t], dim=1)  # (3,2)
cat0_t   = torch.cat([v1_t, v2_t], dim=0)    # (6,)

show("stack0_t (dim=0)", stack0_t)
show("stack1_t (dim=1)", stack1_t)
show("cat0_t (dim=0)", cat0_t)

A_t2 = torch.arange(6.).reshape(2, 3)
B_t2 = A_t2 + 10

S0_t = torch.stack([A_t2, B_t2], dim=0)  # (2,2,3)
S1_t = torch.stack([A_t2, B_t2], dim=1)  # (2,2,3)
S2_t = torch.stack([A_t2, B_t2], dim=2)  # (2,3,2)

print("Torch S0,S1,S2 shapes:", S0_t.shape, S1_t.shape, S2_t.shape)

C_row_t = torch.cat([A_t2, B_t2], dim=0)  # (4,3)
C_col_t = torch.cat([A_t2, B_t2], dim=1)  # (2,6)
print("Torch C_row.shape:", C_row_t.shape, "C_col.shape:", C_col_t.shape)

# =========================================================================================
# 4) LINEAR ALGEBRA: MATMUL, NORMS, QR, EIG
# =========================================================================================

print("\n================ 4) Linear Algebra / Matmul ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: matmul / linalg ---")

X_jax = jnp.arange(35., dtype=jnp.float32).reshape(7, 5) / 10.0  # (7,5)
W_jax = jnp.arange(15., dtype=jnp.float32).reshape(3, 5) / 10.0  # (3,5)

Y_jax = X_jax @ W_jax.T     # (7,5) @ (5,3) -> (7,3)
show("Y_jax = X @ W^T", Y_jax)

normW_jax = jnp.linalg.norm(W_jax)  # Frobenius norm
print("JAX ||W||_F:", float(normW_jax))

M_square_jax = jnp.arange(25., dtype=jnp.float32).reshape(5, 5)

# NOTE(review): the two JAX calls below crashed the kernel in the author's environment, so they and the
# show() calls that depend on them are disabled. Root cause not diagnosed — TODO: investigate before
# re-enabling. Until then M_square_jax is defined but unused, and only the PyTorch half demonstrates QR/eig.
# Q_jax, R_jax = jnp.linalg.qr(M_square_jax)       # QR
# eigvals_jax = jnp.linalg.eigvals(M_square_jax)   # eigenvalues (may be complex)

# show("Q_jax (from QR)", Q_jax)
# show("R_jax (from QR)", R_jax)
# show("eigvals_jax", eigvals_jax)

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: matmul / linalg ---")

X_t3 = torch.arange(35., dtype=torch.float32).reshape(7, 5) / 10.0
W_t3 = torch.arange(15., dtype=torch.float32).reshape(3, 5) / 10.0

Y_t3 = X_t3 @ W_t3.T   # (7,3)
show("Y_t3 = X @ W^T", Y_t3)

normW_t = torch.linalg.norm(W_t3)   # Frobenius norm
print("Torch ||W||_F:", normW_t.item())

M_square_t = torch.arange(25., dtype=torch.float32).reshape(5, 5)
Q_t, R_t = torch.linalg.qr(M_square_t)         # QR
eigvals_t = torch.linalg.eigvals(M_square_t)   # complex eigenvalues

show("Q_t (from QR)", Q_t)
show("R_t (from QR)", R_t)
show("eigvals_t", eigvals_t)


# =========================================================================================
# 5) EINSUM: TRACE, DIAGONAL, OUTER, BATCHED MATMUL, BILINEAR FORM
# =========================================================================================
# Reading an einsum string: an index repeated within one operand walks its diagonal;
# an index missing from the output (right of '->') is summed over.

print("\n================ 5) Einsum Patterns ================")

# ------------------------------ JAX ------------------------------
print("\n--- JAX: einsum ---")

# 1) Trace: 'ii->'
M_jax_e = jnp.arange(16., dtype=jnp.float32).reshape(4, 4)
trace_jax = jnp.einsum("ii->", M_jax_e)
show("JAX trace = einsum('ii->', M)", trace_jax)

# 2) Diagonal: 'ii->i'
diag_jax = jnp.einsum("ii->i", M_jax_e)
show("JAX diag = einsum('ii->i', M)", diag_jax)

# 3) Outer product: 'i,j->ij'
x_jax_e = jnp.arange(5., 10.)    # 5 values: 5..9
y_jax_e = jnp.arange(-1., 3.)    # 4 values: -1..2
outer_jax = jnp.einsum("i,j->ij", x_jax_e, y_jax_e)   # (5,4)
show("JAX outer = einsum('i,j->ij', x, y)", outer_jax)

# 4) Batched matmul: 'bij,bjk->bik'
As_jax = jnp.ones((3, 2, 5))                                       # (b=3, i=2, j=5)
Bs_jax = jnp.arange(3*5*4., dtype=jnp.float32).reshape(3, 5, 4)    # (b=3, j=5, k=4)
batched_mm_jax = jnp.einsum("bij,bjk->bik", As_jax, Bs_jax)        # (3,2,4)
show("JAX batched_mm = einsum('bij,bjk->bik', As, Bs)", batched_mm_jax)

# 5) Bilinear form: y_b = x_b^T A x_b (per batch)
xb_jax = jnp.arange(15., dtype=jnp.float32).reshape(3, 5)  # (B=3, N=5)
A2_jax = jnp.eye(5, dtype=jnp.float32)                     # (5,5) identity, so y_b = sum_i x_bi^2
yb_jax = jnp.einsum("bi,ij,bj->b", xb_jax, A2_jax, xb_jax) # (B,)
show("JAX yb = einsum('bi,ij,bj->b', xb, A2, xb)", yb_jax)

# ------------------------------ PyTorch ------------------------------
print("\n--- PyTorch: einsum ---")

# Same five patterns, in the same order, with identical subscript strings.

# 1) Trace
M_t_e = torch.arange(16., dtype=torch.float32).reshape(4, 4)
trace_t = torch.einsum("ii->", M_t_e)
show("Torch trace = einsum('ii->', M)", trace_t)

# 2) Diagonal
diag_t = torch.einsum("ii->i", M_t_e)
show("Torch diag = einsum('ii->i', M)", diag_t)

# 3) Outer product
x_t_e = torch.arange(5., 10.)
y_t_e = torch.arange(-1., 3.)
outer_t = torch.einsum("i,j->ij", x_t_e, y_t_e)
show("Torch outer = einsum('i,j->ij', x, y)", outer_t)

# 4) Batched matmul
As_t = torch.ones(3, 2, 5)
Bs_t = torch.arange(3*5*4., dtype=torch.float32).reshape(3, 5, 4)
batched_mm_t = torch.einsum("bij,bjk->bik", As_t, Bs_t)
show("Torch batched_mm = einsum('bij,bjk->bik', As, Bs)", batched_mm_t)

# 5) Bilinear form
xb_t = torch.arange(15., dtype=torch.float32).reshape(3, 5)  # (B=3, N=5)
A2_t = torch.eye(5, dtype=torch.float32)                     # (5,5)
yb_t = torch.einsum("bi,ij,bj->b", xb_t, A2_t, xb_t)         # (B,)
show("Torch yb = einsum('bi,ij,bj->b', xb, A2, xb)", yb_t)



--- JAX: indexing / reshape / transpose ---
M_jax:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
shape: (3, 4) 

row0_jax = M_jax[0]:
[0 1 2 3]
shape: (4,) 

col1_jax = M_jax[:, 1]:
[1 5 9]
shape: (3,) 

block_jax = M_jax[0:2, 1:4]:
[[1 2 3]
 [5 6 7]]
shape: (2, 3) 

JAX shapes: A: (2, 3, 4) A_T: (2, 4, 3) A_F: (2, 12) A_R: (6, 4)

--- PyTorch: indexing / reshape / transpose ---
M_t:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
shape: torch.Size([3, 4]) 

row0_t = M_t[0]:
tensor([0., 1., 2., 3.])
shape: torch.Size([4]) 

col1_t = M_t[:, 1]:
tensor([1., 5., 9.])
shape: torch.Size([3]) 

block_t = M_t[0:2, 1:4]:
tensor([[1., 2., 3.],
        [5., 6., 7.]])
shape: torch.Size([2, 3]) 

Torch shapes: A: torch.Size([2, 3, 4]) A_T: torch.Size([2, 4, 3]) A_F: torch.Size([2, 12]) A_R: torch.Size([6, 4])


--- JAX: expand_dims / newaxis ---
v_jax:
[0 1 2 3]
shape: (4,) 

v_row_jax = expand_dims(v_jax, axis=0):
[[0 1 2 3]]
shape: (1, 4) 

v_col_jax 

In [15]:
# ========================= Gradients, Jacobians, JVP/VJP: JAX vs PyTorch =========================

# In this section we write everything in a **JAX-style functional way**:
#
#   • JAX:      jax.grad, jax.jacfwd, jax.jacrev, jax.jvp, jax.vjp, jax.vmap
#   • PyTorch:  torch.func.grad, torch.func.jacfwd, torch.func.jacrev,
#               torch.func.jvp, torch.func.vjp, torch.func.vmap
#
# Key idea:
#   • We treat functions as PURE maps on tensors:
#         f(x) -> scalar
#         F(x) -> vector
#     and ask JAX/PyTorch to return derivatives of those functions.

# Helper for printing
def show(name, x):
    """Print ``name`` as a heading, then ``x``, then ``x.shape``.

    If ``x`` has no ``shape`` attribute the text "no .shape" is printed in
    its place. Nothing is returned.
    """
    print(f"{name}:")
    print(x)
    try:
        dims = x.shape
    except AttributeError:
        # Plain Python objects (floats, lists, ...) carry no shape.
        dims = "no .shape"
    print("shape:", dims, "\n")

# =========================================================================================
# 1) GRADIENTS (SCALAR OUTPUT)
# =========================================================================================
# 1.1) f : R → R
# 1.2) F : R^n → R
# 1.3) F : R^{N1×…×Nk} → R (+ flatten equivalence)
# =========================================================================================

print("\n================ 1) GRADIENTS (SCALAR OUTPUT) ================")

# ------------------------------------------------------------------------------
# 1.1) f : R → R   (JAX)
# ------------------------------------------------------------------------------

def f_scalar_jax(x):
    """JAX version of f(x) = sin(x) + 0.1 * x^2 (scalar in, scalar out)."""
    sine_part = jnp.sin(x)
    quadratic_part = 0.1 * x**2
    return sine_part + quadratic_part

# jax.grad turns f into a new function that evaluates df/dx at a given point.
df_dx_jax = jax.grad(f_scalar_jax)

x0_jax = 1.23   # evaluation point (a plain Python float)
print("\n1.1) JAX: f : R -> R")
print("  x0      =", x0_jax)
# float(...) converts the JAX results to Python floats for printing.
print("  f(x0)   =", float(f_scalar_jax(x0_jax)))
print("  df/dx|x0:", float(df_dx_jax(x0_jax)))


# 1.1) f : R → R   (PyTorch)
def f_scalar_torch(x: torch.Tensor) -> torch.Tensor:
    """PyTorch version of f(x) = sin(x) + 0.1 * x^2 (scalar in, scalar out)."""
    sine_part = torch.sin(x)
    quadratic_part = 0.1 * x**2
    return sine_part + quadratic_part

# torch.func.grad is used here exactly like jax.grad: it returns a function that evaluates df/dx.
df_dx_torch = torch.func.grad(f_scalar_torch)

x0_t = torch.tensor(1.23)   # 0-d tensor as the evaluation point
print("\n1.1) Torch: f : R -> R")
# .item() extracts the Python number from each 0-d tensor for printing.
print("  x0      =", x0_t.item())
print("  f(x0)   =", f_scalar_torch(x0_t).item())
print("  df/dx|x0:", df_dx_torch(x0_t).item())


# ------------------------------------------------------------------------------
# 1.2) F : R^n → R   (vector input, scalar output)
# ------------------------------------------------------------------------------

def F_vec_jax(x):
    """
    JAX: scalar-valued function of a length-3 vector.

      F(x) = sum_i [ sin(x_i) + 0.5 * x_i^2 ] + 0.1 * (x^T A x)

    A is a fixed symmetric 3x3 matrix defined inside the function.
    """
    A = jnp.array([
        [2.0, -0.5, 0.0],
        [-0.5, 1.0, 0.3],
        [0.0, 0.3, 1.5],
    ])
    pointwise_sum = jnp.sum(jnp.sin(x) + 0.5 * x**2)
    quad = x @ A @ x   # quadratic form x^T A x
    return pointwise_sum + 0.1 * quad

# Gradient function of F; for a vector input it returns a vector of the same shape.
grad_F_vec_jax = jax.grad(F_vec_jax)

x0_vec_jax = jnp.array([0.7, -1.2, 0.3])   # evaluation point in R^3
g_vec_jax = grad_F_vec_jax(x0_vec_jax)

print("\n1.2) JAX: F : R^n -> R")
show("x0_vec_jax", x0_vec_jax)
print("F(x0)     =", float(F_vec_jax(x0_vec_jax)))
show("∇_x F(x0)_jax", g_vec_jax)
# Printed side by side so the matching input/gradient shapes are visible.
print("shape(x0) =", x0_vec_jax.shape, "shape(∇F) =", g_vec_jax.shape)


def F_vec_torch(x: torch.Tensor) -> torch.Tensor:
    """
    Torch: F: R^3 -> R.
      F(x) = sum_i [ sin(x_i) + 0.5 * x_i^2 ] + 0.1 * (x^T A x)

    Args:
        x: tensor of shape (3,).

    Returns:
        0-d tensor holding F(x).

    The constant matrix A is created on x's device and, for floating-point
    inputs, cast to x's dtype. Without that, ``x @ A`` raises for float64
    inputs or when x lives on a device other than the current default.
    """
    A = torch.tensor([[2.0, -0.5, 0.0],
                      [-0.5, 1.0, 0.3],
                      [0.0,  0.3, 1.5]], device=x.device)
    if x.is_floating_point():
        # Only follow floating dtypes: casting A to an integer dtype would truncate its entries.
        A = A.to(x.dtype)
    quad = x @ A @ x
    return torch.sum(torch.sin(x) + 0.5 * x**2) + 0.1 * quad

# Gradient function of F, built the same way as with jax.grad.
grad_F_vec_torch = torch.func.grad(F_vec_torch)

x0_vec_t = torch.tensor([0.7, -1.2, 0.3])   # same evaluation point as the JAX example
g_vec_t = grad_F_vec_torch(x0_vec_t)

print("\n1.2) Torch: F : R^n -> R")
show("x0_vec_t", x0_vec_t)
print("F(x0)     =", F_vec_torch(x0_vec_t).item())
show("∇_x F(x0)_torch", g_vec_t)
# Printed side by side so the matching input/gradient shapes are visible.
print("shape(x0) =", x0_vec_t.shape, "shape(∇F) =", g_vec_t.shape)


# ------------------------------------------------------------------------------
# 1.3) F : R^{N1×…×Nk} → R   (tensor input, scalar output)
# ------------------------------------------------------------------------------

def F_tensor_jax(X):
    """
    JAX: half the squared Frobenius norm, F(X) = 1/2 * sum of X_ij^2.

    The notebook calls it with X of shape (2, 3); the sum runs over every entry.
    """
    squared_entries = X**2
    return 0.5 * jnp.sum(squared_entries)

# jax.grad accepts a matrix-valued input as well as a vector.
grad_F_tensor_jax = jax.grad(F_tensor_jax)

X0_jax = jnp.array([[0.1, -0.2, 0.3],
                    [1.0,  0.5, -0.7]])   # (2,3) evaluation point

g_tensor_jax = grad_F_tensor_jax(X0_jax)

print("\n1.3) JAX: F : R^{N1×…×Nk} -> R")
show("X0_jax", X0_jax)
print("F(X0)      =", float(F_tensor_jax(X0_jax)))
# show() also prints the gradient's shape for comparison with X0's.
show("∇_X F(X0)_jax", g_tensor_jax)


def F_tensor_torch(X: torch.Tensor) -> torch.Tensor:
    """
    Torch: half the squared Frobenius norm, F(X) = 1/2 * sum of X_ij^2.

    The notebook calls it with X of shape (2, 3); the sum runs over every entry.
    """
    squared_entries = X**2
    return 0.5 * squared_entries.sum()

# torch.func.grad accepts a matrix-valued input as well as a vector.
grad_F_tensor_torch = torch.func.grad(F_tensor_torch)

X0_t = torch.tensor([[0.1, -0.2, 0.3],
                     [1.0,  0.5, -0.7]])   # same (2,3) evaluation point as the JAX example

g_tensor_t = grad_F_tensor_torch(X0_t)

print("\n1.3) Torch: F : R^{N1×…×Nk} -> R")
show("X0_t", X0_t)
print("F(X0)      =", F_tensor_torch(X0_t).item())
# show() also prints the gradient's shape for comparison with X0's.
show("∇_X F(X0)_torch", g_tensor_t)

# Flattening equivalence
def F_flat_jax(x_flat):
    """JAX: the same 1/2 * sum-of-squares objective, written for an already-flattened input."""
    sum_of_squares = jnp.sum(x_flat**2)
    return 0.5 * sum_of_squares

# Differentiate with respect to the flattened input and compare with the flattened matrix gradient.
grad_F_flat_jax = jax.grad(F_flat_jax)
x0_flat_jax = X0_jax.reshape(-1)            # vec(X0): (2,3) -> (6,)
g_flat_jax = grad_F_flat_jax(x0_flat_jax)   # gradient w.r.t. the flat vector

print("\nFlattening equivalence (JAX): vec(∇_X F) = ∇_{vec(X)} F")
show("vec(X0_jax)", x0_flat_jax)
show("∇_{vec(X)} F_jax", g_flat_jax)
show("vec(∇_X F(X0)_jax)", g_tensor_jax.reshape(-1))
# Elementwise difference between the two routes to the gradient.
print("Difference (should be ~0):", g_flat_jax - g_tensor_jax.reshape(-1))

def F_flat_torch(x_flat: torch.Tensor) -> torch.Tensor:
    """Torch: the same 1/2 * sum-of-squares objective, written for an already-flattened input."""
    sum_of_squares = torch.sum(x_flat**2)
    return 0.5 * sum_of_squares

# Differentiate with respect to the flattened input and compare with the flattened matrix gradient.
grad_F_flat_torch = torch.func.grad(F_flat_torch)
x0_flat_t = X0_t.reshape(-1)               # vec(X0): (2,3) -> (6,)
g_flat_t = grad_F_flat_torch(x0_flat_t)    # gradient w.r.t. the flat vector

print("\nFlattening equivalence (Torch): vec(∇_X F) = ∇_{vec(X)} F")
show("vec(X0_t)", x0_flat_t)
show("∇_{vec(X)} F_torch", g_flat_t)
show("vec(∇_X F(X0)_torch)", g_tensor_t.reshape(-1))
# Elementwise difference between the two routes to the gradient.
print("Difference (should be ~0):", g_flat_t - g_tensor_t.reshape(-1))


# =========================================================================================
# 2) JACOBIANS (VECTOR OUTPUT) + JVP / VJP
# =========================================================================================
# 2.1) G : R^n → R^m
# 2.2) G : tensor → vector
# 2.3) G : (X, θ) → vector (multi-arg)
# =========================================================================================

print("\n================ 2) JACOBIANS (VECTOR OUTPUT) ================")

# ------------------------------------------------------------------------------
# 2.1) G : R^n → R^m (vector → vector) + JVP/VJP
# ------------------------------------------------------------------------------

def G_jax(x):
    """
    JAX: vector-valued map G: R^3 -> R^3 with components

      G0 = x0 * exp(x1)
      G1 = sin(x1) + x2^2
      G2 = x0 + x1 + x2
    """
    a, b, c = x   # unpack the three input components
    g0 = a * jnp.exp(b)
    g1 = jnp.sin(b) + c**2
    g2 = a + b + c
    return jnp.array([g0, g1, g2])

x0_jax_vec = jnp.array([0.7, -0.2, 0.5])    # evaluation point in R^3
# Both calls build the full Jacobian of G at x0; they differ only in the autodiff mode used.
J_fwd_jax = jax.jacfwd(G_jax)(x0_jax_vec)   # forward-mode
J_rev_jax = jax.jacrev(G_jax)(x0_jax_vec)   # reverse-mode

print("\n2.1) JAX: G : R^3 -> R^3")
show("G_jax(x0)", G_jax(x0_jax_vec))
show("Jacobian via jacfwd (JAX)", J_fwd_jax)
show("Jacobian via jacrev (JAX)", J_rev_jax)

# JVP: J(x0) · v
# jax.jvp takes tuples of primals and tangents and returns (G(x0), J·v).
# NOTE: v_jax was also used as a name in the expand_dims section of an earlier cell; it is rebound here.
v_jax = jnp.array([1.0, -2.0, 0.5])
y_jax, Jv_jax = jax.jvp(G_jax, (x0_jax_vec,), (v_jax,))
show("G_jax(x0) from jvp", y_jax)
show("JVP (J·v)_jax", Jv_jax)

# VJP: J(x0)^T · w
# jax.vjp returns (G(x0), pullback); the pullback returns one cotangent per input, hence the 1-tuple unpack.
w_jax = jnp.array([0.3, -1.0, 2.0])
G_x0_val_jax, vjp_fun_jax = jax.vjp(G_jax, x0_jax_vec)
(wJ_jax,) = vjp_fun_jax(w_jax)
show("G_jax(x0) from vjp", G_x0_val_jax)
show("VJP (J^T·w)_jax", wJ_jax)


def G_torch(x: torch.Tensor) -> torch.Tensor:
    """
    Torch vector field G: R^3 -> R^3 (mirrors the JAX version).

    For x = (x0, x1, x2) the returned length-3 tensor holds
      [ x0 * exp(x1),  sin(x1) + x2^2,  x0 + x1 + x2 ].
    """
    first, second, third = x
    scaled_exp = first * torch.exp(second)
    sin_plus_square = torch.sin(second) + third**2
    coord_sum = first + second + third
    return torch.stack([scaled_exp, sin_plus_square, coord_sum])

# Evaluation point for the 2.1 Torch demo (same numbers as the JAX version).
# The Jacobian of G_torch is computed in both forward and reverse mode.
x0_t_vec = torch.tensor([0.7, -0.2, 0.5])
J_fwd_t = torch.func.jacfwd(G_torch)(x0_t_vec)
J_rev_t = torch.func.jacrev(G_torch)(x0_t_vec)

# `show` is a display helper defined elsewhere in this notebook.
print("\n2.1) Torch: G : R^3 -> R^3")
show("G_torch(x0)", G_torch(x0_t_vec))
show("Jacobian via torch.func.jacfwd", J_fwd_t)
show("Jacobian via torch.func.jacrev", J_rev_t)

# JVP: J(x0) · v
# Primals and tangents are passed as 1-tuples; the call returns the function
# value together with the directional derivative along v.
v_t = torch.tensor([1.0, -2.0, 0.5])
y_t, Jv_t = torch.func.jvp(G_torch, (x0_t_vec,), (v_t,))
show("G_torch(x0) from jvp", y_t)
show("JVP (J·v)_torch", Jv_t)

# VJP: J(x0)^T · w
# torch.func.vjp returns the function value and a pullback; the pullback's
# result is unpacked from a 1-tuple because G_torch has a single input.
w_t = torch.tensor([0.3, -1.0, 2.0])
G_x0_val_t, vjp_fun_t = torch.func.vjp(G_torch, x0_t_vec)
(wJ_t,) = vjp_fun_t(w_t)
show("G_torch(x0) from vjp", G_x0_val_t)
show("VJP (J^T·w)_torch", wJ_t)


# ------------------------------------------------------------------------------
# 2.2) G : R^{N1×…×Nk} → R^m  (tensor → vector)
# ------------------------------------------------------------------------------

def G_tensor_jax(X):
    """
    JAX map from a 2×3 matrix X to a vector in R^2.

    Components:
      G0 = 0.5 * ||X||^2   (half the sum of squared entries)
      G1 = ⟨A, X⟩          (entry-wise product with a fixed 2×3 matrix A, summed)
    """
    # Fixed coefficient matrix used for the linear component G1.
    coeffs = jnp.array([[1.0,  2.0,  3.0],
                        [0.5, -1.0,  4.0]])
    half_sq_norm = 0.5 * jnp.sum(X**2)
    inner_product = jnp.sum(coeffs * X)
    return jnp.array([half_sq_norm, inner_product])

# 2×3 evaluation point for the tensor -> vector demo (JAX).
X0_jax_T = jnp.array([[0.1, -0.2, 0.3],
                      [1.0,  0.5, -0.7]])

# Reverse-mode Jacobian: the leading axis indexes the output component and the
# trailing axes follow the shape of X.
J_tensor_jax = jax.jacrev(G_tensor_jax)(X0_jax_T)  # (m=2, 2,3)

# Each slice J_tensor_jax[i] is the gradient of output component G_i w.r.t. X.
# `show` is a display helper defined elsewhere in this notebook.
print("\n2.2) JAX: G : tensor -> vector")
show("G_tensor_jax(X0)", G_tensor_jax(X0_jax_T))
print("J_X G_tensor_jax shape:", J_tensor_jax.shape)
show("∇_X G0_jax", J_tensor_jax[0])
show("∇_X G1_jax", J_tensor_jax[1])

# Flattened equivalence
def G_tensor_flat_jax(x_flat):
    """Evaluate G_tensor_jax on a flat input by first viewing it as a 2×3 matrix."""
    return G_tensor_jax(x_flat.reshape(2, 3))

# Differentiate the flattened wrapper at the flattened evaluation point; each
# row of the result is the gradient of one output component w.r.t. vec(X).
x0_flat_T_jax = X0_jax_T.reshape(-1)
J_flat_jax = jax.jacrev(G_tensor_flat_jax)(x0_flat_T_jax)  # (2,6)
print("Flattened Jacobian (JAX) shape:", J_flat_jax.shape)
show("Row0 J_flat_jax", J_flat_jax[0])
show("Row1 J_flat_jax", J_flat_jax[1])


def G_tensor_torch(X: torch.Tensor) -> torch.Tensor:
    """
    Torch: X ∈ R^{2×3}, G(X) ∈ R^2
      G0 = 0.5 * ||X||^2
      G1 = ⟨A, X⟩

    Args:
        X: 2×3 tensor (any shape broadcastable against the 2×3 constant A works).

    Returns:
        1-D tensor of length 2 holding [G0, G1].
    """
    # Build the constant on the same device as X. Without this, A lands on the
    # default device and `A * X` raises a device-mismatch error whenever X lives
    # elsewhere (e.g. an MPS/CUDA tensor while the default device is CPU).
    A = torch.tensor([[1.0,  2.0,  3.0],
                      [0.5, -1.0,  4.0]], device=X.device)
    G0 = 0.5 * torch.sum(X**2)
    G1 = torch.sum(A * X)
    return torch.stack([G0, G1])

# 2×3 evaluation point for the tensor -> vector demo (Torch; same numbers as JAX).
X0_t_T = torch.tensor([[0.1, -0.2, 0.3],
                       [1.0,  0.5, -0.7]])

# Reverse-mode Jacobian: the leading axis indexes the output component and the
# trailing axes follow the shape of X.
J_tensor_t = torch.func.jacrev(G_tensor_torch)(X0_t_T)  # (2,2,3)

# Each slice J_tensor_t[i] is the gradient of output component G_i w.r.t. X.
# `show` is a display helper defined elsewhere in this notebook.
print("\n2.2) Torch: G : tensor -> vector")
show("G_tensor_torch(X0)", G_tensor_torch(X0_t_T))
print("J_X G_tensor_t shape:", J_tensor_t.shape)
show("∇_X G0_torch", J_tensor_t[0])
show("∇_X G1_torch", J_tensor_t[1])

def G_tensor_flat_torch(x_flat: torch.Tensor) -> torch.Tensor:
    """Evaluate G_tensor_torch on a flat input by first viewing it as a 2×3 matrix."""
    return G_tensor_torch(x_flat.reshape(2, 3))

# Differentiate the flattened wrapper at the flattened evaluation point; each
# row of the result is the gradient of one output component w.r.t. vec(X).
x0_flat_T_t = X0_t_T.reshape(-1)
J_flat_t = torch.func.jacrev(G_tensor_flat_torch)(x0_flat_T_t)  # (2,6)
print("Flattened Jacobian (Torch) shape:", J_flat_t.shape)
show("Row0 J_flat_t", J_flat_t[0])
show("Row1 J_flat_t", J_flat_t[1])


# ------------------------------------------------------------------------------
# 2.3) G : (X, θ) → R^m  (multi-arg: X and θ)
# ------------------------------------------------------------------------------

def G_multi_jax(X, theta):
    """
    JAX two-argument map (X, θ) -> R^2, with X and θ both 2×2.

    Components:
      G0 = ⟨X, θ⟩                 (sum of entry-wise products)
      G1 = ||X||^2 + 0.1 ||θ||^2  (squared norms, θ weighted by 0.1)
    """
    inner_product = jnp.sum(X * theta)
    x_sq_norm = jnp.sum(X**2)
    theta_sq_norm = jnp.sum(theta**2)
    return jnp.array([inner_product, x_sq_norm + 0.1 * theta_sq_norm])

# 2×2 evaluation points for the two-argument demo (JAX).
X0_jax_M = jnp.array([[1.0,  0.5],
                      [-0.3, 2.0]])
theta0_jax = jnp.array([[0.2, -1.0],
                        [1.5,  0.7]])

# `show` is a display helper defined elsewhere in this notebook.
print("\n2.3) JAX: multi-arg G(X, θ) -> vector")
show("G_multi_jax(X0, θ0)", G_multi_jax(X0_jax_M, theta0_jax))

# `argnums` selects which positional argument to differentiate:
# 0 -> Jacobian w.r.t. X, 1 -> Jacobian w.r.t. θ.
J_X_jax = jax.jacrev(G_multi_jax, argnums=0)(X0_jax_M, theta0_jax)
J_theta_jax = jax.jacrev(G_multi_jax, argnums=1)(X0_jax_M, theta0_jax)

# Index 0/1 along the leading axis picks the output component G0/G1.
print("J_X_jax shape:", J_X_jax.shape)
show("∇_X G0_jax", J_X_jax[0])
show("∇_X G1_jax", J_X_jax[1])

print("J_θ_jax shape:", J_theta_jax.shape)
show("∇_θ G0_jax", J_theta_jax[0])
show("∇_θ G1_jax", J_theta_jax[1])


def G_multi_torch(X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """
    Torch two-argument map (X, θ) -> R^2, with X and θ both 2×2.

    Components:
      G0 = ⟨X, θ⟩                 (sum of entry-wise products)
      G1 = ||X||^2 + 0.1 ||θ||^2  (squared norms, θ weighted by 0.1)
    """
    inner_product = torch.sum(X * theta)
    x_sq_norm = torch.sum(X**2)
    theta_sq_norm = torch.sum(theta**2)
    return torch.stack([inner_product, x_sq_norm + 0.1 * theta_sq_norm])

# 2×2 evaluation points for the two-argument demo (Torch; same numbers as JAX).
X0_t_M = torch.tensor([[1.0,  0.5],
                       [-0.3, 2.0]])
theta0_t = torch.tensor([[0.2, -1.0],
                         [1.5,  0.7]])

# `show` is a display helper defined elsewhere in this notebook.
print("\n2.3) Torch: multi-arg G(X, θ) -> vector")
show("G_multi_torch(X0, θ0)", G_multi_torch(X0_t_M, theta0_t))

# `argnums` selects which positional argument to differentiate:
# 0 -> Jacobian w.r.t. X, 1 -> Jacobian w.r.t. θ.
J_X_t = torch.func.jacrev(G_multi_torch, argnums=0)(X0_t_M, theta0_t)
J_theta_t = torch.func.jacrev(G_multi_torch, argnums=1)(X0_t_M, theta0_t)

# Index 0/1 along the leading axis picks the output component G0/G1.
print("J_X_t shape:", J_X_t.shape)
show("∇_X G0_torch", J_X_t[0])
show("∇_X G1_torch", J_X_t[1])

print("J_θ_t shape:", J_theta_t.shape)
show("∇_θ G0_torch", J_theta_t[0])
show("∇_θ G1_torch", J_theta_t[1])


# =========================================================================================
# 3) GRADIENTS W.R.T. SUBSET OF INPUTS + SIMPLE BATCHING (VMAP)
# =========================================================================================

# Banner line written to stdout to mark the start of Section 3 in the cell output.
print("\n================ 3) Subset-of-input grads + vmap ================")

# ------------------------------------------------------------------------------
# 3.1) Gradient w.r.t. a subset of inputs (linear regression loss)
# ------------------------------------------------------------------------------

def lin_loss_jax(w, b, x, y):
    """
    JAX mean-squared-error loss of a linear model.

    The prediction is x @ w + b and the returned scalar is the mean of the
    squared residuals (prediction - y).
    """
    residual = (x @ w + b) - y
    return jnp.mean(residual**2)

# Two gradient functions of the same loss: `argnums=0` differentiates w only,
# `argnums=(0, 1)` returns a tuple of gradients for (w, b). x and y are treated
# as constants in both cases.
grad_w_only_jax = jax.grad(lin_loss_jax, argnums=0)
grad_w_and_b_jax = jax.grad(lin_loss_jax, argnums=(0, 1))

# NOTE: `w_jax` was previously bound to the VJP cotangent in Section 2.1; from
# here on it refers to the regression weights instead.
w_jax = jnp.array([1.0, -2.0, 0.5])
b_jax = 0.3  # bias is a plain Python float
x_jax_data = jnp.array([[1.0, 0.0, 2.0],
                        [0.5, 1.0, -1.0]])
y_jax_data = jnp.array([1.5, -0.2])

# `show` is a display helper defined elsewhere in this notebook.
print("\n3.1) JAX: subset-of-input grads (w vs (w,b))")
print("loss_jax =", float(lin_loss_jax(w_jax, b_jax, x_jax_data, y_jax_data)))
dw_jax = grad_w_only_jax(w_jax, b_jax, x_jax_data, y_jax_data)
show("∂loss/∂w_jax", dw_jax)

# The tuple-argnums version returns one gradient per selected argument.
dw_jax2, db_jax2 = grad_w_and_b_jax(w_jax, b_jax, x_jax_data, y_jax_data)
print("∂loss/∂w_jax2:", dw_jax2, "  ∂loss/∂b_jax2:", db_jax2)


def lin_loss_torch(w: torch.Tensor,
                   b: torch.Tensor,
                   x: torch.Tensor,
                   y: torch.Tensor) -> torch.Tensor:
    """
    Torch mean-squared-error loss of a linear model.

    The prediction is x @ w + b and the returned scalar tensor is the mean of
    the squared residuals (prediction - y).
    """
    residual = (x @ w + b) - y
    return torch.mean(residual**2)

# Two gradient functions of the same loss: `argnums=0` differentiates w only,
# `argnums=(0, 1)` returns a tuple of gradients for (w, b). x and y are treated
# as constants in both cases.
grad_w_only_t = torch.func.grad(lin_loss_torch, argnums=0)
grad_w_and_b_t = torch.func.grad(lin_loss_torch, argnums=(0, 1))

# Regression inputs (same numbers as the JAX version); the bias is a 0-d tensor.
w_t_reg = torch.tensor([1.0, -2.0, 0.5])
b_t_reg = torch.tensor(0.3)
x_t_data = torch.tensor([[1.0, 0.0, 2.0],
                         [0.5, 1.0, -1.0]])
y_t_data = torch.tensor([1.5, -0.2])

# `show` is a display helper defined elsewhere in this notebook.
print("\n3.1) Torch: subset-of-input grads (w vs (w,b))")
print("loss_torch =", lin_loss_torch(w_t_reg, b_t_reg, x_t_data, y_t_data).item())
dw_t = grad_w_only_t(w_t_reg, b_t_reg, x_t_data, y_t_data)
show("∂loss/∂w_t", dw_t)

# The tuple-argnums version returns one gradient per selected argument.
dw_t2, db_t2 = grad_w_and_b_t(w_t_reg, b_t_reg, x_t_data, y_t_data)
print("∂loss/∂w_t2:", dw_t2, "  ∂loss/∂b_t2:", db_t2)


# ------------------------------------------------------------------------------
# 3.2) vmap: batch scalar derivatives and vector gradients
# ------------------------------------------------------------------------------

# Reuse scalar f and its grad from Section 1.1

# `df_dx_jax` / `df_dx_torch` are the scalar derivative functions defined
# earlier in the notebook (Section 1.1); vmap maps them over a 1-D batch of
# inputs without an explicit Python loop.
xs_jax = jnp.linspace(-2.0, 2.0, 5)  # (-2,-1,0,1,2)
df_dx_batched_jax = jax.vmap(df_dx_jax)
print("\n3.2) JAX: vmap df/dx over xs")
show("df/dx_jax(xs)", df_dx_batched_jax(xs_jax))


# Same five sample points for the Torch version.
xs_t = torch.linspace(-2.0, 2.0, steps=5)
df_dx_batched_t = torch.func.vmap(df_dx_torch)
print("\n3.2) Torch: vmap df/dx over xs")
show("df/dx_t(xs)", df_dx_batched_t(xs_t))


# Batch a multi-arg function g(a, x) = a cos(x) over x only
def g_jax(a, x):
    """JAX: return a * cos(x)."""
    cosine = jnp.cos(x)
    return a * cosine

# Hold `a` fixed by closing over it in a lambda, so vmap batches over x only.
a_jax = 2.0
g_batched_jax = jax.vmap(lambda x: g_jax(a_jax, x))
show("JAX g(a,x) over xs", g_batched_jax(xs_jax))


def g_torch(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Torch: return a * cos(x)."""
    cosine = torch.cos(x)
    return a * cosine

# Hold `a` fixed by closing over it in a lambda, so vmap batches over x only.
a_t = torch.tensor(2.0)
g_batched_t = torch.func.vmap(lambda x: g_torch(a_t, x))
show("Torch g(a,x) over xs", g_batched_t(xs_t))


# Batched vector gradient: vmap over grad(F_vec)
# Batch of three 3-vectors stacked along a new leading axis.
xs_batch_jax = jnp.stack([
    jnp.array([0.0,  0.0,  0.0]),
    jnp.array([0.5,  0.2, -0.1]),
    jnp.array([1.0, -0.3,  0.7]),
], axis=0)  # (B,3)

# `grad_F_vec_jax` is defined earlier in the notebook (outside this excerpt);
# vmap applies it to each row of the batch.
grads_batched_jax = jax.vmap(grad_F_vec_jax)(xs_batch_jax)
print("\n3.2) JAX: batched ∇F_vec")
show("grads_batched_jax", grads_batched_jax)


# Same batch for the Torch version.
xs_batch_t = torch.stack([
    torch.tensor([0.0,  0.0,  0.0]),
    torch.tensor([0.5,  0.2, -0.1]),
    torch.tensor([1.0, -0.3,  0.7]),
], dim=0)  # (B,3)

# `grad_F_vec_torch` is defined earlier in the notebook (outside this excerpt).
grads_batched_t = torch.func.vmap(grad_F_vec_torch)(xs_batch_t)
print("\n3.2) Torch: batched ∇F_vec")
show("grads_batched_t", grads_batched_t)


# Batched Jacobians: vmap over jacrev(G)
# Composing the transforms yields one 3×3 Jacobian per batch row.
batched_J_jax = jax.vmap(jax.jacrev(G_jax))(xs_batch_jax)   # (B,3,3)
print("JAX batched Jacobians shape:", batched_J_jax.shape)
show("JAX J(xs_batch[0])", batched_J_jax[0])

batched_J_t = torch.func.vmap(torch.func.jacrev(G_torch))(xs_batch_t)  # (B,3,3)
print("Torch batched Jacobians shape:", batched_J_t.shape)
show("Torch J(xs_batch[0])", batched_J_t[0])



1.1) JAX: f : R -> R
  x0      = 1.23
  f(x0)   = 1.0937788486480713
  df/dx|x0: 0.5802377462387085

1.1) Torch: f : R -> R
  x0      = 1.2300000190734863
  f(x0)   = 1.0937788486480713
  df/dx|x0: 0.5802377462387085

1.2) JAX: F : R^n -> R
x0_vec_jax:
[ 0.7 -1.2  0.3]
shape: (3,) 

F(x0)     = 1.3355987071990967
∇_x F(x0)_jax:
[ 1.8648423 -1.1296424  1.2733364]
shape: (3,) 

shape(x0) = (3,) shape(∇F) = (3,)

1.2) Torch: F : R^n -> R
x0_vec_t:
tensor([ 0.7000, -1.2000,  0.3000], device='mps:0')
shape: torch.Size([3]) 

F(x0)     = 1.3355987071990967
∇_x F(x0)_torch:
tensor([ 1.8648, -1.1296,  1.2733], device='mps:0')
shape: torch.Size([3]) 

shape(x0) = torch.Size([3]) shape(∇F) = torch.Size([3])

1.3) JAX: F : R^{N1×…×Nk} -> R
X0_jax:
[[ 0.1 -0.2  0.3]
 [ 1.   0.5 -0.7]]
shape: (2, 3) 

F(X0)      = 0.9399999976158142
∇_X F(X0)_jax:
[[ 0.1 -0.2  0.3]
 [ 1.   0.5 -0.7]]
shape: (2, 3) 


1.3) Torch: F : R^{N1×…×Nk} -> R
X0_t:
tensor([[ 0.1000, -0.2000,  0.3000],
        [ 1.0000,  0.