# **Investigación de JAX**

## **Ejemplo práctico.**

Para este ejemplo usaremos el dataset MNIST de dígitos manuscritos para ver las diferencias entre JAX, TensorFlow y PyTorch. Con cada uno crearemos un modelo de clasificación y sacaremos conclusiones sobre cómo trabaja cada framework.

In [None]:
import keras

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


### JAX

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, random
from tensorflow import keras
import numpy as np

# Load MNIST through Keras (uint8 images, integer labels).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixels into [0, 1] and append a trailing channel axis so the images
# are laid out as (N, H, W, 1), matching the NHWC convolutions below.
x_train = (jnp.array(x_train) / 255.0)[..., None]
x_test = (jnp.array(x_test) / 255.0)[..., None]

# Labels as JAX arrays.
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

# Función para generar mini-batches
def get_batches(x, y, batch_size=128):
    n = x.shape[0]
    for i in range(0, n, batch_size):
        yield x[i:i+batch_size], y[i:i+batch_size]

# Inicializar parámetros
def init_params(key):
    keys = jax.random.split(key, 6)
    return {
        "w1": jax.random.normal(keys[0], (3,3,1,32)) * 0.1,
        "w2": jax.random.normal(keys[1], (3,3,32,64)) * 0.1,
        "w3": jax.random.normal(keys[2], (36864,128)) * 0.1,
        "w4": jax.random.normal(keys[3], (128,10)) * 0.1,
    }

# Convolutional model (forward pass)
def model(params, x):
    """Run the CNN forward pass and return unnormalised class scores.

    Architecture: two unpadded 3x3 convolutions (NHWC input, HWIO kernels),
    each followed by ReLU, then flatten, a ReLU dense layer and a linear
    output layer. No biases and no softmax are applied.

    Args:
        params: Dict with kernels/matrices under "w1", "w2", "w3", "w4".
        x: Batch of images in NHWC layout.
    """
    def conv_relu(inputs, kernel):
        # Stride 1, no padding ("VALID"), NHWC activations / HWIO kernels.
        out = jax.lax.conv_general_dilated(
            inputs,
            kernel,
            window_strides=(1, 1),
            padding="VALID",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
        return jax.nn.relu(out)

    features = x
    for kernel_name in ("w1", "w2"):
        features = conv_relu(features, params[kernel_name])

    flat = features.reshape(features.shape[0], -1)
    hidden = jax.nn.relu(flat @ params["w3"])
    return hidden @ params["w4"]

# Pérdida (MSE)
def loss_fn(params, x, y):
    logits = model(params, x)
    labels = jax.nn.one_hot(y, 10)
    return jnp.mean(jnp.sum((logits - labels) ** 2, axis=1))

# Training setup: SGD learning rate, number of passes over the data, and
# the initial weights drawn from a fixed seed.
lr, epochs = 0.01, 1
params = init_params(jax.random.PRNGKey(0))

@jax.jit
def train_step(params, images, labels):
    """Apply one plain-SGD update and return the new parameter pytree.

    Every leaf is updated as ``p - lr * grad``. ``lr`` is read from module
    scope when the function is traced, so jit bakes it into the compiled
    step. Later rebinding of ``lr`` is not seen by already-compiled traces.
    """
    gradients = jax.grad(loss_fn)(params, images, labels)

    def sgd_update(param, g):
        return param - lr * g

    return jax.tree_util.tree_map(sgd_update, params, gradients)

for epoch_idx in range(epochs):
    print(f"\n--- Epoch {epoch_idx+1}/{epochs} ---")
    batches = get_batches(x_train, y_train, batch_size=128)
    for step, (images, labels) in enumerate(batches):
        params = train_step(params, images, labels)
        # Report progress only every 100 batches. The loss is recomputed on
        # the same batch using the freshly updated parameters.
        if step % 100:
            continue
        current_loss = loss_fn(params, images, labels)
        print(f"  Batch {step+1}, Loss: {current_loss:.4f}")

# Prueba
logits_test = model(params, x_test[:10])
preds = jnp.argmax(logits_test, axis=1)
print("\nPredicciones:", preds)
print("Etiquetas reales:", y_test[:10])

Starting JAX training. The first step will involve JIT compilation, which may take some time.

--- Epoch 1/1 ---
  Batch 1, Loss: 51.3541
  Batch 101, Loss: 0.8574
  Batch 201, Loss: 0.7220
  Batch 301, Loss: 0.6174
  Batch 401, Loss: 0.5756

Predicciones: [7 2 1 0 4 1 8 4 6 9]
Etiquetas reales: [7 2 1 0 4 1 4 9 5 9]


### TensorFlow

In [None]:
import tensorflow as tf

# Load MNIST (uint8 images, integer labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Reshape each split to (N, 28, 28, 1) float32 and scale pixels into [0, 1].
x_train, x_test = [
    images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
    for images in (x_train, x_test)
]

# Modelo CNN
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax') # 10 clases para MNIST
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss='sparse_categorical_crossentropy', # Para etiquetas enteras
    metrics=['accuracy']
)

# Training: one epoch in mini-batches of 128.
print("\n--- Entrenamiento TensorFlow ---")
model.fit(x_train, y_train, epochs=1, batch_size=128, verbose=1)

# Evaluation on the held-out test set.
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"TensorFlow - Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}")

# Sample predictions. The last Dense layer applies softmax, so predict()
# returns class probabilities; argmax picks the most likely digit.
probs_test = model.predict(x_test[:10])
preds = tf.argmax(probs_test, axis=1)
print("\nPredicciones (TensorFlow):", preds.numpy())
print("Etiquetas reales (TensorFlow):", y_test[:10])


--- Entrenamiento TensorFlow ---
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m168s[0m 354ms/step - accuracy: 0.8548 - loss: 0.5382
TensorFlow - Test Loss: 0.0774, Test Accuracy: 0.9751
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 296ms/step

Predicciones (TensorFlow): [7 2 1 0 4 1 4 9 5 9]
Etiquetas reales (TensorFlow): [7 2 1 0 4 1 4 9 5 9]


### PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Convert images to tensors and standardise them with (0.1307, 0.3081).
# These are presumably the commonly quoted MNIST training-set mean/std.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load MNIST (train split first, then test), downloading into ./data if needed.
train_dataset, test_dataset = [
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
]

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Modelo CNN para PyTorch
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = x.view(-1, 64 * 24 * 24)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net()
# Net.forward returns raw logits, which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
print("\n--- Entrenamiento PyTorch ---")
epochs = 1
n_train = len(train_loader.dataset)
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        # Show progress every 100 batches.
        if batch_idx % 100 != 0:
            continue
        print(f'  Epoch {epoch+1}, Batch {batch_idx*len(data)}/{n_train} Loss: {loss.item():.4f}')

# Evaluación
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = correct / total
print(f"PyTorch - Test Accuracy: {accuracy:.4f}")

# Predicciones de ejemplo
data, target = next(iter(test_loader))
output = model(data[:10])
preds = torch.argmax(output, axis=1)
print("\nPredicciones (PyTorch):", preds.numpy())
print("Etiquetas reales (PyTorch):", target[:10].numpy())

100%|██████████| 9.91M/9.91M [00:00<00:00, 22.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 619kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.61MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.78MB/s]



--- Entrenamiento PyTorch ---
  Epoch 1, Batch 0/60000 Loss: 2.3002
  Epoch 1, Batch 12800/60000 Loss: 0.1568
  Epoch 1, Batch 25600/60000 Loss: 0.1032
  Epoch 1, Batch 38400/60000 Loss: 0.1191
  Epoch 1, Batch 51200/60000 Loss: 0.1207
PyTorch - Test Accuracy: 0.9610

Predicciones (PyTorch): [7 2 1 0 4 1 4 9 6 9]
Etiquetas reales (PyTorch): [7 2 1 0 4 1 4 9 5 9]


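Tras los tres ejemplos, lo que más cabe destacar es la diferencia de sintaxis y de forma de trabajar de JAX en comparación con TensorFlow y PyTorch. JAX es muy matemático y la mayoría de pasos tienes que hacerlos tú (inicializar los pesos, definir la función de pérdida, calcular los gradientes y actualizar los parámetros), por lo que se siente muy manual. Por otro lado, TensorFlow y PyTorch tienen la ventaja de que las capas, las funciones de pérdida y los optimizadores ya vienen implementados. En mi experiencia, JAX resulta bastante más complejo: da más libertad, pero obliga a construir por uno mismo facilidades que TensorFlow y PyTorch ya traen por defecto. Además, hay que tener en cuenta que en JAX usamos descenso de gradiente simple (SGD), mientras que en TensorFlow y PyTorch usamos Adam, por lo que las diferencias de precisión no se deben solo al framework.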

Podemos concluir y comprobar que, como dijimos, JAX se centra en la programación funcional y en que el usuario pueda controlar el flujo más a fondo, mientras que los otros dos priorizan la productividad del desarrollador.