In [1]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax import random
from flax import serialization
import json
import optax
import numpy as np
from sklearn.utils.estimator_checks import check_is_fitted
from sklearn.base import BaseEstimator, _fit_context
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from tqdm.notebook import tqdm # progress bar
from sklearn.preprocessing import StandardScaler
import logging
from matplotlib.pyplot import plot

In [2]:
print(jax.devices())

[CpuDevice(id=0)]


In [3]:
features = np.loadtxt("../Data/Training_data/big_test_features.txt")
labels = np.loadtxt("../Data/Training_data/big_test_labels.txt")

In [None]:
# Split label array into individual columns
true_labels = labels[:, 0]          # Just the actual label
avg_radius_column = labels[:, 1]    # The one you want to move to features

# Reshape so you can concatenate it
avg_radius_column = avg_radius_column.reshape(-1, 1)

# Concatenate to features
features_total_np = np.hstack([features, avg_radius_column])

# Overwrite labels array with only the true label
labels_total_np = true_labels.reshape(-1, 1)

In [19]:
# Split into train/test
X_train, X_test, y_train, y_test = train_test_split(
    features_total_np, labels_total_np, test_size=0.2, random_state=42
)

# Standardize input features (only fit on training data!)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to JAX arrays
X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32)

X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32)

In [20]:
@nnx.jit
def loss_fn(model, x, y_target):
   y_pred = model(x)
   return ((y_pred - y_target) ** 2).mean()

# Train for a single epoch
@nnx.jit
def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(grads)
    return loss

def _train_epoch(
    model, optimizer, xs_train, ys_train, batch_size, rng
):
    train_ds_size = len(xs_train)
    steps_per_epoch = train_ds_size // batch_size

    perms = random.permutation(rng, len(xs_train))
    perms = perms[: steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []

    for perm in perms:
        batch_xs = xs_train[perm, :]
        batch_ys = ys_train[perm, :]
        loss = train_step(model, optimizer, batch_xs, batch_ys)
        epoch_loss.append(loss)  # store training loss for the current batch

    train_loss = np.mean(epoch_loss)
    return model, train_loss


def train(model, optimizer, xs_train, ys_train,
          batch_size,epochs, log_period_epoch=1, show_progress=True):

    train_loss_history = []

    for epoch in tqdm(range(1, epochs + 1), disable=not show_progress):
        model, train_loss = _train_epoch(
            model, optimizer, xs_train, ys_train,
            batch_size, random.key(1),
        )

        train_loss_history.append(train_loss)

        if epoch == 1 or epoch % log_period_epoch == 0:
            logging.info(
                "epoch:% 3d, train_loss: %.4f"
                % (epoch, train_loss)
            )
    return train_loss_history

In [None]:
# adapted from template https://github.com/scikit-learn-contrib/project-template/blob/main/skltemplate/_template.py

class MLPRegressor(BaseEstimator):

    def __init__(self, model, lr, epochs, batch_size, log_period_epoch=10,
                 show_progress=True):
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.log_period_epoch = log_period_epoch
        self.show_progress = show_progress

    def fit(self, X, y):
        # `_validate_data` is defined in the `BaseEstimator` class.
        # It allows to:
        # - run different checks on the input data;
        # - define some attributes associated to the input data: `n_features_in_` and
        #   `feature_names_in_`.
        X, y = self._validate_data(X, y, accept_sparse=True)

        # transform row vectors into columns to be compatible with the output of the NN
        y = y.reshape(-1, 1)
        self._optimizer = nnx.Optimizer(self.model, optax.adam(self.lr))

        self.train_loss_history = train(self.model, self._optimizer, X, y,
                                        self.batch_size, self.epochs,
                                        self.log_period_epoch, self.show_progress)

        self.is_fitted_ = True
        # `fit` should always return `self`
        return self

    def predict(self, X):
        # Check if fit had been called
        check_is_fitted(self)
        X = self._validate_data(X, accept_sparse=True, reset=False)
        return self.model(X)

    def save(self, path="mlp_nnx_model"):
        # Get params from optimizer (Flax NNX)
        params = self._optimizer.params  # 👈 this is where they live

        # Serialize params
        params_bytes = serialization.to_bytes(params)
        with open(f"{path}_params.npz", "wb") as f:
            f.write(params_bytes)

        # Save config
        config = {
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "lr": self.lr
        }
        with open(f"{path}_config.json", "w") as f:
            json.dump(config, f)

        print("✅ Model saved.")

    def score(self, X, y):
        y_pred = self.predict(X)
        return r2_score(y, y_pred)    

In [22]:
class Model(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, 64, rngs=rngs)
    self.linear2 = nnx.Linear(64, 64, rngs=rngs)
    self.linear3 = nnx.Linear(64, 32, rngs=rngs)
    self.linear4 = nnx.Linear(32, dout, rngs=rngs)

  def __call__(self, x):
    y = self.linear1(x)
    y = nnx.relu(y)
    y = self.linear2(y)
    y = nnx.relu(y)
    y = self.linear3(y)
    y = nnx.relu(y)
    y = self.linear4(y)
    return y

x_dim = X_train.shape[1]
y_dim = 1
model = Model(x_dim, y_dim, rngs=nnx.Rngs(0))

In [None]:
epochs = 40
learning_rate = 0.01
batch_size = 64

m = MLPRegressor(model, learning_rate, epochs, batch_size)

In [None]:
epochs = 40
learning_rate = 0.01
batch_size = 64

m = MLPRegressor(model, learning_rate, epochs, batch_size)


  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
m.fit(X_train, y_train.ravel())

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predict on test set
y_pred = m.predict(X_test)

# Metrics
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"📊 MSE: {mse:.4f}")
print(f"📊 MAE: {mae:.4f}")
print(f"📈 R² Score: {r2:.4f}")

NotFittedError: This MLPRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

In [46]:
m.save()

AttributeError: 'MLPRegressor' object has no attribute '_optimizer'

## Gridsearch


In [None]:
params = {"learning_rate": [0.01, 0.1, 1],
          "batch_size": [32, 64, 128],
          }

