In [15]:
import numpy as np
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import jax.numpy as jnp
from keras.datasets import mnist
from keras.utils import to_categorical
from sklearn.metrics import confusion_matrix
import seaborn as sns
import tensorflow_datasets as tfds
import random
from jax import random, grad, jit, vmap
from functools import partial

### Intro Function PyNet

In [None]:
#%% PyNet Shared Functions Module
import time
import numpy as np
import wandb

class PyNetBase:
    """Base class containing all shared neural network functionality"""

    def __init__(self, num_features, hidden_units, num_output, weights_init='he', activation='relu', loss='cross_entropy', optimizer='sgd', l2_coeff=0.0, dropout_p=None):
        """
        Initialize neural network with configurable architecture.

        Args:
            num_features: Number of input features
            hidden_units: List of hidden layer sizes [layer1, layer2, ...]
            num_output: Number of output classes
            weights_init: Weight initialization method ('he', 'xavier', 'normal')
            activation: Activation function ('relu', 'tanh', 'sigmoid')
            loss: Loss function ('cross_entropy', 'mse', 'mae')
            optimizer: Optimizer type ('sgd', 'adam', 'rmsprop')
            l2_coeff: L2 regularization coefficient (weight_decay)
            dropout_p: List of dropout probabilities for each hidden layer (None = no dropout)
        """

        # Build layer sizes: input → hidden layers → output
        layer_sizes = [num_features] + hidden_units + [num_output]

        self.layer_sizes = layer_sizes
        self.activation = activation
        self.weights_init = weights_init
        self.loss = loss
        self.optimizer = optimizer
        self.l2_coeff = l2_coeff
        self.dropout_p = dropout_p

        # Validate dropout_p if provided
        num_hidden = len(hidden_units)
        if dropout_p is not None:
            if len(dropout_p) != num_hidden:
                raise ValueError(f"dropout_p must have {num_hidden} values (one per hidden layer)")
            self.dropout_p = dropout_p
        else:
            self.dropout_p = [0.0] * num_hidden  # No dropout by default

        # Initialize weights for each layer
        self.W = []
        for i in range(len(layer_sizes) - 1):
            input_size = layer_sizes[i]
            output_size = layer_sizes[i + 1]

            # Weight initialization
            if weights_init == 'he':
                # He initialization (good for ReLU)
                w = np.random.randn(input_size + 1, output_size) * np.sqrt(2 / input_size)
            elif weights_init == 'xavier':
                # Xavier initialization (good for tanh/sigmoid)
                w = np.random.randn(input_size + 1, output_size) * np.sqrt(1 / input_size)
            elif weights_init == 'normal':
                # Standard normal initialization
                w = np.random.randn(input_size + 1, output_size) * 0.01
            else:
                raise ValueError(f"Unknown weights_init: {weights_init}")

            self.W.append(w)

        # Initialize optimizer state
        if optimizer == 'adam':
            self.m = [np.zeros_like(w) for w in self.W]  # First moment estimates
            self.v = [np.zeros_like(w) for w in self.W]  # Second moment estimates
            self.t = 0  # Time step counter
        elif optimizer == 'rmsprop':
            self.v = [np.zeros_like(w) for w in self.W]  # Moving average of squared gradients


    def forward(self, X, W, dropout_on=False):
        """
        Forward pass through the network with optional dropout

        Args:
            X: Input data
            W: Weights
            dropout_on: Whether to apply dropout (True during training, False during inference)
        Returns:
            y: Output predictions
            h: List of hidden layer activations
            masks: List of dropout masks (one per hidden layer)
        """
        h = []
        masks = []
        a = X
        num_hidden = len(W) - 1

        # Loop through hidden layers
        for l in range(num_hidden):
            a = np.vstack([a, np.ones(a.shape[1])])  # Add bias term
            z = W[l].T @ a
            a = self._activation_function(z)  # Use configurable activation

            # Apply dropout if enabled
            if dropout_on and self.dropout_p[l] > 0.0:
                p = self.dropout_p[l]
                # Inverted dropout: scale active neurons to maintain expected activation
                mask = (np.random.rand(*a.shape) > p).astype(float) / (1.0 - p)
                a *= mask
            else:
                mask = np.ones_like(a)  # No dropout: all neurons active

            h.append(a)
            masks.append(mask)

        # Output layer (no dropout)
        a = np.vstack([a, np.ones(a.shape[1])])  # Add bias term
        y_hat = W[-1].T @ a
        y = self._softmax(y_hat)  # Output layer always uses softmax for classification
        return y, h, masks


    def backward(self, X, T, W, h, masks, eta, y_pred=None, use_clipping=True, max_grad_norm=25.0):
        """
        Backward pass with configurable optimizers, L2 regularization, gradient clipping, and dropout.

        Args:
            X: Input data
            T: Target labels
            W: Weights
            h: Hidden activations from forward pass
            masks: Dropout masks from forward pass
            eta: Learning rate
            y_pred: Pre-computed predictions (optional, for efficiency)
            use_clipping: Whether to use gradient clipping (default True)
            max_grad_norm: Maximum gradient norm for clipping (default 25.0)
        """
        m = X.shape[1]

        if y_pred is None:  # Use pre-computed predictions if available, otherwise compute them
            y, _, _ = self.forward(X, W, dropout_on=False)
        else:
            y = y_pred

        # Increment Adam time step once per backward pass
        if self.optimizer == 'adam':
            self.t += 1

        delta = self._loss_derivative(y, T)  # Use configurable loss derivative

        # Backpropagate through hidden layers (in reverse)
        for l in range(len(W) - 1, 0, -1):
            a_prev = np.vstack([h[l-1], np.ones(h[l-1].shape[1])])  # Add bias term
            Q = a_prev @ delta.T

            # Add L2 regularization to gradient (don't regularize biases - last row)
            if self.l2_coeff > 0:
                Q[:-1, :] += self.l2_coeff * W[l][:-1, :]  # Only regularize weights, not biases

            # Optional gradient clipping
            if use_clipping:
                grad_norm = np.linalg.norm(Q)
                if grad_norm > max_grad_norm:
                    Q *= max_grad_norm / grad_norm

            # Apply optimizer update
            W = self._apply_optimizer_update(W, l, Q, eta, m)

            # Backpropagate delta
            delta = W[l][:-1, :] @ delta
            delta *= self._activation_derivative(h[l-1])  # Use configurable activation derivative
            delta *= masks[l-1]  # Apply dropout mask (only gradients through active neurons)

        # First layer gradient
        a_prev = np.vstack([X, np.ones(X.shape[1])])  # Add bias term
        Q = a_prev @ delta.T

        # Add L2 regularization to first layer gradient
        if self.l2_coeff > 0:
            Q[:-1, :] += self.l2_coeff * W[0][:-1, :]  # Only regularize weights, not biases

        # Optional gradient clipping for first layer
        if use_clipping:
            grad_norm = np.linalg.norm(Q)
            if grad_norm > max_grad_norm:
                Q = Q * (max_grad_norm / grad_norm)

        # Apply optimizer update to first layer
        W = self._apply_optimizer_update(W, 0, Q, eta, m)
        loss = self._loss_function(y, T)
        return W, loss


    def _apply_optimizer_update(self, W, layer_idx, gradient, eta, batch_size):
        """Helper method to apply optimizer-specific weight updates"""

        if self.optimizer == 'sgd':
            # Standard SGD (Stochastic Gradient Descent) update
            W[layer_idx] -= (eta / batch_size) * gradient

        elif self.optimizer == 'adam':
            # Adam (Adaptive Moment Estimation) optimizer update
            beta1, beta2, epsilon = 0.9, 0.999, 1e-8

            # Update biased first moment estimate
            self.m[layer_idx] = beta1 * self.m[layer_idx] + (1 - beta1) * gradient
            # Update biased second raw moment estimate
            self.v[layer_idx] = beta2 * self.v[layer_idx] + (1 - beta2) * gradient**2
            # Compute bias-corrected first moment estimate
            m_hat = self.m[layer_idx] / (1 - beta1**self.t)
            # Compute bias-corrected second raw moment estimate
            v_hat = self.v[layer_idx] / (1 - beta2**self.t)
            # Update weights
            denominator = np.sqrt(v_hat) + epsilon
            update = (eta / batch_size) * m_hat / denominator
            # Clip extreme updates to prevent instability
            update = np.clip(update, -1.0, 1.0)
            W[layer_idx] -= update

        elif self.optimizer == 'rmsprop':
            # RMSprop (Root Mean Square Propagation) optimizer update
            alpha, epsilon = 0.99, 1e-8
            # Update moving average of squared gradients
            self.v[layer_idx] = alpha * self.v[layer_idx] + (1 - alpha) * gradient**2
            # Update weights
            W[layer_idx] -= (eta / batch_size) * gradient / (np.sqrt(self.v[layer_idx]) + epsilon)

        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer}")

        return W


    def _softmax(self, y_hat):
        """Compute softmax probabilities"""
        y_hat = y_hat - np.max(y_hat, axis=0, keepdims=True)  # prevent overflow
        exp_scores = np.exp(y_hat)
        return exp_scores / np.sum(exp_scores, axis=0, keepdims=True)


    def _activation_function(self, z):
        """Apply activation function"""
        if self.activation == 'relu':
            return np.maximum(0, z)
        elif self.activation == 'tanh':
            return np.tanh(z)
        elif self.activation == 'sigmoid':
            return 1 / (1 + np.exp(-np.clip(z, -500, 500)))  # Clip to prevent overflow
        else:
            raise ValueError(f"Unknown activation: {self.activation}")


    def _activation_derivative(self, a):
        """Calculate derivative of activation function"""
        if self.activation == 'relu':
            return a > 0
        elif self.activation == 'tanh':
            return 1 - a**2
        elif self.activation == 'sigmoid':
            return a * (1 - a)
        else:
            raise ValueError(f"Unknown activation: {self.activation}")


    def _loss_function(self, y_pred, y_true):
        """Calculate loss based on configured loss function"""
        epsilon = 1e-12  # Prevent log(0)

        if self.loss == 'cross_entropy':
            # Categorical Cross-Entropy Loss
            return -np.sum(np.log(np.sum(y_pred * y_true, axis=0) + epsilon))
        elif self.loss == 'mse':
            # Mean Squared Error Loss
            return 0.5 * np.sum((y_pred - y_true) ** 2)
        elif self.loss == 'mae':
            # Mean Absolute Error Loss
            return np.sum(np.abs(y_pred - y_true))
        else:
            raise ValueError(f"Unknown loss function: {self.loss}")


    def _loss_derivative(self, y_pred, y_true):
        """Calculate derivative of loss function for backpropagation"""
        if self.loss == 'cross_entropy':
            # For cross-entropy with softmax: derivative is simply (y_pred - y_true)
            return y_pred - y_true
        elif self.loss == 'mse':
            # MSE derivative: (y_pred - y_true)
            return y_pred - y_true
        elif self.loss == 'mae':
            # MAE derivative: sign(y_pred - y_true)
            return np.sign(y_pred - y_true)
        else:
            raise ValueError(f"Unknown loss function: {self.loss}")




