@@ -1,15 +1,14 @@
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from jax.typing import ArrayLike
 from flax import linen as nn
 import jax.numpy as jnp
-from typing import Union, Dict, Any, Callable, Tuple, Optional
+from typing import Union, Dict, Any, Callable
 from flax.training.train_state import TrainState
 import jax
 from flax.typing import FrozenVariableDict
 from model.utils import WrappedModule
-from training.qsetup import QSetup
 from systems import System
+from training.setups.drift import DriftedSetup
 from training.utils import forward_and_derivatives


@@ -21,7 +20,7 @@ class DiagonalWrapper(WrappedModule):
     base_sigma: float

     @nn.compact
-    def _post_process(self, t: ArrayLike, h: ArrayLike):
+    def _post_process(self, h: ArrayLike, t: ArrayLike):
         ndim = self.A.shape[0]
         num_mixtures = self.num_mixtures
         h = nn.Dense(2 * ndim * num_mixtures)(h)
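
Note on the `_post_process` swap: the subclass signature has to match however `WrappedModule` invokes it positionally. `model/utils.py` is not part of this diff, so the following is only a minimal sketch of the assumed contract; the names and the time normalization are illustrative, not the repository's actual code.

import flax.linen as nn
from jax.typing import ArrayLike

class WrappedModule(nn.Module):
    other: nn.Module
    T: float

    @nn.compact
    def __call__(self, t: ArrayLike):
        h = self.other(t / self.T)       # wrapped network runs on normalized time
        return self._post_process(h, t)  # positional (h, t): the reason for the swap

    def _post_process(self, h: ArrayLike, t: ArrayLike):
        return h                         # subclasses such as DiagonalWrapper override this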
@@ -43,15 +42,13 @@ def _post_process(self, t: ArrayLike, h: ArrayLike):


 @dataclass
-class DiagonalSetup(QSetup, ABC):
+class DiagonalSetup(DriftedSetup):
     model_q: DiagonalWrapper
     T: float
-    base_sigma: float
-    num_mixtures: int

-    @abstractmethod
-    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
-        raise NotImplementedError
+    def __init__(self, system: System, model_q: DiagonalWrapper, xi: ArrayLike, order: str, T: float):
+        super().__init__(system, model_q, xi, order)
+        self.T = T

     def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
         [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
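
For context, `DiagonalSetup` now delegates drift selection to the new `DriftedSetup` base class imported above. `training/setups/drift.py` is not shown in this diff; the sketch below assumes it simply hoists the two `_drift` implementations deleted further down behind an `order` flag, and the `QSetup.__init__` signature is likewise an assumption.

from dataclasses import dataclass
import jax.numpy as jnp
from jax.typing import ArrayLike
from model.utils import WrappedModule
from systems import System
from training.qsetup import QSetup

@dataclass
class DriftedSetup(QSetup):
    order: str

    def __init__(self, system: System, model_q: WrappedModule, xi: ArrayLike, order: str):
        super().__init__(system, model_q, xi)
        assert order in ('first', 'second')
        self.order = order

    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
        if self.order == 'first':
            # overdamped dynamics: drift is the scaled negative potential gradient
            return -self.system.dUdx(_x / (gamma * self.system.mass))
        # underdamped dynamics: velocity advances positions;
        # force and friction advance the velocity block
        ndim = self.system.A.shape[0]
        return jnp.hstack([_x[:, ndim:] / self.system.mass,
                           -self.system.dUdx(_x[:, :ndim]) - _x[:, ndim:] * gamma])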
@@ -70,7 +67,7 @@ def v_t(_eps, _t):

                _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * eps

-                if self.num_mixtures == 1:
+                if _mu_t.shape[1] == 1:
                     # This completely ignores the weights and saves some time
                     relative_mixture_weights = 1
                 else:
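
The change from `self.num_mixtures` to `_mu_t.shape[1]` reads the mixture count off the network output itself; from the indexing above, `_mu_t` appears to have shape `(BS, num_mixtures, ndim)`. This is what lets the `num_mixtures` field be dropped from `DiagonalSetup`. A quick shape check under that assumption:

import jax.numpy as jnp

_mu_t = jnp.zeros((64, 1, 2))   # batch of 64, one mixture component, 2 dims
assert _mu_t.shape[1] == 1      # single-component fast path: weights are ignored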
@@ -102,43 +99,3 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic:
         log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)

         return _u_t + 0.5 * (self.xi ** 2) * log_q_t
-
-
-class FirstOrderSetup(DiagonalSetup):
-    def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
-                 trainable_weights: bool):
-        model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
-        super().__init__(system, model_q, xi, T, base_sigma, num_mixtures)
-
-    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
-        return -self.system.dUdx(_x / (gamma * self.system.mass))
-
-
-class SecondOrderSetup(DiagonalSetup):
-    def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
-                 trainable_weights: bool):
-        # We pad the A and B matrices with zeros to account for the velocity
-        self._A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
-        self._B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
-
-        xi_velocity = jnp.ones_like(system.A) * xi
-        xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
-
-        xi_second_order = jnp.concatenate((xi_pos, xi_velocity), axis=-1)
-
-        model_q = DiagonalWrapper(model, T, self._A, self._B, num_mixtures, trainable_weights, base_sigma)
-        super().__init__(system, model_q, xi_second_order, T, base_sigma, num_mixtures)
-
-    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
-        # number of dimensions without velocity
-        ndim = self.system.A.shape[0]
-
-        return jnp.hstack([_x[:, ndim:] / self.system.mass, -self.system.dUdx(_x[:, :ndim]) - _x[:, ndim:] * gamma])
-
-    @property
-    def A(self):
-        return self._A
-
-    @property
-    def B(self):
-        return self._B
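
With `FirstOrderSetup` and `SecondOrderSetup` gone, constructing a setup presumably moves to the call site or a factory elsewhere in this refactor. A hypothetical sketch, reusing the constructor arguments the deleted `__init__` methods took:

# Hypothetical call site after this change (not part of the diff; names illustrative).
model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
setup = DiagonalSetup(system, model_q, xi, 'first', T)

# Second-order dynamics would first pad A, B, and xi for the velocity block,
# exactly as the removed SecondOrderSetup.__init__ did, then pass order='second'.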