jax 0.5.2 -> 0.5.3
jaxlib 0.5.1 -> 0.5.3

J - Just-In-Time compilation (JIT) - byte code
A - Automatic Differentiation
X - XLA (Accelerated Linear Algebra)

JAX arrays are immutable.
JAX has asynchronous dispatch - it gives control back to Python as quickly as possible.


#Just In Time Compilation(JIT)-byte code
@jax.jit
can be used on a function, but not on every function: it fails when a parameter's value is used in a Python condition, e.g. f(x) containing `if x % 2 == 0: ... else: ...`
jax.make_jaxpr helps to see the intermediate representation


#Automatic Differentiation-

jax.grad(f)(value) -> f'(value)

print(f(x, y, z))
print(jax.grad(f, argnums=0) (x, y, z))
print(jax.grad(f, argnums=1) (x, y, z))
print(jax.grad(f, argnums=2) (x, y, z))
 or

def f(arr):
    """Return arr[0]**2 + 2*arr[1]**2 + 3*arr[2]**2 for a length-3 sequence."""
    # Explicit `*` operators: `2arr[1]` / `3 arr[2]` are not valid Python.
    return arr[0] ** 2 + 2 * arr[1] ** 2 + 3 * arr[2] ** 2

x, y, z = 2.0, 2.0, 2.0
f = x^2 + 2y^2 + 3z^2
d/dx (f) = 2x = 4
d/dy (f) = 4y = 8
d/dz (f) = 6z = 12

print(f([x, y, z]))
print(jax.grad(f)([x, y, z]))


#randomness

JAX uses keys for randomness, similar to a seed value. Keys are split for reproducibility, since a key that has been used once should not be used again.

https://www.youtube.com/watch?v=wq-UsiOkBRU

In [1]:
!pip install -q jax jaxlib

In [2]:
import jax
import jax.numpy as jnp

### JAX Randomness and PRNG Keys
Unlike NumPy, JAX requires explicit pseudo-random number generator (PRNG) keys for reproducible randomness.
You should always split keys instead of reusing them.


In [3]:
import jax
import jax.numpy as jnp

# PRNG Key setup: build the root key from the fixed seed 0 so runs are repeatable.
key = jax.random.PRNGKey(0)
print("Main key:", key)

# Always split keys before using: one split call yields a replacement root key
# plus two subkeys, and each subkey is consumed by exactly one sampler below.
key, subkey1, subkey2 = jax.random.split(key, 3)
x = jax.random.normal(subkey1, (3, 3))
y = jax.random.uniform(subkey2, (3, 3))
print("Random normal sample:\n", x)
print("Random uniform sample:\n", y)


INFO:2025-10-23 01:34:52,124:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-10-23 01:34:52,150:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Main key: [0 0]
Random normal sample:
 [[-2.4424558  -2.0356805   0.20554423]
 [-0.3535502  -0.76197404 -1.1785518 ]
 [-1.1482196   0.29716578 -1.3105359 ]]
Random uniform sample:
 [[0.9024495  0.91229284 0.34104764]
 [0.2200911  0.6483767  0.5075352 ]
 [0.71720433 0.22564113 0.5910187 ]]


In [4]:
# JAX arrays are created and used much like NumPy arrays
# (they can run on an accelerator when a backend for one is available).
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

print("x:", x)
print("y:", y)
# Dot product of the two length-3 vectors defined above.
print("dot product:", jnp.dot(x, y))

x: [1. 2. 3.]
y: [4. 5. 6.]
dot product: 32.0


In [5]:
# Simple function: squared norm
def squared_norm(x):
    """Return the sum of the element-wise squares of `x`."""
    squares = x ** 2
    return jnp.sum(squares)

# jax.grad turns squared_norm into a function that returns its gradient.
grad_fn = jax.grad(squared_norm)

# `x` is whatever array the preceding cell left in the notebook namespace.
print("Value:", squared_norm(x))
print("Gradient:", grad_fn(x))  # derivative wrt x

Value: 14.0
Gradient: [2. 4. 6.]


In [6]:
@jax.jit
def f(w, x):
    """Return the dot product of `w` and `x` (JIT-compiled)."""
    product = jnp.dot(w, x)
    return product

# Weight vector for the demo; `x` comes from an earlier cell.
w = jnp.array([0.1, 0.2, 0.3])
print("f:", f(w, x))

f: 1.4000001