# Shared utility functions
def calculate_accuracy(net, X, T, W):
    """Calculate accuracy percentage (always with dropout OFF)"""
    y, _, _ = net.forward(X, W, dropout_on=False)
    predictions = np.argmax(y, axis=0)
    true_labels = np.argmax(T, axis=0)
    return np.mean(predictions == true_labels) * 100


def train(net, X, T, W, epochs, eta, batchsize=32, use_clipping=True, max_grad_norm=25.0, use_wandb=False, wandb_project=None, wandb_config=None, wandb_mode="online"):
    """
    Training loop for neural network.

    Args:
        net: Neural network instance
        X, T: Training data and labels
        W: Initial weights
        epochs: Number of training epochs
        eta: Learning rate
        batchsize: Mini-batch size
        use_clipping: Whether to use gradient clipping
        max_grad_norm: Maximum gradient norm for clipping
        use_wandb: Whether to use Weights & Biases logging
        wandb_project: W&B project name
        wandb_config: Dictionary of hyperparameters to log to W&B
        wandb_mode: W&B mode - "online", "offline", or "disabled"
    """
    losses = []
    accuracies = []  # Track training accuracy
    epoch_times = []  # Track computation time per epoch

    # Initialize W&B if enabled
    if use_wandb and wandb_project:
        wandb.init(project=wandb_project, config=wandb_config, mode=wandb_mode)

    # Print header for nicely formatted table
    print("-" * 70)
    print(f"{'Epoch':<10} {'Accuracy':<10} {'Gain':<10} {'Time':<10} {'ETA'}")
    print("-" * 70)

    start_total = time.time()

    m = X.shape[1]
    for epoch in range(epochs):
        epoch_start = time.time()  # Start timing this epoch

        order = np.random.permutation(m)
        epoch_loss = 0
        for i in range(0, m, batchsize):
            batch = order[i:i+batchsize]
            X_batch = X[:, batch]
            T_batch = T[:, batch]
            # Forward pass with dropout enabled during training
            y_batch, h, masks = net.forward(X_batch, W, dropout_on=True)
            # Backward pass with dropout masks
            W, loss = net.backward(X_batch, T_batch, W, h, masks, eta, y_batch, use_clipping, max_grad_norm)
            epoch_loss += loss

        # Calculate training accuracy for this epoch
        train_accuracy = calculate_accuracy(net, X, T, W)
        accuracies.append(train_accuracy)
        losses.append(epoch_loss)

        # Calculate gain compared to last epoch
        if epoch > 0:
            gain = train_accuracy - accuracies[-2]  # Current - previous
            if gain > 0:
                gain_str = f"+{gain:.2f}%"
            elif gain < 0:
                gain_str = f"{gain:.2f}%"  # Already has negative sign
            else:
                gain_str = " 0.00%"
        else:
            gain_str = "baseline"

        # Calculate epoch time
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)

        # Calculate ETA (estimated time remaining)
        if epoch > 0:
            avg_time_per_epoch = np.mean(epoch_times)
            remaining_epochs = epochs - (epoch + 1)
            eta_seconds = avg_time_per_epoch * remaining_epochs
            if eta_seconds > 60:
                eta_str = f"{np.floor(eta_seconds/60):.0f}min {eta_seconds%60:.0f}sec"
            else:
                eta_str = f"{eta_seconds:.0f}sec"
        else:
            eta_str = "calculating..."

        # Format epoch info for output
        epoch_str = f"{epoch+1}/{epochs}"
        accuracy_str = f"{train_accuracy:.2f}%"
        time_str = f"{epoch_time:.2f}sec"

        # Show progress
        print(f"{epoch_str:<10} {accuracy_str:<10} {gain_str:<10} {time_str:<10} {eta_str}")

        # Log to W&B if enabled
        if use_wandb and wandb_project:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": float(epoch_loss),
                "train_accuracy": float(train_accuracy),
                "epoch_time": float(epoch_time)
            })

    total_time = time.time() - start_total
    avg_epoch_time = np.mean(epoch_times)

    print("-" * 70)
    print(f"Total training time: {total_time:.1f}sec")
    print(f"Average per epoch: {avg_epoch_time:.2f}sec")
    print("-" * 70)

    # Don't finish W&B here - let evaluate_model do it after logging test metrics

    return W, losses, accuracies

