In [1]:
# Key splitting option given for multi devices parrallelism which is not there in PyTorch
# Still a bit unclear the exact advantage of splitting where they can define multiple key for multiple devices.

# Linear Regression

In [2]:
from sklearn.datasets import load_breast_cancer
import pandas as pd

# Load dataset
data = load_breast_cancer()

# Convert to DataFrame
df = pd.DataFrame(data.data, columns=data.feature_names)
df["target"] = data.target  # add target column

df.head()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,17.33,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,25.53,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,26.5,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0


In [3]:
df.describe()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target
count,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0,...,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0,569.0
mean,14.127292,19.289649,91.969033,654.889104,0.09636,0.104341,0.088799,0.048919,0.181162,0.062798,...,25.677223,107.261213,880.583128,0.132369,0.254265,0.272188,0.114606,0.290076,0.083946,0.627417
std,3.524049,4.301036,24.298981,351.914129,0.014064,0.052813,0.07972,0.038803,0.027414,0.00706,...,6.146258,33.602542,569.356993,0.022832,0.157336,0.208624,0.065732,0.061867,0.018061,0.483918
min,6.981,9.71,43.79,143.5,0.05263,0.01938,0.0,0.0,0.106,0.04996,...,12.02,50.41,185.2,0.07117,0.02729,0.0,0.0,0.1565,0.05504,0.0
25%,11.7,16.17,75.17,420.3,0.08637,0.06492,0.02956,0.02031,0.1619,0.0577,...,21.08,84.11,515.3,0.1166,0.1472,0.1145,0.06493,0.2504,0.07146,0.0
50%,13.37,18.84,86.24,551.1,0.09587,0.09263,0.06154,0.0335,0.1792,0.06154,...,25.41,97.66,686.5,0.1313,0.2119,0.2267,0.09993,0.2822,0.08004,1.0
75%,15.78,21.8,104.1,782.7,0.1053,0.1304,0.1307,0.074,0.1957,0.06612,...,29.72,125.4,1084.0,0.146,0.3391,0.3829,0.1614,0.3179,0.09208,1.0
max,28.11,39.28,188.5,2501.0,0.1634,0.3454,0.4268,0.2012,0.304,0.09744,...,49.54,251.2,4254.0,0.2226,1.058,1.252,0.291,0.6638,0.2075,1.0


In [4]:
df['target'].unique()

array([0, 1])

In [5]:
# ------------------------------
# sklearn implementation
# ------------------------------
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Load dataset
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train sklearn Logistic Regression (linear classifier for binary task)
clf = LogisticRegression(max_iter=10000)
clf.fit(X_train, y_train)

print("sklearn Accuracy:", accuracy_score(y_test, clf.predict(X_test)))

sklearn Accuracy: 0.956140350877193


In [6]:
from sklearn.preprocessing import StandardScaler
import optax
import jax.numpy as jnp
import jax

# Scale features (like sklearn would)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert to JAX
X_train = jnp.array(X_train_scaled, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32).reshape(-1, 1)
X_test = jnp.array(X_test_scaled, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32).reshape(-1, 1)

# Initialize weights + bias
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (X_train.shape[1], 1)) * 0.01 # Here, 0.01 is going to be the std of normal distribution it is a scaling factor and due to scaling the value itself becomes std

#W is a (num_features, 1) weight matrix.
#Initialized from 𝒩(0, 0.01²) distribution.
#If W = 0, every feature contributes equally at start.


b = jnp.zeros((1,))   # Numpy/JAX broadcasting this single 
params = (W, b)

# Sigmoid
def sigmoid(x): return 1 / (1 + jnp.exp(-x))

def predict(params, X):
    W, b = params
    return sigmoid(X @ W + b)

# Binary cross-entropy loss
def loss_fn(params, X, y):
    y_pred = predict(params, X)
    return -jnp.mean(y * jnp.log(y_pred + 1e-8) + (1 - y) * jnp.log(1 - y_pred + 1e-8))
#Computes the predictions, which is the prediction probability y^ (y cap), after that BCE penalizes the predictions which are far from true_pred(y). 
# In that way, if diff between y_pred and y is larger than BCE is more which is how the loss works.

