<a href="https://colab.research.google.com/github/HarrisonSantiago/WebsiteNotebooks/blob/main/Coding/JAX_vs_Numpy_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is to show the speed difference in using jax and numpy in creating and training a neural network from scratch. It accompanies my post at https://harrisonsantiago.com/?page_id=86

#utils

In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, device_put, random
from functools import partial
from jax import random

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from scipy.ndimage import gaussian_filter

import matplotlib.pyplot as plt
import seaborn as sns

from abc import ABC, abstractmethod
from typing import Callable, Optional, Union, Tuple, List, Dict
import time
from dataclasses import dataclass


def plot_results(X_test, y_test, predictions, save_path='prediction_results.png'):
    """Plot comprehensive results for neural network classification.

    Args:
        X_test: Test input data of shape (n_samples, 2)
        y_test: Test target data of shape (n_samples, 3) - one-hot encoded
        network: Trained neural network
        save_path: Path to save the plot
    """
    sns.set_theme()
    fig = plt.figure(figsize=(20, 5))

    true_classes = np.argmax(y_test, axis=1)
    pred_classes = np.argmax(predictions, axis=1)

    # Create meshgrid for decision boundary
    x_min, x_max = X_test[:, 0].min() - 0.5, X_test[:, 0].max() + 0.5
    y_min, y_max = X_test[:, 1].min() - 0.5, X_test[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
                        np.arange(y_min, y_max, 0.02))
    grid_points = np.c_[xx.ravel(), yy.ravel()]

    # Function to create smooth decision boundaries
    def get_smooth_decision_boundary(X, y, grid_points):
        knn = KNeighborsClassifier(n_neighbors=5)
        knn.fit(X, y)
        Z = knn.predict(grid_points)
        Z = Z.reshape(xx.shape)
        Z_smooth = gaussian_filter(Z, sigma=2)
        return Z_smooth

    # Get decision boundaries for true and predicted classes
    Z_true = get_smooth_decision_boundary(X_test, true_classes, grid_points)
    Z_pred = get_smooth_decision_boundary(X_test, pred_classes, grid_points)

    # Plot true classes with true decision boundary
    ax1 = plt.subplot(1, 3, 1)
    contour = ax1.contourf(xx, yy, Z_true, alpha=0.3, cmap='viridis', vmin=0, vmax=2)
    scatter = ax1.scatter(X_test[:, 0], X_test[:, 1],
                         c=true_classes,
                         cmap='viridis',
                         alpha=1,
                         vmin=0,
                         vmax=2,
                         edgecolors='black',
                         linewidth=1)
    plt.colorbar(scatter, ax=ax1, ticks=[0, 1, 2])
    ax1.set_title('True Classes with Ideal Decision Boundary')
    ax1.set_xlabel('X1')
    ax1.set_ylabel('X2')

    # Plot predicted classes with model's decision boundary
    ax2 = plt.subplot(1, 3, 2)
    contour = ax2.contourf(xx, yy, Z_pred, alpha=0.3, cmap='viridis', vmin=0, vmax=2)
    scatter = ax2.scatter(X_test[:, 0], X_test[:, 1],
                         c=pred_classes,
                         cmap='viridis',
                         alpha=1,
                         vmin=0,
                         vmax=2,
                         edgecolors='black',
                         linewidth=1)
    plt.colorbar(scatter, ax=ax2, ticks=[0, 1, 2])
    ax2.set_title('Predicted Classes with Model Decision Boundary')
    ax2.set_xlabel('X1')
    ax2.set_ylabel('X2')

    # Plot confusion matrix
    ax3 = plt.subplot(1, 3, 3)
    confusion = np.zeros((3, 3))
    for t, p in zip(true_classes, pred_classes):
        confusion[t, p] += 1
    confusion_normalized = confusion / confusion.sum(axis=1, keepdims=True)

    im = ax3.imshow(confusion_normalized, cmap='Blues', vmin=0, vmax=1)
    plt.colorbar(im, ax=ax3)
    ax3.set_title('Confusion Matrix (Normalized)')
    ax3.set_xlabel('Predicted Class')
    ax3.set_ylabel('True Class')

    # Add text annotations to confusion matrix
    for i in range(3):
        for j in range(3):
            text_color = 'white' if confusion_normalized[i, j] > 0.5 else 'black'
            ax3.text(j, i, f'{confusion[i, j]:.0f}\n({confusion_normalized[i, j]:.2f})',
                    ha='center', va='center',
                    color=text_color)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def generate_circle_data_np(n_points=100, n_classes=3, noise=0.1, max_radius=1.0):
    """
    Generate concentric circles dataset for classification.

    Parameters:
    -----------
    n_points : int
        Number of points per class
    n_classes : int
        Number of classes (circles)
    noise : float
        Standard deviation of Gaussian noise
    max_radius : float
        Radius of the outermost circle

    Returns:
    --------
    X : ndarray of shape (n_points * n_classes, 2)
    y : ndarray of shape (n_points * n_classes,)
    """
    X = np.zeros((n_points * n_classes, 2))
    y = np.zeros(n_points * n_classes, dtype='uint8')

    for class_idx in range(n_classes):
      ix = range(n_points * class_idx, n_points * (class_idx + 1))

      # Quadratically increasing spaces (more space between outer rings)
      r = max_radius * ((n_classes - class_idx) / n_classes) ** 2
      theta = np.linspace(0, 2*np.pi, n_points) + np.random.randn(n_points) * noise
      r_noise = r + np.random.randn(n_points) * noise * r

      # polar -> Cartesian
      X[ix] = np.column_stack((
          r_noise * np.cos(theta),
          r_noise * np.sin(theta)
      ))

      y[ix] = class_idx

      y_onehot = np.eye(3)[y]


    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_onehot, test_size=0.2, random_state=42, stratify=y
    )

    return X_train, X_test, y_train, y_test