def evaluate_model(net, X_test, T_test, y_test, W, train_accuracies, use_wandb=False):
    """
    Evaluate model performance and print results

    Args:
        net: Neural network instance
        X_test, T_test: Test data and labels
        y_test: Test labels (not one-hot encoded)
        W: Trained weights
        train_accuracies: List of training accuracies from training
        use_wandb: Whether to log test metrics to W&B
    """
    # Make predictions and calculate accuracy (dropout OFF for evaluation)
    y_test_pred, _, _ = net.forward(X_test.T, W, dropout_on=False)
    y_pred = np.argmax(y_test_pred, axis=0)
    test_accuracy = np.mean(y_pred == y_test)

    # Calculate test loss using the configurable loss function
    test_loss = net._loss_function(y_test_pred, T_test.T) / X_test.shape[0]  # Average per sample

    print(f"\n================== Final Results ==================")
    print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
    print(f"Test Loss (avg per sample): {test_loss:.4f}")
    print(f"Training Accuracy Improvement: {(train_accuracies[-1] - train_accuracies[0]):.1f}% points")
    print(f"Final Training Accuracy: {train_accuracies[-1]:.2f}%")

    # Log test metrics to W&B if enabled and run is still active
    if use_wandb and wandb.run is not None:
        wandb.log({
            "test_accuracy": float(test_accuracy * 100),
            "test_loss": float(test_loss)
        })
        wandb.finish(quiet=False)  # Finish the W&B run after logging test metrics (quiet=False shows summary)

    return y_pred, test_accuracy, test_loss





## Start Numpy Plots