# Optimizer (higher learning rate)
optimizer = optax.sgd(learning_rate=1e-5)
opt_state = optimizer.init(params)

@jax.jit
def update(params, opt_state, X, y):
    grads = jax.grad(loss_fn)(params, X, y) # computes gradients of loss w.r.t W,b
    updates, opt_state = optimizer.update(grads, opt_state) # Opt then updates the finds the difference in parameters, updates and returns them along with a new opt state
    new_params = optax.apply_updates(params, updates) # The new paramters are applied again 
    return new_params, opt_state

# In PyTorch:

# Call .backward() to compute grads (stored inside tensors).

# Call optimizer.step() to apply updates (internally modifies parameters).

# The state is hidden inside the optimizer object.

# In JAX:

# You explicitly get grads.

# You explicitly compute updates.

# You explicitly return (new_params, new_state).

# Training loop
for epoch in range(10000):
    params, opt_state = update(params, opt_state, X_train, y_train)

# Evaluate
y_pred = predict(params, X_test) > 0.5
acc = jnp.mean(y_pred == y_test)
print("JAX Accuracy:", float(acc))


INFO:2025-10-07 23:51:13,998:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-10-07 23:51:14,000:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


JAX Accuracy: 0.9649122953414917


In [7]:
# #Training Differences:

## PyTorch

# # 1. Forward pass
# y_pred = model(X)             
    
# # 2. Compute loss
# loss = loss_fn(y_pred, y)    
    
# # 3. Zero old gradients (important!)
# optimizer.zero_grad()          
    
# # 4. Backward pass (compute grads)
# loss.backward()                
    
# # 5. Update parameters
# optimizer.step()        

# Gradients are stored inside each parameter (param.grad).

# loss.backward() accumulates into param.grad.

# zero_grad() clears old grads so accumulation doesn’t mess things up.

# optimizer.step() uses the stored grads to update params in-place.



## JAX

# # 1. Compute gradients w.r.t loss
# grads = jax.grad(loss_fn)(params, X, y)  
    
# # 2. Use optimizer to compute updates + new opt state
# updates, opt_state = optimizer.update(grads, opt_state)  
    
# # 3. Apply updates to parameters
# params = optax.apply_updates(params, updates)  

# No zero_grad() — grads don’t accumulate, they’re computed fresh each step.

# opt_state is explicitly passed around to track optimizer history

# Params are immutable → we return new params each step.

# Functional style: each function call = new state, no side effects & with JIT faster runtime.

# MLP for Binary Classification

In [8]:
X_train.shape[1]

30

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load and preprocess dataset
X, y = load_breast_cancer(return_X_y=True)
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train_torch = torch.tensor(X_train, dtype=torch.float32)
y_train_torch = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test_torch = torch.tensor(X_test, dtype=torch.float32)
y_test_torch = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# Define MLP
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32) 
        self.fc3 = nn.Linear(32, 1)
    def forward(self, x):
        x = torch.relu(self.fc1(x)) # ReLU for adding non-linearity
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

# Initialize model, loss, optimizer
model = MLP(X_train.shape[1])
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Training loop
num_epochs = 256
batch_size = 64

for epoch in range(num_epochs):
    perm = torch.randperm(X_train_torch.size(0)) #Random permutations so that each epoch get different batches which helps the model to learn and converge quickly
    total_loss = 0
    for i in range(0, X_train_torch.size(0), batch_size):
        idx = perm[i:i+batch_size]
        xb, yb = X_train_torch[idx], y_train_torch[idx]
        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * len(xb)
    avg_loss = total_loss / len(X_train_torch)
    
    # Evaluate each epoch
    with torch.no_grad():
        preds = (model(X_test_torch) > 0.5).float()
        acc = (preds.eq(y_test_torch)).float().mean().item()
    print(f"Epoch [{epoch+1}/{num_epochs}]  Loss: {avg_loss:.4f}  Test Accuracy: {acc*100:.2f}%")

# Final evaluation
with torch.no_grad():
    preds = (model(X_test_torch) > 0.5).float()
    test_acc = (preds.eq(y_test_torch)).float().mean().item()

print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")


