# JAX vs PyTorch 代码示例

本 Notebook 通过可运行的小示例对比 JAX 和 PyTorch 的核心特性。注意：示例中的计时结果受硬件、运行设备（CPU/GPU/TPU）以及是否预热的影响，仅供参考。

## 1. 安装依赖

In [None]:
# 安装 JAX
!pip install -q jax jaxlib

# PyTorch 通常已预装在 Colab 中
import torch
import jax
import jax.numpy as jnp
import numpy as np
import time

# Report the installed framework versions so results can be tied to a specific setup.
print("PyTorch version:", torch.__version__)
print("JAX version:", jax.__version__)

## 2. 基础操作对比

In [None]:
# --- Basic op: 1000x1000 matrix multiplication ------------------------------
# Timing methodology:
# * Each framework runs the op once before timing, so one-time start-up /
#   compilation costs are not attributed to the measured call.
# * JAX dispatches work asynchronously. Inputs are synchronised before the timer
#   starts (so pending random-number generation is not timed), and the output is
#   synchronised before the timer stops.
# * time.perf_counter() is the monotonic high-resolution clock meant for intervals.
# NOTE(review): JAX places arrays on its default device (GPU/TPU if available)
# while these torch tensors are created on the CPU, so on an accelerator runtime
# the two numbers measure different hardware.

# PyTorch
x_torch = torch.randn(1000, 1000)
y_torch = torch.randn(1000, 1000)
torch.matmul(x_torch, y_torch)  # warm-up, not timed

start = time.perf_counter()
z_torch = torch.matmul(x_torch, y_torch)
print(f"PyTorch 矩阵乘法: {time.perf_counter() - start:.4f}s")

# JAX
x_jax = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
y_jax = jax.random.normal(jax.random.PRNGKey(1), (1000, 1000))
x_jax.block_until_ready()  # make sure input generation has finished
y_jax.block_until_ready()
jnp.matmul(x_jax, y_jax).block_until_ready()  # warm-up, not timed

start = time.perf_counter()
z_jax = jnp.matmul(x_jax, y_jax)
z_jax.block_until_ready()  # JAX is asynchronous: wait for the result before stopping the clock
print(f"JAX 矩阵乘法: {time.perf_counter() - start:.4f}s")

## 3. 自动微分对比

In [None]:
# PyTorch autodiff example
def pytorch_loss(x):
    """Return the scalar sum of squares of the torch tensor ``x``."""
    squared = x.pow(2)
    return squared.sum()

# Leaf tensor tracked by autograd (requires_grad=True).
x_torch = torch.randn(100, requires_grad=True)
loss = pytorch_loss(x_torch)
# backward() stores d(loss)/d(x) in x_torch.grad; for sum(x ** 2) that is 2 * x.
loss.backward()
print(f"PyTorch 梯度形状: {x_torch.grad.shape}")
print(f"PyTorch 梯度前5个值: {x_torch.grad[:5]}")

# JAX autodiff example
def jax_loss(x):
    """Return the scalar sum of squares of the JAX array ``x`` (differentiable with jax.grad)."""
    return jnp.sum(jnp.square(x))

# jax.grad transforms jax_loss into a new function that returns d(loss)/d(x).
# The input is an explicit argument: there is no requires_grad flag and no
# mutable .grad attribute, unlike the PyTorch version above.
x_jax = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(jax_loss)
# For sum(x ** 2) the gradient is 2 * x_jax.
grads = grad_fn(x_jax)
print(f"\nJAX 梯度形状: {grads.shape}")
print(f"JAX 梯度前5个值: {grads[:5]}")

## 4. JIT 编译对比

In [None]:
# A compute-heavy function used by the PyTorch eager / torch.compile benchmark in this cell.
def complex_fn(x):
    """Apply ``x <- x @ x.T + 1`` ten times and return the sum of the final matrix.

    Works with any array type supporting ``@``, ``.T``, ``+`` and ``.sum()``.
    Magnitudes grow very quickly, so for random float32 input the result will
    likely overflow to inf/nan; the value is only used for timing.
    """
    for _step in range(10):
        x = x @ x.T + 1
    return x.sum()

# PyTorch (eager): every call executes op by op from Python.
x_torch = torch.randn(100, 100)
complex_fn(x_torch)  # warm-up call so one-time start-up costs are not timed

start = time.perf_counter()
for _ in range(100):
    result = complex_fn(x_torch)
pytorch_time = time.perf_counter() - start
print(f"PyTorch (Eager): {pytorch_time:.4f}s")

# PyTorch (torch.compile): the first call triggers graph capture and compilation,
# which can take seconds. Run it once before timing, as the JAX JIT benchmark
# does, so the measured loop reflects steady-state execution only.
complex_fn_compiled = torch.compile(complex_fn)
complex_fn_compiled(x_torch)  # warm-up: compile outside the timed region

start = time.perf_counter()
for _ in range(100):
    result = complex_fn_compiled(x_torch)
pytorch_compile_time = time.perf_counter() - start
print(f"PyTorch (Compiled): {pytorch_compile_time:.4f}s")

# JAX (JIT): the same computation as complex_fn, applied to jax.numpy arrays.
def jax_complex_fn(x):
    """Apply ``x <- x @ x.T + 1`` ten times and return the sum of the final matrix.

    The loop has a fixed Python trip count, so under ``jax.jit`` it is unrolled
    during tracing into one compiled program.
    """
    step = 0
    while step < 10:
        x = x @ x.T + 1
        step += 1
    return x.sum()

jax_complex_fn_jit = jax.jit(jax_complex_fn)
x_jax = jax.random.normal(jax.random.PRNGKey(0), (100, 100))

# Warm-up: the first call traces and compiles the function; keep that out of the timing.
_ = jax_complex_fn_jit(x_jax).block_until_ready()

start = time.perf_counter()
for _ in range(100):
    # block_until_ready() so each iteration's asynchronously dispatched work is actually timed
    result = jax_complex_fn_jit(x_jax).block_until_ready()
jax_time = time.perf_counter() - start
print(f"JAX (JIT): {jax_time:.4f}s")

# NOTE(review): JAX runs on its default device (GPU/TPU when available) while the
# PyTorch tensors in this cell are on the CPU; the cross-framework ratios are only
# like-for-like on a CPU-only runtime.
print("\n加速比:")
print(f"PyTorch Compile vs Eager: {pytorch_time/pytorch_compile_time:.2f}x")
print(f"JAX JIT vs PyTorch Eager: {pytorch_time/jax_time:.2f}x")
# Both sides compiled and warmed up: the most direct comparison.
print(f"JAX JIT vs PyTorch Compile: {pytorch_compile_time/jax_time:.2f}x")

## 5. vmap (自动向量化) - JAX 原生支持（PyTorch 2.0+ 也提供 `torch.func.vmap`）

In [None]:
# Single-example affine layer: dot(x, W) + b
def predict_single(params, x):
    """Apply an affine map to one example.

    ``params`` is a ``(W, b)`` pair. ``x`` is presumably a 1-D feature vector whose
    length matches ``W.shape[0]`` (TODO confirm). Returns ``jnp.dot(x, W) + b``.
    """
    weights, bias = params
    projected = jnp.dot(x, weights)
    return projected + bias

# Manual batching: a Python-level loop over examples, then stack the outputs.
def predict_batch_manual(params, X):
    """Apply ``predict_single`` to each row of ``X`` one at a time and stack the results.

    This is the hand-written loop that ``jax.vmap`` replaces; it issues a separate
    set of operations for every row.
    """
    per_row = map(lambda row: predict_single(params, row), X)
    return jnp.stack(list(per_row))

# Automatic vectorization: map predict_single over axis 0 of X while sharing params
# (in_axes=None means "do not batch this argument").
predict_batch_vmap = jax.vmap(predict_single, in_axes=(None, 0))

# Test data
W = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
b = jax.random.normal(jax.random.PRNGKey(1), (5,))
params = (W, b)
X = jax.random.normal(jax.random.PRNGKey(2), (32, 10))

# Warm up both paths once. The first eager execution of each op at a given shape
# includes one-time compilation, which would otherwise dominate these tiny timings.
# Blocking on the outputs also guarantees the random inputs above are materialised.
predict_batch_manual(params, X).block_until_ready()
predict_batch_vmap(params, X).block_until_ready()

# Manual batching. JAX dispatches asynchronously, so wait for the result before
# reading the clock; otherwise only the dispatch overhead would be measured.
start = time.perf_counter()
result_manual = predict_batch_manual(params, X).block_until_ready()
manual_time = time.perf_counter() - start

# vmap
start = time.perf_counter()
result_vmap = predict_batch_vmap(params, X).block_until_ready()
vmap_time = time.perf_counter() - start

print(f"手动批处理: {manual_time:.6f}s")
print(f"vmap: {vmap_time:.6f}s")
print(f"加速比: {manual_time/vmap_time:.2f}x")
print(f"结果一致: {jnp.allclose(result_manual, result_vmap)}")

## 6. 简单神经网络训练对比

In [None]:
# PyTorch 版本
import torch.nn as nn
import torch.optim as optim

class PyTorchMLP(nn.Module):
    """Two-layer perceptron: Linear(10 -> 64) -> ReLU -> Linear(64 -> 1)."""

    def __init__(self):
        super().__init__()
        in_features, hidden_features, out_features = 10, 64, 1
        # fc1 is created before fc2, which fixes parameter order and RNG consumption.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Seed PyTorch's global RNG so the initial weights and training data are
# reproducible, consistent with the JAX version, which derives all randomness
# from PRNGKey(0).
torch.manual_seed(0)

model_torch = PyTorchMLP()
optimizer = optim.Adam(model_torch.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Training data (random inputs and targets; only used to demonstrate the loop)
X_train = torch.randn(100, 10)
y_train = torch.randn(100, 1)

# Training loop: one full-batch gradient step per epoch
for epoch in range(10):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    pred = model_torch(X_train)
    loss = criterion(pred, y_train)
    loss.backward()
    optimizer.step()
    if epoch % 2 == 0:
        # the reported loss was computed before this epoch's parameter update
        print(f"PyTorch Epoch {epoch}, Loss: {loss.item():.4f}")

In [None]:
# JAX 版本 (纯函数式)
import optax

def jax_mlp(params, x):
    """Forward pass of a two-layer MLP written as a pure function of ``params``.

    ``params`` is the tuple ``(W1, b1, W2, b2)``; the result is
    ``dot(relu(dot(x, W1) + b1), W2) + b2``.
    """
    w1, b1, w2, b2 = params
    hidden = jnp.dot(x, w1) + b1
    hidden = jax.nn.relu(hidden)
    return jnp.dot(hidden, w2) + b2

def mse_loss(params, X, y):
    """Mean squared error of ``jax_mlp`` predictions against targets ``y``.

    ``jax_mlp`` is applied per example with ``jax.vmap`` over axis 0 of ``X``;
    ``params`` is shared across the batch (``in_axes=None``).
    """
    batched_mlp = jax.vmap(jax_mlp, in_axes=(None, 0))
    residual = batched_mlp(params, X) - y
    return jnp.mean(residual ** 2)

# Initialise parameters. One split yields a distinct key per random tensor, so
# no key is reused (JAX randomness is explicit and stateless).
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)
W1 = jax.random.normal(keys[0], (10, 64)) * 0.1
b1 = jnp.zeros(64)
W2 = jax.random.normal(keys[1], (64, 1)) * 0.1
b2 = jnp.zeros(1)
params = (W1, b1, W2, b2)

# Optimizer. It is named `jax_optimizer` rather than `optimizer` so this cell
# does not overwrite the PyTorch `optim.Adam` instance created in the PyTorch
# cell. `train_step` reads this name as a global when it is traced.
jax_optimizer = optax.adam(0.01)
opt_state = jax_optimizer.init(params)

# Training data
X_train_jax = jax.random.normal(keys[2], (100, 10))
y_train_jax = jax.random.normal(keys[3], (100, 1))


# One JIT-compiled optimisation step: state goes in, new state comes out (no mutation).
@jax.jit
def train_step(params, opt_state, X, y):
    """Return ``(new_params, new_opt_state, loss)``; ``loss`` is evaluated before the update."""
    loss, grads = jax.value_and_grad(mse_loss)(params, X, y)
    updates, opt_state = jax_optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


# Training loop
for epoch in range(10):
    params, opt_state, loss = train_step(params, opt_state, X_train_jax, y_train_jax)
    if epoch % 2 == 0:
        print(f"JAX Epoch {epoch}, Loss: {loss:.4f}")

## 总结

通过以上示例可以看到：

**PyTorch 优势**：
- 代码更直观，易于理解
- 面向对象设计，符合传统编程习惯
- Eager 模式下调试方便（可直接使用 print 和 Python 调试器）

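**JAX 优势**：
- `jax.jit`（基于 XLA）编译成熟，常带来显著性能提升（PyTorch 2.x 也可通过 `torch.compile` 获得类似收益）
- `vmap` 自动向量化非常强大，可与 `grad`、`jit` 自由组合
- 函数式编程，可组合性强
- 自动微分以函数变换的形式提供（`grad`、`jacfwd`/`jacrev`、高阶导数），更加灵活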

选择哪个框架取决于你的具体需求！