In [None]:
# | default_exp experiments

# Experiments

> The code implementing the experiments in the paper:
> 
> Davor Runje, Sharath M. Shankaranarayana. <i>Constrained Monotonic Neural Networks</i>. 40th International Conference on Machine Learning, 2023.


## Imports

In [None]:
# | export

from contextlib import contextmanager
from datetime import datetime
from os import environ
from pathlib import Path
from typing import *

from tempfile import TemporaryDirectory
import urllib.request
import shutil

from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns
import tensorflow as tf
from keras_tuner import (
    BayesianOptimization,
    Objective,
    Tuner,
    HyperParameters,
    HyperModel,
)
from numpy.typing import ArrayLike, NDArray
from tensorflow.keras import Model
from tensorflow.keras.backend import count_params
from tensorflow.keras.layers import Concatenate, Dense, Dropout, Input
from tensorflow.keras.optimizers.experimental import AdamW
from tensorflow.types.experimental import TensorLike

from mono_dense_keras import (
    MonoDense,
    replace_kernel_using_monotonicity_indicator,
    create_type_1,
    create_type_2,
)

In [None]:
from keras_tuner import RandomSearch

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# NOTE(review): TensorFlow presumably reads this when the GPU device is first
# initialised, so it must be set before any TF op runs — confirm.
environ.update({"TF_FORCE_GPU_ALLOW_GROWTH": "true"})

## Monotonic Dense Layer


### Monotonic Dense Layer

This is an implementation of our Monotonic Dense Unit, also called the Constrained Monotone Fully Connected Layer. The figure below, taken from the paper, is included for reference.

In the code, the variable `monotonicity_indicator` corresponds to **t** in the figure and the variable `activation_selector` corresponds to **s**. 

Parameters `convexity_indicator` and `epsilon` are used to calculate `activation_selector` as follows:
- if `convexity_indicator` is -1 or 1, then `activation_selector` will have all elements 0 or 1, respectively.
- if `convexity_indicator` is `None`, then `epsilon` must have a value between 0 and 1 and corresponds to the percentage of elements of `activation_selector` set to 1.

![mono-dense-layer-diagram](images/mono-dense-layer-diagram.png)

In [None]:
# Demo configuration: a single MonoDense layer applied to a small random batch.
units = 18
activation = "relu"
batch_size = 9
x_len = 11

tf.keras.utils.set_random_seed(42)


def display_kernel(kernel: Union[tf.Variable, np.typing.NDArray[float]]) -> None:
    """Render a 2-D array as a colour-coded table with two decimals.

    The colour range is clipped to [-1e-8, 1e-8], so cells are coloured
    essentially by sign (one end of the `coolwarm_r` palette for positive
    values, the other for negative ones).
    """
    palette = sns.color_palette("coolwarm_r", as_cmap=True)
    table = pd.DataFrame(kernel).style.format("{:.2f}")
    display(table.background_gradient(cmap=palette, vmin=-1e-8, vmax=1e-8))


x = np.random.default_rng(42).normal(size=(batch_size, x_len))

for monotonicity_indicator in [
    [1] * 4 + [0] * 4 + [-1] * 3,  # mixed, one entry per input feature
    1,  # scalar +1
    np.ones((x_len,)),  # +1 given per input feature
    -1,  # scalar -1
    -np.ones((x_len,)),  # -1 given per input feature
]:
    print("*" * 120)
    mono_layer = MonoDense(
        units=units,
        activation=activation,
        monotonicity_indicator=monotonicity_indicator,
        activation_weights=(7, 7, 4),
    )

    print("input:")
    display_kernel(x)

    # Calling the layer also builds its weights, so `kernel` exists afterwards.
    y = mono_layer(x)

    print(f"monotonicity_indicator = {monotonicity_indicator}")
    display_kernel(mono_layer.monotonicity_indicator)

    print("kernel:")
    # Show the kernel as swapped in by the context manager; in the displayed
    # output its row signs follow the monotonicity indicator.
    with replace_kernel_using_monotonicity_indicator(
        mono_layer, mono_layer.monotonicity_indicator
    ):
        display_kernel(mono_layer.kernel)

    print("output:")
    display_kernel(y)

print("ok")

************************************************************************************************************************
input:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.3,-1.04,0.75,0.94,-1.95,-1.3,0.13,-0.32,-0.02,-0.85,0.88
1,0.78,0.07,1.13,0.47,-0.86,0.37,-0.96,0.88,-0.05,-0.18,-0.68
2,1.22,-0.15,-0.43,-0.35,0.53,0.37,0.41,0.43,2.14,-0.41,-0.51
3,-0.81,0.62,1.13,-0.11,-0.84,-0.82,0.65,0.74,0.54,-0.67,0.23
4,0.12,0.22,0.87,0.22,0.68,0.07,0.29,0.63,-1.46,-0.32,-0.47
5,-0.64,-0.28,1.49,-0.87,0.97,-1.68,-0.33,0.16,0.59,0.71,0.79
6,-0.35,-0.46,0.86,-0.19,-1.28,-1.13,-0.92,0.5,0.14,0.69,-0.43
7,0.16,0.63,-0.31,0.46,-0.66,-0.36,-0.38,-1.2,0.49,-0.47,0.01
8,0.48,0.45,0.67,-0.1,-0.42,-0.08,-1.69,-1.45,-1.32,-1.0,0.4


monotonicity_indicator = [1, 1, 1, 1, 0, 0, 0, 0, -1, -1, -1]


Unnamed: 0,0
0,1.0
1,1.0
2,1.0
3,1.0
4,0.0
5,0.0
6,0.0
7,0.0
8,-1.0
9,-1.0


kernel:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.33,0.15,0.13,0.41,0.38,0.14,0.43,0.3,0.02,0.12,0.38,0.05,0.42,0.03,0.0,0.24,0.44,0.28
1,0.01,0.39,0.42,0.32,0.38,0.22,0.33,0.34,0.03,0.06,0.06,0.27,0.26,0.45,0.35,0.05,0.21,0.34
2,0.21,0.29,0.16,0.14,0.42,0.06,0.15,0.1,0.41,0.08,0.03,0.22,0.34,0.2,0.11,0.01,0.43,0.35
3,0.27,0.33,0.06,0.17,0.42,0.42,0.24,0.3,0.11,0.2,0.17,0.25,0.17,0.07,0.32,0.3,0.17,0.36
4,0.32,-0.25,0.12,-0.37,0.41,0.2,0.06,-0.28,-0.27,0.43,-0.41,-0.17,-0.24,-0.31,0.33,0.31,0.11,0.03
5,0.04,0.19,-0.02,-0.34,0.36,-0.12,0.28,0.32,-0.11,-0.4,0.41,0.3,0.06,-0.28,-0.27,0.23,-0.41,-0.12
6,0.35,-0.04,-0.28,0.16,-0.03,0.35,-0.03,-0.16,0.39,-0.36,-0.31,-0.18,0.02,-0.38,-0.4,0.39,0.35,-0.19
7,0.33,-0.34,0.11,-0.29,0.25,-0.21,0.11,0.08,-0.19,-0.39,0.01,0.1,0.39,-0.25,-0.37,-0.27,0.04,0.34
8,-0.27,-0.09,-0.02,-0.45,-0.16,-0.12,-0.09,-0.43,-0.36,-0.09,-0.23,-0.42,-0.28,-0.24,-0.3,-0.31,-0.07,-0.07
9,-0.38,-0.34,-0.44,-0.42,-0.32,-0.06,-0.27,-0.28,-0.22,-0.05,-0.08,-0.07,-0.21,-0.39,-0.01,-0.26,-0.24,-0.42


output:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.01,0.4,0.0,1.38,0.0,0.1,0.0,-0.0,-0.0,-0.13,-0.0,-0.26,-0.0,-0.0,-0.55,-0.52,0.79,0.64
1,0.45,1.02,0.96,0.71,1.22,0.0,0.86,-0.0,-0.0,-0.09,-0.0,-0.0,-0.0,-0.0,0.26,-0.17,0.54,1.0
2,0.3,0.0,0.33,0.0,0.41,0.0,0.42,-0.53,-0.89,-0.29,-0.23,-0.84,-0.16,-0.93,-0.9,0.08,0.37,0.08
3,0.21,0.26,0.33,0.42,0.0,0.0,0.0,-0.16,-0.0,-0.61,-0.53,-0.07,-0.0,-0.0,-0.55,-0.66,0.83,0.78
4,1.38,0.49,0.7,0.82,1.47,0.54,0.63,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.73,0.97,0.94,0.91
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.86,-0.25,-0.0,-1.57,-1.19,-0.61,-0.23,0.13,-1.0,0.5,-0.06
6,0.0,0.0,0.0,0.17,0.0,0.0,0.0,-0.15,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.06,-1.0,0.0,0.12
7,0.0,0.96,0.35,0.93,0.0,0.32,0.17,-0.0,-0.0,-0.0,-0.0,-0.0,-0.17,-0.0,0.67,0.06,0.12,0.17
8,0.0,1.33,0.92,1.63,0.52,0.0,0.66,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,1.0,0.23,0.18,0.81


************************************************************************************************************************
input:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.3,-1.04,0.75,0.94,-1.95,-1.3,0.13,-0.32,-0.02,-0.85,0.88
1,0.78,0.07,1.13,0.47,-0.86,0.37,-0.96,0.88,-0.05,-0.18,-0.68
2,1.22,-0.15,-0.43,-0.35,0.53,0.37,0.41,0.43,2.14,-0.41,-0.51
3,-0.81,0.62,1.13,-0.11,-0.84,-0.82,0.65,0.74,0.54,-0.67,0.23
4,0.12,0.22,0.87,0.22,0.68,0.07,0.29,0.63,-1.46,-0.32,-0.47
5,-0.64,-0.28,1.49,-0.87,0.97,-1.68,-0.33,0.16,0.59,0.71,0.79
6,-0.35,-0.46,0.86,-0.19,-1.28,-1.13,-0.92,0.5,0.14,0.69,-0.43
7,0.16,0.63,-0.31,0.46,-0.66,-0.36,-0.38,-1.2,0.49,-0.47,0.01
8,0.48,0.45,0.67,-0.1,-0.42,-0.08,-1.69,-1.45,-1.32,-1.0,0.4


monotonicity_indicator = 1


Unnamed: 0,0
0,1.0


kernel:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.44,0.02,0.24,0.22,0.29,0.35,0.18,0.03,0.39,0.17,0.25,0.02,0.1,0.13,0.0,0.42,0.21,0.31
1,0.35,0.06,0.26,0.42,0.05,0.41,0.16,0.33,0.03,0.26,0.11,0.03,0.23,0.04,0.37,0.27,0.32,0.4
2,0.37,0.3,0.36,0.14,0.21,0.4,0.01,0.28,0.16,0.44,0.43,0.23,0.27,0.22,0.23,0.25,0.43,0.05
3,0.32,0.25,0.05,0.45,0.08,0.18,0.26,0.24,0.34,0.07,0.07,0.14,0.04,0.19,0.29,0.23,0.43,0.09
4,0.36,0.05,0.2,0.41,0.38,0.29,0.01,0.44,0.17,0.04,0.31,0.34,0.29,0.16,0.25,0.18,0.01,0.28
5,0.34,0.31,0.38,0.34,0.08,0.4,0.15,0.16,0.14,0.25,0.15,0.2,0.1,0.06,0.44,0.19,0.42,0.21
6,0.01,0.38,0.43,0.18,0.0,0.43,0.45,0.28,0.25,0.18,0.03,0.26,0.22,0.26,0.08,0.23,0.45,0.42
7,0.04,0.12,0.28,0.17,0.11,0.0,0.15,0.24,0.05,0.05,0.27,0.32,0.33,0.11,0.09,0.4,0.19,0.06
8,0.3,0.17,0.21,0.42,0.21,0.29,0.19,0.38,0.03,0.34,0.32,0.3,0.34,0.15,0.28,0.11,0.44,0.19
9,0.1,0.1,0.35,0.32,0.24,0.28,0.3,0.28,0.1,0.12,0.3,0.41,0.15,0.0,0.1,0.4,0.18,0.24


output:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.0,0.01,0.0,0.0,0.0,0.0,0.0,-0.93,-0.0,-0.07,-0.58,-0.88,-0.58,-0.0,-0.87,-0.49,-0.05,-1.0
1,0.73,0.1,0.22,0.18,0.18,0.16,0.0,-0.23,-0.0,-0.0,-0.0,-0.09,-0.0,-0.0,0.16,0.47,0.53,-0.27
2,1.15,0.36,0.82,1.2,0.8,1.06,0.61,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.53,0.61,1.0,0.94
3,0.0,0.45,0.28,0.0,0.0,0.11,0.14,-0.0,-0.21,-0.0,-0.0,-0.0,-0.0,-0.0,0.15,0.08,0.72,-0.08
4,0.34,0.19,0.36,0.05,0.15,0.3,0.0,-0.0,-0.0,-0.08,-0.0,-0.0,-0.0,-0.0,0.06,0.38,0.04,0.14
5,0.0,0.0,0.26,0.0,0.67,0.05,0.0,-0.0,-0.16,-0.0,-0.0,-0.0,-0.0,-0.0,-0.08,0.3,-0.17,-0.17
6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.76,-0.68,-0.28,-0.11,-0.37,-0.42,-0.4,-0.88,-0.41,-0.67,-1.0
7,0.01,0.0,0.0,0.0,0.0,0.0,0.0,-0.45,-0.17,-0.04,-0.57,-0.82,-0.5,-0.22,-0.07,-0.62,-0.13,-0.18
8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.32,-0.35,-0.39,-0.77,-1.63,-1.12,-0.6,-0.47,-0.99,-1.0,-1.0


************************************************************************************************************************
input:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.3,-1.04,0.75,0.94,-1.95,-1.3,0.13,-0.32,-0.02,-0.85,0.88
1,0.78,0.07,1.13,0.47,-0.86,0.37,-0.96,0.88,-0.05,-0.18,-0.68
2,1.22,-0.15,-0.43,-0.35,0.53,0.37,0.41,0.43,2.14,-0.41,-0.51
3,-0.81,0.62,1.13,-0.11,-0.84,-0.82,0.65,0.74,0.54,-0.67,0.23
4,0.12,0.22,0.87,0.22,0.68,0.07,0.29,0.63,-1.46,-0.32,-0.47
5,-0.64,-0.28,1.49,-0.87,0.97,-1.68,-0.33,0.16,0.59,0.71,0.79
6,-0.35,-0.46,0.86,-0.19,-1.28,-1.13,-0.92,0.5,0.14,0.69,-0.43
7,0.16,0.63,-0.31,0.46,-0.66,-0.36,-0.38,-1.2,0.49,-0.47,0.01
8,0.48,0.45,0.67,-0.1,-0.42,-0.08,-1.69,-1.45,-1.32,-1.0,0.4