Epoch [1/256]  Loss: 0.6839  Test Accuracy: 63.16%
Epoch [2/256]  Loss: 0.6831  Test Accuracy: 63.16%
Epoch [3/256]  Loss: 0.6823  Test Accuracy: 63.16%
Epoch [4/256]  Loss: 0.6815  Test Accuracy: 64.04%
Epoch [5/256]  Loss: 0.6807  Test Accuracy: 64.91%
Epoch [6/256]  Loss: 0.6800  Test Accuracy: 64.91%
Epoch [7/256]  Loss: 0.6792  Test Accuracy: 64.91%
Epoch [8/256]  Loss: 0.6784  Test Accuracy: 65.79%
Epoch [9/256]  Loss: 0.6777  Test Accuracy: 65.79%
Epoch [10/256]  Loss: 0.6769  Test Accuracy: 65.79%
Epoch [11/256]  Loss: 0.6761  Test Accuracy: 65.79%
Epoch [12/256]  Loss: 0.6753  Test Accuracy: 65.79%
Epoch [13/256]  Loss: 0.6745  Test Accuracy: 65.79%
Epoch [14/256]  Loss: 0.6738  Test Accuracy: 65.79%
Epoch [15/256]  Loss: 0.6730  Test Accuracy: 65.79%
Epoch [16/256]  Loss: 0.6722  Test Accuracy: 65.79%
Epoch [17/256]  Loss: 0.6714  Test Accuracy: 65.79%
Epoch [18/256]  Loss: 0.6706  Test Accuracy: 65.79%
Epoch [19/256]  Loss: 0.6698  Test Accuracy: 65.79%
Epoch [20/256]  Loss:

In [10]:
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.preprocessing import StandardScaler
import optax
import flax.linen as nn
from flax.training import train_state
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# ------------------------------------------------------------
# Load and preprocess dataset
# ------------------------------------------------------------
X, y = load_breast_cancer(return_X_y=True)
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32).reshape(-1, 1)
X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32).reshape(-1, 1)

# ------------------------------------------------------------
# Define MLP model (64 -> 32 -> 1)
# ------------------------------------------------------------
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        x = nn.relu(nn.Dense(32)(x))
        x = nn.Dense(1)(x)
        return nn.sigmoid(x)

# ------------------------------------------------------------
# Initialize model and optimizer
# ------------------------------------------------------------
model = MLP()
rng = jax.random.PRNGKey(0)
params = model.init(rng, X_train) # model.init() requires key and x_train so that it will automatically calls the shape and initializes the internal parameter iniatialization which is done explicitly in linear regression example.

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-5)   # keep learning rate 1e-5
)

# ------------------------------------------------------------
# Loss and accuracy functions
# ------------------------------------------------------------
def loss_fn(params, X, y):
    preds = model.apply(params, X)
    return jnp.mean(optax.sigmoid_binary_cross_entropy(preds, y))

def accuracy_fn(params, X, y):
    preds = model.apply(params, X)
    preds_binary = preds > 0.5
    return jnp.mean(preds_binary == y)

# ------------------------------------------------------------
# Training step (JIT compiled)
# ------------------------------------------------------------
@jax.jit
def train_step(state, X, y):
    grads = jax.grad(loss_fn)(state.params, X, y)
    state = state.apply_gradients(grads=grads)
    loss = loss_fn(state.params, X, y)
    acc = accuracy_fn(state.params, X, y)
    return state, loss, acc

# ------------------------------------------------------------
# Data loader for batching
# ------------------------------------------------------------
def data_loader(X, y, batch_size, rng):
    n = X.shape[0]
    perm = jax.random.permutation(rng, n)
    for i in range(0, n, batch_size):
        idx = perm[i:i+batch_size]
        yield X[idx], y[idx]

# ------------------------------------------------------------
# Training loop
# ------------------------------------------------------------
epochs = 256
batch_size = 64
for epoch in range(epochs):
    rng, input_rng = jax.random.split(rng)
    epoch_losses, epoch_accs = [], []
    
    for X_batch, y_batch in data_loader(X_train, y_train, batch_size, input_rng):
        state, loss_val, acc_val = train_step(state, X_batch, y_batch)
        epoch_losses.append(loss_val)
        epoch_accs.append(acc_val)
    
    mean_loss = jnp.mean(jnp.array(epoch_losses))
    mean_acc = jnp.mean(jnp.array(epoch_accs))
    
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:03d} | Loss: {mean_loss:.4f} | Train Acc: {mean_acc:.4f}")