In [7]:
### Vectorization with `vmap`
#`jax.vmap` automatically vectorizes functions over batches of inputs, avoiding manual Python loops.
def f(x):
    """Return x squared plus one."""
    squared = x ** 2
    return squared + 1

xs = jnp.arange(5)
# Manual version: a Python loop calling f once per element.
print("Manual:", [f(x) for x in xs])
# Vectorized version: jax.vmap maps f over the leading axis of xs in one call.
print("Vectorized:", jax.vmap(f)(xs))


Manual: [Array(1, dtype=int32), Array(2, dtype=int32), Array(5, dtype=int32), Array(10, dtype=int32), Array(17, dtype=int32)]
Vectorized: [ 1  2  5 10 17]


In [8]:
# Define scalar function
def f_single(w, x):
    """Return the dot product of one weight vector `w` with one input `x`."""
    product = jnp.dot(w, x)
    return product

# Batch of data: three copies of `w`, and three shifted versions of `x`,
# each stacked along a new leading axis.
batch_w = jnp.stack([w, w, w])
batch_x = jnp.stack([x, x + 1, x + 2])

# in_axes=(0, 0): map over axis 0 of both arguments, pairing row i with row i.
f_batch = jax.vmap(f_single, in_axes=(0, 0))  # apply across batches
print("Batched dot products:", f_batch(batch_w, batch_x))

Batched dot products: [1.4000001 2.        2.6      ]


In [9]:
### Gradients and `value_and_grad`
# Use `jax.value_and_grad` when you need both the loss value and its gradients efficiently.
def loss_fn(params, x, y):
    """Mean squared error of the linear model params["w"] * x + params["b"] against `y`."""
    slope, intercept = params["w"], params["b"]
    residual = slope * x + intercept - y
    return jnp.mean(residual ** 2)

# Initial parameters (a dict pytree) and a tiny three-point dataset.
params = {"w": 1.0, "b": 0.0}
x, y = jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 4.1, 5.9])

# One call returns both the loss and its gradient w.r.t. the first argument
# (params); the gradient has the same dict structure as params.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
print("Loss:", loss)
print("Gradients:", grads)

Loss: 4.6066666
Gradients: {'b': Array(-4., dtype=float32, weak_type=True), 'w': Array(-9.266666, dtype=float32, weak_type=True)}


In [10]:
# Linear regression y = w*x + b
def model(w, b, x):
    """Evaluate the linear model w * x + b."""
    scaled = w * x
    return scaled + b

def mse_loss(w, b, x, y):
    """Mean squared error between model(w, b, x) and the targets `y`."""
    residuals = model(w, b, x) - y
    return jnp.mean(residuals**2)

# Example data
x_data = jnp.array([1.0, 2.0, 3.0])
y_data = jnp.array([2.0, 4.1, 6.0])

# Gradient wrt (w, b): argnums=(0, 1) differentiates with respect to the first
# two positional arguments and returns a 2-tuple (d/dw, d/db).
loss_grad = jax.grad(mse_loss, argnums=(0,1))
# Evaluated at w=1.0, b=0.0.
print("Gradients:", loss_grad(1.0, 0.0, x_data, y_data))

Gradients: (Array(-9.466667, dtype=float32, weak_type=True), Array(-4.0666666, dtype=float32, weak_type=True))


In [11]:
!pip install -q jax jaxlib tensorflow_datasets


  pid, fd = os.forkpty()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m323.2/323.2 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
google-api-core 1.34.1 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<4.0.0dev,>=3.19.5, but you have protobuf 6.33.0 which is incompatible.
google-cloud-translate 3.12.1 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.19.5, but you have protobuf 6.33.0 which is incompatible.
google-cloud-bigtable 2.32.0 requires google-api-core[grpc]<3.0.0,>=2.17.0, but you have google-api-core 1.34.1 which is incompatible.
bigframes 2.12.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you

In [12]:
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds

In [13]:
# --- 1. Imports ---
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds

# --- 2. Preprocessing function ---
def preprocess(image, label):
    """
    Convert one (image, label) example into JAX arrays.

    - image: cast to float32, flattened to 1-D, and divided by 255.0
      (for a 28x28 MNIST image this gives a length-784 vector in [0, 1])
    - label: cast to int32

    Returns the pair (flat_image, label).
    """
    pixels = jnp.array(image, dtype=jnp.float32)
    flat_image = pixels.reshape(-1) / 255.0
    return flat_image, jnp.array(label, dtype=jnp.int32)

# --- 3. Load dataset ---
# Only the training split is loaded with shuffle_files=True.
ds_train = tfds.load("mnist", split="train", shuffle_files=True)
ds_test = tfds.load("mnist", split="test")

