-
Notifications
You must be signed in to change notification settings - Fork 124
/
neuralde.py
219 lines (184 loc) · 12.2 KB
/
neuralde.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Generator, Iterable, List, Tuple, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchsde
from torch import Tensor

from torchdyn.core.defunc import DEFunc, DEFuncBase, SDEFunc
from torchdyn.core.problems import MultipleShootingProblem, ODEProblem, SDEProblem
from torchdyn.core.utils import standardize_vf_call_signature
from torchdyn.numerics import odeint
class NeuralODE(ODEProblem, pl.LightningModule):
    def __init__(self, vector_field: Union[Callable, nn.Module], solver: Union[str, nn.Module] = 'tsit5', order: int = 1,
                 atol: float = 1e-3, rtol: float = 1e-3, sensitivity='autograd', solver_adjoint: Union[str, nn.Module, None] = None,
                 atol_adjoint: float = 1e-4, rtol_adjoint: float = 1e-4, interpolator: Union[str, Callable, None] = None,
                 integral_loss: Union[Callable, None] = None, seminorm: bool = False, return_t_eval: bool = True,
                 optimizable_params: Union[Iterable, Generator] = ()):
        """Generic Neural Ordinary Differential Equation.

        Args:
            vector_field (Union[Callable, nn.Module]): the vector field, called with `vector_field(t, x)` or `vector_field(x)`.
                In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): ODE solver used for the forward integration. Defaults to 'tsit5'.
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Absolute tolerance of the adjoint solver. Defaults to 1e-4.
            rtol_adjoint (float, optional): Relative tolerance of the adjoint solver. Defaults to 1e-4.
            interpolator (Union[str, Callable, None], optional): Interpolator used by interpolated adjoints. Defaults to None.
            integral_loss (Union[Callable, None], optional): Integral loss attached to the vector field. Defaults to None.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`.
            optimizable_params (Union[Iterable, Generator]): parameters to calculate sensitivities for. Defaults to ().

        Notes:
            In `torchdyn`-style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the solution is evaluated
            as well as the solution itself. This behavior can be controlled by setting `return_t_eval` to False. Calling `trajectory` also returns
            the solution only.

            The Neural ODE class automates certain delicate steps that must be done depending on the solver and model used.
            The `_prep_integration` method carries out such steps. Neural ODEs wrap `ODEProblem`.
        """
        super().__init__(vector_field=standardize_vf_call_signature(vector_field, order, defunc_wrap=True), order=order, sensitivity=sensitivity,
                         solver=solver, atol=atol, rtol=rtol, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint,
                         seminorm=seminorm, interpolator=interpolator, integral_loss=integral_loss, optimizable_params=optimizable_params)
        self.u, self.controlled, self.t_span = None, False, None  # data-control conditioning
        self.return_t_eval = return_t_eval
        if integral_loss is not None:
            self.vf.integral_loss = integral_loss
        self.vf.sensitivity = sensitivity

    def _prep_integration(self, x: Tensor, t_span: Tensor) -> Tuple[Tensor, Tensor]:
        """Performs generic checks before integration. Assigns data control inputs and augments state for CNFs.

        Returns the (possibly unchanged) initial condition `x` and the resolved integration interval `t_span`.
        """
        # assign a basic value to `t_span` for `forward` calls that do not explicitly pass an integration interval
        if t_span is None and self.t_span is None:
            t_span = torch.linspace(0, 1, 2)
        elif t_span is None:
            t_span = self.t_span

        # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
        excess_dims = 0
        if self.integral_loss is not None and self.sensitivity == 'autograd':
            excess_dims += 1

        # handle aux. operations required for some jacobian trace CNF estimators e.g. Hutchinson's
        # as well as data-control set to DataControl module
        for _, module in self.vf.named_modules():
            if hasattr(module, 'trace_estimator'):
                if module.noise_dist is not None:
                    module.noise = module.noise_dist.sample((x.shape[0],))
                excess_dims += 1

            # data-control set routine. Is performed once at the beginning of odeint since the control is fixed to IC
            if hasattr(module, 'u'):
                self.controlled = True
                module.u = x[:, excess_dims:].detach()
        return x, t_span

    def forward(self, x: Tensor, t_span: Tensor = None):
        """Integrate the Neural ODE from initial condition `x` over `t_span`.

        Returns `(t_eval, sol)` when `return_t_eval` is True, otherwise `sol` only.
        """
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol = super().forward(x, t_span)
        if self.return_t_eval:
            return t_eval, sol
        else:
            return sol

    def trajectory(self, x: torch.Tensor, t_span: Tensor):
        "Returns the solution evaluated at `t_span`, discarding the time points."
        x, t_span = self._prep_integration(x, t_span)
        _, sol = odeint(self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol)
        return sol

    def __repr__(self):
        npar = sum([p.numel() for p in self.vf.parameters()])
        return f"Neural ODE:\n\t- order: {self.order}\
        \n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}\
        \n\t- tolerances: relative {self.rtol} absolute {self.atol}\
        \n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}\
        \n\t- num_parameters: {npar}\
        \n\t- NFE: {self.vf.nfe}"