In [None]:
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmathiasdyhr[0m ([33motovo-dtu-qa[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## PyNetM10

In [None]:
np.random.seed(42)
random.seed(42)

# Dataset Configuration
num_features = 28 * 28     # MNIST: 28x28 pixels
num_classes = 10           # MNIST: digits 0–9

# Architecture Configuration
hidden_units = [32, 32]
activation = 'relu'
weights_init = 'he'

# Training Configuration
num_epochs = 100
learning_rate = 0.001
batch_size = 32
loss = 'cross_entropy'
optimizer = 'adam'
l2_coeff = 1e-8
dropout_p = [0.1, 0.1]
use_grad_clipping = False
max_grad_norm = 50.0

# WandB Configuration
use_wandb = True
wandb_project = "PyNetxJaxNet"        # must match your project name on wandb.ai
wandb_mode = "online"                  # ✅ online logging
wandb_config = {
    # Architecture
    "num_features": num_features,
    "hidden_units": hidden_units,
    "num_classes": num_classes,
    "activation": activation,
    "weights_init": weights_init,

    # Training
    "optimizer": optimizer,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "loss": loss,
    "l2_coeff": l2_coeff,
    "dropout_p": dropout_p,
    "use_grad_clipping": use_grad_clipping,
    "max_grad_norm": max_grad_norm,

    # Metadata
    "dataset": "MNIST",
    "framework": "PyNet (NumPy)",
    "seed": 42
}


#%%######################### 2. Load MNIST Data ############################
print("Loading MNIST dataset...")
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(-1, 28*28) / 255.0
X_test = X_test.reshape(-1, 28*28) / 255.0

T_train = to_categorical(y_train, num_classes=10)
T_test = to_categorical(y_test, num_classes=10)

print(f"Successfully loaded!")
print(f"Training samples: {X_train.shape[0]:,}")
print(f"Test samples: {X_test.shape[0]:,}")
print(f"Classes: 0–9 (10 total)")
print(f"Image shape: 28×28 → {X_train.shape[1]} features")


#%%################ 3. Initialize MNIST Neural Network #####################
class PyNet_M10(PyNetBase):
    """MNIST-specific neural network using shared base functionality"""
    pass

net = PyNet_M10(
    num_features, hidden_units, num_classes,
    weights_init, activation, loss, optimizer, l2_coeff, dropout_p
)

print("\nNetwork Architecture:")
print(f"   Input features: {num_features}")
print(f"   Hidden layers: {hidden_units}")
print(f"   Output classes: {num_classes}")
print(f"   Activation: {activation}")
print(f"   Weight init: {weights_init}")
print("Training Configuration:")
print(f"   Optimizer: {optimizer}")
print(f"   Learning rate: {learning_rate}")
print(f"   Batch size: {batch_size}")
print(f"   Epochs: {num_epochs}")
print(f"   Loss function: {loss}")
print(f"   L2 coefficient: {l2_coeff}")
print(f"   Gradient clipping: {use_grad_clipping}")
print(f"   Max gradient norm: {max_grad_norm}")
print(f"   Dropout: {dropout_p}")


#%%########################### 4. Training Loop ############################
net.W, losses, train_accuracies = train(
    net, X_train.T, T_train.T, net.W,
    num_epochs, learning_rate, batch_size,
    use_clipping=use_grad_clipping, max_grad_norm=max_grad_norm,
    use_wandb=use_wandb,
    wandb_project=wandb_project,
    wandb_config=wandb_config,
    wandb_mode=wandb_mode
)


#%%########################## 5. Evaluate Model ############################
y_pred, test_accuracy, test_loss = evaluate_model(
    net, X_test, T_test, y_test, net.W, train_accuracies, use_wandb=True
)


#%%########################## 6. Confusion Matrix ##########################
if wandb.run is not None:
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.title("Confusion Matrix (MNIST, PyNet NumPy)")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()

    # Log confusion matrix image to WandB
    wandb.log({"confusion_matrix": wandb.Image(plt)})
    plt.show()

    wandb.finish()


Loading MNIST dataset...
Successfully loaded!
Training samples: 60,000
Test samples: 10,000
Classes: 0–9 (10 total)
Image shape: 28×28 → 784 features

Network Architecture:
   Input features: 784
   Hidden layers: [32, 32]
   Output classes: 10
   Activation: relu
   Weight init: he
Training Configuration:
   Optimizer: adam
   Learning rate: 0.001
   Batch size: 32
   Epochs: 100
   Loss function: cross_entropy
   L2 coefficient: 1e-08
   Gradient clipping: False
   Max gradient norm: 50.0
   Dropout: [0.1, 0.1]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_time,▂▃▃▁▄▃▃▄▅▄▃▃▃▄▄▄▄▅▆▃▃▄▄▄█▃▄▃▄▂▃▃▃▃▄▃▄▄▃▂
train_accuracy,▁▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
train_loss,█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,44.0
epoch_time,1.79172
train_accuracy,95.635
train_loss,13205.5108


----------------------------------------------------------------------
Epoch      Accuracy   Gain       Time       ETA
----------------------------------------------------------------------
1/100      71.89%     baseline   1.68sec    calculating...
2/100      84.92%     +13.04%    1.71sec    2min 46sec
3/100      87.26%     +2.34%     1.89sec    2min 51sec
4/100      88.61%     +1.35%     1.99sec    2min 55sec
5/100      89.45%     +0.85%     1.80sec    2min 52sec
6/100      90.28%     +0.82%     1.71sec    2min 49sec
7/100      90.73%     +0.45%     1.71sec    2min 46sec
8/100      91.22%     +0.48%     1.80sec    2min 44sec
9/100      91.52%     +0.30%     1.81sec    2min 43sec
10/100     91.90%     +0.38%     1.71sec    2min 40sec
11/100     92.18%     +0.29%     1.80sec    2min 39sec
12/100     92.48%     +0.30%     1.81sec    2min 37sec
13/100     92.67%     +0.19%     1.89sec    2min 36sec
14/100     92.92%     +0.25%     1.81sec    2min 34sec
15/100     93.08%     +0.16%     1.8

0,1
epoch,▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
epoch_time,▁▂▅▂▅▇▅▅█▅▅█▅▇▅▄▂▅█▅▇▅▅▅▅▅▅▅█▂███▅█▂█▅▂▅
test_accuracy,▁
test_loss,▁
train_accuracy,▁▂▃▃▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████
train_loss,█▅▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
epoch_time,1.89872
test_accuracy,96.25
test_loss,0.12532
train_accuracy,97.33
train_loss,9244.02213


## PyNetE26

In [None]:
np.random.seed(42)
random.seed(42)

# Dataset Configuration
num_features = 28 * 28     # EMNIST: 28x28 pixels
num_classes = 26           # EMNIST: letters A-Z

# Architecture Configuration
hidden_units = [32, 32]    # Units per hidden layer [layer1, layer2, ...]
activation = 'relu'        # Activation function: 'relu', 'tanh', 'sigmoid'
weights_init = 'he'        # Weight initialization: 'he', 'xavier', 'normal'

# Training Configuration
num_epochs = 100           # Number of training epochs
learning_rate = 0.001      # Learning rate for gradient descent
batch_size = 32            # Mini-batch size
loss = 'cross_entropy'     # Loss function: 'cross_entropy', 'mse', 'mae'
optimizer = 'adam'         # Optimizer: 'sgd', 'adam', 'rmsprop'
l2_coeff = 1e-8            # L2 regularization coefficient (weight_decay)
dropout_p = [0.0, 0.0]     # Dropout probabilities per layer [hidden1, hidden2, ...]; 0.0 = no dropout
use_grad_clipping = False  # Enable/disable gradient clipping
max_grad_norm = 50.0       # Maximum gradient norm for clipping

# WandB Configuration
use_wandb = True                           # Enable W&B logging
wandb_project = "PyNetxJaxNet"             # Your W&B project name
wandb_mode = "online"                      # W&B mode: "online", "offline", or "disabled"
wandb_config = {
    # Architecture
    "num_features": num_features,
    "hidden_units": hidden_units,
    "num_classes": num_classes,
    "activation": activation,
    "weights_init": weights_init,

    # Training
    "optimizer": optimizer,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "loss": loss,
    "l2_coeff": l2_coeff,
    "dropout_p": dropout_p,
    "use_grad_clipping": use_grad_clipping,
    "max_grad_norm": max_grad_norm,

    # Metadata
    "dataset": "EMNIST",
    "framework": "PyNet"
}

#%%######################## 2. Load EMNIST Data ############################

# Load EMNIST Letters dataset using TensorFlow Datasets
print("Loading EMNIST Letters dataset...")
ds_train, ds_test = tfds.load('emnist/letters', split=['train', 'test'], as_supervised=True)

# Convert to numpy arrays
def preprocess_data(ds):
    images, labels = [], []
    for image, label in ds:
        images.append(image.numpy())
        labels.append(label.numpy())
    return np.array(images), np.array(labels)

print("Converting to numpy arrays...")
X_train, y_train = preprocess_data(ds_train)
X_test, y_test = preprocess_data(ds_test)

# Reshape and normalize inputs (same as MNIST)
X_train = X_train.reshape(-1, 28*28) / 255.0
X_test = X_test.reshape(-1, 28*28) / 255.0

# EMNIST letters uses labels 1-26, need to convert to 0-25 for one-hot encoding
print(f"Original label range: {y_train.min()}-{y_train.max()}")
y_train = y_train - 1  # Convert 1-26 to 0-25
y_test = y_test - 1    # Convert 1-26 to 0-25
print(f"Adjusted label range: {y_train.min()}-{y_train.max()}")

# One-hot encode labels (now 0-25 for A-Z)
T_train = to_categorical(y_train, num_classes=26)
T_test = to_categorical(y_test, num_classes=26)

print(f"Successfully loaded!")
print(f"Training samples: {X_train.shape[0]:,}")
print(f"Test samples: {X_test.shape[0]:,}")
print(f"Classes: A-Z (26 total)")
print(f"Image shape: 28x28 → {X_train.shape[1]} features")




#%%################ 3. Initialize EMNIST Neural Network #####################

# Create EMNIST-specific network class
class PyNet_E26(PyNetBase):
    """EMNIST Letters-specific neural network using shared base functionality"""
    pass

# Initialize network
net = PyNet_E26(num_features, hidden_units, num_classes, weights_init, activation, loss, optimizer, l2_coeff, dropout_p)

print(f"\nNetwork Architecture:")
print(f"   Input features: {num_features}")
print(f"   Hidden layers: {hidden_units}")
print(f"   Output classes: {num_classes}")
print(f"   Activation: {activation}")
print(f"   Weight init: {weights_init}")
print(f"Training Configuration:")
print(f"   Optimizer: {optimizer}")
print(f"   Learning rate: {learning_rate}")
print(f"   Batch size: {batch_size}")
print(f"   Epochs: {num_epochs}")
print(f"   Loss function: {loss}")
print(f"   L2 coefficient: {l2_coeff}")
print(f"   Gradient clipping: {use_grad_clipping}")
print(f"   Max gradient norm: {max_grad_norm}")
print(f"   Dropout: {dropout_p if dropout_p is not None else 'None'}")




#%%########################### 4. Training Loop ############################

# Train the model (using configured gradient clipping)
net.W, losses, train_accuracies = train(
    net, X_train.T, T_train.T, net.W,
    num_epochs, learning_rate, batch_size,
    use_clipping=use_grad_clipping, max_grad_norm=max_grad_norm,
    use_wandb=use_wandb,
    wandb_project=wandb_project,
    wandb_config=wandb_config,
    wandb_mode=wandb_mode
)




#%%########################## 5. Evaluate Model ############################

# Evaluate and display results
y_pred, test_accuracy, test_loss = evaluate_model(
    net, X_test, T_test, y_test, net.W, train_accuracies
)

# Convert some predictions to letters for demonstration
def number_to_letter(num):
    return chr(ord('A') + num)

print(f"\n Sample Letter Predictions:")
sample_indices = np.random.choice(len(y_test), 5, replace=False)
for i in sample_indices:
    true_letter = number_to_letter(y_test[i])  # y_test is already 0-25 range
    pred_letter = number_to_letter(y_pred[i])
    print(f"True: {true_letter}, Predicted: {pred_letter}")

Loading EMNIST Letters dataset...
Converting to numpy arrays...
Original label range: 1-26
Adjusted label range: 0-25
Successfully loaded!
Training samples: 88,800
Test samples: 14,800
Classes: A-Z (26 total)
Image shape: 28x28 → 784 features

Network Architecture:
   Input features: 784
   Hidden layers: [32, 32]
   Output classes: 26
   Activation: relu
   Weight init: he
Training Configuration:
   Optimizer: adam
   Learning rate: 0.001
   Batch size: 32
   Epochs: 100
   Loss function: cross_entropy
   L2 coefficient: 1e-08
   Gradient clipping: False
   Max gradient norm: 50.0
   Dropout: [0.0, 0.0]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇████
epoch_time,▆▃▄█▃▄▃▄▄▃▆▃▁▄▄▄▁▄▄▄▄▆▃▆▃▃▁▃▃▆▃▆▆▆▆▆▆▆▃▆
train_accuracy,▁▃▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████
train_loss,█▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
epoch_time,2.69706
train_accuracy,83.72635
train_loss,49165.53155


----------------------------------------------------------------------
Epoch      Accuracy   Gain       Time       ETA
----------------------------------------------------------------------
1/100      35.41%     baseline   2.53sec    calculating...
2/100      49.85%     +14.44%    2.59sec    4min 11sec
3/100      56.84%     +6.99%     2.51sec    4min 7sec
4/100      60.53%     +3.69%     2.49sec    4min 3sec
5/100      63.34%     +2.80%     2.51sec    3min 60sec
6/100      65.56%     +2.22%     2.70sec    4min 0sec
7/100      67.08%     +1.52%     2.49sec    3min 57sec
8/100      68.55%     +1.47%     2.51sec    3min 54sec
9/100      69.65%     +1.10%     2.49sec    3min 51sec
10/100     70.49%     +0.85%     2.50sec    3min 48sec
11/100     71.50%     +1.00%     2.50sec    3min 45sec
12/100     72.13%     +0.63%     2.50sec    3min 42sec
13/100     72.78%     +0.65%     2.50sec    3min 40sec
14/100     73.40%     +0.62%     2.50sec    3min 37sec
15/100     73.94%     +0.54%     2.49se

## PyNetE47B


In [None]:
np.random.seed(42)
random.seed(42)

# Dataset Configuration
num_features = 28 * 28     # EMNIST: 28x28 pixels
num_classes = 47           # EMNIST Balanced: 47 classes (merged digits and letters)

# Architecture Configuration
hidden_units = [32, 32]    # Units per hidden layer [layer1, layer2, ...]
activation = 'relu'        # Activation function: 'relu', 'tanh', 'sigmoid'
weights_init = 'he'        # Weight initialization: 'he', 'xavier', 'normal'

# Training Configuration
num_epochs = 100           # Number of training epochs
learning_rate = 0.001      # Learning rate for gradient descent
batch_size = 32            # Mini-batch size
loss = 'cross_entropy'     # Loss function: 'cross_entropy', 'mse', 'mae'
optimizer = 'adam'         # Optimizer: 'sgd', 'adam', 'rmsprop'
l2_coeff = 1e-8            # L2 regularization coefficient (weight_decay)
dropout_p = [0.1, 0.1]     # Dropout probabilities per layer [hidden1, hidden2, ...]; 0.0 = no dropout
use_grad_clipping = False  # Enable/disable gradient clipping
max_grad_norm = 50.0       # Maximum gradient norm for clipping

# WandB Configuration
use_wandb = True                           # Enable W&B logging
wandb_project = "PyNetxJaxNet"             # Your W&B project name
wandb_mode = "online"                      # W&B mode: "online", "offline", or "disabled"
wandb_config = {
    # Architecture
    "num_features": num_features,
    "hidden_units": hidden_units,
    "num_classes": num_classes,
    "activation": activation,
    "weights_init": weights_init,

    # Training
    "optimizer": optimizer,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "loss": loss,
    "l2_coeff": l2_coeff,
    "dropout_p": dropout_p,
    "use_grad_clipping": use_grad_clipping,
    "max_grad_norm": max_grad_norm,

    # Metadata
    "dataset": "EMNIST",
    "framework": "PyNet"
}

#%%######################## 2. Load EMNIST Data ############################

# Load EMNIST Balanced dataset using TensorFlow Datasets
print("Loading EMNIST Balanced dataset...")
ds_train, ds_test = tfds.load('emnist/balanced', split=['train', 'test'], as_supervised=True)

# Convert to numpy arrays
def preprocess_data(ds):
    images, labels = [], []
    for image, label in ds:
        images.append(image.numpy())
        labels.append(label.numpy())
    return np.array(images), np.array(labels)

print("Converting to numpy arrays...")
X_train, y_train = preprocess_data(ds_train)
X_test, y_test = preprocess_data(ds_test)

# Reshape and normalize inputs (same as MNIST)
X_train = X_train.reshape(-1, 28*28) / 255.0
X_test = X_test.reshape(-1, 28*28) / 255.0

# EMNIST balanced uses labels 0-46 (47 classes)
print(f"Label range: {y_train.min()}-{y_train.max()}")

# One-hot encode labels (0-46 for 47 classes)
T_train = to_categorical(y_train, num_classes=47)
T_test = to_categorical(y_test, num_classes=47)

print(f"Successfully loaded!")
print(f"Training samples: {X_train.shape[0]:,}")
print(f"Test samples: {X_test.shape[0]:,}")
print(f"Classes: 0-9 + merged letters (47 total)")
print(f"Image shape: 28x28 → {X_train.shape[1]} features")




#%%################ 3. Initialize EMNIST Neural Network #####################

# Create EMNIST Balanced-specific network class
class PyNet_E47B(PyNetBase):
    """EMNIST Balanced-specific neural network using shared base functionality"""
    pass

# Initialize network
net = PyNet_E47B(num_features, hidden_units, num_classes, weights_init, activation, loss, optimizer, l2_coeff, dropout_p)

print(f"\nNetwork Architecture:")
print(f"   Input features: {num_features}")
print(f"   Hidden layers: {hidden_units}")
print(f"   Output classes: {num_classes}")
print(f"   Activation: {activation}")
print(f"   Weight init: {weights_init}")
print(f"Training Configuration:")
print(f"   Optimizer: {optimizer}")
print(f"   Learning rate: {learning_rate}")
print(f"   Batch size: {batch_size}")
print(f"   Epochs: {num_epochs}")
print(f"   Loss function: {loss}")
print(f"   L2 coefficient: {l2_coeff}")
print(f"   Dropout probabilities: {dropout_p}")
print(f"   Gradient clipping: {use_grad_clipping}")
print(f"   Max gradient norm: {max_grad_norm}")




#%%########################### 4. Training Loop ############################

# Train the model (using configured gradient clipping)
net.W, losses, train_accuracies = train(
    net, X_train.T, T_train.T, net.W,
    num_epochs, learning_rate, batch_size,
    use_clipping=use_grad_clipping, max_grad_norm=max_grad_norm,
    use_wandb=use_wandb,
    wandb_project=wandb_project,
    wandb_config=wandb_config,
    wandb_mode=wandb_mode
)




#%%########################## 5. Evaluate Model ############################

# Evaluate and display results
y_pred, test_accuracy, test_loss = evaluate_model(
    net, X_test, T_test, y_test, net.W, train_accuracies
)

print(f"\n Sample Predictions (Class IDs):")
sample_indices = np.random.choice(len(y_test), 5, replace=False)
for i in sample_indices:
    true_class = y_test[i]
    pred_class = y_pred[i]
    status = "✅" if true_class == pred_class else "❌"
    print(f"{status} True: {true_class}, Predicted: {pred_class}")

Loading EMNIST Balanced dataset...




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/emnist/balanced/3.1.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/emnist/balanced/incomplete.HBRPOI_3.1.0/emnist-train.tfrecord*...:   0%|  …

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/emnist/balanced/incomplete.HBRPOI_3.1.0/emnist-test.tfrecord*...:   0%|   …

Dataset emnist downloaded and prepared to /root/tensorflow_datasets/emnist/balanced/3.1.0. Subsequent calls will reuse this data.
Converting to numpy arrays...
Label range: 0-46
Successfully loaded!
Training samples: 112,800
Test samples: 18,800
Classes: 0-9 + merged letters (47 total)
Image shape: 28x28 → 784 features

Network Architecture:
   Input features: 784
   Hidden layers: [32, 32]
   Output classes: 47
   Activation: relu
   Weight init: he
Training Configuration:
   Optimizer: adam
   Learning rate: 0.001
   Batch size: 32
   Epochs: 100
   Loss function: cross_entropy
   L2 coefficient: 1e-08
   Dropout probabilities: [0.1, 0.1]
   Gradient clipping: False
   Max gradient norm: 50.0


----------------------------------------------------------------------
Epoch      Accuracy   Gain       Time       ETA
----------------------------------------------------------------------
1/100      15.27%     baseline   3.53sec    calculating...
2/100      31.35%     +16.09%    3.46sec    5min 42sec
3/100      40.98%     +9.63%     5.14sec    6min 32sec
4/100      47.89%     +6.91%     3.46sec    6min 14sec
5/100      52.51%     +4.62%     3.53sec    6min 3sec
6/100      55.44%     +2.93%     5.05sec    6min 18sec
7/100      57.59%     +2.15%     3.61sec    6min 9sec
8/100      59.12%     +1.53%     3.49sec    5min 60sec
9/100      60.47%     +1.35%     5.07sec    6min 7sec
10/100     61.47%     +1.00%     3.53sec    5min 59sec
11/100     62.62%     +1.15%     3.58sec    5min 52sec
12/100     63.39%     +0.78%     5.00sec    5min 55sec
13/100     64.07%     +0.68%     3.44sec    5min 47sec
14/100     64.67%     +0.60%     3.48sec    5min 40sec
15/100     65.31%     +0.64%     5.10se

# JaxNet

### Intro Function JaxNet

In [None]:
class JAXNetBase:
    """Base class containing all shared neural network functionality"""

    def __init__(self, num_features, hidden_units, num_output, weights_init='he', activation='relu', loss='cross_entropy', optimizer='sgd', l2_coeff=0.0, dropout_p=None, seed=42):
        """
        Initialize neural network with configurable architecture.

        Args:
            num_features: Number of input features
            hidden_units: List of hidden layer sizes [layer1, layer2, ...]
            num_output: Number of output classes
            weights_init: Weight initialization method ('he', 'xavier', 'normal')
            activation: Activation function ('relu', 'tanh', 'sigmoid')
            loss: Loss function ('cross_entropy', 'mse', 'mae')
            optimizer: Optimizer type ('sgd', 'adam', 'rmsprop')
            l2_coeff: L2 regularization coefficient (weight_decay)
            dropout_p: List of dropout probabilities for each hidden layer (None = no dropout)
            seed: Random seed for weight initialization
        """

        # Build layer sizes: input → hidden layers → output
        layer_sizes = [num_features] + hidden_units + [num_output]

        self.layer_sizes = layer_sizes
        self.activation = activation
        self.weights_init = weights_init
        self.loss = loss
        self.optimizer = optimizer
        self.l2_coeff = l2_coeff
        self.dropout_p = dropout_p
        self.seed = seed

        # Validate dropout_p if provided
        num_hidden = len(hidden_units)
        if dropout_p is not None:
            if len(dropout_p) != num_hidden:
                raise ValueError(f"dropout_p must have {num_hidden} values (one per hidden layer)")
            self.dropout_p = dropout_p
        else:
            self.dropout_p = [0.0] * num_hidden  # No dropout by default

        # Initialize weights for each layer
        self.W = []
        key = random.PRNGKey(seed)

        for i in range(len(layer_sizes) - 1):
            key, subkey = random.split(key)
            input_size = layer_sizes[i]
            output_size = layer_sizes[i + 1]

            # Weight initialization
            if weights_init == 'he':
                # He initialization (good for ReLU)
                w = random.normal(subkey, (input_size + 1, output_size)) * jnp.sqrt(2 / input_size)
            elif weights_init == 'xavier':
                # Xavier initialization (good for tanh/sigmoid)
                w = random.normal(subkey, (input_size + 1, output_size)) * jnp.sqrt(1 / input_size)
            elif weights_init == 'normal':
                # Standard normal initialization
                w = random.normal(subkey, (input_size + 1, output_size)) * 0.01
            else:
                raise ValueError(f"Unknown weights_init: {weights_init}")

            self.W.append(w)

        # Initialize optimizer state
        if optimizer == 'adam':
            self.m = [jnp.zeros_like(w) for w in self.W]  # First moment estimates
            self.v = [jnp.zeros_like(w) for w in self.W]  # Second moment estimates
            self.t = 0  # Time step counter
        elif optimizer == 'rmsprop':
            self.v = [jnp.zeros_like(w) for w in self.W]  # Moving average of squared gradients


    def forward(self, X, W, dropout_on=False, rng_key=None):
        """
        Forward pass through the network with optional dropout

        Args:
            X: Input data
            W: Weights
            dropout_on: Whether to apply dropout (True during training, False during inference)
            rng_key: JAX random key for dropout (required if dropout_on=True)
        Returns:
            y: Output predictions
            h: List of hidden layer activations
            masks: List of dropout masks (one per hidden layer)
        """
        h = []
        masks = []
        a = X
        num_hidden = len(W) - 1

        # Loop through hidden layers
        for l in range(num_hidden):
            a = jnp.vstack([a, jnp.ones((1, a.shape[1]))])  # Add bias term
            z = W[l].T @ a
            a = self._activation_function(z)  # Use configurable activation

            # Apply dropout if enabled
            if dropout_on and self.dropout_p[l] > 0.0:
                if rng_key is None:
                    raise ValueError("rng_key must be provided when dropout_on=True")
                rng_key, subkey = random.split(rng_key)
                p = self.dropout_p[l]
                # Inverted dropout: scale active neurons to maintain expected activation
                mask = (random.uniform(subkey, a.shape) > p).astype(float) / (1.0 - p)
                a = a * mask
            else:
                mask = jnp.ones_like(a)  # No dropout: all neurons active

            h.append(a)
            masks.append(mask)

        # Output layer (no dropout)
        a = jnp.vstack([a, jnp.ones((1, a.shape[1]))])  # Add bias term
        y_hat = W[-1].T @ a
        y = self._softmax(y_hat)  # Output layer always uses softmax for classification
        return y, h, masks


    def backward(self, X, T, W, h, masks, eta, y_pred=None, use_clipping=True, max_grad_norm=25.0):
        """
        Backward pass with configurable optimizers, L2 regularization, gradient clipping, and dropout.

        Args:
            X: Input data
            T: Target labels
            W: Weights
            h: Hidden activations from forward pass
            masks: Dropout masks from forward pass
            eta: Learning rate
            y_pred: Pre-computed predictions (optional, for efficiency)
            use_clipping: Whether to use gradient clipping (default True)
            max_grad_norm: Maximum gradient norm for clipping (default 25.0)
        """
        m = X.shape[1]

        if y_pred is None:  # Use pre-computed predictions if available, otherwise compute them
            y, _, _ = self.forward(X, W, dropout_on=False)
        else:
            y = y_pred

        # Increment Adam time step once per backward pass
        if self.optimizer == 'adam':
            self.t += 1

        delta = self._loss_derivative(y, T)  # Use configurable loss derivative

        # Backpropagate through hidden layers (in reverse)
        for l in range(len(W) - 1, 0, -1):
            a_prev = jnp.vstack([h[l-1], jnp.ones((1, h[l-1].shape[1]))])  # Add bias term
            Q = a_prev @ delta.T

            # Add L2 regularization to gradient (don't regularize biases - last row)
            if self.l2_coeff > 0:
                Q = Q.at[:-1, :].add(self.l2_coeff * W[l][:-1, :])  # Only regularize weights, not biases

            # Optional gradient clipping
            if use_clipping:
                grad_norm = jnp.linalg.norm(Q)
                Q = jnp.where(grad_norm > max_grad_norm, Q * (max_grad_norm / grad_norm), Q)

            # Apply optimizer update
            W = self._apply_optimizer_update(W, l, Q, eta, m)

            # Backpropagate delta
            delta = W[l][:-1, :] @ delta
            delta = delta * self._activation_derivative(h[l-1])  # Use configurable activation derivative
            delta = delta * masks[l-1]  # Apply dropout mask (only gradients through active neurons)

        # First layer gradient
        a_prev = jnp.vstack([X, jnp.ones((1, X.shape[1]))])  # Add bias term
        Q = a_prev @ delta.T

        # Add L2 regularization to first layer gradient
        if self.l2_coeff > 0:
            Q = Q.at[:-1, :].add(self.l2_coeff * W[0][:-1, :])  # Only regularize weights, not biases

        # Optional gradient clipping for first layer
        if use_clipping:
            grad_norm = jnp.linalg.norm(Q)
            Q = jnp.where(grad_norm > max_grad_norm, Q * (max_grad_norm / grad_norm), Q)

        # Apply optimizer update to first layer
        W = self._apply_optimizer_update(W, 0, Q, eta, m)
        loss = self._loss_function(y, T)
        return W, loss


    def _apply_optimizer_update(self, W, layer_idx, gradients, eta, batch_size):
        """
        Apply optimizer-specific weight updates with optional update clipping.

        Args:
            W: Current weights (list of arrays)
            layer_idx: Index of layer to update
            gradients: Computed gradients
            eta: Learning rate
            batch_size: Size of current batch

        Returns:
            W: Updated weights
        """
        # Create a copy of weights list for functional update
        W = [w for w in W]  # Shallow copy for JAX functional programming

        # Normalize gradients by batch size
        grad = gradients / batch_size

        if self.optimizer == 'sgd':
            # Standard SGD update
            W[layer_idx] = W[layer_idx] - eta * grad

        elif self.optimizer == 'adam':
            # Adam optimizer with bias correction and update clipping
            beta1, beta2, epsilon = 0.9, 0.999, 1e-8

            # Update biased first moment estimate
            self.m[layer_idx] = beta1 * self.m[layer_idx] + (1 - beta1) * grad

            # Update biased second raw moment estimate
            self.v[layer_idx] = beta2 * self.v[layer_idx] + (1 - beta2) * (grad ** 2)

            # Compute bias-corrected first moment estimate
            m_hat = self.m[layer_idx] / (1 - beta1 ** self.t)

            # Compute bias-corrected second raw moment estimate
            v_hat = self.v[layer_idx] / (1 - beta2 ** self.t)

            # Compute the raw update
            update = eta * m_hat / (jnp.sqrt(v_hat) + epsilon)

            # Apply update clipping for Adam (clip the update, not the gradient)
            update_norm = jnp.linalg.norm(update)
            max_update_norm = 1.0  # Maximum allowed update norm for Adam
            update = jnp.where(update_norm > max_update_norm,
                             update * (max_update_norm / update_norm),
                             update)

            # Apply the clipped update
            W[layer_idx] = W[layer_idx] - update

        elif self.optimizer == 'rmsprop':
            # RMSprop optimizer
            decay_rate, epsilon = 0.9, 1e-8

            # Update moving average of squared gradients
            self.v[layer_idx] = decay_rate * self.v[layer_idx] + (1 - decay_rate) * (grad ** 2)

            # Apply update
            W[layer_idx] = W[layer_idx] - eta * grad / (jnp.sqrt(self.v[layer_idx]) + epsilon)

        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer}")

        return W


    def _softmax(self, y_hat):
        """Compute softmax probabilities"""
        y_hat = y_hat - jnp.max(y_hat, axis=0, keepdims=True)  # prevent overflow
        exp_scores = jnp.exp(y_hat)
        return exp_scores / jnp.sum(exp_scores, axis=0, keepdims=True)


    def _activation_function(self, z):
        """Apply activation function"""
        if self.activation == 'relu':
            return jnp.maximum(0, z)
        elif self.activation == 'tanh':
            return jnp.tanh(z)
        elif self.activation == 'sigmoid':
            return 1 / (1 + jnp.exp(-jnp.clip(z, -500, 500)))  # Clip to prevent overflow
        else:
            raise ValueError(f"Unknown activation: {self.activation}")


    def _activation_derivative(self, a):
        """Calculate derivative of activation function"""
        if self.activation == 'relu':
            return (a > 0).astype(jnp.float32)
        elif self.activation == 'tanh':
            return 1 - a**2
        elif self.activation == 'sigmoid':
            return a * (1 - a)
        else:
            raise ValueError(f"Unknown activation: {self.activation}")


    def _loss_function(self, y_pred, y_true):
        """Calculate loss based on configured loss function"""
        epsilon = 1e-12  # Prevent log(0)

        if self.loss == 'cross_entropy':
            # Categorical Cross-Entropy Loss
            return -jnp.sum(jnp.log(jnp.sum(y_pred * y_true, axis=0) + epsilon))
        elif self.loss == 'mse':
            # Mean Squared Error Loss
            return 0.5 * jnp.sum((y_pred - y_true) ** 2)
        elif self.loss == 'mae':
            # Mean Absolute Error Loss
            return jnp.sum(jnp.abs(y_pred - y_true))
        else:
            raise ValueError(f"Unknown loss function: {self.loss}")


    def _loss_derivative(self, y_pred, y_true):
        """Calculate derivative of loss function for backpropagation"""
        if self.loss == 'cross_entropy':
            # For cross-entropy with softmax: derivative is simply (y_pred - y_true)
            return y_pred - y_true
        elif self.loss == 'mse':
            # MSE derivative: (y_pred - y_true)
            return y_pred - y_true
        elif self.loss == 'mae':
            # MAE derivative: sign(y_pred - y_true)
            return jnp.sign(y_pred - y_true)
        else:
            raise ValueError(f"Unknown loss function: {self.loss}")


# Shared utility functions
def calculate_accuracy(net, X, T, W):
    """Calculate accuracy percentage (always with dropout OFF)"""
    y, _, _ = net.forward(X, W, dropout_on=False)
    predictions = jnp.argmax(y, axis=0)
    true_labels = jnp.argmax(T, axis=0)
    return jnp.mean(predictions == true_labels) * 100


def train(net, X, T, W, epochs, eta, batchsize=32, use_clipping=True, max_grad_norm=25.0, use_wandb=False, wandb_project=None, wandb_config=None, wandb_mode="online"):
    """
    Training loop for neural network.

    Args:
        net: Neural network instance
        X, T: Training data and labels
        W: Initial weights
        epochs: Number of training epochs
        eta: Learning rate
        batchsize: Mini-batch size
        use_clipping: Whether to use gradient clipping
        max_grad_norm: Maximum gradient norm for clipping
        use_wandb: Whether to use Weights & Biases logging
        wandb_project: W&B project name
        wandb_config: Dictionary of hyperparameters to log to W&B
        wandb_mode: W&B mode - "online", "offline", or "disabled"
    """
    losses = []
    accuracies = []  # Track training accuracy
    epoch_times = []  # Track computation time per epoch

    # Initialize W&B if enabled
    if use_wandb and wandb_project:
        wandb.init(project=wandb_project, config=wandb_config, mode=wandb_mode)

    # Print header for nicely formatted table
    print("-" * 70)
    print(f"{'Epoch':<10} {'Accuracy':<10} {'Gain':<10} {'Time':<10} {'ETA'}")
    print("-" * 70)

    start_total = time.time()

    m = X.shape[1]
    key = random.PRNGKey(42)  # For batch shuffling and dropout

    for epoch in range(epochs):
        epoch_start = time.time()  # Start timing this epoch

        key, subkey = random.split(key)
        order = random.permutation(subkey, m)
        epoch_loss = 0

        for i in range(0, m, batchsize):
            batch = order[i:i+batchsize]
            X_batch = X[:, batch]
            T_batch = T[:, batch]

            # Forward pass with dropout enabled during training
            key, dropout_key = random.split(key)
            y_batch, h, masks = net.forward(X_batch, W, dropout_on=True, rng_key=dropout_key)

            # Backward pass with dropout masks
            W, loss = net.backward(X_batch, T_batch, W, h, masks, eta, y_batch, use_clipping, max_grad_norm)
            epoch_loss += loss

        # Calculate training accuracy for this epoch
        train_accuracy = float(calculate_accuracy(net, X, T, W))  # Convert to Python float
        accuracies.append(train_accuracy)
        losses.append(float(epoch_loss))  # Convert to Python float

        # Calculate gain compared to last epoch
        if epoch > 0:
            gain = train_accuracy - accuracies[-2]  # Current - previous
            if gain > 0:
                gain_str = f"+{gain:.2f}%"
            elif gain < 0:
                gain_str = f"{gain:.2f}%"  # Already has negative sign
            else:
                gain_str = " 0.00%"
        else:
            gain_str = "baseline"

        # Calculate epoch time
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)

        # Calculate ETA (estimated time remaining)
        if epoch > 0:
            avg_time_per_epoch = sum(epoch_times) / len(epoch_times)  # Use pure Python for averaging
            remaining_epochs = epochs - (epoch + 1)
            eta_seconds = avg_time_per_epoch * remaining_epochs
            if eta_seconds > 60:
                eta_str = f"{int(eta_seconds//60)}min {int(eta_seconds%60)}sec"
            else:
                eta_str = f"{int(eta_seconds)}sec"
        else:
            eta_str = "calculating..."

        # Format epoch info for output
        epoch_str = f"{epoch+1}/{epochs}"
        accuracy_str = f"{train_accuracy:.2f}%"
        time_str = f"{epoch_time:.2f}sec"

        # Show progress
        print(f"{epoch_str:<10} {accuracy_str:<10} {gain_str:<10} {time_str:<10} {eta_str}")

        # Log to W&B if enabled
        if use_wandb and wandb_project:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": float(epoch_loss),
                "train_accuracy": float(train_accuracy),
                "epoch_time": float(epoch_time)
            })

    total_time = time.time() - start_total
    avg_epoch_time = sum(epoch_times) / len(epoch_times)

    print("-" * 70)
    print(f"Total training time: {total_time:.1f}sec")
    print(f"Average per epoch: {avg_epoch_time:.2f}sec")
    print("-" * 70)

    # Don't finish W&B here - let evaluate_model do it after logging test metrics

    return W, losses, accuracies