monotonicity_indicator = [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


Unnamed: 0,0
0,1.0
1,1.0
2,1.0
3,1.0
4,1.0
5,1.0
6,1.0
7,1.0
8,1.0
9,1.0


kernel:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.31,0.02,0.11,0.29,0.1,0.33,0.37,0.06,0.39,0.35,0.15,0.13,0.15,0.45,0.07,0.19,0.03,0.06
1,0.12,0.02,0.06,0.41,0.32,0.24,0.34,0.28,0.22,0.06,0.33,0.27,0.25,0.23,0.43,0.09,0.45,0.27
2,0.19,0.11,0.19,0.25,0.07,0.42,0.32,0.35,0.15,0.05,0.0,0.24,0.22,0.39,0.44,0.11,0.19,0.1
3,0.15,0.37,0.21,0.41,0.25,0.04,0.37,0.04,0.05,0.22,0.31,0.35,0.35,0.08,0.38,0.01,0.25,0.29
4,0.17,0.45,0.24,0.32,0.01,0.0,0.19,0.34,0.17,0.19,0.18,0.34,0.02,0.24,0.03,0.41,0.26,0.0
5,0.29,0.1,0.07,0.34,0.04,0.3,0.39,0.27,0.39,0.16,0.33,0.45,0.06,0.19,0.23,0.04,0.36,0.04
6,0.13,0.15,0.22,0.4,0.14,0.3,0.11,0.45,0.14,0.17,0.26,0.16,0.36,0.1,0.17,0.32,0.14,0.08
7,0.25,0.25,0.24,0.45,0.17,0.45,0.3,0.35,0.41,0.4,0.11,0.26,0.32,0.08,0.22,0.34,0.05,0.09
8,0.16,0.27,0.1,0.23,0.08,0.21,0.19,0.16,0.06,0.04,0.17,0.05,0.39,0.11,0.26,0.25,0.13,0.05
9,0.17,0.17,0.0,0.13,0.12,0.03,0.39,0.11,0.01,0.29,0.43,0.2,0.21,0.43,0.39,0.18,0.19,0.27


output:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.0,0.0,0.08,0.0,0.0,0.0,0.0,-0.82,-0.58,-0.32,-1.07,-1.09,-0.0,-0.63,-0.21,-0.74,-1.0,-0.15
1,0.36,0.0,0.0,0.51,0.11,0.72,0.76,-0.12,-0.0,-0.0,-0.05,-0.0,-0.0,-0.0,0.56,-0.34,0.13,0.22
2,0.72,0.68,0.32,1.1,0.1,0.84,0.68,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.2,0.97,0.33,-0.07
3,0.0,0.0,0.36,0.35,0.36,0.82,0.0,-0.0,-0.0,-0.19,-0.29,-0.13,-0.0,-0.2,0.67,0.2,-0.0,0.14
4,0.18,0.14,0.26,0.68,0.09,0.38,0.36,-0.0,-0.0,-0.0,-0.0,-0.0,-0.07,-0.0,0.14,0.15,0.33,0.1
5,0.01,0.55,0.5,0.0,0.0,0.21,0.0,-0.0,-0.27,-0.0,-0.44,-0.25,-0.0,-0.0,0.44,0.83,-0.24,-0.01
6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.89,-0.85,-0.48,-0.77,-0.9,-0.21,-0.3,-0.09,-0.69,-0.83,-0.03
7,0.0,0.0,0.0,0.0,0.01,0.0,0.0,-0.79,-0.59,-0.65,-0.21,-0.55,-0.19,-0.37,-0.17,-0.71,-0.1,0.03
8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.24,-0.48,-0.95,-1.13,-0.71,-1.4,-0.3,-0.76,-1.0,-0.47,-0.39


************************************************************************************************************************
input:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.3,-1.04,0.75,0.94,-1.95,-1.3,0.13,-0.32,-0.02,-0.85,0.88
1,0.78,0.07,1.13,0.47,-0.86,0.37,-0.96,0.88,-0.05,-0.18,-0.68
2,1.22,-0.15,-0.43,-0.35,0.53,0.37,0.41,0.43,2.14,-0.41,-0.51
3,-0.81,0.62,1.13,-0.11,-0.84,-0.82,0.65,0.74,0.54,-0.67,0.23
4,0.12,0.22,0.87,0.22,0.68,0.07,0.29,0.63,-1.46,-0.32,-0.47
5,-0.64,-0.28,1.49,-0.87,0.97,-1.68,-0.33,0.16,0.59,0.71,0.79
6,-0.35,-0.46,0.86,-0.19,-1.28,-1.13,-0.92,0.5,0.14,0.69,-0.43
7,0.16,0.63,-0.31,0.46,-0.66,-0.36,-0.38,-1.2,0.49,-0.47,0.01
8,0.48,0.45,0.67,-0.1,-0.42,-0.08,-1.69,-1.45,-1.32,-1.0,0.4


monotonicity_indicator = -1


Unnamed: 0,0
0,-1.0


kernel:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,-0.29,-0.12,-0.0,-0.17,-0.33,-0.17,-0.33,-0.36,-0.28,-0.16,-0.24,-0.22,-0.1,-0.13,-0.02,-0.38,-0.23,-0.02
1,-0.36,-0.13,-0.05,-0.07,-0.41,-0.3,-0.38,-0.06,-0.4,-0.42,-0.44,-0.03,-0.27,-0.03,-0.32,-0.31,-0.35,-0.4
2,-0.3,-0.07,-0.4,-0.06,-0.1,-0.21,-0.16,-0.22,-0.06,-0.36,-0.4,-0.42,-0.23,-0.22,-0.2,-0.33,-0.45,-0.06
3,-0.05,-0.08,-0.07,-0.3,-0.44,-0.23,-0.4,-0.25,-0.13,-0.31,-0.11,-0.13,-0.13,-0.34,-0.15,-0.05,-0.36,-0.13
4,-0.45,-0.34,-0.41,-0.39,-0.15,-0.1,-0.4,-0.32,-0.19,-0.13,-0.29,-0.39,-0.43,-0.29,-0.13,-0.05,-0.39,-0.01
5,-0.09,-0.38,-0.0,-0.12,-0.07,-0.42,-0.01,-0.12,-0.26,-0.28,-0.16,-0.06,-0.08,-0.43,-0.23,-0.28,-0.28,-0.07
6,-0.34,-0.38,-0.15,-0.44,-0.41,-0.19,-0.25,-0.41,-0.34,-0.22,-0.43,-0.36,-0.25,-0.28,-0.06,-0.12,-0.15,-0.16
7,-0.17,-0.39,-0.4,-0.26,-0.4,-0.2,-0.1,-0.14,-0.42,-0.21,-0.18,-0.25,-0.15,-0.21,-0.13,-0.41,-0.14,-0.14
8,-0.38,-0.03,-0.1,-0.21,-0.13,-0.04,-0.19,-0.0,-0.09,-0.38,-0.01,-0.27,-0.24,-0.24,-0.13,-0.18,-0.37,-0.21
9,-0.43,-0.08,-0.2,-0.29,-0.1,-0.27,-0.08,-0.43,-0.22,-0.37,-0.27,-0.24,-0.15,-0.22,-0.01,-0.45,-0.35,-0.31


output:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,1.05,0.88,0.59,0.61,0.0,0.7,0.64,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.24,0.74,1.0,0.55
1,0.27,0.26,0.0,0.41,0.0,0.0,0.0,-0.0,-0.23,-0.33,-0.21,-0.2,-0.0,-0.02,-0.04,-0.82,-0.52,-0.02
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.36,-0.77,-0.71,-0.39,-1.0,-0.82,-0.67,-0.11,-0.74,-0.97,-0.31
3,0.0,0.0,0.0,0.0,0.0,0.01,0.0,-0.0,-0.15,-0.5,-0.38,-0.33,-0.2,-0.0,-0.39,-0.2,-0.12,-0.36
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.45,-0.46,-0.0,-0.84,-0.48,-0.36,-0.13,-0.08,-0.28,-0.33,0.13
5,0.0,0.02,0.0,0.0,0.12,0.33,0.0,-0.41,-0.0,-0.44,-0.33,-0.9,-0.56,-0.04,-0.24,-0.27,-0.48,-0.16
6,0.74,1.2,0.11,0.9,0.84,0.65,0.87,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.6,0.01,0.53,0.12
7,0.47,0.89,0.91,0.62,0.26,0.37,0.01,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.07,0.61,0.29,0.01
8,1.3,1.17,0.98,1.61,1.09,0.59,0.65,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.09,0.93,0.94,0.81


************************************************************************************************************************
input:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.3,-1.04,0.75,0.94,-1.95,-1.3,0.13,-0.32,-0.02,-0.85,0.88
1,0.78,0.07,1.13,0.47,-0.86,0.37,-0.96,0.88,-0.05,-0.18,-0.68
2,1.22,-0.15,-0.43,-0.35,0.53,0.37,0.41,0.43,2.14,-0.41,-0.51
3,-0.81,0.62,1.13,-0.11,-0.84,-0.82,0.65,0.74,0.54,-0.67,0.23
4,0.12,0.22,0.87,0.22,0.68,0.07,0.29,0.63,-1.46,-0.32,-0.47
5,-0.64,-0.28,1.49,-0.87,0.97,-1.68,-0.33,0.16,0.59,0.71,0.79
6,-0.35,-0.46,0.86,-0.19,-1.28,-1.13,-0.92,0.5,0.14,0.69,-0.43
7,0.16,0.63,-0.31,0.46,-0.66,-0.36,-0.38,-1.2,0.49,-0.47,0.01
8,0.48,0.45,0.67,-0.1,-0.42,-0.08,-1.69,-1.45,-1.32,-1.0,0.4


monotonicity_indicator = [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]


Unnamed: 0,0
0,-1.0
1,-1.0
2,-1.0
3,-1.0
4,-1.0
5,-1.0
6,-1.0
7,-1.0
8,-1.0
9,-1.0


kernel:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,-0.45,-0.28,-0.3,-0.41,-0.17,-0.39,-0.22,-0.45,-0.28,-0.4,-0.18,-0.2,-0.16,-0.18,-0.1,-0.13,-0.14,-0.35
1,-0.09,-0.27,-0.09,-0.14,-0.02,-0.36,-0.21,-0.05,-0.05,-0.01,-0.02,-0.45,-0.03,-0.09,-0.01,-0.05,-0.39,-0.05
2,-0.17,-0.15,-0.37,-0.35,-0.32,-0.03,-0.24,-0.31,-0.35,-0.41,-0.0,-0.37,-0.18,-0.26,-0.09,-0.44,-0.09,-0.17
3,-0.42,-0.17,-0.11,-0.31,-0.32,-0.11,-0.2,-0.1,-0.34,-0.15,-0.24,-0.22,-0.22,-0.08,-0.4,-0.02,-0.23,-0.38
4,-0.13,-0.17,-0.06,-0.13,-0.32,-0.42,-0.28,-0.44,-0.03,-0.26,-0.38,-0.45,-0.08,-0.06,-0.04,-0.33,-0.27,-0.38
5,-0.32,-0.38,-0.19,-0.19,-0.33,-0.01,-0.15,-0.08,-0.31,-0.27,-0.07,-0.11,-0.21,-0.22,-0.18,-0.27,-0.19,-0.15
6,-0.3,-0.16,-0.09,-0.25,-0.23,-0.44,-0.25,-0.16,-0.05,-0.13,-0.2,-0.09,-0.14,-0.18,-0.15,-0.22,-0.37,-0.38
7,-0.2,-0.14,-0.12,-0.1,-0.42,-0.42,-0.14,-0.04,-0.44,-0.11,-0.1,-0.17,-0.06,-0.29,-0.22,-0.24,-0.01,-0.45
8,-0.31,-0.11,-0.16,-0.21,-0.16,-0.39,-0.12,-0.36,-0.36,-0.29,-0.24,-0.24,-0.2,-0.18,-0.33,-0.39,-0.2,-0.02
9,-0.41,-0.14,-0.12,-0.21,-0.01,-0.37,-0.03,-0.22,-0.38,-0.22,-0.09,-0.22,-0.19,-0.17,-0.13,-0.32,-0.3,-0.21


output:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.2,0.84,0.11,0.0,0.55,1.24,0.55,-0.0,-0.02,-0.0,-0.0,-0.0,-0.0,-0.0,-0.2,0.98,1.0,0.3
1,0.0,0.0,0.0,0.0,0.0,0.19,0.0,-0.14,-0.87,-0.5,-0.0,-0.34,-0.28,-0.53,-0.24,-0.34,0.23,-0.09
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-1.34,-0.82,-1.02,-0.75,-0.74,-0.56,-0.68,-0.71,-1.0,-0.65,-0.56
3,0.23,0.18,0.0,0.0,0.0,0.0,0.0,-0.0,-0.27,-0.0,-0.0,-0.21,-0.0,-0.28,-0.21,-0.24,0.02,0.0
4,0.09,0.0,0.0,0.0,0.0,0.0,0.0,-0.08,-0.0,-0.14,-0.0,-0.5,-0.01,-0.25,0.23,-0.2,-0.14,-0.66
5,0.18,0.49,0.0,0.0,0.03,0.0,0.0,-0.79,-0.36,-0.49,-0.39,-0.69,-0.0,-0.09,0.08,-0.84,0.1,-0.25
6,0.64,0.76,0.08,0.5,0.62,0.79,0.68,-0.0,-0.06,-0.0,-0.0,-0.0,-0.0,-0.0,0.28,0.24,0.86,0.87
7,0.32,0.24,0.23,0.18,0.76,0.62,0.28,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.13,0.73,0.09,0.87
8,1.23,0.5,0.27,0.51,1.08,2.0,0.6,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,1.0,1.0,1.0,1.0


ok


In [None]:
# MonoDense on a rank-4 input: only the last axis (8 features) is transformed,
# as the summary's output shape (None, 5, 7, 12) shows.
x = Input(shape=(5, 7, 8))

layer = MonoDense(
    units=12,
    activation=activation,  # value set in the previous demo cell
    monotonicity_indicator=[1, 1, 1, -1, -1, -1, 0, 0],
    is_convex=False,
    is_concave=False,
)
y = layer(x)

model = Model(inputs=x, outputs=y)
model.summary()

display_kernel(layer.monotonicity_indicator)

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 5, 7, 8)]         0         
                                                                 
 mono_dense_5 (MonoDense)    (None, 5, 7, 12)          108       
                                                                 