# ------------------------------------------------------------
# Final test accuracy
# ------------------------------------------------------------
test_preds = model.apply(state.params, X_test) > 0.5
test_acc = jnp.mean(test_preds == y_test)
print("\nFinal Test Accuracy:", float(test_acc))

Epoch 001 | Loss: 0.6331 | Train Acc: 0.7065
Epoch 010 | Loss: 0.6179 | Train Acc: 0.7517
Epoch 020 | Loss: 0.6402 | Train Acc: 0.7453
Epoch 030 | Loss: 0.6188 | Train Acc: 0.7729
Epoch 040 | Loss: 0.6292 | Train Acc: 0.8203
Epoch 050 | Loss: 0.5979 | Train Acc: 0.8359
Epoch 060 | Loss: 0.5918 | Train Acc: 0.8477
Epoch 070 | Loss: 0.5867 | Train Acc: 0.8594
Epoch 080 | Loss: 0.5849 | Train Acc: 0.8750
Epoch 090 | Loss: 0.5912 | Train Acc: 0.8848
Epoch 100 | Loss: 0.5842 | Train Acc: 0.8926
Epoch 110 | Loss: 0.5810 | Train Acc: 0.9004
Epoch 120 | Loss: 0.5831 | Train Acc: 0.9062
Epoch 130 | Loss: 0.5681 | Train Acc: 0.8943
Epoch 140 | Loss: 0.5591 | Train Acc: 0.9121
Epoch 150 | Loss: 0.5816 | Train Acc: 0.9199
Epoch 160 | Loss: 0.5677 | Train Acc: 0.8761
Epoch 170 | Loss: 0.5656 | Train Acc: 0.9277
Epoch 180 | Loss: 0.5630 | Train Acc: 0.9316
Epoch 190 | Loss: 0.5627 | Train Acc: 0.9216
Epoch 200 | Loss: 0.5634 | Train Acc: 0.9235
Epoch 210 | Loss: 0.5620 | Train Acc: 0.9414
Epoch 220 

### CNN for MNIST

In [12]:
# ------------------------------
# PyTorch CNN on MNIST (CUDA optimized for Kaggle)
# ------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset
transform = transforms.ToTensor()
train_data = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_data = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)

# Define CNN
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.fc1 = nn.Linear(26*26*32, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Initialize model, loss, and optimizer
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Training loop
epochs = 16
for epoch in range(epochs):
    model.train()
    correct = 0
    total = 0
    running_loss = 0.0

    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * xb.size(0)
        correct += (pred.argmax(dim=1) == yb).sum().item()
        total += yb.size(0)

    train_acc = correct / total
    avg_loss = running_loss / total
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Train Acc: {train_acc*100:.2f}%")

# Evaluation on test data
model.eval()
correct = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb).argmax(dim=1)
        correct += pred.eq(yb).sum().item()

test_acc = correct / len(test_data)
print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")


Using device: cuda
Epoch 1/16 | Loss: 1.5009 | Train Acc: 74.42%
Epoch 2/16 | Loss: 0.7410 | Train Acc: 85.61%
Epoch 3/16 | Loss: 0.5176 | Train Acc: 88.06%
Epoch 4/16 | Loss: 0.4219 | Train Acc: 89.39%
Epoch 5/16 | Loss: 0.3693 | Train Acc: 90.16%
Epoch 6/16 | Loss: 0.3350 | Train Acc: 90.81%
Epoch 7/16 | Loss: 0.3107 | Train Acc: 91.32%
Epoch 8/16 | Loss: 0.2912 | Train Acc: 91.84%
Epoch 9/16 | Loss: 0.2755 | Train Acc: 92.21%
Epoch 10/16 | Loss: 0.2617 | Train Acc: 92.65%
Epoch 11/16 | Loss: 0.2499 | Train Acc: 92.95%
Epoch 12/16 | Loss: 0.2393 | Train Acc: 93.24%
Epoch 13/16 | Loss: 0.2299 | Train Acc: 93.50%
Epoch 14/16 | Loss: 0.2214 | Train Acc: 93.74%
Epoch 15/16 | Loss: 0.2133 | Train Acc: 93.95%
Epoch 16/16 | Loss: 0.2057 | Train Acc: 94.22%