def generate_circle_data_jax(key, n_points=100, n_classes=3, noise=0.1, max_radius=1.0):
    """
    Generate concentric circles dataset for classification using JAX.

    Parameters:
    -----------
    key : jax.random.PRNGKey
        Random number generator key
    n_points : int
        Number of points per class
    n_classes : int
        Number of classes (circles)
    noise : float
        Standard deviation of Gaussian noise
    max_radius : float
        Radius of the outermost circle

    Returns:
    --------
    X : ndarray of shape (n_points * n_classes, 2)
    y : ndarray of shape (n_points * n_classes,)
    """
    X = jnp.zeros((n_points * n_classes, 2))
    y = jnp.zeros(n_points * n_classes, dtype='int32')

    for class_idx in range(n_classes):
        # Get new random key for this iteration
        key, subkey = random.split(key)

        # Convert range to array for indexing
        ix = jnp.arange(n_points * class_idx, n_points * (class_idx + 1))

        r = max_radius * ((n_classes - class_idx) / n_classes) ** 2
        theta = jnp.linspace(0, 2*jnp.pi, n_points) + \
                random.normal(key, (n_points,)) * noise
        key, subkey = random.split(key)
        r_noise = r + random.normal(subkey, (n_points,)) * noise * r

        # Convert polar coordinates to Cartesian
        X = X.at[ix].set(jnp.stack([
            r_noise * jnp.cos(theta),  # x coordinates
            r_noise * jnp.sin(theta)   # y coordinates
        ], axis=1))

        y = y.at[ix].set(class_idx)

    y_onehot = jnp.eye(3)[y]

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_onehot, test_size=0.2, random_state=42, stratify=y
    )

    return X_train, X_test, y_train, y_test


#Numpy Version

In [3]:
class Layer(ABC):
  """Abstract base class for neural network layers."""

  def __init__(self) -> None:
    self.input = None
    self.output = None

  @abstractmethod
  def forward(self, input_data: np.ndarray) -> np.ndarray:
    """Forward pass computation.

    Args:
        input_data: Input tensor of shape (batch_size, input_features)

    Returns:
        Output tensor of shape (batch_size, output_features)
    """
    raise NotImplementedError

  @abstractmethod
  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    """Backward pass computation.

    Args:
        output_error: Gradient of the loss with respect to layer output
        learning_rate: Learning rate for parameter updates

    Returns:
        Gradient of the loss with respect to layer input
    """
    raise NotImplementedError


class DenseLayer(Layer):
  """Fully connected neural network layer."""

  def __init__(self, input_size: int, output_size: int) -> None:
    """Initialize layer parameters.

    Args:
        input_size: Number of input features
        output_size: Number of output features
    """
    super().__init__()

    limit = np.sqrt(6 / (input_size + output_size))
    self.weights = np.random.uniform(-limit, limit, (input_size, output_size))
    self.bias = np.zeros((1, output_size))

    # Parameter gradients
    self.weights_grad: Optional[np.ndarray] = None
    self.bias_grad: Optional[np.ndarray] = None

  def forward(self, input_data: np.ndarray) -> np.ndarray:
    """Compute forward pass.

    Args:
        input_data: Input tensor of shape (batch_size, input_features)

    Returns:
        Output tensor of shape (batch_size, output_features)
    """
    self.input = input_data
    self.output = np.dot(input_data, self.weights) + self.bias
    return self.output

  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    """Compute backward pass and update parameters.

    Args:
        output_error: Gradient of loss with respect to layer output
        learning_rate: Learning rate for parameter updates

    Returns:
        Gradient of loss with respect to layer input
    """
    # Compute gradients
    input_error = np.dot(output_error, self.weights.T)
    self.weights_grad = np.dot(self.input.T, output_error) / self.input.shape[0]  # Added batch normalization
    self.bias_grad = np.sum(output_error, axis=0, keepdims=True) / self.input.shape[0]  # Added batch normalization

    # Update parameters
    self.weights -= learning_rate * self.weights_grad
    self.bias -= learning_rate * self.bias_grad

    return input_error


