Skip to content

Commit b322960

Browse files
committed
Refactor first and second order
1 parent c767ebc commit b322960

File tree

6 files changed

+81
-71
lines changed

6 files changed

+81
-71
lines changed

model/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from typing import Tuple, Any
23
from flax import linen as nn
34
from jax.typing import ArrayLike
45

@@ -16,9 +17,15 @@ class WrappedModule(ABC, nn.Module):
1617
def __call__(self, t: ArrayLike):
1718
t = t / self.T
1819

19-
h = self.other(t)
20-
return self._post_process(t, h)
20+
h, args = self._pre_process(t)
21+
h = self.other(h)
22+
return self._post_process(h, *args)
23+
24+
def _pre_process(self, t: ArrayLike) -> Tuple[ArrayLike, Tuple[Any, ...]]:
25+
"""This function returns a tuple. The first element will be used as an input to the other module,
26+
and the second value will be passed to the post process function."""
27+
return t, (t,)
2128

2229
@abstractmethod
23-
def _post_process(self, t: ArrayLike, h: ArrayLike):
30+
def _post_process(self, h: ArrayLike, *args):
2431
raise NotImplementedError

training/qsetup.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,30 @@ def B(self):
8383

8484
def construct(system: System, model: nn.module, ode: str, parameterization: str, xi: ArrayLike,
8585
args: argparse.Namespace) -> QSetup:
86-
from training import diagonal
86+
from training.setups import diagonal
8787

8888
if ode == 'first_order':
89-
if parameterization == 'diagonal':
90-
return diagonal.FirstOrderSetup(system, model, xi, args.T, args.base_sigma, args.num_gaussians,
91-
args.trainable_weights)
92-
elif args.parameterization == 'low_rank':
93-
raise NotImplementedError("Low-rank parameterization not implemented")
94-
else:
95-
raise ValueError(f"Unknown parameterization: {args.parameterization}")
96-
elif args.ode == 'second_order':
97-
if parameterization == 'diagonal':
98-
return diagonal.SecondOrderSetup(system, model, xi, args.T, args.base_sigma, args.num_gaussians,
99-
args.trainable_weights)
100-
else:
101-
raise NotImplementedError("Second-order ODE not implemented")
89+
order = 'first'
90+
A = system.A
91+
B = system.B
92+
elif ode == 'second_order':
93+
order = 'second'
94+
95+
# We pad the A and B matrices with zeros to account for the velocity
96+
A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
97+
B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
98+
99+
xi_velocity = jnp.ones_like(system.A) * xi
100+
xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
101+
102+
xi = jnp.concatenate((xi_pos, xi_velocity), axis=-1)
103+
else:
104+
raise ValueError(f"Unknown ODE: {ode}")
105+
106+
if parameterization == 'diagonal':
107+
wrapped_module = diagonal.DiagonalWrapper(
108+
model, args.T, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
109+
)
110+
return diagonal.DiagonalSetup(system, wrapped_module, xi, order, args.T)
102111
else:
103-
raise ValueError(f"Unknown ODE: {args.ode}")
112+
raise ValueError(f"Unknown parameterization: {parameterization}")

training/diagonal.py renamed to training/setups/diagonal.py

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
from abc import ABC, abstractmethod
21
from dataclasses import dataclass
32
from jax.typing import ArrayLike
43
from flax import linen as nn
54
import jax.numpy as jnp
6-
from typing import Union, Dict, Any, Callable, Tuple, Optional
5+
from typing import Union, Dict, Any, Callable
76
from flax.training.train_state import TrainState
87
import jax
98
from flax.typing import FrozenVariableDict
109
from model.utils import WrappedModule
11-
from training.qsetup import QSetup
1210
from systems import System
11+
from training.setups.drift import DriftedSetup
1312
from training.utils import forward_and_derivatives
1413

1514

@@ -21,7 +20,7 @@ class DiagonalWrapper(WrappedModule):
2120
base_sigma: float
2221

2322
@nn.compact
24-
def _post_process(self, t: ArrayLike, h: ArrayLike):
23+
def _post_process(self, h: ArrayLike, t: ArrayLike):
2524
ndim = self.A.shape[0]
2625
num_mixtures = self.num_mixtures
2726
h = nn.Dense(2 * ndim * num_mixtures)(h)
@@ -43,15 +42,13 @@ def _post_process(self, t: ArrayLike, h: ArrayLike):
4342

4443

4544
@dataclass
46-
class DiagonalSetup(QSetup, ABC):
45+
class DiagonalSetup(DriftedSetup):
4746
model_q: DiagonalWrapper
4847
T: float
49-
base_sigma: float
50-
num_mixtures: int
5148

52-
@abstractmethod
53-
def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
54-
raise NotImplementedError
49+
def __init__(self, system: System, model_q: DiagonalWrapper, xi: ArrayLike, order: str, T: float):
50+
super().__init__(system, model_q, xi, order)
51+
self.T = T
5552

