from functools import partial
from typing import Dict

from flax.training import train_state
from flax import jax_utils

import jax.numpy as jnp
from jax import lax, jit, grad, pmap, random, jacrev
from jax.tree_util import tree_map, tree_reduce, tree_leaves

import optax

from jaxpi import archs
from jaxpi.utils import flatten_pytree


class TrainState(train_state.TrainState):
    weights: Dict
    momentum: float

    def apply_weights(self, weights, **kwargs):
        """Updates the loss `weights` in the return value using a running average.

        Returns:
            An updated instance of `self`, with `weights` replaced by the running
            average of the old and incoming values, and additional attributes
            replaced as specified by `kwargs`.
        """
        running_average = (
            lambda old_w, new_w: old_w * self.momentum + (1 - self.momentum) * new_w
        )
        weights = tree_map(running_average, self.weights, weights)
        weights = lax.stop_gradient(weights)

        return self.replace(
            step=self.step,
            params=self.params,
            opt_state=self.opt_state,
            weights=weights,
            **kwargs,
        )
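
# The running average above is an exponential moving average (EMA):
#
#     w_new = momentum * w_old + (1 - momentum) * w_proposed
#
# For example, with momentum = 0.9, an old weight of 1.0 and a proposed weight
# of 10.0 give 0.9 * 1.0 + 0.1 * 10.0 = 1.9: the loss weights drift toward
# their new estimates instead of jumping to them, which smooths the updates.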


def _create_arch(config):
    if config.arch_name == "Mlp":
        arch = archs.Mlp(**config)
    elif config.arch_name == "ModifiedMlp":
        arch = archs.ModifiedMlp(**config)
    elif config.arch_name == "DeepONet":
        arch = archs.DeepONet(**config)
    else:
        raise NotImplementedError(f"Arch {config.arch_name} not supported yet!")

    return arch


def _create_optimizer(config):
    if config.optimizer == "Adam":
        lr = optax.exponential_decay(
            init_value=config.learning_rate,
            transition_steps=config.decay_steps,
            decay_rate=config.decay_rate,
        )
        tx = optax.adam(
            learning_rate=lr, b1=config.beta1, b2=config.beta2, eps=config.eps
        )
    else:
        raise NotImplementedError(f"Optimizer {config.optimizer} not supported yet!")

    # Gradient accumulation
    if config.grad_accum_steps > 1:
        tx = optax.MultiSteps(tx, every_k_schedule=config.grad_accum_steps)

    return tx
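
# With the default staircase=False, optax.exponential_decay yields a smooth
# schedule
#
#     lr(step) = init_value * decay_rate ** (step / transition_steps),
#
# and optax.MultiSteps accumulates (averages) gradients over `every_k_schedule`
# mini-steps before applying a single update, emulating a larger batch size.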


def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch)
    x = jnp.ones(config.input_dim)
    params = arch.init(random.PRNGKey(config.seed), x)

    # Initialize optax optimizer
    tx = _create_optimizer(config.optim)

    # Convert the initial loss weights from the config entry to a plain dict
    init_weights = dict(config.weighting.init_weights)

    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )

    return jax_utils.replicate(state)
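
# Illustrative sketch of the config layout this function reads (jaxpi configs
# are ml_collections.ConfigDict instances; the concrete field names inside
# `arch` below are assumptions for the example, not guaranteed by this module):
#
#     config = ml_collections.ConfigDict()
#     config.seed = 42
#     config.input_dim = 2  # e.g. (t, x) for a 1D time-dependent problem
#     config.arch = ml_collections.ConfigDict(
#         {"arch_name": "Mlp", "num_layers": 4, "hidden_dim": 128, "out_dim": 1}
#     )
#     config.optim = ...       # optimizer, learning_rate, decay_steps, ...
#     config.weighting = ...   # scheme, init_weights, momentum, ...
#     state = _create_train_state(config)  # replicated across local devices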


class PINN:
    def __init__(self, config):
        self.config = config
        self.state = _create_train_state(config)

    def u_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def r_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def losses(self, params, batch, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def compute_diag_ntk(self, params, batch, *args):
        raise NotImplementedError("Subclasses should implement this!")

    @partial(jit, static_argnums=(0,))
    def loss(self, params, weights, batch, *args):
        # Compute losses
        losses = self.losses(params, batch, *args)
        # Compute weighted loss
        weighted_losses = tree_map(lambda x, y: x * y, losses, weights)
        # Sum weighted losses
        loss = tree_reduce(lambda x, y: x + y, weighted_losses)

        return loss
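
    # The two tree operations above implement the weighted objective
    #
    #     L(theta) = sum_i w_i * L_i(theta),
    #
    # where `losses` and `weights` are dicts sharing the same keys, one entry
    # per loss term (e.g. initial/boundary conditions and the PDE residual).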

    @partial(jit, static_argnums=(0,))
    def compute_weights(self, params, batch, *args):
        if self.config.weighting.scheme == "grad_norm":
            # Compute the gradient of each loss w.r.t. the parameters
            grads = jacrev(self.losses)(params, batch, *args)

            # Compute the grad norm of each loss
            grad_norm_dict = {}
            for key, value in grads.items():
                flattened_grad = flatten_pytree(value)
                grad_norm_dict[key] = jnp.linalg.norm(flattened_grad)

            # Compute the mean of grad norms over all losses
            mean_grad_norm = jnp.mean(jnp.stack(tree_leaves(grad_norm_dict)))

            # Grad norm weighting
            w = tree_map(lambda x: (mean_grad_norm / x), grad_norm_dict)

        elif self.config.weighting.scheme == "ntk":
            # Compute the diagonal of the NTK of each loss
            ntk = self.compute_diag_ntk(params, batch, *args)

            # Compute the mean of the diagonal NTK corresponding to each loss
            mean_ntk_dict = tree_map(lambda x: jnp.mean(x), ntk)

            # Compute the average over all NTK means
            mean_ntk = jnp.mean(jnp.stack(tree_leaves(mean_ntk_dict)))

            # NTK weighting
            w = tree_map(lambda x: (mean_ntk / x), mean_ntk_dict)

        else:
            raise NotImplementedError(
                f"Weighting scheme {self.config.weighting.scheme} not supported yet!"
            )

        return w
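
    # Both schemes balance the loss terms with the same recipe: each weight is
    # a global mean divided by a per-term statistic,
    #
    #     grad_norm:  w_i = mean_j ||grad L_j|| / ||grad L_i||
    #     ntk:        w_i = mean_j mean(K_j)    / mean(K_i)
    #
    # so terms whose gradients (or NTK diagonals) are small get amplified and
    # all terms are trained at comparable rates.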

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def update_weights(self, state, batch, *args):
        weights = self.compute_weights(state.params, batch, *args)
        weights = lax.pmean(weights, "batch")
        state = state.apply_weights(weights=weights)
        return state

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def step(self, state, batch, *args):
        grads = grad(self.loss)(state.params, state.weights, batch, *args)
        grads = lax.pmean(grads, "batch")
        state = state.apply_gradients(grads=grads)
        return state
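
    # Minimal training-loop sketch (hypothetical names; `sample_batch` and the
    # weight-update frequency are assumptions, not part of this module):
    #
    #     model = SomePinnSubclass(config)
    #     for it in range(num_steps):
    #         batch = sample_batch(it)  # leading axis = local device count
    #         if it % update_every == 0:
    #             model.state = model.update_weights(model.state, batch)
    #         model.state = model.step(model.state, batch)
    #
    # Both methods are pmapped, so `state` and `batch` must carry a leading
    # device axis (cf. jax_utils.replicate in _create_train_state).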


class ForwardIVP(PINN):
    def __init__(self, config):
        super().__init__(config)

        if config.weighting.use_causal:
            self.tol = config.weighting.causal_tol
            self.num_chunks = config.weighting.num_chunks
            self.M = jnp.triu(jnp.ones((self.num_chunks, self.num_chunks)), k=1).T
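            # M is a strictly lower-triangular matrix of ones, so (M @ l)_i
            # accumulates the residual losses of all chunks earlier in time.
            # In the causal training scheme this induces the weights
            #
            #     w_i = exp(-tol * (M @ l)_i) = exp(-tol * sum_{j<i} l_j),
            #
            # which suppress later temporal chunks until earlier ones are fit.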


class ForwardBVP(PINN):
    def __init__(self, config):
        super().__init__(config)
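

# Sketch of a concrete subclass (hypothetical 2D Poisson problem; the residual
# and batch layout below are illustrative assumptions, not part of this module):
#
#     class Poisson(ForwardBVP):
#         def u_net(self, params, x, y):
#             z = jnp.stack([x, y])
#             return self.state.apply_fn(params, z)[0]
#
#         def r_net(self, params, x, y):
#             u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x, y)
#             u_yy = grad(grad(self.u_net, argnums=2), argnums=2)(params, x, y)
#             return u_xx + u_yy  # Laplacian; compare against the source term
#
#         def losses(self, params, batch):
#             ...  # return a dict of scalar losses, e.g. {"bcs": ..., "res": ...}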