class ActivationLayer(Layer):
  """Neural network activation layer."""

  def __init__(self,
                activation_fn: Callable[[np.ndarray], np.ndarray],
                activation_prime: Callable[[np.ndarray], np.ndarray]) -> None:
    """Initialize activation functions.

    Args:
        activation_fn: Activation function
        activation_prime: Derivative of activation function
    """
    super().__init__()
    self.activation_fn = activation_fn
    self.activation_prime = activation_prime

  def forward(self, input_data: np.ndarray) -> np.ndarray:
    """Apply activation function.

    Args:
        input_data: Input tensor

    Returns:
        Activated tensor
    """
    self.input = input_data
    self.output = self.activation_fn(self.input)
    return self.output

  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    """Compute gradient through activation function.

    Args:
        output_error: Gradient of loss with respect to layer output
        learning_rate: Unused, kept for API consistency

    Returns:
        Gradient of loss with respect to layer input
    """
    return self.activation_prime(self.input) * output_error



In [4]:
class LossFunctions:
  """Collection of loss functions and their derivatives."""

  @staticmethod
  def cross_entropy(y_true: np.ndarray,
                    y_pred: np.ndarray,
                    epsilon: float = 1e-15) -> float:
    """Categorical cross-entropy loss for multi-class classification.

    Args:
        y_true: One-hot encoded ground truth values
        y_pred: Predicted probabilities for each class
        epsilon: Small constant to avoid log(0)

    Returns:
        Categorical cross-entropy loss value
    """
    # Clip predictions to avoid log(0)
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)

    # Calculate categorical cross-entropy
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

  @staticmethod
  def cross_entropy_prime(y_true: np.ndarray,
                          y_pred: np.ndarray,
                          epsilon: float = 1e-15) -> np.ndarray:
    """Derivative of categorical cross-entropy loss.

    Args:
        y_true: One-hot encoded ground truth values
        y_pred: Predicted probabilities for each class
        epsilon: Small constant to avoid division by zero

    Returns:
        Gradient of categorical cross-entropy with respect to predictions
    """
     # Clip predictions to avoid division by zero
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)

    return y_pred - y_true