# Convert to numpy + preprocess subset (faster for demo):
# the first 5000 training and 1000 test examples become Python lists of
# (image, label) pairs, one entry per example.
train_data = [preprocess(ex["image"], ex["label"])
              for ex in tfds.as_numpy(ds_train.take(5000))]
test_data  = [preprocess(ex["image"], ex["label"])
              for ex in tfds.as_numpy(ds_test.take(1000))]

# --- 4. Initialize parameters ---
# Fixed seed; the key is consumed once, for the weight matrix only.
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (784, 10)) * 0.01   # weights: small normal values, 784 inputs x 10 classes
b = jnp.zeros((10,))                           # bias: one zero per class

# Parameters travel together as a (W, b) tuple.
params = (W, b)

# --- 5. Model function ---
def model(params, x):
    """Linear classifier: return the logits x . W + b for params = (W, b)."""
    weights, bias = params
    return jnp.dot(x, weights) + bias

# --- 6. Loss function ---
def loss_fn(params, x, y):
    """Softmax cross-entropy between model(params, x) and the 10-class one-hot encoding of `y`."""
    targets = jax.nn.one_hot(y, 10)  # integer label -> one-hot vector
    cross_entropy = optax.softmax_cross_entropy(model(params, x), targets)
    return cross_entropy.mean()

# --- 7. Accuracy function ---
def accuracy(params, dataset):
    """Return the fraction of (x, y) pairs in `dataset` whose argmax prediction equals y."""
    # Each comparison yields a boolean scalar; summing them counts the hits.
    hits = sum(jnp.argmax(model(params, x)) == y for x, y in dataset)
    return hits / len(dataset)

# --- 8. Optimizer ---
# Optax SGD optimizer with learning rate 0.1; its state is initialised
# from the current params.
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

# --- 9. Training step (with JIT for speed) ---
@jax.jit
def train_step(params, opt_state, x, y):
    """Apply one optimizer update for the example (x, y).

    Returns the updated (params, opt_state) pair.
    """
    gradients = jax.grad(loss_fn)(params, x, y)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state

# --- 10. Training loop ---
for epoch in range(5):   # train for 5 epochs
    # train_data holds individual examples, so every (x, y) pair triggers
    # its own train_step call.
    for x, y in train_data:
        params, opt_state = train_step(params, opt_state, x, y)

    # Evaluate on the held-out test subset once per epoch.
    acc = accuracy(params, test_data)
    print(f"Epoch {epoch+1}, Test accuracy: {acc:.4f}")


Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

2025-10-23 01:35:34.058026: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761183334.420987      71 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761183334.523223      71 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

2025-10-23 01:35:53.611198: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


Shuffling /root/tensorflow_datasets/mnist/incomplete.6NG3OA_3.0.1/mnist-train.tfrecord*...:   0%|          | 0…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.6NG3OA_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/…

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1, Test accuracy: 0.8610
Epoch 2, Test accuracy: 0.8700
Epoch 3, Test accuracy: 0.8510
Epoch 4, Test accuracy: 0.8800
Epoch 5, Test accuracy: 0.8730


In [14]:
import time

import jax
import jax.numpy as jnp
import jax.numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler

In [15]:
# Load the iris dataset: features in X, class labels as a column vector in y.
data = load_iris()
X = data.data
y = data.target.reshape(-1, 1)

# One-hot encode the labels (dense output rather than a sparse matrix).
encoder=OneHotEncoder(sparse_output=False)
y = encoder.fit_transform(y)

# Hold out 20% of the rows for testing.
# NOTE(review): no random_state is passed, so the split may differ between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Fit the scaler on the training features only, then apply the same
# transform to the test features.
scaler= StandardScaler()
X_train=scaler.fit_transform(X_train)
X_test=scaler.transform(X_test)

In [16]:
#input -> hidden layer 1-> hidden layer 2 -> output
def init_params(input_dim, hidden_dim1, hidden_dim2, output_dim, random_key):
    """Create parameters for a three-layer network (input -> hidden 1 -> hidden 2 -> output).

    `random_key` is split into three keys, one per weight matrix. Weights are
    drawn with jax.random.normal; biases start at zero.

    Returns the flat tuple (W1, b1, W2, b2, W3, b3).
    """
    layer_sizes = (input_dim, hidden_dim1, hidden_dim2, output_dim)
    layer_keys = jax.random.split(random_key, 3)

    flat_params = []
    for idx, (fan_in, fan_out) in enumerate(zip(layer_sizes, layer_sizes[1:])):
        flat_params.append(jax.random.normal(layer_keys[idx], (fan_in, fan_out)))
        flat_params.append(jnp.zeros((fan_out,)))
    return tuple(flat_params)

