@@ -53,7 +53,7 @@ class DiagonalSetup(QSetup, ABC):
53
53
def _drift (self , _x : ArrayLike , gamma : float ) -> ArrayLike :
54
54
raise NotImplementedError
55
55
56
- def construct_loss (self , state_q : TrainState , xi : ArrayLike , gamma : float , BS : int ) -> Callable [
56
+ def construct_loss (self , state_q : TrainState , gamma : float , BS : int ) -> Callable [
57
57
[Union [FrozenVariableDict , Dict [str , Any ]], ArrayLike ], ArrayLike ]:
58
58
59
59
def loss_fn (params_q : Union [FrozenVariableDict , Dict [str , Any ]], key : ArrayLike ) -> ArrayLike :
@@ -80,14 +80,14 @@ def v_t(_eps, _t):
80
80
log_q_t = - (relative_mixture_weights / (_sigma_t ** 2 ) * (_x - _mu_t )).sum (axis = 1 )
81
81
u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t ) + _dmudt )).sum (axis = 1 )
82
82
83
- return u_t - self ._drift (_x .reshape (BS , ndim ), gamma ) + 0.5 * (xi ** 2 ) * log_q_t
83
+ return u_t - self ._drift (_x .reshape (BS , ndim ), gamma ) + 0.5 * (self . xi ** 2 ) * log_q_t
84
84
85
- loss = 0.5 * ((v_t (eps , t ) / xi ) ** 2 ).sum (- 1 , keepdims = True )
85
+ loss = 0.5 * ((v_t (eps , t ) / self . xi ) ** 2 ).sum (- 1 , keepdims = True )
86
86
return loss .mean ()
87
87
88
88
return loss_fn
89
89
90
- def u_t (self , state_q : TrainState , t : ArrayLike , x_t : ArrayLike , xi : ArrayLike , * args , ** kwargs ) -> ArrayLike :
90
+ def u_t (self , state_q : TrainState , t : ArrayLike , x_t : ArrayLike , deterministic : bool , * args , ** kwargs ) -> ArrayLike :
91
91
_mu_t , _sigma_t , _w_logits , _dmudt , _dsigmadt = forward_and_derivatives (state_q , t )
92
92
_x = x_t [:, None , :]
93
93
@@ -96,56 +96,45 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, xi: ArrayLike,
96
96
97
97
_u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t ) + _dmudt )).sum (axis = 1 )
98
98
99
- if xi == 0 :
99
+ if deterministic :
100
100
return _u_t
101
101
102
102
log_q_t = - (relative_mixture_weights / (_sigma_t ** 2 ) * (_x - _mu_t )).sum (axis = 1 )
103
103
104
- return _u_t + 0.5 * (xi ** 2 ) * log_q_t
104
+ return _u_t + 0.5 * (self . xi ** 2 ) * log_q_t
105
105
106
106
107
107
class FirstOrderSetup(DiagonalSetup):
    """Diagonal setup whose drift is the potential gradient scaled by ``1 / (gamma * mass)``.

    The state handed to ``_drift`` is used as a whole (no position/velocity split).
    """

    def __init__(self, system: System, model: nn.Module, xi: ArrayLike, T: float, base_sigma: float,
                 num_mixtures: int, trainable_weights: bool):
        """Wrap ``model`` in a ``DiagonalWrapper`` and initialise the parent setup.

        Args:
            system: provides ``A`` and ``B`` (handed to the wrapper) as well as ``dUdx`` and
                ``mass`` (used by ``_drift``).
            model: network wrapped by ``DiagonalWrapper``.
            xi: forwarded unchanged to the parent constructor.
            T: forwarded to the wrapper and to the parent constructor.
            base_sigma: forwarded to the wrapper and to the parent constructor.
            num_mixtures: forwarded to the wrapper and to the parent constructor.
            trainable_weights: forwarded to the wrapper.
        """
        model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
        super().__init__(system, model_q, xi, T, base_sigma, num_mixtures)

    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
        """Return ``-dUdx(_x) / (gamma * mass)``.

        Args:
            _x: state at which the drift is evaluated.
            gamma: scalar that, together with ``system.mass``, divides the gradient.

        Returns:
            The drift, with the shape produced by ``system.dUdx(_x)`` broadcast against
            ``system.mass``.
        """
        # The gradient is evaluated at the unscaled state (as SecondOrderSetup._drift does for
        # its positions); only the resulting force is divided by gamma * mass.
        return -self.system.dUdx(_x) / (gamma * self.system.mass)