def evaluate_model(net, X_test, T_test, y_test, W, train_accuracies, use_wandb=False):
    """
    Evaluate model performance and print results

    Args:
        net: Neural network instance
        X_test, T_test: Test data and labels
        y_test: Test labels (not one-hot encoded)
        W: Trained weights
        train_accuracies: List of training accuracies from training
        use_wandb: Whether to log test metrics to W&B
    """
    # Make predictions and calculate accuracy (dropout OFF for evaluation)
    y_test_pred, _, _ = net.forward(X_test.T, W, dropout_on=False)
    y_pred = jnp.argmax(y_test_pred, axis=0)
    test_accuracy = float(jnp.mean(y_pred == y_test))  # Convert to Python float

    # Calculate test loss using the configurable loss function
    test_loss = float(net._loss_function(y_test_pred, T_test.T) / X_test.shape[0])  # Average per sample

    print(f"\n================== Final Results ==================")
    print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
    print(f"Test Loss (avg per sample): {test_loss:.4f}")
    print(f"Training Accuracy Improvement: {(train_accuracies[-1] - train_accuracies[0]):.1f}% points")
    print(f"Final Training Accuracy: {train_accuracies[-1]:.2f}%")

    # Log test metrics to W&B if enabled and run is still active
    if use_wandb and wandb.run is not None:
        wandb.log({
            "test_accuracy": float(test_accuracy * 100),
            "test_loss": float(test_loss)
        })
        wandb.finish(quiet=False)  # Finish the W&B run after logging test metrics (quiet=False shows summary)

    return y_pred, test_accuracy, test_loss