Total params: 108
Trainable params: 108
Non-trainable params: 0
_________________________________________________________________


Unnamed: 0,0
0,1.0
1,1.0
2,1.0
3,-1.0
4,-1.0
5,-1.0
6,0.0
7,0.0


## Experiments

For our experiments, we employ the datasets used by the authors of Certified Monotonic Networks [1] and COMET [2]. We use the exact train-test splits provided by the authors; their respective repositories are linked in the references below. We load the train-test splits directly from files saved after running the code released by the authors of each paper.


References:


1.   Xingchao Liu, Xing Han, Na Zhang, and Qiang Liu. Certified monotonic neural networks. Advances in Neural Information Processing Systems, 33:15427–15438, 2020
  
  GitHub repository: https://github.com/gnobitab/CertifiedMonotonicNetwork



2.   Aishwarya Sivaraman, Golnoosh Farnadi, Todd Millstein, and Guy Van den Broeck. Counterexample-guided learning of monotonic neural networks. Advances in Neural Information Processing Systems, 33:11936–11948, 2020

  GitHub repository: https://github.com/AishwaryaSivaraman/COMET

In [None]:
# | exporti


class DownloadProgressBar(tqdm):
    """`tqdm` progress bar whose `update_to` can be used as a `urlretrieve` reporthook."""

    def update_to(
        self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None
    ) -> None:
        """Move the bar to an absolute position of `b * bsize` bytes.

        Args:
            b: number of blocks transferred so far.
            bsize: size of each block in bytes.
            tsize: total size in bytes. `urlretrieve` passes -1 when the size
                is unknown, so only non-negative values are used as the total.
        """
        if tsize is not None and tsize >= 0:
            self.total = tsize
        # `update` takes an increment, not an absolute position.
        self.update(b * bsize - self.n)


