# Step 1: Define the Model

In [5]:
import jax
import jax.numpy as jnp
from flax import linen as nn

# Two-layer feed-forward network used as the binary classifier
class SimpleNN(nn.Module):
    """Small fully connected network: Dense(10) -> ReLU -> Dense(1) -> sigmoid.

    Because the final step is a sigmoid, calling the module yields one value
    in (0, 1) per input row (a probability-like score), not a raw logit.
    """

    @nn.compact
    def __call__(self, x):
        # Hidden layer: 10 units followed by a ReLU non-linearity
        hidden = nn.relu(nn.Dense(10)(x))
        # Output layer: a single unit, squashed by the sigmoid below
        score = nn.Dense(1)(hidden)
        return nn.sigmoid(score)

# Create the network object; its parameters are produced separately by init()
model = SimpleNN()

# Parameter initialization needs a PRNG key plus an example input,
# which tells Flax the shape of the inputs the model will receive.
key = jax.random.PRNGKey(0)  # fixed seed used for parameter initialization
input_shape = (1, 2)         # a single example with 2 input features
params = model.init(
    key,
    jnp.ones(input_shape),   # dummy input; only its shape is specified here
)['params']

# Step 2: Prepare the Data

In [7]:
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split

# Generate the dataset using sklearn.
# random_state pins the sampling and the noise so that the same points are
# produced on every Restart & Run All.
X, y = make_circles(n_samples=1000, factor=0.5, noise=0.05, random_state=42)

# Split into training and validation sets while the data is still NumPy:
# train_test_split documents lists, NumPy arrays, SciPy sparse matrices and
# pandas DataFrames as its inputs, so JAX arrays are converted afterwards.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert the full dataset and every split to JAX arrays
X, y = jnp.array(X), jnp.array(y)
X_train, X_val, y_train, y_val = (
    jnp.array(part) for part in (X_train, X_val, y_train, y_val)
)

# Step 3: Define the Loss Function and Accuracy


In [24]:
# Define the binary cross-entropy loss function
def binary_cross_entropy_loss(params, X, y, eps=1e-7):
    """Mean binary cross-entropy of the model's predictions against ``y``.

    Args:
        params: Parameter tree handed to ``model.apply`` under the 'params' key.
        X: Batch of inputs forwarded to the model.
        y: Binary labels; reshaped to a single column before being combined
            with the model output.
        eps: Clipping margin; predictions are confined to [eps, 1 - eps]
            before the logarithms are taken.

    Returns:
        The scalar mean loss over the batch.
    """
    # The values returned by the model are combined with log(p) and log(1 - p),
    # i.e. they are treated as probabilities rather than raw logits.
    probs = model.apply({'params': params}, X)
    # Keep predictions strictly inside (0, 1): an exact 0 or 1 would give
    # log(0) = -inf, and 0 * -inf = NaN would poison the loss and its gradients.
    probs = jnp.clip(probs, eps, 1.0 - eps)
    labels = y.reshape(-1, 1)
    loss = -jnp.mean(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))
    return loss

# Classification accuracy of the model on a labelled batch
def accuracy(params, X, y):
    """Return the fraction of rows whose thresholded prediction equals the label.

    Args:
        params: Parameter tree handed to ``model.apply`` under the 'params' key.
        X: Batch of inputs forwarded to the model.
        y: Labels; reshaped to a single column before the comparison.

    Returns:
        The mean of the element-wise "prediction == label" comparison.
    """
    outputs = model.apply({'params': params}, X)
    targets = y.reshape(-1, 1)
    # An output strictly above 0.5 counts as the positive class
    hits = (outputs > 0.5) == targets
    return jnp.mean(hits)

# Parameter update: one step of plain gradient descent, compiled with jax.jit
@jax.jit
def update(params, X, y, lr=0.01):
    """Take a single gradient-descent step on the binary cross-entropy loss.

    Args:
        params: Current parameter tree.
        X: Batch of inputs.
        y: Labels for the batch.
        lr: Step size multiplying each gradient.

    Returns:
        A new parameter tree where every leaf is ``p - lr * g``.
    """
    # jax.grad differentiates with respect to the first argument (params)
    loss_gradient = jax.grad(binary_cross_entropy_loss)

    def descend(param, grad):
        return param - lr * grad

    # tree_map from jax.tree_util applies the step leaf by leaf
    return jax.tree_util.tree_map(descend, params, loss_gradient(params, X, y))