## Start Jax Plots

In [None]:
wandb.login()



True

## JaxNetM10

In [None]:
# Dataset Configuration
num_features = 28 * 28  # MNIST: 28x28 pixels
num_classes = 10        # MNIST: digits 0-9

# Architecture Configuration
hidden_units = [32, 32]    # Units per hidden layer [layer1, layer2, ...]
activation = 'relu'        # Activation function: 'relu', 'tanh', 'sigmoid'
weights_init = 'he'        # Weight initialization: 'he', 'xavier', 'normal'

# Training Configuration
num_epochs = 100           # Number of training epochs
learning_rate = 0.001      # Learning rate for gradient descent
batch_size = 512           # Mini-batch size
loss = 'cross_entropy'     # Loss function: 'cross_entropy', 'mse', 'mae'
optimizer = 'adam'         # Optimizer: 'sgd', 'adam', 'rmsprop'
l2_coeff = 1e-8            # L2 regularization coefficient (weight_decay)
dropout_p = [0.1, 0.1]     # Dropout probabilities per layer [hidden1, hidden2, ...]; 0.0 = no dropout
use_grad_clipping = False  # Enable/disable gradient clipping
max_grad_norm = 1.0        # Maximum gradient norm for clipping

# WandB Configuration
use_wandb = True                           # Enable W&B logging
wandb_project = "PyNetxJaxNet"             # Your W&B project name
wandb_mode = "online"                      # W&B mode: "online", "offline", or "disabled"
wandb_config = {
    # Architecture
    "num_features": num_features,
    "hidden_units": hidden_units,
    "num_classes": num_classes,
    "activation": activation,
    "weights_init": weights_init,

    # Training
    "optimizer": optimizer,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "loss": loss,
    "l2_coeff": l2_coeff,
    "dropout_p": dropout_p,
    "use_grad_clipping": use_grad_clipping,
    "max_grad_norm": max_grad_norm,

    # Metadata
    "dataset": "MNIST",
    "framework": "JAXNet"
}




