In [1]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax import random
from flax import serialization
import json
import optax
import numpy as np
from sklearn.utils.estimator_checks import check_is_fitted
from sklearn.base import BaseEstimator, _fit_context
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from tqdm.notebook import tqdm # progress bar
from sklearn.preprocessing import StandardScaler
import logging
from matplotlib.pyplot import plot

In [2]:
print(jax.devices())

[CudaDevice(id=0)]


In [3]:
features = np.loadtxt("../Data/Training_data/big_test_features.txt")
labels = np.loadtxt("../Data/Training_data/big_test_labels.txt")

In [5]:
# Split label array into individual columns
true_labels = labels[:, 0]          # Just the actual label
avg_radius_column = labels[:, 1]    # The one you want to move to features

# Reshape so you can concatenate it
avg_radius_column = avg_radius_column.reshape(-1, 1)

# Concatenate to features
features_total_np = np.hstack([features, avg_radius_column])

# Overwrite labels array with only the true label
labels_total_np = true_labels.reshape(-1, 1)


features_total_np = np.delete(features_total_np, [1,4,5,9], axis=1) # Remove variation


In [7]:
# Split into train/test
X_train, X_test, y_train, y_test = train_test_split(features_total_np, labels_total_np, test_size=0.2, random_state=42)



# Standardize input features (only fit on training data!)
scaler = StandardScaler()


X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_grid, _, y_grid, _ = train_test_split(X_train, y_train, train_size=0.4, random_state=42) # Only using 40% of the data for gridsearching


# Convert to JAX arrays
X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32)

X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32)

# X_grid = jnp.array(X_grid, dtype=jnp.float32)
# y_grid = jnp.array(y_grid, dtype=jnp.float32)



In [None]:
X_test.shape

(459421, 8)

In [9]:
@nnx.jit
def loss_fn(model, x, y_target):
   y_pred = model(x)
   return ((y_pred - y_target) ** 2).mean()

# Train for a single epoch
@nnx.jit
def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(grads)
    return loss

def _train_epoch(
    model, optimizer, xs_train, ys_train, batch_size, rng
):
    train_ds_size = len(xs_train)
    steps_per_epoch = train_ds_size // batch_size

    perms = random.permutation(rng, len(xs_train))
    perms = perms[: steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []

    for perm in perms:
        batch_xs = xs_train[perm, :]
        batch_ys = ys_train[perm, :]
        loss = train_step(model, optimizer, batch_xs, batch_ys)
        epoch_loss.append(loss)  # store training loss for the current batch

    train_loss = np.mean(epoch_loss)
    return model, train_loss


def train(model, optimizer, xs_train, ys_train,
          batch_size,epochs, log_period_epoch=1, show_progress=True):

    train_loss_history = []

    for epoch in tqdm(range(1, epochs + 1), disable=not show_progress):
        model, train_loss = _train_epoch(
            model, optimizer, xs_train, ys_train,
            batch_size, random.key(1),
        )

        train_loss_history.append(train_loss)

        if epoch == 1 or epoch % log_period_epoch == 0:
            logging.info(
                "epoch:% 3d, train_loss: %.4f"
                % (epoch, train_loss)
            )
    return train_loss_history

In [20]:
# adapted from template https://github.com/scikit-learn-contrib/project-template/blob/main/skltemplate/_template.py

class MLPRegressor(BaseEstimator):

    def __init__(self, model, lr, epochs, batch_size, log_period_epoch=10,
                 show_progress=True):
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.log_period_epoch = log_period_epoch
        self.show_progress = show_progress

    def fit(self, X, y):

        # `_validate_data` is defined in the `BaseEstimator` class.
        # It allows to:
        # - run different checks on the input data;
        # - define some attributes associated to the input data: `n_features_in_` and
        #   `feature_names_in_`.
        X, y = self._validate_data(X, y, accept_sparse=True)

        # transform row vectors into columns to be compatible with the output of the NN
        y = y.reshape(-1, 1)
        self._optimizer = nnx.Optimizer(self.model, optax.adam(self.lr))
        self.train_loss_history = train(self.model, self._optimizer, X, y,
                                        self.batch_size, self.epochs,
                                        self.log_period_epoch, self.show_progress)
        self.trained_params = self._optimizer.target  
        self.is_fitted_ = True
        # `fit` should always return `self`
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = self._validate_data(X, accept_sparse=True, reset=False)

        self.model = self.model.replace(self.trained_params)  # Inject trained weights

        return self.model(X)

    

    def save_joblib(self, path="mlp_nnx_model.joblib"):
        import joblib
        joblib.dump(self, path)
        print(f"✅ Saved to {path}")


    def score(self, X, y):
        y_pred = self.predict(X)
        return r2_score(y, y_pred)    


In [21]:
class Model(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, 64, rngs=rngs)
    self.linear2 = nnx.Linear(64, 64, rngs=rngs)
    self.linear3 = nnx.Linear(64, 32, rngs=rngs)
    self.linear4 = nnx.Linear(32, dout, rngs=rngs)

  def __call__(self, x):
    y = self.linear1(x)
    y = nnx.relu(y)
    y = self.linear2(y)
    y = nnx.relu(y)
    y = self.linear3(y)
    y = nnx.relu(y)
    y = self.linear4(y)
    return y




x_dim = X_train.shape[1]
y_dim = 1
model = Model(x_dim, y_dim, rngs=nnx.Rngs(0))

In [22]:
epochs = 1
learning_rate = 0.01
batch_size = 128

m = MLPRegressor(model, learning_rate, epochs, batch_size)

In [23]:
m.fit(X_train, y_train.ravel())  # Training the model



  0%|          | 0/1 [00:00<?, ?it/s]

AttributeError: 'Optimizer' object has no attribute 'target'

In [12]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predict on test set
y_pred = m.predict(X_test)

# Metrics
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"📊 MSE: {mse:.4f}")
print(f"📊 MAE: {mae:.4f}")
print(f"📈 R² Score: {r2:.4f}")