class Activations:
  """Collection of activation functions and their derivatives."""

  @staticmethod
  def leaky_relu(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    return np.where(x > 0, x, alpha * x)

  @staticmethod
  def leaky_relu_prime(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    return np.where(x > 0, 1.0, alpha)



In [5]:
@dataclass
class NetworkConfig:
    """Neural network configuration parameters."""
    learning_rate: float = 0.01
    epochs: int = 1000
    batch_size: int = 32
    clip_value: float = 5.0  # Add gradient clipping
    epsilon: float = 1e-15   # Small constant to prevent division by zero


class NeuralNetwork:
    """Simple neural network implementation."""

    def __init__(self, config: NetworkConfig = NetworkConfig()) -> None:
        self.layers: list[Layer] = []
        self.loss = LossFunctions.cross_entropy
        self.loss_prime = LossFunctions.cross_entropy_prime
        self.config = config

    def add(self, layer: Layer) -> None:
        self.layers.append(layer)

    def predict(self, input_data: np.ndarray) -> np.ndarray:
        """Generate predictions with NaN checking."""
        output = input_data


        for i, layer in enumerate(self.layers):
            output = layer.forward(output)

        return output

    def _clip_gradients(self, grad: np.ndarray) -> np.ndarray:
        """Clip gradients to prevent explosion."""
        return np.clip(grad, -self.config.clip_value, self.config.clip_value)

    def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> list[float]:
        """Train the neural network with NaN prevention."""
        samples = len(x_train)

        for epoch in range(self.config.epochs):
            epoch_loss = 0
            batch_count = 0

            # Mini-batch gradient descent
            for i in range(0, samples, self.config.batch_size):
                batch_x = x_train[i:i + self.config.batch_size]
                batch_y = y_train[i:i + self.config.batch_size]
                actual_batch_size = len(batch_x)

                # Forward propagation
                output = self.predict(batch_x)
                # Prevent division by zero in log operations
                output = np.clip(output, self.config.epsilon, 1 - self.config.epsilon)


                # Backward propagation
                error = self.loss_prime(batch_y, output)
                error = self._clip_gradients(error)  # Clip initial gradients

                for j, layer in enumerate(reversed(self.layers)):
                    error = layer.backward(error, self.config.learning_rate)
                    error = self._clip_gradients(error)  # Clip initial gradients


# JAX Version

The jax version is simple, the difficulty lies in keeping a functional paradigm. I find this difficult after years of OOP.

In [6]:
def get_layer_params(input_size: int,
                     output_size: int,
                     key: random.PRNGKey) -> Dict:
  """
  Return the weights and biases for a dense layer

  Args:
    - input_size: size of this layer
    - output_size: size of the next layer
    - key: random key for this layer

  Returns:
    - params: a dictionary of weights and biases for this layer
  """
  limit = jnp.sqrt(6 / (input_size + output_size))
  W_key, b_key = random.split(key)
  return {
    'weights': random.uniform(W_key, (output_size, input_size), minval=-limit, maxval=limit),
    'bias': random.uniform(b_key, (output_size))
    }


@jit
def step(params: List[Dict],
         x: jnp.ndarray,
         y: jnp.ndarray,
         lr: float = 0.05):

  """Optimized training step with static learning rate"""

  grads = grad(loss)(params, x, y)
  updated_params = []

  for param, grad_param in zip(params, grads):
    updated_param = {
      'weights': param['weights'] - lr * grad_param['weights'],
      'bias': param['bias'] - lr * grad_param['bias']
    }
    updated_params.append(updated_param)
  return updated_params

@jit
def loss(params: List[Dict],
         x: jnp.ndarray,
         targets: jnp.ndarray) -> float:
  """Compute cross entropy loss"""

  predictions = batched_predict(params, x)
  log_softmax = predictions - jax.scipy.special.logsumexp(predictions, axis=1, keepdims=True)
  return -jnp.mean(jnp.sum(targets * log_softmax, axis=1))

@jit
def predict(params: List[Dict], x: jnp.ndarray) -> jnp.ndarray:
  """Implicitly defines densley connected network"""
  alpha = 1e-15
  for p in params[:-1]:
    x = jnp.dot(p['weights'], x) + p['bias']
    x = jnp.where(x > 0, x, alpha * x)

  final_weight = params[-1]['weights']
  final_bias = params[-1]['bias']
  return jnp.dot(final_weight, x) + final_bias

batched_predict = vmap(predict, in_axes=(None, 0))


# Comparing the Networks

In [7]:
#Numpy version

#Get data from utils
X_train, X_test, y_train, y_test = generate_circle_data_np(n_points= 2000)

for _ in range(10):
  # Define our network for each trial
  network = NeuralNetwork(NetworkConfig(
      learning_rate=0.05,
      epochs=50,
      batch_size=64,
  ))

  network.add(DenseLayer(2, 128))
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(128, 256))
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(256, 128))
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(128, 3))

  # Time how long training takes
  start_time = time.time()
  network.fit(X_train, y_train)
  end_time = time.time()
  print(f"Training time: {end_time - start_time:.2f} seconds")

#plot last network for sanity check
predictions = network.predict(X_test)
plot_results(X_test, y_test, predictions, 'numpy_demo.png')

Training time: 7.30 seconds
Training time: 7.54 seconds
Training time: 7.76 seconds
Training time: 7.57 seconds
Training time: 7.71 seconds
Training time: 7.63 seconds
Training time: 7.61 seconds
Training time: 7.08 seconds
Training time: 7.55 seconds
Training time: 7.19 seconds


In [8]:
# JAX version

PRN = random.key(0)
X_train, X_test, y_train, y_test = generate_circle_data_jax(PRN, n_points=2000)


input_dim = X_train.shape[1]
num_classes = y_train.shape[1]

layer_sizes = [input_dim, 128, 256, 128 ,num_classes]
keys = random.split(PRN, len(layer_sizes))
num_epochs = 50
batch_size = 64

for _ in range(10):

  #restart out network every trial
  params = [get_layer_params(input_size, output_size, key) \
        for input_size, output_size, key \
        in zip(layer_sizes[:-1], layer_sizes[1:], keys)]

  #Train the network
  start_time = time.time()
  for epoch in range(num_epochs):
    epoch_loss = 0
    for i in range(0, len(X_train), batch_size):
      batch_x = X_train[i:i + batch_size]
      batch_y = y_train[i:i + batch_size]
      params = step(params, batch_x, batch_y)


  end_time = time.time()
  print(f"Training time: {end_time - start_time:.2f} seconds")

#plot out last network for sanity check
predictions = batched_predict(params, X_test)
plot_results(X_test, y_test, predictions, 'jax_demo.png')

Training time: 2.00 seconds
Training time: 1.67 seconds
Training time: 1.77 seconds
Training time: 1.82 seconds
Training time: 1.73 seconds
Training time: 1.73 seconds
Training time: 1.85 seconds
Training time: 1.95 seconds
Training time: 2.00 seconds
Training time: 1.96 seconds
