<a href="https://colab.research.google.com/github/Mathavk1606/1DHD-PINN/blob/main/test/SNN_using_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from jax.scipy.special import logsumexp

In [2]:
def init_network_params(layer_sizes, key):
    params = []
    keys = random.split(key, len(layer_sizes))

    for i in range(len(layer_sizes) - 1):
        w_key, b_key = random.split(keys[i])
        # Xavier initialization
        scale = jnp.sqrt(2.0 / (layer_sizes[i] + layer_sizes[i + 1]))
        w = random.normal(w_key, (layer_sizes[i], layer_sizes[i + 1])) * scale
        b = jnp.zeros(layer_sizes[i + 1])
        params.append((w, b))

    return params

In [3]:
def relu(x):
    return jnp.maximum(0, x)

In [4]:
def predict(params, x):
    activations = x
    for w, b in params[:-1]:
        activations = relu(jnp.dot(activations, w) + b)

    # Final layer (no activation)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)

In [5]:
def loss(params, x, y):
    preds = predict(params, x)
    return -jnp.mean(preds * y)

In [6]:
def accuracy(params, x, y):
    pred_classes = jnp.argmax(predict(params, x), axis=1)
    true_classes = jnp.argmax(y, axis=1)
    return jnp.mean(pred_classes == true_classes)


In [7]:
@jit
def update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [8]:
def generate_data(n_samples, key):
    key1, key2 = random.split(key)
    X = random.uniform(key1, (n_samples, 2), minval=-1, maxval=1)
    y_raw = ((X[:, 0] > 0) ^ (X[:, 1] > 0)).astype(int)
    y = jnp.eye(2)[y_raw]
    return X, y

In [10]:
key = random.PRNGKey(42)
layer_sizes = [2, 16, 16, 2]
params = init_network_params(layer_sizes, key)

key, subkey = random.split(key)
X_train, y_train = generate_data(1000, subkey)
X_test, y_test = generate_data(200, subkey)

learning_rate = 0.01
n_epochs = 100

print("Training simple neural network with JAX...")
print(f"Architecture: {' -> '.join(map(str, layer_sizes))}")
print(f"Learning rate: {learning_rate}")
print("-" * 50)

for epoch in range(n_epochs):
    params = update(params, X_train, y_train, learning_rate)

    if (epoch + 1) % 10 == 0:
        train_loss = loss(params, X_train, y_train)
        train_acc = accuracy(params, X_train, y_train)
        test_acc = accuracy(params, X_test, y_test)
        print(f"Epoch {epoch + 1:3d} | Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f}")

print("-" * 50)
print("Training complete!")

test_point = jnp.array([[0.5, -0.5]])
prediction = predict(params, test_point)
pred_class = jnp.argmax(prediction)
print(f"\nTest prediction for point {test_point[0]}: Class {pred_class}")

Training simple neural network with JAX...
Architecture: 2 -> 16 -> 16 -> 2
Learning rate: 0.01
--------------------------------------------------
Epoch  10 | Loss: 0.3443 | Train Acc: 0.563 | Test Acc: 0.610
Epoch  20 | Loss: 0.3426 | Train Acc: 0.561 | Test Acc: 0.610
Epoch  30 | Loss: 0.3410 | Train Acc: 0.552 | Test Acc: 0.600
Epoch  40 | Loss: 0.3395 | Train Acc: 0.540 | Test Acc: 0.585
Epoch  50 | Loss: 0.3380 | Train Acc: 0.541 | Test Acc: 0.585
Epoch  60 | Loss: 0.3366 | Train Acc: 0.545 | Test Acc: 0.585
Epoch  70 | Loss: 0.3352 | Train Acc: 0.544 | Test Acc: 0.580
Epoch  80 | Loss: 0.3339 | Train Acc: 0.547 | Test Acc: 0.590
Epoch  90 | Loss: 0.3326 | Train Acc: 0.552 | Test Acc: 0.590
Epoch 100 | Loss: 0.3313 | Train Acc: 0.559 | Test Acc: 0.590
--------------------------------------------------
Training complete!

Test prediction for point [ 0.5 -0.5]: Class 0