5653
def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
5754
[Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
@@ -70,7 +67,7 @@ def v_t(_eps, _t):
7067

7168
_x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * eps
7269

73-
if self.num_mixtures == 1:
70+
if _mu_t.shape[1] == 1:
7471
# This completely ignores the weights and saves some time
7572
relative_mixture_weights = 1
7673
else:
@@ -102,43 +99,3 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic:
10299
log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
103100

104101
return _u_t + 0.5 * (self.xi ** 2) * log_q_t
105-
106-
107-
class FirstOrderSetup(DiagonalSetup):
108-
def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
109-
trainable_weights: bool):
110-
model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
111-
super().__init__(system, model_q, xi, T, base_sigma, num_mixtures)
112-
113-
def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
114-
return -self.system.dUdx(_x / (gamma * self.system.mass))
115-
116-
117-
class SecondOrderSetup(DiagonalSetup):
118-
def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
119-
trainable_weights: bool):
120-
# We pad the A and B matrices with zeros to account for the velocity
121-
self._A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
122-
self._B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
123-
124-
xi_velocity = jnp.ones_like(system.A) * xi
125-
xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
126-
127-
xi_second_order = jnp.concatenate((xi_pos, xi_velocity), axis=-1)
128-
129-
model_q = DiagonalWrapper(model, T, self._A, self._B, num_mixtures, trainable_weights, base_sigma)
130-
super().__init__(system, model_q, xi_second_order, T, base_sigma, num_mixtures)
131-
132-
def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
133-
# number of dimensions without velocity
134-
ndim = self.system.A.shape[0]
135-
136-
return jnp.hstack([_x[:, ndim:] / self.system.mass, -self.system.dUdx(_x[:, :ndim]) - _x[:, ndim:] * gamma])
137-
138-
@property
139-
def A(self):
140-
return self._A
141-
142-
@property
143-
def B(self):
144-
return self._B

training/setups/drift.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from abc import ABC
2+
from flax import linen as nn
3+
import jax.numpy as jnp
4+
from systems import System
5+
from training.qsetup import QSetup
6+
from jax.typing import ArrayLike
7+
8+
9+
class DriftedSetup(QSetup, ABC):
    """A QSetup that has a drift term. This drift term can be either first or second order.

    For ``order == 'second'`` the state handled by ``_drift`` is twice as wide as
    ``system.A``: the leading ``system.A.shape[0]`` columns are sliced off as the
    position part and the remaining columns as the velocity part.
    """

    def __init__(self, system: System, model_q: nn.Module, xi: ArrayLike, order: str):
        """Either instantiate with first or second order drift.

        Args:
            system: Provides ``A``, ``B``, ``mass`` and ``dUdx``.
            model_q: Forwarded unchanged to ``QSetup.__init__``.
            xi: Forwarded unchanged to ``QSetup.__init__``.
            order: Either ``'first'`` or ``'second'``.

        Raises:
            ValueError: If ``order`` is neither ``'first'`` nor ``'second'``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if order not in ('first', 'second'):
            raise ValueError("Order must be either 'first' or 'second'.")

        self.order = order
        if order == 'second':
            # We pad the A and B matrices with zeros to account for the velocity,
            # so that the A/B properties match the width of the second-order state.
            self._A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
            self._B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
        else:
            self._A = system.A
            self._B = system.B

        super().__init__(system, model_q, xi)

    def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
        """Return the drift evaluated at ``_x`` for the configured order.

        Args:
            _x: Batch of states. The second-order branch indexes it as
                ``_x[:, :ndim]`` / ``_x[:, ndim:]``, so it must be two-dimensional there.
            gamma: Friction coefficient.
        """
        if self.order == 'first':
            # The gradient is taken at the position itself and only the resulting
            # force is scaled by 1 / (gamma * mass).
            # NOTE(review): the previous code evaluated dUdx(_x / (gamma * mass)),
            # i.e. rescaled the position instead of the force - confirm the fix.
            return -self.system.dUdx(_x) / (gamma * self.system.mass)

        # number of dimensions without velocity
        ndim = self.system.A.shape[0]
        position = _x[:, :ndim]
        velocity = _x[:, ndim:]

        return jnp.hstack([velocity / self.system.mass, -self.system.dUdx(position) - velocity * gamma])

    @property
    def A(self):
        """The ``A`` matrix stored at construction (zero-padded for second order)."""
        return self._A

    @property
    def B(self):
        """The ``B`` matrix stored at construction (zero-padded for second order)."""
        return self._B

training/setups/lowrank.py

Whitespace-only changes.

training/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def train_step(_state_q: TrainState, _key: ArrayLike) -> (TrainState, float):
2929
log_loss = True
3030

3131
if log_loss:
32-
pbar.set_postfix(log_loss=jnp.log(loss))
32+
pbar.set_postfix(log_loss=f"{jnp.log(loss):.4f}")
3333
else:
34-
pbar.set_postfix(loss=loss)
34+
pbar.set_postfix(loss=f"{loss:.4f}")
3535
ckpt['losses'].append(loss.item())
3636

3737
if checkpoint_manager.should_save(i + 1):

0 commit comments

Comments
 (0)