115
115
116
116
117
117
class SecondOrderSetup (DiagonalSetup ):
118
def __init__(self, system: System, model: nn.Module, xi: ArrayLike, T: float, base_sigma: float,
             num_mixtures: int, trainable_weights: bool):
    """Build the padded ``A``/``B`` and the per-coordinate ``xi``, then initialise the parent.

    Args:
        system: provides ``A`` and ``B``; both are padded with zeros of the same shape.
        model: network wrapped by ``DiagonalWrapper``.
        xi: value for the velocity coordinates; it is multiplied with ``ones_like(system.A)``,
            so it must broadcast against ``system.A``.
        T: forwarded to the wrapper and to the parent constructor.
        base_sigma: forwarded to the wrapper and to the parent constructor.
        num_mixtures: forwarded to the wrapper and to the parent constructor.
        trainable_weights: forwarded to the wrapper.
    """
    # We pad the A and B matrices with zeros to account for the velocity
    self._A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
    self._B = jnp.hstack([system.B, jnp.zeros_like(system.B)])

    # The velocity half receives the requested xi. The position half is set to 1e-4 rather
    # than exactly zero -- presumably because the loss divides by xi (TODO confirm).
    xi_velocity = jnp.ones_like(system.A) * xi
    xi_pos = jnp.zeros_like(xi_velocity) + 1e-4

    # Same ordering as the padded A/B: position entries first, velocity entries second.
    xi_second_order = jnp.concatenate((xi_pos, xi_velocity), axis=-1)

    model_q = DiagonalWrapper(model, T, self._A, self._B, num_mixtures, trainable_weights, base_sigma)
    super().__init__(system, model_q, xi_second_order, T, base_sigma, num_mixtures)
126
131
127
132
def _drift (self , _x : ArrayLike , gamma : float ) -> ArrayLike :
128
133
# number of dimensions without velocity
129
134
ndim = self .system .A .shape [0 ]
130
135
131
136
return jnp .hstack ([_x [:, ndim :] / self .system .mass , - self .system .dUdx (_x [:, :ndim ]) - _x [:, ndim :] * gamma ])
132
137
133
- def _xi_to_second_order (self , xi : ArrayLike ) -> ArrayLike :
134
- if xi .shape == self .model_q .A .shape :
135
- return xi
136
-
137
- xi_velocity = jnp .ones_like (self .system .A ) * xi
138
- xi_pos = jnp .zeros_like (xi_velocity ) + 1e-4
139
-
140
- return jnp .concatenate ((xi_pos , xi_velocity ), axis = - 1 )
141
-
142
- def construct_loss (self , state_q : TrainState , xi : ArrayLike , gamma : float , BS : int ) -> Callable [
143
- [Union [FrozenVariableDict , Dict [str , Any ]], ArrayLike ], ArrayLike ]:
144
- return super ().construct_loss (state_q , self ._xi_to_second_order (xi ), gamma , BS )
145
-
146
- def u_t (self , state_q : TrainState , t : ArrayLike , x_t : ArrayLike , xi : ArrayLike , * args , ** kwargs ) -> ArrayLike :
147
- return super ().u_t (state_q , t , x_t , self ._xi_to_second_order (xi ), * args , ** kwargs )
148
-
149
138
@property
def A(self):
    """Return the array stored in ``self._A``."""
    padded_a = self._A
    return padded_a
0 commit comments