# Ejemplos Prácticos con JAX

Este cuaderno demuestra las capacidades principales de JAX mediante ejemplos prácticos.

## 1. Setup e Importaciones

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import matplotlib.pyplot as plt

# Report the installed JAX version and the backend JAX uses by default.
print(f"JAX version: {jax.__version__}")
print(f"Backend: {jax.default_backend()}")

## 2. Transformación `grad`: Diferenciación Automática

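JAX puede calcular automáticamente las derivadas de funciones numéricas escritas en Python: `grad(f)` devuelve una nueva función que evalúa la derivada de `f` respecto a su primer argumento.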

In [None]:
# Simple cubic polynomial used as the differentiation example
def f(x):
    """Evaluate the polynomial x^3 + 2x^2 - 5x + 3 at ``x``."""
    cubic_term = x**3
    quadratic_term = 2*x**2
    linear_term = 5*x
    return cubic_term + quadratic_term - linear_term + 3

# Build the first derivative, then differentiate it again for the second one
df_dx = grad(f)
d2f_dx2 = grad(df_dx)

# Evaluate the function and its first derivative at a sample point,
# next to the hand-derived expression for comparison
x = 2.0
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {df_dx(x)}")
print(f"Esperado: 3*{x}^2 + 4*{x} - 5 = {3*x**2 + 4*x - 5}")

# Second derivative at the same point
print(f"f''({x}) = {d2f_dx2(x)}")

## 3. Transformación `jit`: Compilación JIT

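`jit` compila funciones con XLA para mayor rendimiento. La primera llamada incluye el tiempo de trazado y compilación, por lo que conviene hacer una llamada de calentamiento antes de medir tiempos.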

In [None]:
def slow_function(x):
    """Replace ``x`` with ``jnp.dot(x, x)`` ten times in a row and return the result."""
    n_squarings = 10
    result = x
    for _ in range(n_squarings):
        result = jnp.dot(result, result)
    return result

# Compiled version
fast_function = jit(slow_function)

# Every entry is 1/100, so x @ x gives x back (up to rounding). The values stay
# finite through the ten squarings, whereas an all-ones 100x100 matrix
# overflows float32 to inf after a few of them.
x = jnp.ones((100, 100)) / 100

# Warm up the JIT so the one-time compilation is not part of the measurement
_ = fast_function(x).block_until_ready()

# Compare timings
import time

# perf_counter() is the high-resolution clock meant for measuring short
# intervals; time.time() is wall-clock time and can be coarse or be adjusted.
start = time.perf_counter()
_ = slow_function(x).block_until_ready()
time_slow = time.perf_counter() - start

start = time.perf_counter()
_ = fast_function(x).block_until_ready()
time_fast = time.perf_counter() - start

print(f"Sin JIT: {time_slow*1000:.3f} ms")
print(f"Con JIT: {time_fast*1000:.3f} ms")
print(f"Aceleración: {time_slow/time_fast:.1f}x")

## 4. Transformación `vmap`: Vectorización Automática

`vmap` transforma una función escrita para un solo ejemplo en otra que opera sobre un lote (batch) completo, sin necesidad de bucles explícitos.

In [None]:
def compute_norm(vector):
    """Return the L2 norm of a vector."""
    squared = vector ** 2
    total = jnp.sum(squared)
    return jnp.sqrt(total)

# Batched version of compute_norm, built with vmap
compute_norms_batch = vmap(compute_norm)

# Batch of vectors
vectors = jnp.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
])

print("Vectores:")
print(vectors)

# Apply the batched function to the whole array in one call
norms = compute_norms_batch(vectors)

print(f"\nNormas: {norms}")

## 5. Ejemplo: Regresión Lineal

Implementación completa de regresión lineal con gradient descent.

In [None]:
# Generate synthetic data: a line with known slope/intercept plus Gaussian noise
np.random.seed(42)
n_samples = 100
true_w = 3.5
true_b = 2.0

# Inputs are drawn before the noise; this order fixes the values for the seed
X = np.random.uniform(-5, 5, n_samples)
noise = np.random.normal(0, 1, n_samples)
y = true_w * X + true_b + noise

# JAX copies of the data for the training code
X_jax, y_jax = jnp.array(X), jnp.array(y)

# Visualize the data
plt.figure(figsize=(8, 5))
plt.scatter(X, y, alpha=0.5)
plt.title('Datos de entrenamiento')
plt.xlabel('X')
plt.ylabel('y')
plt.grid(True, alpha=0.3)
plt.show()

print(f"Parámetros verdaderos: w={true_w}, b={true_b}")

In [None]:
# Model and loss function
def predict(params, X):
    """Return the linear prediction ``params['w'] * X + params['b']``."""
    slope = params['w']
    intercept = params['b']
    return slope * X + intercept

def mse_loss(params, X, y):
    """Return the mean squared difference between ``predict(params, X)`` and ``y``."""
    residuals = predict(params, X) - y
    return jnp.mean(residuals ** 2)

# Initialize parameters
# Split the key so 'w' and 'b' come from independent draws: calling
# jax.random.normal twice with the same key returns the same number twice.
key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)
params = {
    'w': jax.random.normal(w_key, shape=()),
    'b': jax.random.normal(b_key, shape=()),
}

print(f"Parámetros iniciales: w={params['w']:.4f}, b={params['b']:.4f}")

# JIT-compiled gradient of the loss
grad_loss = jit(grad(mse_loss))

# Training settings and per-epoch log
learning_rate = 0.01
n_epochs = 100
history = {'loss': [], 'w': [], 'b': []}

for epoch in range(n_epochs):
    grads = grad_loss(params, X_jax, y_jax)
    # Gradient-descent step on each parameter
    for name in ('w', 'b'):
        params[name] = params[name] - learning_rate * grads[name]

    # The logged loss is the one measured after this epoch's update
    loss = mse_loss(params, X_jax, y_jax)
    for column, value in (('loss', loss), ('w', params['w']), ('b', params['b'])):
        history[column].append(float(value))

    if (epoch + 1) % 20 == 0:
        print(f"Época {epoch+1:3d} | Loss: {loss:.4f} | w: {params['w']:.4f} | b: {params['b']:.4f}")

print(f"\nParámetros finales: w={params['w']:.4f}, b={params['b']:.4f}")
print(f"Parámetros reales:  w={true_w}, b={true_b}")

In [None]:
# Visualize results: fitted line on the left, loss curve on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_fit, ax_loss = axes

# Evaluate the learned and the true line on an evenly spaced grid over the data range
x_line = jnp.linspace(X.min(), X.max(), 100)
y_pred = predict(params, x_line)
y_true_line = true_w * x_line + true_b

# Regression panel
ax_fit.scatter(X, y, alpha=0.5, label='Datos')
ax_fit.plot(x_line, y_pred, 'r-', linewidth=2, label='Predicción')
ax_fit.plot(x_line, y_true_line, 'g--', linewidth=2, label='Real')
ax_fit.set_title('Regresión Lineal')
ax_fit.set_xlabel('X')
ax_fit.set_ylabel('y')
ax_fit.legend()
ax_fit.grid(True, alpha=0.3)

# Loss panel, on a logarithmic y axis
ax_loss.plot(history['loss'], linewidth=2)
ax_loss.set_title('Evolución de la Pérdida')
ax_loss.set_xlabel('Época')
ax_loss.set_ylabel('MSE Loss')
ax_loss.set_yscale('log')
ax_loss.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Ejemplo: Red Neuronal Simple

Implementación de una red neuronal para clasificación del dataset Iris.

In [None]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the Iris features and labels
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the samples, stratified by class label
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

# Fit the scaler on the training split only, then apply it to both splits
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# One-hot encoding
def one_hot(y, num_classes):
    """Return the rows of the ``num_classes`` x ``num_classes`` identity matrix indexed by ``y``."""
    identity = jnp.eye(num_classes)
    return identity[y]

# Convert the scaled features to JAX arrays and one-hot encode the labels (3 classes)
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train_onehot, y_test_onehot = one_hot(y_train, 3), one_hot(y_test, 3)

print(f"Train: {X_train.shape}, Test: {X_test.shape}")
print(f"Clases: {iris.target_names}")

In [None]:
# Initialize the neural network
def init_network(layer_sizes, key):
    """Create the weights and biases of a fully connected network.

    Args:
        layer_sizes: Sequence of layer widths, input first and output last.
        key: JAX PRNG key; it is split into one sub-key per layer.

    Returns:
        A list with one ``{'W': ..., 'b': ...}`` dict per pair of consecutive
        sizes. ``W`` has shape ``(n_in, n_out)`` and holds normal samples
        multiplied by ``sqrt(2 / (n_in + n_out))``; ``b`` is a zero vector of
        length ``n_out``.
    """
    layer_keys = jax.random.split(key, len(layer_sizes) - 1)
    fan_pairs = zip(layer_sizes[:-1], layer_sizes[1:])

    network = []
    for layer_key, (n_in, n_out) in zip(layer_keys, fan_pairs):
        # Only the first half of the split is used; biases start at zero
        w_key, _ = jax.random.split(layer_key)
        scale = jnp.sqrt(2.0 / (n_in + n_out))
        network.append({
            'W': jax.random.normal(w_key, (n_in, n_out)) * scale,
            'b': jnp.zeros(n_out),
        })

    return network

# Forward pass
def forward(params, x):
    """Run ``x`` through the network and return softmax outputs.

    Every layer except the last applies ``dot(., W) + b`` followed by ReLU;
    the last layer's ``dot(., W) + b`` is normalized with a softmax over the
    final axis.
    """
    hidden = x
    for layer in params[:-1]:
        pre_activation = jnp.dot(hidden, layer['W']) + layer['b']
        hidden = jnp.maximum(0, pre_activation)  # ReLU

    output_layer = params[-1]
    logits = jnp.dot(hidden, output_layer['W']) + output_layer['b']

    # Softmax; the maximum is subtracted before exponentiating
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    exp_shifted = jnp.exp(shifted)
    return exp_shifted / jnp.sum(exp_shifted, axis=-1, keepdims=True)

# Loss and accuracy
def cross_entropy_loss(params, X, y_true):
    """Return the mean cross-entropy between ``forward(params, X)`` and ``y_true``.

    The outputs of ``forward`` are clipped to ``[1e-10, 1.0]`` before the
    logarithm is taken.
    """
    probs = jnp.clip(forward(params, X), 1e-10, 1.0)
    per_sample = jnp.sum(y_true * jnp.log(probs), axis=1)
    return -jnp.mean(per_sample)

def accuracy(params, X, y_true):
    """Return the fraction of rows where the argmax of ``forward(params, X)`` matches the argmax of ``y_true``."""
    hits = jnp.argmax(forward(params, X), axis=1) == jnp.argmax(y_true, axis=1)
    return jnp.mean(hits)

# Architecture: 4 -> 16 -> 8 -> 3
key = jax.random.PRNGKey(42)
layer_sizes = [4, 16, 8, 3]
params = init_network(layer_sizes=layer_sizes, key=key)

print(f"Arquitectura: {' -> '.join(map(str, layer_sizes))}")

In [None]:
# Train
# JIT-compile the loss, the accuracy and the loss gradient once, before the loop
loss_fn = jit(cross_entropy_loss)
accuracy_fn = jit(accuracy)
grad_loss = jit(grad(cross_entropy_loss))

learning_rate = 0.01
n_epochs = 200
history = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': []}

for epoch in range(n_epochs):
    grads = grad_loss(params, X_train, y_train_onehot)
    # Gradient-descent step on every layer's weights and biases
    for layer, layer_grads in zip(params, grads):
        for name in ('W', 'b'):
            layer[name] = layer[name] - learning_rate * layer_grads[name]

    # Metrics are evaluated on the first epoch and on every 20th epoch
    is_report_epoch = epoch == 0 or (epoch + 1) % 20 == 0
    if not is_report_epoch:
        continue

    train_loss = loss_fn(params, X_train, y_train_onehot)
    test_loss = loss_fn(params, X_test, y_test_onehot)
    train_acc = accuracy_fn(params, X_train, y_train_onehot)
    test_acc = accuracy_fn(params, X_test, y_test_onehot)

    metrics = (
        ('train_loss', train_loss),
        ('test_loss', test_loss),
        ('train_acc', train_acc),
        ('test_acc', test_acc),
    )
    for column, value in metrics:
        history[column].append(float(value))

    print(f"Época {epoch+1:3d} | Train Loss: {train_loss:.4f} | "
          f"Test Loss: {test_loss:.4f} | Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f}")

final_acc = accuracy_fn(params, X_test, y_test_onehot)
print(f"\nAccuracy final: {final_acc:.2%}")

In [None]:
# Visualize results: loss curves on the left, accuracy curves on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel: axis, history keys for train/test, y label, title
panels = [
    (axes[0], 'train_loss', 'test_loss', 'Loss', 'Evolución de la Pérdida'),
    (axes[1], 'train_acc', 'test_acc', 'Accuracy', 'Evolución del Accuracy'),
]
for ax, train_key, test_key, y_label, title in panels:
    ax.plot(history[train_key], label='Train', linewidth=2)
    ax.plot(history[test_key], label='Test', linewidth=2)
    ax.set_xlabel('Época (x20)')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Fixed y range for the accuracy panel only
axes[1].set_ylim([0, 1.05])

plt.tight_layout()
plt.show()

# Confusion matrix
from sklearn.metrics import confusion_matrix

# Class indices: argmax of the one-hot test labels and of the network outputs
y_true_class = jnp.argmax(y_test_onehot, axis=1)
y_pred = forward(params, X_test)
y_pred_class = jnp.argmax(y_pred, axis=1)

cm = confusion_matrix(y_true_class, y_pred_class)
print("\nMatriz de Confusión:")
print(cm)

## 7. Conclusión

JAX ofrece:
- **Diferenciación automática** con `grad`
- **Alto rendimiento** con `jit`
- **Vectorización** con `vmap`
- **API familiar** tipo NumPy

Es ideal para investigación y proyectos que requieren máximo control y rendimiento.