-
Notifications
You must be signed in to change notification settings - Fork 14
/
derivatives.py
704 lines (580 loc) · 21.2 KB
/
derivatives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
"""Wrapper classes for JAX automatic differentiation and finite differences."""
from abc import ABC, abstractmethod
import numpy as np
from termcolor import colored
from desc.backend import fori_loop, jnp, put, use_jax
if use_jax:
import jax
class _Derivative(ABC):
    """Abstract interface shared by every derivative-matrix calculator.

    A concrete subclass wraps a callable ``fun`` and, when called, evaluates some
    derivative of it with respect to a single positional argument.

    Parameters
    ----------
    fun : callable
        Function to be differentiated.
    argnum : int, optional
        Index of the positional argument to differentiate with respect to.
    mode : str, optional
        Which kind of derivative to compute; the accepted values are defined by
        each subclass.

    """

    @abstractmethod
    def __init__(self, fun, argnum=0, mode=None, **kwargs):
        pass

    @abstractmethod
    def compute(self, *args, **kwargs):
        """Evaluate the derivative of ``fun`` at the given arguments.

        Parameters
        ----------
        *args : list
            Positional arguments of ``fun`` at which the derivative is evaluated.
        **kwargs : dict
            Keyword arguments forwarded by the concrete implementation.

        Returns
        -------
        D : ndarray of float
            Derivative of f at x, where f is the output of ``fun`` and x is its
            positional argument number ``argnum``. The exact shape and meaning
            depend on ``mode``.

        """

    @property
    def fun(self):
        """Callable : the function whose derivative is computed."""
        return self._fun

    @fun.setter
    def fun(self, fun):
        self._fun = fun

    @property
    def argnum(self):
        """Integer : index of the argument differentiated with respect to."""
        return self._argnum

    @argnum.setter
    def argnum(self, argnum):
        self._argnum = argnum

    @property
    def mode(self):
        """String : which derivative is computed (for example ``'grad'``)."""
        return self._mode

    def __call__(self, *args, **kwargs):
        """Shorthand for ``self.compute(*args, **kwargs)``.

        Parameters
        ----------
        *args : list
            Positional arguments of ``fun`` at which the derivative is evaluated.
        **kwargs : dict
            Keyword arguments passed straight through to ``compute``.

        Returns
        -------
        D : ndarray of float
            Whatever ``compute`` returns for the current ``mode``.

        """
        return self.compute(*args, **kwargs)

    def __repr__(self):
        """Return ``'<ClassName> at <hex id> (fun=..., argnum=..., mode=...)'``."""
        details = f"(fun={self.fun!r}, argnum={self.argnum}, mode={self.mode})"
        return f"{type(self).__name__} at {hex(id(self))} {details}"