def download_url(url: str, output_path: Path) -> None:
    """Download `url` into the file `output_path`, showing a byte progress bar.

    The bar is labelled with the last `/`-separated segment of the URL.
    """
    with DownloadProgressBar(
        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
    ) as t:
        # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected
        urllib.request.urlretrieve(
            url, filename=output_path, reporthook=t.update_to
        )  # nosec

In [None]:
# | export


def get_data_path(data_path: Optional[Union[Path, str]] = None) -> Path:
    """Return `data_path` as a `Path`, falling back to `./data` when it is None."""
    if data_path is None:
        data_path = "./data"
    return Path(data_path)


def download_data(
    dataset_name: str,
    data_path: Optional[Union[Path, str]] = "data",
    force_download: bool = False,
) -> None:
    """Fetch the train/test CSV files of `dataset_name` from Zenodo.

    The files are named `train_{dataset_name}.csv` and `test_{dataset_name}.csv`
    and are stored under `data_path`, which is created if missing. A file that
    already exists is kept (and a message printed) unless `force_download` is
    True.

    Each file is downloaded into a temporary directory first and only copied
    into `data_path` afterwards, so a failed download leaves no partial file
    at the final location.
    """
    data_path = get_data_path(data_path)
    data_path.mkdir(exist_ok=True, parents=True)

    for prefix in ["train", "test"]:
        filename = f"{prefix}_{dataset_name}.csv"
        target = data_path / filename
        if force_download or not target.exists():
            with TemporaryDirectory() as d:
                tmp_file = Path(d) / filename
                download_url(
                    f"https://zenodo.org/record/7968969/files/{filename}",
                    tmp_file,
                )
                shutil.copyfile(tmp_file, target)
        else:
            print(f"Download skipped, file {target.resolve()} exists.")

In [None]:
download_data("auto", force_download=True)

data_dir = Path("data")

# List the CSV files with their sizes in bytes (portable replacement for `!ls -l data`).
for csv_path in sorted(data_dir.glob("*.csv")):
    print(f"{csv_path.stat().st_size:>10}  {csv_path.name}")

# Both halves of the split are needed later, not just the training file.
for prefix in ("train", "test"):
    assert (data_dir / f"{prefix}_auto.csv").exists(), f"missing {prefix}_auto.csv"

train_auto.csv: 49.2kB [00:01, 47.5kB/s]                            
test_auto.csv: 16.4kB [00:00, 25.0kB/s]                            


total 257812
-rw-rw-r-- 1 davor davor    11161 Jun  1 08:40 test_auto.csv
-rw-rw-r-- 1 davor davor 11340054 May 25 04:48 test_blog.csv
-rw-rw-r-- 1 davor davor   101210 May 25 04:48 test_compas.csv
-rw-rw-r-- 1 davor davor    15798 May 25 04:48 test_heart.csv
-rw-rw-r-- 1 davor davor 13339777 May 25 04:48 test_loan.csv
-rw-rw-r-- 1 davor davor    44626 Jun  1 08:40 train_auto.csv
-rw-rw-r-- 1 davor davor 79478767 May 25 04:48 train_blog.csv
-rw-rw-r-- 1 davor davor   405660 May 25 04:48 train_compas.csv
-rw-rw-r-- 1 davor davor    62282 May 25 04:48 train_heart.csv
-rw-rw-r-- 1 davor davor 79588030 May 25 04:48 train_loan.csv
-rw-rw-r-- 1 davor davor 79588030 May 29 13:57 {prefix}_{name}.csv


In [None]:
# | export