📊 MSE: 0.8117
📊 MAE: 0.6842
📈 R² Score: 0.7688


## Gridsearch


In [26]:
param_grid = {
    "lr": [0.01, 0.1, 1],
    "batch_size": [64, 128],
}


search = GridSearchCV(
    MLPRegressor(model, lr=0.01, epochs=50, batch_size=64),
    param_grid=param_grid,
    scoring="r2",
    cv=3,
    verbose=2
)



In [28]:
search.fit(X_grid, np.ravel(y_grid))

Fitting 3 folds for each of 6 candidates, totalling 18 fits


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=64, lr=0.01; total time=36.2min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=64, lr=0.01; total time=34.5min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=64, lr=0.01; total time=33.0min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ..............................batch_size=64, lr=0.1; total time=30.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ..............................batch_size=64, lr=0.1; total time=32.7min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ..............................batch_size=64, lr=0.1; total time=33.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ................................batch_size=64, lr=1; total time=33.2min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ................................batch_size=64, lr=1; total time=31.8min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ................................batch_size=64, lr=1; total time=33.3min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ............................batch_size=128, lr=0.01; total time=17.0min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ............................batch_size=128, lr=0.01; total time=15.8min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ............................batch_size=128, lr=0.01; total time=15.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=128, lr=0.1; total time=16.7min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=128, lr=0.1; total time=16.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END .............................batch_size=128, lr=0.1; total time=16.6min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ...............................batch_size=128, lr=1; total time=16.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ...............................batch_size=128, lr=1; total time=16.4min


  0%|          | 0/50 [00:00<?, ?it/s]

[CV] END ...............................batch_size=128, lr=1; total time=16.8min


  0%|          | 0/50 [00:00<?, ?it/s]

In [29]:
best_params = search.best_params_
print("🎯 Best Parameters Found:", best_params)

🎯 Best Parameters Found: {'batch_size': 128, 'lr': 0.01}


In [30]:
print("📈 Best Cross-Validated R²:", search.best_score_)


📈 Best Cross-Validated R²: 0.7618448535601298


In [33]:
import pandas as pd
pd.DataFrame(search.cv_results_).sort_values("rank_test_score")


Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_batch_size,param_lr,params,split0_test_score,split1_test_score,split2_test_score,mean_test_score,std_test_score,rank_test_score
3,963.674136,40.319145,0.10303,0.001208,128,0.01,"{'batch_size': 128, 'lr': 0.01}",0.760199,0.767713,0.757622,0.761845,0.004281,1
0,2074.235761,81.088929,-0.220861,0.849345,64,0.01,"{'batch_size': 64, 'lr': 0.01}",0.765398,0.769718,0.740881,0.758666,0.012699,2
4,995.966957,6.63042,0.102977,0.002213,128,0.1,"{'batch_size': 128, 'lr': 0.1}",-0.003209,-0.00059,-6.4e-05,-0.001288,0.001375,3
5,990.985527,12.138161,0.102486,0.001232,128,1.0,"{'batch_size': 128, 'lr': 1}",-2.2e-05,-0.001559,-0.00584,-0.002474,0.002462,4
1,1930.788955,77.549282,0.102861,0.000519,64,0.1,"{'batch_size': 64, 'lr': 0.1}",-0.025542,-0.000317,-1.4e-05,-0.008624,0.011963,5
2,1964.484499,41.339932,0.103885,0.002582,64,1.0,"{'batch_size': 64, 'lr': 1}",-0.029662,-0.232707,-0.154879,-0.139082,0.083642,6