#%%######################### 2. Load MNIST Data ############################

# Load MNIST dataset
print("Loading MNIST dataset...")
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Reshape and normalize inputs
X_train = X_train.reshape(-1, 28*28) / 255.0
X_test = X_test.reshape(-1, 28*28) / 255.0

# Convert to JAX arrays
X_train = jnp.array(X_train)
X_test = jnp.array(X_test)

# One-hot encode labels
T_train = to_categorical(y_train, num_classes=10)
T_test = to_categorical(y_test, num_classes=10)

# Convert to JAX arrays
T_train = jnp.array(T_train)
T_test = jnp.array(T_test)

print(f"Successfully loaded!")
print(f"Training samples: {X_train.shape[0]:,}")
print(f"Test samples: {X_test.shape[0]:,}")
print(f"Classes: 0-9 (10 total)")
print(f"Image shape: 28x28 → {X_train.shape[1]} features")




#%%################ 3. Initialize MNIST Neural Network #####################

# Create MNIST-specific network class
class JAXNet_M10(JAXNetBase):
    """MNIST-specific neural network using shared base functionality"""
    pass

# Initialize network
net = JAXNet_M10(num_features, hidden_units, num_classes, weights_init, activation, loss, optimizer, l2_coeff, dropout_p)

print(f"\nNetwork Architecture:")
print(f"   Input features: {num_features}")
print(f"   Hidden layers: {hidden_units}")
print(f"   Output classes: {num_classes}")
print(f"   Activation: {activation}")
print(f"   Weight init: {weights_init}")
print(f"Training Configuration:")
print(f"   Optimizer: {optimizer}")
print(f"   Learning rate: {learning_rate}")
print(f"   Batch size: {batch_size}")
print(f"   Epochs: {num_epochs}")
print(f"   Loss function: {loss}")
print(f"   L2 coefficient: {l2_coeff}")
print(f"   Gradient clipping: {use_grad_clipping}")
print(f"   Max gradient norm: {max_grad_norm}")
print(f"   Dropout: {dropout_p if dropout_p is not None else 'None'}")