def sanitize_col_names(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with every space in its column names replaced by `_`.

    The input frame is left unchanged.
    """
    renamed = {}
    for col in df:
        renamed[col] = col.replace(" ", "_")
    return df.rename(columns=renamed)

In [None]:
sanitized = sanitize_col_names(pd.DataFrame({"a b": [1, 2, 3]}))
# Pin the behaviour the demo shows: the space becomes an underscore.
assert list(sanitized.columns) == ["a_b"]
sanitized

Unnamed: 0,a_b
0,1
1,2
2,3


In [None]:
# | export


def get_train_n_test_data(
    dataset_name: str,
    *,
    data_path: Optional[Union[Path, str]] = "./data",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the train and test splits of `dataset_name`.

    Reads `train_{dataset_name}.csv` and `test_{dataset_name}.csv` from
    `data_path` and replaces spaces in their column names with underscores.

    Returns:
        A `(train_df, test_df)` tuple.
    """
    root = get_data_path(data_path)

    def _load(prefix: str) -> pd.DataFrame:
        # Read one split and normalise its column names.
        return sanitize_col_names(pd.read_csv(root / f"{prefix}_{dataset_name}.csv"))

    train_df = _load("train")
    test_df = _load("test")
    return train_df, test_df

In [None]:
train_df, test_df = get_train_n_test_data("auto")

# Both splits share one schema that includes the `ground_truth` target column.
assert list(train_df.columns) == list(test_df.columns)
assert "ground_truth" in train_df.columns
assert not train_df.empty and not test_df.empty

display(train_df)
display(test_df)

Unnamed: 0,Cylinders,Displacement,Horsepower,Weight,Acceleration,Model_Year,Origin,ground_truth
0,1.482807,1.073028,0.650564,0.606625,-1.275546,-1.631803,-0.701669,18.0
1,1.482807,1.482902,1.548993,0.828131,-1.452517,-1.631803,-0.701669,15.0
2,1.482807,1.044432,1.163952,0.523413,-1.275546,-1.631803,-0.701669,16.0
3,1.482807,1.025368,0.907258,0.542165,-1.806460,-1.631803,-0.701669,17.0
4,1.482807,2.235927,2.396084,1.587581,-1.983431,-1.631803,-0.701669,15.0
...,...,...,...,...,...,...,...,...
309,0.310007,0.358131,0.188515,-0.177437,-0.319901,1.720778,-0.701669,22.0
310,-0.862792,-0.566468,-0.530229,-0.722413,-0.921604,1.720778,-0.701669,36.0
311,-0.862792,-0.928683,-1.351650,-1.003691,3.184131,1.720778,0.557325,44.0
312,-0.862792,-0.566468,-0.530229,-0.810312,-1.417123,1.720778,-0.701669,32.0


Unnamed: 0,Cylinders,Displacement,Horsepower,Weight,Acceleration,Model_Year,Origin,ground_truth
0,-0.862792,-1.043066,-1.017947,-1.027131,1.272841,1.162014,1.816319,40.8
1,1.482807,1.177880,1.163952,0.526929,-1.629489,-1.631803,-0.701669,18.0
2,1.482807,1.482902,1.934034,0.794143,-1.629489,-0.793657,-0.701669,11.0
3,0.310007,0.529707,-0.119518,0.346443,-0.213718,-1.352421,-0.701669,19.0
4,-0.862792,-1.004939,-0.863931,-1.243949,-0.567661,0.882633,0.557325,31.9
...,...,...,...,...,...,...,...,...
73,-0.862792,-0.699916,0.188515,-0.062582,-0.390690,-1.073039,0.557325,18.0
74,-0.862792,-0.518809,-0.838261,-0.686081,1.379024,-0.793657,-0.701669,21.0
75,0.310007,-0.251914,0.701903,-0.089538,-1.487912,1.162014,1.816319,32.7
76,1.482807,1.492434,1.138283,1.580549,-0.390690,0.323869,-0.701669,16.0


In [None]:
# | export


def df2ds(df: pd.DataFrame) -> tf.data.Dataset:
    """Convert a DataFrame into a `(features, target)` `tf.data.Dataset`.

    The `ground_truth` column becomes the target. Every other column becomes an
    entry of the features dict, keyed by column name. Each dataset element is
    one row, and `df` itself is not modified.
    """
    features = df.to_dict("list")
    target = features.pop("ground_truth")
    return tf.data.Dataset.from_tensor_slices((features, target))


def peek(ds: Iterable[Any]) -> Any:
    """Return the first element of `ds`, or None if it is empty.

    For a (possibly batched) dataset built by `df2ds`, that element is a
    `(features, target)` tuple rather than a single tensor.
    """
    return next(iter(ds), None)

In [None]:
# First batch of 8 rows: `x` maps feature names to tensors, `y` holds the targets.
x, y = peek(df2ds(train_df).batch(8))
display(x, y)

# Every column except the `ground_truth` target should appear as a feature.
expected = {
    "Acceleration",
    "Cylinders",
    "Displacement",
    "Horsepower",
    "Model_Year",
    "Origin",
    "Weight",
}
assert x.keys() == expected
for k in expected:
    assert x[k].shape == (8,)
assert y.shape == (8,)

{'Cylinders': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([1.4828068, 1.4828068, 1.4828068, 1.4828068, 1.4828068, 1.4828068,
        1.4828068, 1.4828068], dtype=float32)>,
 'Displacement': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([1.0730283, 1.4829025, 1.0444324, 1.0253685, 2.235927 , 2.474226 ,
        2.3407786, 1.8641808], dtype=float32)>,
 'Horsepower': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([0.65056413, 1.5489933 , 1.1639522 , 0.9072582 , 2.3960838 ,
        2.9608107 , 2.8324637 , 2.1907284 ], dtype=float32)>,
 'Weight': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([0.6066247, 0.828131 , 0.5234134, 0.5421652, 1.5875812, 1.602817 ,
        1.5535934, 1.0121336], dtype=float32)>,
 'Acceleration': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([-1.2755462, -1.4525175, -1.2755462, -1.8064601, -1.9834315,
        -2.3373742, -2.5143454, -2.5143454], dtype=float32)>,
 'Model_Year': <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([-

<tf.Tensor: shape=(8,), dtype=float32, numpy=array([18., 15., 16., 17., 15., 14., 14., 15.], dtype=float32)>

In [None]:
# | export


def build_mono_model_f(
    *,
    monotonicity_indicator: Dict[str, int],
    final_activation: Union[str, Callable[[TensorLike], TensorLike]],
    loss: Union[str, Callable[[TensorLike, TensorLike], TensorLike]],
    metrics: Union[str, Callable[[TensorLike, TensorLike], TensorLike]],
    train_ds: tf.data.Dataset,
    batch_size: int,
    units: int,
    n_layers: int,
    activation: Union[str, Callable[[TensorLike], TensorLike]],
    learning_rate: float,
    weight_decay: float,
    dropout: float,
    decay_rate: float,
) -> Model:
    """Build and compile a type-2 constrained monotonic network for tabular data.

    One `Input` of shape `(1,)` is created per key of `monotonicity_indicator`
    (the key is also the input's name). The network is assembled by
    `create_type_2` with a single output unit and `is_convex=False`,
    `is_concave=False`.

    The model is compiled with `AdamW`. Its learning rate follows a staircase
    `ExponentialDecay` that multiplies it by `decay_rate` every
    `len(train_ds.batch(batch_size))` steps, i.e. once per pass over `train_ds`
    batched with `batch_size`.

    Returns:
        The compiled Keras `Model`.
    """
    feature_inputs = {
        name: Input(name=name, shape=(1,)) for name in monotonicity_indicator
    }
    network_output = create_type_2(
        feature_inputs,
        units=units,
        final_units=1,
        activation=activation,
        n_layers=n_layers,
        monotonicity_indicator=monotonicity_indicator,
        is_convex=False,
        is_concave=False,
        dropout=dropout,
        final_activation=final_activation,
    )
    model = Model(inputs=feature_inputs, outputs=network_output)

    steps_per_epoch = len(train_ds.batch(batch_size))
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        learning_rate,
        decay_steps=steps_per_epoch,
        decay_rate=decay_rate,
        staircase=True,
    )
    model.compile(
        optimizer=AdamW(learning_rate=schedule, weight_decay=weight_decay),
        loss=loss,
        metrics=metrics,
    )
    return model

In [None]:
train_df, test_df = get_train_n_test_data("auto")
train_ds = df2ds(train_df)
test_ds = df2ds(test_df)


def build_model_f() -> Model:
    """Build a model for the `auto` dataset with fixed, hand-picked hyperparameters.

    Displacement, Horsepower and Weight use monotonicity indicator -1. All
    other features are unconstrained (indicator 0).
    """
    return build_mono_model_f(
        monotonicity_indicator={
            "Cylinders": 0,
            "Displacement": -1,
            "Horsepower": -1,
            "Weight": -1,
            "Acceleration": 0,
            "Model_Year": 0,
            "Origin": 0,
        },
        final_activation=None,
        loss="mse",
        metrics="mse",
        train_ds=train_ds,
        batch_size=8,
        units=16,
        n_layers=3,
        activation="elu",
        learning_rate=0.01,
        weight_decay=0.001,
        dropout=0.25,
        decay_rate=0.95,
    )


model = build_model_f()
model.summary()
model.fit(train_ds.batch(8), validation_data=test_ds.batch(256), epochs=1)

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Acceleration (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 Cylinders (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 Displacement (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 Horsepower (InputLayer)        [(None, 1)]          0           []                               
                                                                                            

<keras.callbacks.History>

In [None]:
# | export


def get_build_model_with_hp_f(
    build_model_f: Callable[..., Model],
    hp_params_f: Optional[Callable[[HyperParameters], Dict[str, Any]]] = None,
    **kwargs: Any,
) -> Callable[[HyperParameters], Model]:
    """Turn a keyword-argument model factory into a Keras Tuner `build(hp)` function.

    The returned function samples the default search space below from `hp`.
    Any values returned by `hp_params_f(hp)` replace those defaults. It then
    calls `build_model_f(**hyperparameters, **kwargs)`.

    Default search space:
        units: int in [8, 32], step 1
        n_layers: int in [1, 4]
        activation: "elu" (only choice)
        learning_rate: float in [1e-3, 0.3], log sampling
        weight_decay: float in [1e-1, 0.3], log sampling
        dropout: float in [0.0, 0.5], linear sampling
        decay_rate: float in [0.5, 1.0], reverse-log sampling

    Args:
        build_model_f: model factory; called with keyword arguments only.
        hp_params_f: optional function returning overrides for the default
            hyperparameters (e.g. `hp.Fixed(...)` values).
        **kwargs: fixed keyword arguments forwarded to `build_model_f`. They
            must not repeat a hyperparameter name, otherwise the call raises
            `TypeError` for a duplicate keyword.

    Returns:
        A function mapping a `HyperParameters` object to a built model.
    """

    def build_model_with_hp_f(
        hp: HyperParameters,
        hp_params_f: Optional[
            Callable[[HyperParameters], Dict[str, Any]]
        ] = hp_params_f,
        kwargs: Dict[str, Any] = kwargs,
    ) -> Model:
        # Overrides are evaluated before the defaults below, so their
        # hyperparameter definitions are the first ones registered with `hp`.
        override_kwargs = hp_params_f(hp) if hp_params_f is not None else {}

        default_kwargs = dict(
            units=hp.Int("units", min_value=8, max_value=32, step=1),
            n_layers=hp.Int("n_layers", min_value=1, max_value=4),
            activation=hp.Choice("activation", values=["elu"]),
            learning_rate=hp.Float(
                "learning_rate", min_value=1e-3, max_value=0.3, sampling="log"
            ),
            weight_decay=hp.Float(
                "weight_decay", min_value=1e-1, max_value=0.3, sampling="log"
            ),
            dropout=hp.Float(
                "dropout", min_value=0.0, max_value=0.5, sampling="linear"
            ),
            decay_rate=hp.Float(
                "decay_rate", min_value=0.5, max_value=1.0, sampling="reverse_log"
            ),
        )

        default_kwargs.update(override_kwargs)
        return build_model_f(**default_kwargs, **kwargs)

    return build_model_with_hp_f


class TestHyperModel(HyperModel):
    """Keras Tuner `HyperModel` that builds models with `build_mono_model_f`.

    All constructor keyword arguments (e.g. `hp_params_f`, `train_ds`,
    `monotonicity_indicator`) are forwarded to `get_build_model_with_hp_f`,
    and from there to `build_mono_model_f`.
    """

    def __init__(self, **kwargs: Any):
        # Initialise the base class so its attributes (`name`, `tunable`) are set
        # up as for any other HyperModel.
        super().__init__()
        self.kwargs = kwargs

    def build(self, hp: HyperParameters) -> Model:
        """Build one model for the hyperparameters sampled in `hp`."""
        build_model_with_hp_f = get_build_model_with_hp_f(
            build_mono_model_f, **self.kwargs  # type: ignore
        )
        return build_model_with_hp_f(hp)

In [None]:
def hp_params_f(hp: HyperParameters) -> Dict[str, Any]:
    """Pin the architecture to a tiny network so the tuner smoke test runs fast.

    Keys must match `build_mono_model_f` parameters (`units`, `n_layers`), and
    each fixed value gets its own hyperparameter name.
    """
    return dict(
        units=hp.Fixed(name="units", value=3),
        n_layers=hp.Fixed(name="n_layers", value=1),
    )


with TemporaryDirectory() as d:
    tuner = RandomSearch(
        hypermodel=TestHyperModel(
            monotonicity_indicator={
                "Cylinders": 0,
                "Displacement": -1,
                "Horsepower": -1,
                "Weight": -1,
                "Acceleration": 0,
                "Model_Year": 0,
                "Origin": 0,
            },
            hp_params_f=hp_params_f,
            final_activation=None,
            loss="mse",
            metrics="mse",
            train_ds=train_ds,
            batch_size=8,
        ),
        directory=d,
        project_name="testing",
        max_trials=2,
        objective="val_loss",
    )
    tuner.search(
        train_ds.shuffle(len(train_ds)).batch(8).prefetch(2),
        validation_data=test_ds.batch(256),
        epochs=2,
    )

Trial 2 Complete [00h 00m 04s]
val_loss: 60.91734313964844

Best val_loss So Far: 60.91734313964844
Total elapsed time: 00h 00m 08s
INFO:tensorflow:Oracle triggered exit


In [None]:
# | export


def find_hyperparameters(
    dataset_name: str,
    *,
    monotonicity_indicator: Dict[str, int],
    final_activation: Optional[
        Union[str, Callable[[TensorLike, TensorLike], TensorLike]]
    ],
    loss: Union[str, Callable[[TensorLike, TensorLike], TensorLike]],
    metrics: Union[str, Callable[[TensorLike, TensorLike], TensorLike]],
    hp_params_f: Optional[Callable[[HyperParameters], Dict[str, Any]]] = None,
    max_trials: int = 100,
    max_epochs: int = 50,
    batch_size: int = 8,
    objective: Union[str, Objective],
    direction: str,
    dir_root: Union[Path, str] = "tuner",
    seed: int = 42,
    executions_per_trial: int = 3,
    max_consecutive_failed_trials: int = 5,
    patience: int = 10,
) -> Tuner:
    """Run a Bayesian-optimisation hyperparameter search on a dataset.

    Args:
        dataset_name: name of the dataset; used both to load the train/test
            split and as the tuner's project name.
        monotonicity_indicator: per-feature monotonicity, forwarded to the
            hypermodel.
        final_activation: activation of the output layer (may be `None`).
        loss: loss used to compile the models.
        metrics: metrics used to compile the models.
        hp_params_f: optional function of the trial hyperparameters, forwarded
            to the hypermodel (presumably to override parts of the default
            search space — see `get_build_model_with_hp_f`).
        max_trials: maximum number of trials of the search.
        max_epochs: maximum number of epochs per training run.
        batch_size: training batch size.
        objective: name of the metric to optimise, or a ready `Objective`.
        direction: `"min"` or `"max"`; used only when `objective` is a name.
        dir_root: root directory where the tuner stores its results.
        seed: seed for the global random state and the tuner.
        executions_per_trial: number of models trained per trial.
        max_consecutive_failed_trials: tuner stops after this many consecutive
            failures.
        patience: early-stopping patience, monitored on `val_loss`.

    Returns:
        The tuner after `search` has finished.
    """
    tf.keras.utils.set_random_seed(seed)

    train_df, test_df = get_train_n_test_data(dataset_name)
    train_ds, test_ds = df2ds(train_df), df2ds(test_df)

    # The object builds models from hyperparameters, i.e. it is the
    # hypermodel (the oracle is created internally by the tuner).
    hypermodel = TestHyperModel(
        monotonicity_indicator=monotonicity_indicator,
        hp_params_f=hp_params_f,
        final_activation=final_activation,
        loss=loss,
        metrics=metrics,
        train_ds=train_ds,
        batch_size=batch_size,
    )

    # Accept an already-constructed `Objective` as-is instead of wrapping it
    # inside another `Objective`.
    tuner_objective = (
        objective
        if isinstance(objective, Objective)
        else Objective(objective, direction)
    )

    tuner = BayesianOptimization(
        hypermodel,
        objective=tuner_objective,
        max_trials=max_trials,
        seed=seed,
        directory=Path(dir_root),
        project_name=dataset_name,
        executions_per_trial=executions_per_trial,
        max_consecutive_failed_trials=max_consecutive_failed_trials,
    )

    stop_early = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=patience)

    tuner.search(
        train_ds.shuffle(len(train_ds)).batch(batch_size).prefetch(2),
        validation_data=test_ds.batch(256),
        callbacks=[stop_early],
        epochs=max_epochs,
    )

    return tuner

In [None]:
# Remove results of any previous run from the tuner directory before searching.
shutil.rmtree("tuner", ignore_errors=True)

tuner = find_hyperparameters(
    "auto",
    monotonicity_indicator={
        "Cylinders": 0,
        "Displacement": -1,
        "Horsepower": -1,
        "Weight": -1,
        "Acceleration": 0,
        "Model_Year": 0,
        "Origin": 0,
    },
    final_activation=None,
    loss="mse",
    metrics="mse",
    objective="val_mse",
    direction="min",
    max_trials=2,
    max_epochs=1,
    executions_per_trial=1,
    dir_root="tuner",
)

Trial 2 Complete [00h 00m 03s]
val_mse: 32.87412643432617

Best val_mse So Far: 32.87412643432617
Total elapsed time: 00h 00m 06s
INFO:tensorflow:Oracle triggered exit


In [None]:
# | export


def count_model_params(model: Model) -> int:
    """Return the total number of scalars stored in the variables of `model`'s layers.

    Sums `count_params` over `layer.variables` for every layer in
    `model.layers` (in Keras this includes non-trainable variables too).
    """
    total = 0
    for layer in model.layers:
        for variable in layer.variables:
            total += count_params(variable)
    return total


def create_model_stats(
    tuner: Tuner,
    hp: HyperParameters,
    *,
    stats: Optional[pd.DataFrame] = None,
    max_epochs: int,
    num_runs: int,
    top_runs: int,
    batch_size: int,
    patience: int,
    verbose: int,
    train_ds: tf.data.Dataset,
    test_ds: tf.data.Dataset,
    seed: int = 42,
) -> pd.DataFrame:
    """Retrain one hyperparameter configuration several times and summarise it.

    A fresh model is built from `hp` and trained `num_runs` times, each run
    early-stopped on `val_loss`. For each run, the best per-epoch value of the
    tuner's objective is recorded. The `top_runs` best values are summarised.

    Args:
        tuner: tuner providing the hypermodel and the objective.
        hp: hyperparameters to evaluate.
        stats: unused; kept for backward compatibility.
        max_epochs: maximum number of epochs per run.
        num_runs: number of independent training runs.
        top_runs: number of best runs included in the summary.
        batch_size: training batch size.
        patience: early-stopping patience on `val_loss`.
        verbose: verbosity passed to `model.fit`.
        train_ds: unbatched training dataset.
        test_ds: unbatched validation dataset.
        seed: seed for the global random state, set once before all runs.

    Returns:
        A one-row DataFrame with the hyperparameter values, the mean/std/min/max
        of the objective over the top runs, and the model's parameter count.
    """
    tf.keras.utils.set_random_seed(seed)

    objective_name = tuner.oracle.objective.name
    maximize = tuner.oracle.objective.direction == "max"

    def _best_objective() -> float:
        """Train one fresh model and return its best per-epoch objective value."""
        model = tuner.hypermodel.build(hp)
        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=patience
        )
        history = model.fit(
            train_ds.shuffle(len(train_ds)).batch(batch_size).prefetch(2),
            epochs=max_epochs,
            validation_data=test_ds.batch(256),
            verbose=verbose,
            callbacks=[stop_early],
        )
        values = history.history[objective_name]
        return max(values) if maximize else min(values)  # type: ignore

    # Best runs first, according to the objective's direction.
    run_results = sorted(
        [_best_objective() for _ in range(num_runs)], reverse=maximize
    )
    summary = pd.Series(run_results[:top_runs]).describe()
    summary_stats = {
        f"{objective_name}_{k}": summary[k] for k in ["mean", "std", "min", "max"]
    }

    # An extra model is built only to report its parameter count.
    model = tuner.hypermodel.build(hp)
    return pd.DataFrame(
        dict(**hp.values, **summary_stats, params=count_model_params(model)),  # type: ignore
        index=[0],
    )


def create_tuner_stats(
    tuner: Tuner,
    *,
    num_models: int = 10,
    stats: Optional[pd.DataFrame] = None,
    max_epochs: int = 50,
    batch_size: int = 8,
    patience: int = 10,
    verbose: int = 0,
    num_runs: int = 10,
    top_runs: int = 5,
) -> pd.DataFrame:
    """Collect retraining statistics for the best hyperparameter sets of a tuner.

    For each of the `num_models` best hyperparameter sets, `create_model_stats`
    is called and its one-row result is appended to the table. The current
    table is displayed after every model when running under IPython.

    Args:
        tuner: finished tuner; its `project_name` selects the dataset.
        num_models: number of best hyperparameter sets to evaluate.
        stats: optional previously collected table that new rows are appended to.
        max_epochs: maximum number of epochs per training run.
        batch_size: training batch size.
        patience: early-stopping patience.
        verbose: verbosity passed to training.
        num_runs: training runs per hyperparameter set.
        top_runs: best runs per hyperparameter set included in the summary.

    Returns:
        The table sorted by the mean of the tuner's objective (an empty
        DataFrame if there is nothing to report).
    """
    sort_key = f"{tuner.oracle.objective.name}_mean"

    train_df, test_df = get_train_n_test_data(tuner.project_name)
    train_ds, test_ds = df2ds(train_df), df2ds(test_df)

    for hp in tuner.get_best_hyperparameters(num_trials=num_models):
        new_entry = create_model_stats(
            tuner,
            hp,
            stats=stats,
            max_epochs=max_epochs,
            num_runs=num_runs,
            top_runs=top_runs,
            batch_size=batch_size,
            patience=patience,
            verbose=verbose,
            train_ds=train_ds,
            test_ds=test_ds,
        )
        stats = (
            new_entry
            if stats is None
            else pd.concat([stats, new_entry]).reset_index(drop=True)
        )

        # Best-effort progress display: `display` is only defined inside
        # IPython, so outside a notebook this is silently skipped.
        try:
            display(stats.sort_values(sort_key))  # type: ignore
        # nosemgrep
        except Exception:  # nosec
            pass

    if stats is None:
        # The tuner returned no hyperparameter sets and no table was supplied.
        return pd.DataFrame()
    return stats.sort_values(sort_key)

In [None]:
# | notest

# Retrain the best configurations found by the tuner and summarise them.
stats = create_tuner_stats(tuner=tuner, verbose=0)



Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_mse_mean,val_mse_std,val_mse_min,val_mse_max,params
0,9,2,elu,0.265157,0.196993,0.456821,0.560699,12.738773,1.8673,10.745923,15.125115,173


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_mse_mean,val_mse_std,val_mse_min,val_mse_max,params
0,9,2,elu,0.265157,0.196993,0.456821,0.560699,12.738773,1.8673,10.745923,15.125115,173
1,23,1,elu,0.004715,0.265345,0.175923,0.816107,21.378424,1.743336,18.393278,22.992584,106
