-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
85 lines (69 loc) · 2.44 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import tensorflow as tf
import numpy as np
from multistartopt.layers import FunctionLayer
class MultiStartOptimizationModel(tf.keras.models.Model):
    """
    Keras model whose only weights are a set of N candidate points in R^n.

    The model wraps a single ``FunctionLayer`` built with ``units=N`` and
    ``activation=function``. After construction the layer's weights are
    overwritten with the transposed starting points, so fitting the model
    updates all points together.

    NOTE(review): this class never calls ``compile``; ``minimize`` and
    ``find_level_set`` call ``fit`` directly, so presumably the caller compiles
    the model (optimizer and loss) beforehand -- confirm against usage.
    """

    def __init__(self,
                 function,
                 starting_pts,
                 dtype=tf.float32,
                 **kwargs
                 ):
        """
        Model initialization.

        :param function: tensorflow function (better if vectorized for batches of inputs)
        :param starting_pts: array-like N-by-n (N points in R^n); anything ``tf.cast``
            can turn into a rank-2 tensor is accepted
        :param dtype: dtype used for the model, the function layer and the stored
            starting points
        :param kwargs: forwarded to the ``tf.keras.models.Model`` constructor
        """
        super().__init__(dtype=dtype, **kwargs)

        self._tf_dtype = dtype
        self._starting_pts = tf.cast(starting_pts, dtype=self._tf_dtype)
        # Read the shape from the cast tensor rather than from the raw argument,
        # so inputs without a ``.shape`` attribute (e.g. nested lists) work too.
        # Unpacking still raises ValueError when the input is not 2-dimensional.
        self._N, self._n = self._starting_pts.shape.as_list()
        self._function = function

        self._func_layer = FunctionLayer(
            units=self._N,
            activation=self._function,
            dtype=self._tf_dtype,
        )

        # Constant dummy input of shape (1, n); it is the only ``x`` ever fed to fit.
        self._fake_input = np.ones((1, self._n))

        # One forward pass so the layer has weights before they are overwritten
        # below with the starting points (stored transposed, i.e. n-by-N).
        self.call(self._fake_input)
        self.set_weights([self._starting_pts.numpy().T])

    def get_config(self):
        """
        Return the base-class config extended with this model's attributes.

        NOTE(review): the added values are live objects (tensor, callable, layer,
        array) and the keys ``N``, ``n``, ``func_layer`` and ``fake_input`` are not
        ``__init__`` parameters -- confirm whether a ``from_config`` round trip is
        expected to work before relying on it.

        :return: config dictionary
        """
        config = super().get_config()
        config['starting_pts'] = self._starting_pts
        config['N'] = self._N
        config['n'] = self._n
        config['function'] = self._function
        config['func_layer'] = self._func_layer
        config['fake_input'] = self._fake_input
        return config

    def call(self, inputs, training=None, mask=None):
        """
        Cast ``inputs`` to the model dtype and pass them through the function layer.

        :param inputs: input batch; cast to the model dtype before use
        :param training: unused
        :param mask: unused
        :return: output of the function layer
        """
        inputs = tf.cast(inputs, dtype=self._tf_dtype)
        return self._func_layer(inputs)

    def _fit_to_target(self, target, epochs, verbose, callbacks, initial_epoch):
        """
        Fit the model on the constant dummy input against ``target``.

        Shared implementation of ``minimize`` and ``find_level_set``.

        :param target: array of shape (1, N) used as ``y`` in ``fit``
        :param epochs: forwarded to ``fit``
        :param verbose: forwarded to ``fit``
        :param callbacks: forwarded to ``fit``
        :param initial_epoch: forwarded to ``fit``
        :return: the first weight array of the model, transposed (one point per
            row, matching the N-by-n layout of ``starting_pts``)
        """
        self.fit(
            x=self._fake_input,
            y=target,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            initial_epoch=initial_epoch,
        )
        # Weights are stored n-by-N (see __init__); transpose back to N-by-n.
        return self.get_weights()[0].T

    def minimize(self,
                 epochs=100,
                 verbose=False,
                 callbacks=None,
                 initial_epoch=0,
                 ):
        """
        Fit the points against an all-zero target of shape (1, N).

        How the distance from the target is measured depends on the loss the
        model was compiled with, which is not set in this class.

        :param epochs: forwarded to ``fit``
        :param verbose: forwarded to ``fit``
        :param callbacks: forwarded to ``fit``
        :param initial_epoch: forwarded to ``fit``
        :return: N-by-n array of the points after fitting
        """
        return self._fit_to_target(
            np.zeros((1, self._N)),
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            initial_epoch=initial_epoch,
        )

    def find_level_set(self,
                       y_level,
                       epochs=100,
                       verbose=False,
                       callbacks=None,
                       initial_epoch=0,
                       ):
        """
        Fit the points against a constant target equal to ``y_level``.

        The target is ``y_level * np.ones((1, N))``. How the distance from the
        target is measured depends on the loss the model was compiled with,
        which is not set in this class.

        :param y_level: target value multiplied into a (1, N) array of ones
        :param epochs: forwarded to ``fit``
        :param verbose: forwarded to ``fit``
        :param callbacks: forwarded to ``fit``
        :param initial_epoch: forwarded to ``fit``
        :return: N-by-n array of the points after fitting
        """
        return self._fit_to_target(
            y_level * np.ones((1, self._N)),
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            initial_epoch=initial_epoch,
        )