In [17]:
def forward(params, X):
    """Run the three-layer network on `X` and return the output logits.

    `params` unpacks as (W1, b1, W2, b2, W3, b3). The two hidden layers use
    ReLU; the output layer is left linear.
    """
    W1, b1, W2, b2, W3, b3 = params
    activations = X
    for weight, bias in ((W1, b1), (W2, b2)):
        activations = jax.nn.relu(jnp.dot(activations, weight) + bias)
    return jnp.dot(activations, W3) + b3

In [18]:
def loss_fn(params, x, y, l2_reg=0.0001):
    """Cross-entropy of softmax(forward(params, x)) against one-hot `y`, plus an L2 penalty.

    The penalty is `l2_reg` times the summed squares of every other entry of
    `params` starting at index 0 (params[::2]).
    """
    probs = jax.nn.softmax(forward(params, x))
    # 1e-8 keeps the log finite when a probability is zero.
    cross_entropy = -jnp.mean(jnp.sum(y * jnp.log(probs + 1e-8), axis=1))
    weight_penalty = l2_reg * sum(jnp.sum(w ** 2) for w in params[::2])
    return cross_entropy + weight_penalty

In [19]:
@jax.jit
def train_step(params, x, y, lr):
    """Take one plain gradient-descent step and return the updated parameters as a list."""
    grads = jax.grad(loss_fn)(params, x, y)
    updated = []
    for param, grad in zip(params, grads):
        updated.append(param - lr * grad)
    return updated

In [20]:
def accuracy(params, x, y):
    """Return the fraction of rows where the argmax of forward(params, x) matches the argmax of `y`."""
    predicted = jnp.argmax(forward(params, x), axis=1)
    expected = jnp.argmax(y, axis=1)
    hits = predicted == expected
    return jnp.mean(hits)

def data_loader(X, y, batch_size):
    """Yield consecutive (X, y) slices of at most `batch_size` items, in order.

    The final batch is shorter when len(X) is not a multiple of `batch_size`.
    """
    for start in range(0, len(X), batch_size):
        window = slice(start, start + batch_size)
        yield X[window], y[window]

In [21]:
# Seed the PRNG from the wall clock, so the initial weights differ between runs.
random_key = jax.random.key(int(time.time()))
# Layer widths: input and output sizes come from the data, hidden sizes are fixed.
input_dim = X_train.shape[1]
hidden_dim1 = 16
hidden_dim2 = 8
output_dim = y_train.shape [1]
# Training hyperparameters.
learning_rate = 0.005
batch_size = 16
epochs = 200

params= init_params(input_dim, hidden_dim1, hidden_dim2, output_dim, random_key)

for epoch in range(epochs):
    # One pass over the training set in fixed-order mini-batches.
    for X_batch, y_batch in data_loader (X_train, y_train, batch_size):
        params = train_step(params, X_batch, y_batch, learning_rate)

    # Report accuracy every 10th epoch (epochs 0, 10, 20, ...).
    if epoch % 10 == 0:
        train_acc = accuracy(params, X_train, y_train)
        test_acc = accuracy(params, X_test, y_test)

        print(f'Epoch {epoch}: Train Acc ({train_acc}), Test Acc ({test_acc})')



# Accuracy on the held-out test set after the last epoch.
print(f'Final Test Acc: {accuracy(params,X_test,y_test)}')



Epoch 0: Train Acc (0.5583333373069763), Test Acc (0.6000000238418579)
Epoch 10: Train Acc (0.8333333730697632), Test Acc (0.8333333730697632)
Epoch 20: Train Acc (0.8750000596046448), Test Acc (0.8333333730697632)
Epoch 30: Train Acc (0.9250000715255737), Test Acc (0.8333333730697632)
Epoch 40: Train Acc (0.9666666984558105), Test Acc (0.8333333730697632)
Epoch 50: Train Acc (0.98333340883255), Test Acc (0.9000000357627869)
Epoch 60: Train Acc (0.98333340883255), Test Acc (0.9333333969116211)
Epoch 70: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 80: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 90: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 100: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 110: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 120: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 130: Train Acc (0.98333340883255), Test Acc (0.9666666984558105)
Epoch 1