## Step 1: Set Up the Environment
### 1. Install JAX, Flax, and Optax

Create a new Jupyter notebook and start with installing the required libraries:

In [None]:
# !pip install jax jaxlib flax optax 

### 2. Import Libraries

Start a new notebook cell with necessary imports:

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


## Step 2: Data Loading and Preprocessing

### 1. Load MNIST Dataset

In [3]:
mnist = fetch_openml('mnist_784')
X = mnist.data / 255.0
y = mnist.target.astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)


#### 2. Create Data Loaders

In [4]:
batch_size = 32
train_loader = [(X_train[i:i + batch_size], y_train[i:i + batch_size]) for i in range(0, len(X_train), batch_size)]
test_loader = [(X_test[i:i + batch_size], y_test[i:i + batch_size]) for i in range(0, len(X_test), batch_size)]


## Step 3: Model Definition

### 1. Define a Feedforward Neural Network using Flax

In [5]:
class FeedforwardNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # Flatten the input
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


## Step 4: Loss Function and Optimizer

### 1. Define the Loss Function

In [6]:
def cross_entropy_loss(logits, labels):
    one_hot = jax.nn.one_hot(labels, num_classes=10)
    return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))


#### 2. Define the Optimizer

In [9]:
def create_train_state(rng, learning_rate, model, input_shape):
    params = model.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


## Step 5: Training Loop
#### 1. Define the Training Step

In [10]:
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits, batch['label'])
        return loss
    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state


#### 2. Initialize Parameters and Optimizer State

In [11]:
rng = jax.random.PRNGKey(0)
model = FeedforwardNN()
input_shape = (batch_size, 28, 28)
state = create_train_state(rng, 0.001, model, input_shape)


#### 3. Training Loop

In [12]:
num_epochs = 5
for epoch in range(num_epochs):
    for x_batch, y_batch in train_loader:
        batch = {'image': jnp.array(x_batch), 'label': jnp.array(y_batch)}
        state = train_step(state, batch)
    print(f'Epoch {epoch + 1}/{num_epochs} completed.')


Epoch 1/5 completed.
Epoch 2/5 completed.
Epoch 3/5 completed.
Epoch 4/5 completed.
Epoch 5/5 completed.


## Step 6: Model Evaluation 

#### 1. Define the Evaluation Step

In [13]:
@jax.jit
def eval_step(params, batch):
    logits = model.apply({'params': params}, batch['image'])
    return logits


#### 2. Evaluate the Model

In [14]:
correct = 0
total = 0
for x_batch, y_batch in test_loader:
    batch = {'image': jnp.array(x_batch), 'label': jnp.array(y_batch)}
    logits = eval_step(state.params, batch)
    predicted = jnp.argmax(logits, axis=1)
    total += batch['label'].shape[0]
    correct += jnp.sum(predicted == batch['label'])
accuracy = correct / total
print(f'Accuracy: {accuracy:.4f}')


Accuracy: 0.9666


## Step 7: Interactive Elements
#### 1. Adding Markdown Cells for Explanation
##### Use Markdown cells to explain each step, providing context and instructions.

#### 2. Interactive Widgets
##### Consider using ipywidgets to make the notebook interactive:

In [15]:
from ipywidgets import interact, IntSlider

def train_and_evaluate(learning_rate=0.001, num_epochs=5):
    state = create_train_state(rng, learning_rate, model, input_shape)
    for epoch in range(num_epochs):
        for x_batch, y_batch in train_loader:
            batch = {'image': jnp.array(x_batch), 'label': jnp.array(y_batch)}
            state = train_step(state, batch)
        print(f'Epoch {epoch + 1}/{num_epochs} completed.')
    correct = 0
    total = 0
    for x_batch, y_batch in test_loader:
        batch = {'image': jnp.array(x_batch), 'label': jnp.array(y_batch)}
        logits = eval_step(state.params, batch)
        predicted = jnp.argmax(logits, axis=1)
        total += batch['label'].shape[0]
        correct += jnp.sum(predicted == batch['label'])
    accuracy = correct / total
    print(f'Accuracy: {accuracy:.4f}')
    return accuracy

interact(train_and_evaluate, learning_rate=0.001, num_epochs=IntSlider(min=1, max=10, step=1, value=5))


interactive(children=(FloatSlider(value=0.001, description='learning_rate', max=0.003, min=-0.001), IntSlider(…

<function __main__.train_and_evaluate(learning_rate=0.001, num_epochs=5)>