#%%########################### 4. Training Loop ############################

# Train the model (using configured gradient clipping)
net.W, losses, train_accuracies = train(
    net, X_train.T, T_train.T, net.W,
    num_epochs, learning_rate, batch_size,
    use_clipping=use_grad_clipping, max_grad_norm=max_grad_norm,
    use_wandb=use_wandb,
    wandb_project=wandb_project,
    wandb_config=wandb_config,
    wandb_mode=wandb_mode
)

#%%########################## 5. Evaluate Model ############################

# Evaluate and display results
y_pred, test_accuracy, test_loss = evaluate_model(
    net, X_test, T_test, y_test, net.W, train_accuracies
)

Loading MNIST dataset...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Successfully loaded!
Training samples: 60,000
Test samples: 10,000
Classes: 0-9 (10 total)
Image shape: 28x28 → 784 features

Network Architecture:
   Input features: 784
   Hidden layers: [32, 32]
   Output classes: 10
   Activation: relu
   Weight init: he
Training Configuration:
   Optimizer: adam
   Learning rate: 0.001
   Batch size: 512
   Epochs: 100
   Loss function: cross_entropy
   L2 coefficient: 1e-08
   Gradient clipping: False
   Max gradient norm: 1.0
   Dropout: [0.1, 0.1]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇████
epoch_time,▁▃▁▁▁▂▁▂▂▁▁▆▅▅▂▁▂▁▃▁▁▁▂▁▂▁▁▂▁▂▃▂▁▁█▂▇▂▁▁
train_accuracy,▁▄▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
train_loss,█▇▆▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
epoch_time,3.4708
train_accuracy,75.5656
train_loss,123356.5391


----------------------------------------------------------------------
Epoch      Accuracy   Gain       Time       ETA
----------------------------------------------------------------------
1/100      89.03%     baseline   9.28sec    calculating...
2/100      91.82%     +2.78%     2.62sec    9min 42sec
3/100      93.03%     +1.21%     2.09sec    7min 32sec
4/100      93.89%     +0.87%     2.06sec    6min 25sec
5/100      94.56%     +0.67%     2.09sec    5min 44sec
6/100      95.09%     +0.53%     2.07sec    5min 16sec
7/100      95.46%     +0.38%     2.73sec    5min 4sec
8/100      95.85%     +0.38%     2.16sec    4min 48sec
9/100      96.11%     +0.26%     2.12sec    4min 35sec
10/100     96.30%     +0.19%     2.06sec    4min 23sec
11/100     96.54%     +0.24%     2.09sec    4min 13sec
12/100     96.74%     +0.20%     2.80sec    4min 10sec
13/100     96.93%     +0.19%     2.09sec    4min 2sec
14/100     97.04%     +0.11%     2.05sec    3min 55sec


KeyboardInterrupt: 