# Step 4: The Training Loop

In [25]:
# Training hyperparameters: number of passes and gradient-descent step size
num_epochs = 100
learning_rate = 0.01

# NOTE(review): the output captured with this cell shows the train loss only
# moving from 0.7171 to 0.7146 over the first 11 epochs, with accuracy staying
# near 0.50, so this step size / epoch budget looks too small for the model to
# fit the data -- confirm and consider raising them.

# Full-batch training: every epoch is one update on the whole training set
for epoch in range(num_epochs):
    params = update(params, X_train, y_train, lr=learning_rate)

    # Losses after the update, on the training and validation splits
    train_loss = binary_cross_entropy_loss(params, X_train, y_train)
    val_loss = binary_cross_entropy_loss(params, X_val, y_val)

    # Accuracies after the update, on the same two splits
    train_accuracy = accuracy(params, X_train, y_train)
    val_accuracy = accuracy(params, X_val, y_val)

    # One report line per epoch (epochs are shown 1-based)
    print(', '.join((
        f'Epoch {epoch + 1}',
        f'Train loss: {train_loss:.4f}',
        f'Train accuracy: {train_accuracy:.4f}',
        f'Val loss: {val_loss:.4f}',
        f'Val accuracy: {val_accuracy:.4f}',
    )))


Epoch 1, Train loss: 0.7171, Train accuracy: 0.5025, Val loss: 0.7115, Val accuracy: 0.5050
Epoch 2, Train loss: 0.7169, Train accuracy: 0.5025, Val loss: 0.7113, Val accuracy: 0.4950
Epoch 3, Train loss: 0.7166, Train accuracy: 0.5037, Val loss: 0.7110, Val accuracy: 0.4950
Epoch 4, Train loss: 0.7164, Train accuracy: 0.5037, Val loss: 0.7107, Val accuracy: 0.4950
Epoch 5, Train loss: 0.7161, Train accuracy: 0.5050, Val loss: 0.7104, Val accuracy: 0.4950
Epoch 6, Train loss: 0.7158, Train accuracy: 0.5062, Val loss: 0.7101, Val accuracy: 0.4950
Epoch 7, Train loss: 0.7156, Train accuracy: 0.5062, Val loss: 0.7098, Val accuracy: 0.4950
Epoch 8, Train loss: 0.7153, Train accuracy: 0.5062, Val loss: 0.7096, Val accuracy: 0.4950
Epoch 9, Train loss: 0.7151, Train accuracy: 0.5050, Val loss: 0.7093, Val accuracy: 0.4950
Epoch 10, Train loss: 0.7148, Train accuracy: 0.5037, Val loss: 0.7090, Val accuracy: 0.4950
Epoch 11, Train loss: 0.7146, Train accuracy: 0.5025, Val loss: 0.7087, Val acc

# Step 5: Making Predictions

In [26]:
def predict(params, X_new):
    """Return boolean class predictions for new inputs.

    Args:
        params: Parameter tree handed to ``model.apply`` under the 'params' key.
        X_new: Batch of inputs to classify.

    Returns:
        A boolean array that is True wherever the model output is strictly
        greater than 0.5.
    """
    outputs = model.apply({'params': params}, X_new)
    return outputs > 0.5

# Example inputs to classify; they must have the same features and
# preprocessing as the training data. Replace with actual new data.
X_new = jnp.array([
    [0.1, 0.2],
    [1.2, 0.2],
    [-0.5, 0.6],
    [0.3, -0.4],
])

# Run the model with the current parameters and show the boolean predictions
new_predictions = predict(params, X_new)
print("Predicted classes:", new_predictions)

Predicted classes: [[False]
 [ True]
 [ True]
 [False]]