class NeuralSDE(SDEProblem, pl.LightningModule):
    def __init__(self, drift_func, diffusion_func, noise_type='diagonal', sde_type='ito', order=1,
                 sensitivity='autograd', s_span=torch.linspace(0, 1, 2), solver='srk',
                 atol=1e-4, rtol=1e-4, ds=1e-3, intloss=None):
        """Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class.

        Args:
            drift_func (Callable): drift function
            diffusion_func (Callable): diffusion function
            noise_type (str, optional): Noise type passed to `torchsde`. Defaults to 'diagonal'.
            sde_type (str, optional): SDE formulation passed to `torchsde`. Defaults to 'ito'.
            order (int, optional): Defaults to 1.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint']. Defaults to 'autograd'.
            s_span (Tensor, optional): Defaults to torch.linspace(0, 1, 2).
            solver (str, optional): Defaults to 'srk'.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
            ds (float, optional): Fixed step size for the SDE solver. Defaults to 1e-3.
            intloss (Callable, optional): Integral loss. Defaults to None.

        Raises:
            NotImplementedError: higher-order Neural SDEs are not yet implemented, raised by setting `order` to >1.

        Notes:
            The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed to have the same features.
        """
        super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order, sensitivity=sensitivity, s_span=s_span, solver=solver,
                         atol=atol, rtol=rtol)
        if order != 1:
            raise NotImplementedError
        self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type
        self.adaptive = False
        self.intloss = intloss
        self.u, self.controlled = None, False  # data-control
        self.ds = ds

    def _prep_sdeint(self, x: torch.Tensor):
        "Moves `s_span` to the device/dtype of `x` and assigns data-control inputs before integration."
        self.s_span = self.s_span.to(x)
        # data-control set routine. Is performed once at the beginning of odeint since the control is fixed to IC
        excess_dims = 0
        for _, module in self.defunc.named_modules():
            if hasattr(module, 'u'):
                self.controlled = True
                module.u = x[:, excess_dims:].detach()
        return x

    def forward(self, x: torch.Tensor):
        """Integrate the Neural SDE from initial condition `x` over `self.s_span`.

        Raises:
            ValueError: if `self.sensitivity` is not one of the supported methods.
        """
        x = self._prep_sdeint(x)
        switcher = {
            'autograd': self._autograd,
            'adjoint': self._adjoint,
        }
        sdeint = switcher.get(self.sensitivity)
        # fail loudly on an unsupported sensitivity mode instead of crashing with
        # an opaque "'NoneType' object is not callable" TypeError
        if sdeint is None:
            raise ValueError(f"Invalid sensitivity method: '{self.sensitivity}'. "
                             f"Expected one of {sorted(switcher)}.")
        out = sdeint(x)
        return out

    def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        "Returns the full solution evaluated at `s_span`."
        x = self._prep_sdeint(x)
        sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol,
                              method=self.solver, dt=self.ds)
        return sol

    def backward_trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        raise NotImplementedError

    def _autograd(self, x):
        "Integrates with plain autograd sensitivity; returns the terminal state only."
        self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
        return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                               adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]

    def _adjoint(self, x):
        "Integrates with the stochastic adjoint sensitivity; returns the terminal state only."
        out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                                      adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]
        return out
class MultipleShootingLayer(MultipleShootingProblem, pl.LightningModule):
    def __init__(self, vector_field: Callable, solver: str, sensitivity: str = 'autograd',
                 maxiter: int = 4, fine_steps: int = 4, solver_adjoint: Union[str, nn.Module, None] = None, atol_adjoint: float = 1e-6,
                 rtol_adjoint: float = 1e-6, seminorm: bool = False, integral_loss: Union[Callable, None] = None):
        """Multiple Shooting Layer as defined in https://arxiv.org/abs/2106.03885.

        Uses parallel-in-time ODE solvers to solve an ODE parametrized by neural network `vector_field`.

        Args:
            vector_field (Callable): the vector field, called with `vector_field(t, x)` or `vector_field(x)`.
                In the second case, the Callable is automatically wrapped for consistency
            solver (str): parallel-in-time solver, ['zero', 'direct']
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            maxiter (int): number of iterations of the root finding routine defined to parallel solve the ODE.
            fine_steps (int): number of fine-solver steps to perform in each subinterval of the parallel solution.
            solver_adjoint (Union[str, nn.Module, None], optional): Standard sequential ODE solver for the adjoint system.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            integral_loss (Union[Callable, None], optional): Currently not implemented

        Notes:
            The number of shooting parameters (first dimension in `B0`) is implicitly defined by passing `t_span` during forward calls.
            For example, a `t_span=torch.linspace(0, 1, 10)` will define 9 intervals and 10 shooting parameters.

            For the moment only a thin wrapper around `MultipleShootingProblem`. At this level will be convenience routines for special
            initializations of shooting parameters `B0`, as well as usual convenience checks for integral losses.
        """
        # explicit keyword arguments guard against silent mis-ordering if the base
        # class signature ever changes
        super().__init__(vector_field=vector_field, solver=solver, sensitivity=sensitivity,
                         maxiter=maxiter, fine_steps=fine_steps, solver_adjoint=solver_adjoint,
                         atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint,
                         seminorm=seminorm, integral_loss=integral_loss)