class AutoDiffDerivative(_Derivative):
    """Computes derivatives using automatic differentiation with JAX.

    Parameters
    ----------
    fun : callable
        Function to be differentiated.
    argnum : int, optional
        Specifies which positional argument to differentiate with respect to
    mode : str, optional
        Automatic differentiation mode.
        One of ``'fwd'`` (forward mode Jacobian), ``'rev'`` (reverse mode Jacobian),
        ``'grad'`` (gradient of a scalar function),
        ``'hess'`` (Hessian of a scalar function),
        ``'jvp'`` (Jacobian vector product),
        or ``'looped'`` (Jacobian assembled one column at a time from forward
        mode Jacobian vector products inside a ``fori_loop``)
        Default = ``'fwd'``

    Raises
    ------
    ValueError, if mode is not supported

    """

    def __init__(self, fun, argnum=0, mode="fwd", **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self._fun = fun
        self._argnum = argnum
        self._set_mode(mode)

    def compute(self, *args, **kwargs):
        """Compute the derivative matrix.

        Parameters
        ----------
        *args : list
            Arguments of the objective function where the derivative is to be
            evaluated at.
        **kwargs : dict
            Keyword arguments passed to the underlying derivative routine.

        Returns
        -------
        D : ndarray of float
            derivative of f evaluated at x, where f is the output of the function
            fun and x is the input argument at position argnum. Exact shape and meaning
            will depend on "mode"

        """
        return self._compute(*args, **kwargs)

    @classmethod
    def compute_vjp(cls, fun, argnum, v, *args, **kwargs):
        """Compute v.T * df/dx.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum : int
            argument to differentiate with respect to. Only a single argument is
            currently supported.
        v : array-like
            cotangent vector, contracted with the output of fun via ``v.T @ fun(...)``
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun. ``rel_step`` is discarded if present.

        Returns
        -------
        vjp : array-like
            Vector v times Jacobian

        """
        assert jnp.isscalar(argnum), "vjp for multiple args not currently supported"
        _ = kwargs.pop("rel_step", None)  # unused by autodiff

        # grad of the scalar v.T @ f(x) is exactly v.T @ df/dx
        def _fun(*args):
            return v.T @ fun(*args, **kwargs)

        return jax.grad(_fun, argnum)(*args)

    @classmethod
    def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
        """Compute df/dx*v.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum : int or tuple
            arguments to differentiate with respect to
        v : array-like or tuple/list of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun. ``rel_step`` is discarded if present.

        Returns
        -------
        jvp : array-like
            Jacobian times vectors v, summed over different argnums

        """
        _ = kwargs.pop("rel_step", None)  # unused by autodiff
        argnum = (argnum,) if jnp.isscalar(argnum) else tuple(argnum)
        # jax.jvp requires primals and tangents to share the same pytree
        # structure. Primals are always a tuple below, so a list of tangents
        # must be converted to a tuple too (otherwise jax raises TypeError).
        v = tuple(v) if isinstance(v, (tuple, list)) else (v,)

        # Close over the non-differentiated arguments so that jax.jvp only
        # sees the arguments listed in argnum as primals.
        def _fun(*x):
            _args = list(args)
            for i, xi in zip(argnum, x):
                _args[i] = xi
            return fun(*_args, **kwargs)

        _, u = jax.jvp(_fun, tuple(args[i] for i in argnum), v)
        return u

    @classmethod
    def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
        """Compute d^2f/dx^2*v1*v2.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum1, argnum2 : int or tuple of int
            arguments to differentiate with respect to. First entry corresponds to v1,
            second to v2
        v1,v2 : array-like or tuple of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        jvp2 : array-like
            second derivative times vectors v1, v2, summed over different argnums

        """
        if np.isscalar(argnum1):
            v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
            argnum1 = (argnum1,)
        else:
            v1 = tuple(v1)
        # dfdx takes dx1 as an extra leading argument, so every original
        # argument index is shifted by one for the outer derivative.
        if np.isscalar(argnum2):
            argnum2 = (argnum2 + 1,)
            v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
        else:
            argnum2 = tuple([i + 1 for i in argnum2])
            v2 = tuple(v2)
        dfdx = lambda dx1, *args: cls.compute_jvp(fun, argnum1, dx1, *args, **kwargs)
        d2fdx2 = lambda dx1, dx2: cls.compute_jvp(
            dfdx, argnum2, dx2, dx1, *args, **kwargs
        )
        return d2fdx2(v1, v2)

    @classmethod
    def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
        """Compute d^3f/dx^3*v1*v2*v3.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum1, argnum2, argnum3 : int or tuple of int
            arguments to differentiate with respect to. First entry corresponds to v1,
            second to v2 etc
        v1,v2,v3 : array-like or tuple of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        jvp3 : array-like
            third derivative times vectors v1, v2, v3, summed over different argnums

        """
        if np.isscalar(argnum1):
            v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
            argnum1 = (argnum1,)
        else:
            v1 = tuple(v1)
        # each nesting level prepends one tangent argument, hence the +1 / +2
        # shifts of the original argument indices.
        if np.isscalar(argnum2):
            argnum2 = (argnum2 + 1,)
            v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
        else:
            argnum2 = tuple([i + 1 for i in argnum2])
            v2 = tuple(v2)
        if np.isscalar(argnum3):
            argnum3 = (argnum3 + 2,)
            v3 = (v3,) if not isinstance(v3, (tuple, list)) else v3
        else:
            argnum3 = tuple([i + 2 for i in argnum3])
            v3 = tuple(v3)
        dfdx = lambda dx1, *args: cls.compute_jvp(fun, argnum1, dx1, *args, **kwargs)
        d2fdx2 = lambda dx1, dx2, *args: cls.compute_jvp(
            dfdx, argnum2, dx2, dx1, *args, **kwargs
        )
        d3fdx3 = lambda dx1, dx2, dx3: cls.compute_jvp(
            d2fdx2, argnum3, dx3, dx2, dx1, *args, **kwargs
        )
        return d3fdx3(v1, v2, v3)

    def _compute_jvp(self, v, *args, **kwargs):
        """Jacobian-vector product of this object's fun w.r.t. its argnum."""
        return self.compute_jvp(self._fun, self.argnum, v, *args, **kwargs)

    def _jac_looped(self, *args, **kwargs):
        """Compute the Jacobian one column at a time with forward mode jvps.

        Column ``i`` is the jvp with the ``i``-th row of the identity matrix; the
        columns are written into a preallocated array inside a ``fori_loop``.
        The result has shape ``(*fun_output_shape, n)`` where ``n`` is the size
        of the differentiated argument.
        """
        n = args[self._argnum].size
        # ``rel_step`` is a finite-difference option that compute_jvp discards;
        # strip it so it is not forwarded to fun while tracing the output shape.
        fun_kwargs = {key: val for key, val in kwargs.items() if key != "rel_step"}
        # Trace with the same kwargs the jvps below use, so the preallocated
        # output matches the shape fun actually returns for this call.
        shp = jax.eval_shape(self._fun, *args, **fun_kwargs).shape
        eye = jnp.eye(n)
        # Stored transposed so that row i of J holds column i of the Jacobian.
        J = jnp.zeros((*shp, n)).T

        def body(i, J):
            tangents = eye[i]
            Ji = self._compute_jvp(tangents, *args, **kwargs)
            J = put(J, i, Ji.T)
            return J

        return fori_loop(0, n, body, J).T

    def _set_mode(self, mode) -> None:
        """Validate ``mode`` and bind the matching routine to ``self._compute``.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported options.
        """
        if mode not in ["fwd", "rev", "grad", "hess", "jvp", "looped"]:
            raise ValueError(
                colored("invalid mode option for automatic differentiation", "red")
            )
        self._mode = mode
        if self._mode == "fwd":
            self._compute = jax.jacfwd(self._fun, self._argnum)
        elif self._mode == "rev":
            self._compute = jax.jacrev(self._fun, self._argnum)
        elif self._mode == "grad":
            self._compute = jax.grad(self._fun, self._argnum)
        elif self._mode == "hess":
            self._compute = jax.hessian(self._fun, self._argnum)
        elif self._mode == "jvp":
            self._compute = self._compute_jvp
        elif self._mode == "looped":
            self._compute = self._jac_looped
class FiniteDiffDerivative(_Derivative):
    """Computes derivatives using 2nd order centered finite differences.

    Parameters
    ----------
    fun : callable
        Function to be differentiated.
    argnum : int, optional
        Specifies which positional argument to differentiate with respect to
    mode : str, optional
        Kind of derivative to compute. The names mirror those accepted by
        ``AutoDiffDerivative`` so either class can stand in for the other:
        ``'fwd'``, ``'rev'``, ``'grad'`` and ``'looped'`` all compute the same
        centered-difference Jacobian (raveled to a gradient when fun has a
        single output), ``'hess'`` computes the Hessian of a scalar function,
        and ``'jvp'`` computes a Jacobian vector product.
        Default = ``'fwd'``
    rel_step : float, optional
        Relative step size: dx = max(1, abs(x))*rel_step
        (for ``'jvp'`` the step along the normalized tangent is ``rel_step``
        itself, not scaled by abs(x)).
        Default = 1e-3

    Raises
    ------
    ValueError, if mode is not supported

    """

    def __init__(self, fun, argnum=0, mode="fwd", rel_step=1e-3, **kwargs):
        self._fun = fun
        self._argnum = argnum
        self.rel_step = rel_step
        self._set_mode(mode)

    def _compute_hessian(self, *args, **kwargs):
        """Compute the Hessian matrix using 2nd order centered finite differences.

        ``fun`` is expected to return a scalar.

        Parameters
        ----------
        args : tuple
            Arguments of the objective function where the derivative is to be
            evaluated at.
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        H : ndarray of float, shape(len(x),len(x))
            d^2f/dx^2, where f is the output of the function fun and x is the input
            argument at position argnum.

        """

        # fun as a function of the differentiated argument only
        def f(x):
            tempargs = args[0 : self._argnum] + (x,) + args[self._argnum + 1 :]
            return self._fun(*tempargs, **kwargs)

        x = np.atleast_1d(args[self._argnum])
        n = len(x)
        fx = f(x)
        h = np.maximum(1.0, np.abs(x)) * self.rel_step
        ee = np.diag(h)
        # hess starts as the outer product h_i*h_j, which is exactly the
        # denominator scale each entry needs before it is overwritten.
        hess = np.outer(h, h)
        for i in range(n):
            eei = ee[i, :]
            # diagonal: second difference with step 2*h_i, so divide by (2 h_i)^2
            hess[i, i] = (f(x + 2 * eei) - 2 * fx + f(x - 2 * eei)) / (4.0 * hess[i, i])
            for j in range(i + 1, n):
                eej = ee[j, :]
                # hess[j, i] (j > i) still holds h_j*h_i here; it is only
                # overwritten with the mixed derivative on the next line.
                hess[i, j] = (
                    f(x + eei + eej)
                    - f(x + eei - eej)
                    - f(x - eei + eej)
                    + f(x - eei - eej)
                ) / (4.0 * hess[j, i])
                hess[j, i] = hess[i, j]
        return hess

    def _compute_grad_or_jac(self, *args, **kwargs):
        """Compute the gradient or Jacobian matrix (ie, first derivative).

        Parameters
        ----------
        args : tuple
            Arguments of the objective function where the derivative is to be
            evaluated at.
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        J : ndarray of float, shape(len(f),len(x))
            df/dx, where f is the output of the function fun and x is the input
            argument at position argnum. Raveled to shape(len(x),) when fun has
            a single output.

        """

        # fun as a function of the differentiated argument only
        def f(x):
            tempargs = args[0 : self._argnum] + (x,) + args[self._argnum + 1 :]
            return self._fun(*tempargs, **kwargs)

        x0 = np.atleast_1d(args[self._argnum])
        f0 = f(x0)
        # np.size also works when fun returns a plain Python scalar, which has
        # no ``.size`` attribute.
        m = np.size(f0)
        n = x0.size
        J = np.zeros((m, n))
        h = np.maximum(1.0, np.abs(x0)) * self.rel_step
        h_vecs = np.diag(h)
        for i in range(n):
            x1 = x0 - h_vecs[i]
            x2 = x0 + h_vecs[i]
            # divide by the step actually taken in floating point
            dx = x2[i] - x1[i]
            f1 = f(x1)
            f2 = f(x2)
            df = f2 - f1
            dfdx = df / dx
            # column i of the Jacobian
            J = put(J.T, i, dfdx.flatten()).T
        if m == 1:
            J = np.ravel(J)
        return J

    @classmethod
    def compute_vjp(cls, fun, argnum, v, *args, **kwargs):
        """Compute v.T * df/dx.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum : int
            argument to differentiate with respect to. Only a single argument is
            currently supported.
        v : array-like
            cotangent vector, contracted with the output of fun via ``v.T @ fun(...)``
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun. ``rel_step`` (default 1e-3) sets the
            finite difference step.

        Returns
        -------
        vjp : array-like
            Vector v times Jacobian

        """
        assert np.isscalar(argnum), "vjp for multiple args not currently supported"
        rel_step = kwargs.pop("rel_step", 1e-3)

        # gradient of the scalar v.T @ f(x) is exactly v.T @ df/dx
        def _fun(*args):
            return v.T @ fun(*args, **kwargs)

        return FiniteDiffDerivative(_fun, argnum, "grad", rel_step)(*args)

    @classmethod
    def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
        """Compute df/dx*v.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum : int or tuple
            arguments to differentiate with respect to
        v : array-like or tuple of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun. ``rel_step`` (default 1e-3) sets the
            finite difference step.

        Returns
        -------
        jvp : array-like
            Jacobian times vectors v, summed over different argnums

        """
        rel_step = kwargs.pop("rel_step", 1e-3)
        if np.isscalar(argnum):
            nargs = 1
            argnum = (argnum,)
        else:
            nargs = len(argnum)
        v = (v,) if not isinstance(v, tuple) else v
        # one directional derivative per argument, then summed
        f = np.array(
            [
                cls._compute_jvp_1arg(
                    fun, argnum[i], v[i], *args, rel_step=rel_step, **kwargs
                )
                for i in range(nargs)
            ]
        )
        return np.sum(f, axis=0)

    @classmethod
    def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
        """Compute d^2f/dx^2*v1*v2.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum1, argnum2 : int or tuple of int
            arguments to differentiate with respect to. First entry corresponds to v1,
            second to v2
        v1,v2 : array-like or tuple of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        jvp2 : array-like
            second derivative times vectors v1, v2, summed over different argnums

        """
        if np.isscalar(argnum1):
            v1 = (v1,) if not isinstance(v1, tuple) else v1
            argnum1 = (argnum1,)
        else:
            v1 = tuple(v1)
        # dfdx takes dx1 as an extra leading argument, so every original
        # argument index is shifted by one for the outer derivative.
        if np.isscalar(argnum2):
            argnum2 = (argnum2 + 1,)
            v2 = (v2,) if not isinstance(v2, tuple) else v2
        else:
            argnum2 = tuple([i + 1 for i in argnum2])
            v2 = tuple(v2)
        dfdx = lambda dx1, *args: cls.compute_jvp(fun, argnum1, dx1, *args, **kwargs)
        d2fdx2 = lambda dx1, dx2: cls.compute_jvp(
            dfdx, argnum2, dx2, dx1, *args, **kwargs
        )
        return d2fdx2(v1, v2)

    @classmethod
    def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
        """Compute d^3f/dx^3*v1*v2*v3.

        Parameters
        ----------
        fun : callable
            function to differentiate
        argnum1, argnum2, argnum3 : int or tuple of int
            arguments to differentiate with respect to. First entry corresponds to v1,
            second to v2 etc
        v1,v2,v3 : array-like or tuple of array-like
            tangent vectors. Should be one for each argnum
        args : tuple
            arguments passed to fun
        kwargs : dict
            keyword arguments passed to fun

        Returns
        -------
        jvp3 : array-like
            third derivative times vectors v1, v2, v3, summed over different argnums

        """
        if np.isscalar(argnum1):
            v1 = (v1,) if not isinstance(v1, tuple) else v1
            argnum1 = (argnum1,)
        else:
            v1 = tuple(v1)
        # each nesting level prepends one tangent argument, hence the +1 / +2
        # shifts of the original argument indices.
        if np.isscalar(argnum2):
            argnum2 = (argnum2 + 1,)
            v2 = (v2,) if not isinstance(v2, tuple) else v2
        else:
            argnum2 = tuple([i + 1 for i in argnum2])
            v2 = tuple(v2)
        if np.isscalar(argnum3):
            argnum3 = (argnum3 + 2,)
            v3 = (v3,) if not isinstance(v3, tuple) else v3
        else:
            argnum3 = tuple([i + 2 for i in argnum3])
            v3 = tuple(v3)
        dfdx = lambda dx1, *args: cls.compute_jvp(fun, argnum1, dx1, *args, **kwargs)
        d2fdx2 = lambda dx1, dx2, *args: cls.compute_jvp(
            dfdx, argnum2, dx2, dx1, *args, **kwargs
        )
        d3fdx3 = lambda dx1, dx2, dx3: cls.compute_jvp(
            d2fdx2, argnum3, dx3, dx2, dx1, *args, **kwargs
        )
        return d3fdx3(v1, v2, v3)

    def _compute_jvp(self, v, *args, **kwargs):
        """Jacobian-vector product of this object's fun using its rel_step."""
        return self.compute_jvp(
            self._fun, self._argnum, v, *args, rel_step=self.rel_step, **kwargs
        )

    @classmethod
    def _compute_jvp_1arg(cls, fun, argnum, v, *args, **kwargs):
        """Compute a jvp wrt a single argument.

        Takes a centered difference along the unit vector ``v / |v|`` with step
        ``rel_step`` (popped from kwargs, default 1e-3), then rescales by ``|v|``.
        NOTE: unlike the Jacobian/Hessian routines, this step is not scaled by
        ``abs(x)``.
        """
        rel_step = kwargs.pop("rel_step", 1e-3)
        normv = np.linalg.norm(v)
        # a zero tangent is left as-is; the result is then zero after rescaling
        if normv != 0:
            vh = v / normv
        else:
            vh = v
        x = args[argnum]

        # fun as a function of the differentiated argument only
        def f(x):
            tempargs = args[0:argnum] + (x,) + args[argnum + 1 :]
            return fun(*tempargs, **kwargs)

        h = rel_step
        df = (f(x + h * vh) - f(x - h * vh)) / (2 * h)
        return df * normv

    def _set_mode(self, mode):
        """Validate ``mode`` and bind the matching routine to ``self._compute``.

        ``'looped'`` is accepted for parity with ``AutoDiffDerivative``; finite
        differences already build the Jacobian one column at a time.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported options.
        """
        if mode not in ["fwd", "rev", "grad", "hess", "jvp", "looped"]:
            raise ValueError(
                colored(
                    "invalid mode option for finite difference differentiation", "red"
                )
            )
        self._mode = mode
        if self._mode in ("fwd", "rev", "grad", "looped"):
            # finite differences make no forward/reverse distinction
            self._compute = self._compute_grad_or_jac
        elif self._mode == "hess":
            self._compute = self._compute_hessian
        elif self._mode == "jvp":
            self._compute = self._compute_jvp

    def compute(self, *args, **kwargs):
        """Compute the derivative matrix.

        Parameters
        ----------
        *args : list
            Arguments of the objective function where the derivative is to be
            evaluated at.
        **kwargs : dict
            Keyword arguments passed to the underlying derivative routine.

        Returns
        -------
        D : ndarray of float
            derivative of f evaluated at x, where f is the output of the function
            fun and x is the input argument at position argnum. Exact shape and meaning
            will depend on "mode"

        """
        return self._compute(*args, **kwargs)
# Public default derivative class: JAX automatic differentiation when the JAX
# backend is enabled, otherwise centered finite differences.
if use_jax:
    Derivative = AutoDiffDerivative
else:
    Derivative = FiniteDiffDerivative