Final Test Accuracy: 94.41%


In [13]:
#No need to manually tell the code to use GPU device it does it automatically.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import tensorflow_datasets as tfds

# Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split="train").batch(64).prefetch(1)
test_ds = ds_builder.as_dataset(split="test").batch(1000)

def preprocess(batch):
    images = jnp.array(batch["image"], dtype=jnp.float32) / 255.0
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    return images, labels

# Define CNN
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(32, (3,3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.Dense(10)(x)
        return x

# Init model
model = CNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1,28,28,1)))
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-5)
)

# Loss and accuracy functions
def loss_fn(params, X, y):
    logits = model.apply(params, X)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

def accuracy_fn(params, X, y):
    logits = model.apply(params, X)
    preds = jnp.argmax(logits, axis=1)
    return jnp.mean(preds == y)

@jax.jit
def train_step(state, X, y):
    grads = jax.grad(loss_fn)(state.params, X, y)
    state = state.apply_gradients(grads=grads)
    loss = loss_fn(state.params, X, y)
    acc = accuracy_fn(state.params, X, y)
    return state, loss, acc

# Training loop
for epoch in range(16):
    train_loss, train_acc = 0.0, 0.0
    num_batches = 0
    for batch in train_ds:
        X, y = preprocess(batch)
        state, loss_val, acc_val = train_step(state, X, y)
        train_loss += loss_val
        train_acc += acc_val
        num_batches += 1
    train_loss /= num_batches
    train_acc /= num_batches

    # Evaluate on test set
    correct, total = 0, 0
    for batch in test_ds:
        X, y = preprocess(batch)
        logits = model.apply(state.params, X)
        preds = jnp.argmax(logits, axis=1)
        correct += (preds == y).sum()
        total += len(y)
    test_acc = correct / total

    print(f"Epoch {epoch+1:02d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Acc: {float(test_acc):.4f}")

## apply_gradients is for updating gradient which does both forward n backward propagation and apply is for updating the opt state it only does forward pass not the backward

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

2025-10-07 23:53:27.873815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759881208.070627     306 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759881208.135063     306 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

I0000 00:00:1759881217.070578     110 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 2967 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Shuffling /root/tensorflow_datasets/mnist/incomplete.GI32C6_3.0.1/mnist-train.tfrecord*...:   0%|          | 0…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.GI32C6_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/…

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 01 | Train Loss: 0.7114 | Train Acc: 0.8497 | Test Acc: 0.9174
Epoch 02 | Train Loss: 0.2998 | Train Acc: 0.9222 | Test Acc: 0.9337
Epoch 03 | Train Loss: 0.2356 | Train Acc: 0.9371 | Test Acc: 0.9432
Epoch 04 | Train Loss: 0.2001 | Train Acc: 0.9459 | Test Acc: 0.9486
Epoch 05 | Train Loss: 0.1752 | Train Acc: 0.9522 | Test Acc: 0.9541
Epoch 06 | Train Loss: 0.1559 | Train Acc: 0.9575 | Test Acc: 0.9582
Epoch 07 | Train Loss: 0.1403 | Train Acc: 0.9625 | Test Acc: 0.9606
Epoch 08 | Train Loss: 0.1273 | Train Acc: 0.9662 | Test Acc: 0.9639
Epoch 09 | Train Loss: 0.1163 | Train Acc: 0.9695 | Test Acc: 0.9672
Epoch 10 | Train Loss: 0.1068 | Train Acc: 0.9719 | Test Acc: 0.9692
Epoch 11 | Train Loss: 0.0985 | Train Acc: 0.9746 | Test Acc: 0.9705
Epoch 12 | Train Loss: 0.0913 | Train Acc: 0.9768 | Test Acc: 0.9722
Epoch 13 | Train Loss: 0.0848 | Train Acc: 0.9788 | T

In [16]:
print(jax.devices())

[CudaDevice(id=0)]
