
GLAPPO #299

Open · wants to merge 185 commits into base: master

Changes from 139 commits (185 commits total)
38a2076
bugfix
fandreuz Jan 1, 2023
81e45ca
as_numpy_array -> as_array
fandreuz Dec 3, 2022
68d5cab
generic linalg draft
fandreuz Dec 4, 2022
b2bfede
generify
fandreuz Dec 4, 2022
71c2114
adapt CDMD to generic linalg
fandreuz Dec 4, 2022
0d93d67
support multiple scipy.sparse versions
fandreuz Dec 4, 2022
66c1ec1
add missing decorator
fandreuz Dec 6, 2022
ffae98a
no need to allocated linalg classes
fandreuz Dec 6, 2022
8f71b25
dot()
fandreuz Dec 6, 2022
6066d3c
minor fix
fandreuz Dec 9, 2022
6f8bf39
pinv. minor fixes
fandreuz Dec 9, 2022
17c3c2f
generic linalg
fandreuz Dec 9, 2022
0cbfa8e
code cleanup
fandreuz Dec 9, 2022
6146a8d
fix imports
fandreuz Dec 29, 2022
71cb291
fix warning
fandreuz Dec 29, 2022
f8aac30
fix some tests
fandreuz Dec 30, 2022
8fadd3e
fix bug
fandreuz Dec 30, 2022
a645c78
guidelines draft
fandreuz Dec 30, 2022
6fcf7c3
remove usages of @ instead of dot
fandreuz Dec 30, 2022
4393634
do not check in internal code
fandreuz Dec 30, 2022
080f3f6
remove append. add *stack
fandreuz Dec 30, 2022
0ca6a29
f-strings. add missing import
fandreuz Dec 30, 2022
f5ebb1f
improved to()
fandreuz Dec 30, 2022
0377820
improve logging
fandreuz Dec 30, 2022
d08c1b1
force conversion to complex in torch dot/multi_dot
fandreuz Dec 30, 2022
418b0f4
cleanup
fandreuz Dec 30, 2022
1a4a72b
port
fandreuz Dec 30, 2022
fa31bc8
cleanup
fandreuz Dec 30, 2022
a3d8a29
remove not implemented methods
fandreuz Dec 30, 2022
030dfbd
improve full
fandreuz Dec 30, 2022
0b78beb
split
fandreuz Dec 30, 2022
14ec47c
cleanup
fandreuz Dec 30, 2022
4481db0
nan methods. bugfix
fandreuz Dec 30, 2022
2b4dba6
port
fandreuz Dec 30, 2022
5c684c2
port
fandreuz Dec 30, 2022
a677472
update tests
fandreuz Dec 30, 2022
9e73456
minor fix
fandreuz Dec 30, 2022
2210950
fix tests
fandreuz Dec 30, 2022
d66b882
optimise imports
fandreuz Dec 30, 2022
5d93359
qol
fandreuz Jan 1, 2023
dbdeb05
implement argsort
fandreuz Jan 1, 2023
45704d0
code cleanup
fandreuz Jan 1, 2023
970240b
improve lstsq
fandreuz Jan 1, 2023
cc799a7
atleast_1d
fandreuz Jan 1, 2023
3f1d5b5
code cleanup
fandreuz Jan 1, 2023
cb026b4
as_numpy_array -> as_array
fandreuz Jan 1, 2023
469e792
col_major_2d -> utils.py
fandreuz Jan 1, 2023
3316351
missing underscore
fandreuz Jan 1, 2023
df493d9
move plot code
fandreuz Jan 1, 2023
ec71614
force invalidation of activation bitmask after computation of low-ran…
fandreuz Jan 1, 2023
b9062fd
Merge branch 'plotter'
fandreuz Jan 2, 2023
e2e777c
Merge branch 'master' into generic-linalg
fandreuz Jan 2, 2023
27088b0
cat
fandreuz Jan 2, 2023
bbf0949
param ids
fandreuz Jan 2, 2023
73a6e6d
cleanup
fandreuz Jan 2, 2023
350dd8d
fix vander
fandreuz Jan 2, 2023
073fa61
force invalidation of activation bitmask after computation of low-ran…
fandreuz Jan 1, 2023
afdd815
Merge branch 'fix-303' into generic-linalg
fandreuz Jan 2, 2023
b6cba6c
force invalidation of activation bitmask after computation of low-ran…
fandreuz Jan 1, 2023
dcc107c
Merge branch 'fix-303' into generic-linalg
fandreuz Jan 2, 2023
024c776
reset
fandreuz Jan 2, 2023
1ecf6b2
tests
fandreuz Jan 2, 2023
241207c
Merge branch 'fix-303' into generic-linalg
fandreuz Jan 2, 2023
15b1bbc
reset
fandreuz Jan 2, 2023
66faa3b
improve new_array
fandreuz Jan 2, 2023
8021a89
add tests for backprop
fandreuz Jan 2, 2023
44b84b3
adjust logging level
fandreuz Jan 3, 2023
0e70d4a
fix device
fandreuz Jan 3, 2023
d7ac480
cleanup
fandreuz Jan 4, 2023
f3867cf
initial support for batched DMD
fandreuz Jan 3, 2023
cc03f9c
hankel-batched
fandreuz Jan 3, 2023
cedd715
remove diag ambiguity
fandreuz Jan 3, 2023
519fa34
fix lstsq backprop batched
fandreuz Jan 3, 2023
710b4d5
compute phase space prediction
fandreuz Jan 5, 2023
fd85502
generic device extraction. arange on appropriate device
fandreuz Jan 5, 2023
abaac75
cleanup tests. unhide hidden tests. qol
fandreuz Jan 6, 2023
e91dcea
Merge branch 'test-cleanup' into generic-linalg
fandreuz Jan 6, 2023
d84f83c
big tests cleanup
fandreuz Jan 6, 2023
8d87f29
pinpoint svd_rank_extra
fandreuz Jan 6, 2023
b7d52fe
cleanup
fandreuz Jan 6, 2023
99f1d28
fix sqrt
fandreuz Jan 6, 2023
1a70132
fix test
fandreuz Jan 6, 2023
02adb63
fix test
fandreuz Jan 6, 2023
2c28156
fix matrix sqrt. fix fbdmd tests
fandreuz Jan 7, 2023
95286f3
tensorize
fandreuz Jan 7, 2023
6911156
code cleanup and hankel bugfix
fandreuz Jan 7, 2023
d1fe8db
fix tests
fandreuz Jan 7, 2023
421a518
tensorize
fandreuz Jan 7, 2023
bd38d20
more tests. fix tests
fandreuz Jan 7, 2023
4bc484d
add comments. fix exception msg
fandreuz Jan 7, 2023
98d36a2
cleanup
fandreuz Jan 7, 2023
86a088a
more cleanup
fandreuz Jan 8, 2023
d618d2a
fix tests. code cleanup
fandreuz Jan 8, 2023
a76c64d
fix for tensorized
fandreuz Jan 8, 2023
4ea41c3
torch.autograd.gradcheck()
fandreuz Jan 8, 2023
9b2b048
reject pytorch
fandreuz Jan 8, 2023
d6f76f1
port subspacedmd
fandreuz Jan 9, 2023
cd64068
general cleanup
fandreuz Jan 9, 2023
4b0f10e
fix subspacedmd for tensorized. fix tests
fandreuz Jan 9, 2023
7c2b3d9
fix symbols
fandreuz Jan 9, 2023
e870a64
qr and random
fandreuz Jan 10, 2023
1be15c7
port RDMD
fandreuz Jan 10, 2023
62131ca
fixed tensorized training. tests
fandreuz Jan 10, 2023
b8a7f24
Merge branch 'master' into generic-linalg
fandreuz Jan 14, 2023
1c4f095
feasible
fandreuz Jan 14, 2023
a2893f2
remove gradcheck by default
fandreuz Jan 14, 2023
afeab3e
batch
fandreuz Jan 14, 2023
61069f3
remove as not needed
fandreuz Jan 14, 2023
f5da174
fix
fandreuz Jan 14, 2023
69359ec
cleanup
fandreuz Jan 14, 2023
5137016
fix msg
fandreuz Jan 14, 2023
ccb95f8
fix CI workflow name. show appropriate badge
fandreuz Jan 13, 2023
3cf0de8
add condition number check
fandreuz Jan 14, 2023
d0045de
disable torch
fandreuz Jan 16, 2023
d0c5f3d
port
fandreuz Jan 17, 2023
a873dde
cleanup
fandreuz Jan 18, 2023
fff605e
fix mrdmd plotter tests
fandreuz Jan 18, 2023
0f2db0e
remove old vals change forwarding
fandreuz Jan 18, 2023
9c0669d
add more info
fandreuz Jan 19, 2023
c8f9ef2
fix alignment
fandreuz Jan 19, 2023
07b118a
wishlist
fandreuz Jan 19, 2023
24af593
fix formatting
fandreuz Jan 19, 2023
dc10727
colon
fandreuz Jan 19, 2023
26f4fbd
fix formatting
fandreuz Jan 19, 2023
bc6564b
lighter phase space. matrix_norm
fandreuz Jan 21, 2023
89614af
Merge branch 'master' into generic-linalg
fandreuz Jan 26, 2023
cd6b25e
Python 3.7 compatible prod
fandreuz Jan 26, 2023
13baff2
multi_dot
fandreuz Feb 4, 2023
bf00cb6
getsvda
fandreuz Feb 4, 2023
00e0176
fix import
fandreuz Feb 4, 2023
499fe21
Merge remote-tracking branch 'upstream/master' into generic-linalg
fandreuz Feb 4, 2023
daf1a6a
Revert "getsvda"
fandreuz Feb 4, 2023
abf9e4b
fix import problems
fandreuz Feb 4, 2023
b73aeab
fix import
fandreuz Feb 4, 2023
5a9f837
fix
fandreuz Feb 5, 2023
67ebb0e
fix cond
fandreuz Feb 5, 2023
27a47f9
enable setting SVD driver
fandreuz Feb 5, 2023
2f1afee
enable/disable cond check
fandreuz Feb 5, 2023
1faa539
Merge branch 'master' into generic-linalg
fandreuz May 29, 2023
6a53e49
ops
fandreuz Jun 5, 2023
e226aad
Merge branch 'master' into generic-linalg
fandreuz Jun 5, 2023
6ecdb0b
Merge branch 'master' into generic-linalg
fandreuz Jul 28, 2023
af61add
black
fandreuz Jul 28, 2023
8a9a43a
cleanup
fandreuz Jul 28, 2023
252a4de
update TODO
fandreuz Jul 28, 2023
0662b5d
link to glappo repo
fandreuz Jul 30, 2023
1a6a3fa
Merge branch 'master' into generic-linalg
fandreuz Aug 1, 2023
ae55e4f
remove
fandreuz Aug 1, 2023
b76cc6b
black
fandreuz Aug 1, 2023
a395c11
move fit_reconstruct
fandreuz Aug 1, 2023
388941d
black
fandreuz Aug 1, 2023
9a1076c
remove
fandreuz Aug 1, 2023
da42578
fix
fandreuz Aug 1, 2023
8f624d7
fix
fandreuz Aug 1, 2023
ad3e90e
bugfix
fandreuz Aug 1, 2023
d97d967
bugfix
fandreuz Aug 1, 2023
3ba5eaf
don't know what happened here
fandreuz Aug 1, 2023
b13d6b0
remove backprop tests
fandreuz Aug 1, 2023
db45633
remove mrdmd stuff
fandreuz Aug 1, 2023
09431ea
param not used
fandreuz Aug 1, 2023
cea040e
ops. new tests
fandreuz Aug 1, 2023
299fbef
remove
fandreuz Aug 1, 2023
6dc0075
code cleanup
fandreuz Aug 1, 2023
16c25a5
minor fixes
fandreuz Aug 2, 2023
4c99592
restore
fandreuz Aug 2, 2023
af8f97e
cleanup
fandreuz Aug 2, 2023
e245b56
code cleanup
fandreuz Aug 2, 2023
ea73a50
code cleanup
fandreuz Aug 2, 2023
7399ff5
bugfix
fandreuz Aug 2, 2023
8f61dc2
nonzero, logical_and
fandreuz Aug 1, 2023
09a913d
generic_linalg
fandreuz Aug 3, 2023
16a28fd
ops
fandreuz Aug 3, 2023
8e6fca0
todo comment
fandreuz Aug 3, 2023
50410c2
unused stuff
fandreuz Aug 3, 2023
e994455
typo
fandreuz Aug 3, 2023
f8c09cd
more tests
fandreuz Aug 3, 2023
43d46be
move
fandreuz Aug 3, 2023
7bb8fde
readme
fandreuz Aug 3, 2023
7e107b1
move
fandreuz Aug 3, 2023
9313503
fix imports
fandreuz Aug 4, 2023
494d8df
no batching
fandreuz Aug 4, 2023
47869c9
dunnoh
fandreuz Aug 11, 2023
7d365aa
Merge branch 'master' into generic-linalg
fandreuz Aug 26, 2023
e50c6eb
isort
fandreuz Aug 26, 2023
d2938d6
black
fandreuz Aug 26, 2023
30 changes: 19 additions & 11 deletions pydmd/cdmd.py
@@ -7,12 +7,12 @@

import numpy as np
import scipy.sparse
from scipy.linalg import sqrtm

from .dmdbase import DMDBase
from .dmdoperator import DMDOperator
from .snapshots import Snapshots
from pydmd.linalg import build_linalg_module
from .utils import compute_svd, compute_tlsq
from .snapshots import Snapshots


class CDMDOperator(DMDOperator):
@@ -78,6 +78,7 @@ def compute_operator(self, compressedX, compressedY, nonCompressedY):
:rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
"""

linalg_module = build_linalg_module(compressedX)
U, s, V = compute_svd(compressedX, svd_rank=self._svd_rank)

atilde = self._least_square_operator(U, s, V, compressedY)
@@ -86,7 +87,9 @@ def compute_operator(self, compressedX, compressedY, nonCompressedY):
# b stands for "backward"
bU, bs, bV = compute_svd(compressedY, svd_rank=self._svd_rank)
atilde_back = self._least_square_operator(bU, bs, bV, compressedX)
atilde = sqrtm(atilde.dot(np.linalg.inv(atilde_back)))
atilde_back_inv = linalg_module.inv(atilde_back)
atilde_dotted = linalg_module.dot(atilde, atilde_back_inv)
atilde = linalg_module.matrix_sqrt(atilde_dotted)

self._Atilde = atilde
self._compute_eigenquantities()
@@ -196,7 +199,7 @@ def _compress_snapshots(self):
:rtype: numpy.ndarray
"""

C_shape = (self.snapshots.shape[1], self.snapshots.shape[0])
C_shape = (self.snapshots.shape[-1], self.snapshots.shape[-2])
if isinstance(self.compression_matrix, np.ndarray):
C = self.compression_matrix
elif self.compression_matrix == "uniform":
@@ -212,29 +215,34 @@ def _compress_snapshots(self):
np.random.choice(*self.snapshots.shape, replace=False),
] = 1.0

linalg_module = build_linalg_module(self.snapshots)
C = linalg_module.to(self.snapshots, C)

# compress the matrix
Y = C.dot(self.snapshots)
Y = linalg_module.dot(C, self.snapshots)

return Y

def fit(self, X):
def fit(self, X, batch=False):
"""
Compute the Dynamic Modes Decomposition to the input data.

:param X: the input snapshots.
:type X: numpy.ndarray or iterable
:param batch: If `True`, the first dimension is dedicated to batching.
:type batch: bool
"""
self._reset()

self._snapshots_holder = Snapshots(X)
self._snapshots_holder = Snapshots(X, batch=batch)
compressed_snapshots = self._compress_snapshots()

n_samples = compressed_snapshots.shape[1]
X = compressed_snapshots[:, :-1]
Y = compressed_snapshots[:, 1:]
n_samples = compressed_snapshots.shape[-1]
X = compressed_snapshots[..., :-1]
Y = compressed_snapshots[..., 1:]

X, Y = compute_tlsq(X, Y, self._tlsq_rank)
self.operator.compute_operator(X, Y, self.snapshots[:, 1:])
self.operator.compute_operator(X, Y, self.snapshots[..., 1:])

# Default timesteps
self._set_initial_time_dictionary(
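The changes above route every array operation through a linalg module selected from the input array's type, instead of calling `numpy`/`scipy` directly. A minimal sketch of such a type-based dispatcher follows; the class and function names are illustrative only and are not the actual `pydmd.linalg` API:

```python
import numpy as np


class NumpyLinalg:
    """Illustrative backend exposing a uniform linear-algebra interface.

    A torch backend would expose the same methods implemented with
    torch.linalg, so algorithm code never imports a backend directly.
    """

    @staticmethod
    def dot(a, b):
        return a @ b

    @staticmethod
    def inv(a):
        return np.linalg.inv(a)

    @staticmethod
    def pinv(a):
        return np.linalg.pinv(a)


def build_linalg_module(x):
    # Dispatch on the array's type; a hypothetical TorchLinalg would be
    # returned here when x is a torch.Tensor.
    if isinstance(x, np.ndarray):
        return NumpyLinalg
    raise TypeError(f"no linalg backend registered for {type(x)}")


A = np.array([[2.0, 0.0], [0.0, 4.0]])
mod = build_linalg_module(A)
B = mod.dot(A, mod.inv(A))  # A @ inv(A), i.e. the identity matrix
```

The design keeps DMD algorithm code backend-agnostic: every operation goes through whatever module `build_linalg_module` returns for the snapshots it was handed.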
24 changes: 14 additions & 10 deletions pydmd/dmd.py
@@ -2,10 +2,8 @@
Derived module from dmdbase.py for classic dmd.
"""

import numpy as np
from scipy.linalg import pinv

from .dmdbase import DMDBase
from pydmd.linalg import assert_same_linalg_type, build_linalg_module
from .snapshots import Snapshots
from .utils import compute_tlsq

@@ -46,20 +44,22 @@ class DMD(DMDBase):
:type tikhonov_regularization: int or float
"""

def fit(self, X):
def fit(self, X, batch=False):
"""
Compute the Dynamic Modes Decomposition to the input data.

:param X: the input snapshots.
:type X: numpy.ndarray or iterable
:param batch: If `True`, the first dimension is dedicated to batching.
:type batch: bool
"""
self._reset()

self._snapshots_holder = Snapshots(X)
self._snapshots_holder = Snapshots(X, batch=batch)

n_samples = self.snapshots.shape[1]
X = self.snapshots[:, :-1]
Y = self.snapshots[:, 1:]
n_samples = self.snapshots.shape[-1]
X = self.snapshots[..., :-1]
Y = self.snapshots[..., 1:]

X, Y = compute_tlsq(X, Y, self._tlsq_rank)
self._svd_modes, _, _ = self.operator.compute_operator(X, Y)
@@ -81,6 +81,10 @@ def predict(self, X):
:return: one time-step ahead predicted output.
:rtype: numpy.ndarray
"""
return np.linalg.multi_dot(
[self.modes, np.diag(self.eigs), pinv(self.modes), X]
assert_same_linalg_type(X, self.modes)

linalg_module = build_linalg_module(X)
return linalg_module.multi_dot(
(self.modes, linalg_module.diag_matrix(self.eigs),
linalg_module.pinv(self.modes), X)
)
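The rewritten `predict` computes the one-step-ahead map Phi Lambda Phi^+ x through the generic `multi_dot`. The same computation in plain NumPy, on a small synthetic example (the matrices below are made up for illustration, not taken from the PR's tests):

```python
import numpy as np

# Synthetic rank-2 system: two DMD modes, their eigenvalues, a snapshot.
modes = np.array([[1.0, 1.0],
                  [1.0, -1.0]])
eigs = np.array([0.9, 0.5])
x = np.array([2.0, 0.0])

# One-step-ahead prediction: project x onto the modes, advance each
# modal coefficient by its eigenvalue, then map back to state space.
x_next = np.linalg.multi_dot(
    [modes, np.diag(eigs), np.linalg.pinv(modes), x]
)  # -> array([1.4, 0.4])
```

Since `modes` is square and invertible here, `pinv(modes) @ x` gives the modal coefficients `[1, 1]`, which the eigenvalues scale to `[0.9, 0.5]` before mapping back.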
5 changes: 0 additions & 5 deletions pydmd/dmd_modes_tuner.py
@@ -13,7 +13,6 @@ def select_modes(
criteria,
in_place=True,
return_indexes=False,
nullify_amplitudes=False,
):
"""
Select the DMD modes by using the given `criteria`.
@@ -45,10 +44,6 @@
:param bool return_indexes: If `True`, this function returns the indexes
corresponding to DMD modes cut using the given `criteria` (default
`False`).
:param bool nullify_amplitudes: If `True`, the amplitudes associated with
DMD modes to be removed are set to 0, therefore the number of DMD
modes remains constant. If `False` (default) DMD modes are actually
removed, therefore the number of DMD modes in the instance decreases.
:returns: If `return_indexes` is `True`, the returned value is a tuple
whose items are:

110 changes: 55 additions & 55 deletions pydmd/dmdbase.py
@@ -5,11 +5,12 @@

import pickle
from copy import copy, deepcopy
from os.path import splitext

import numpy as np
from past.utils import old_div

from .dmdoperator import DMDOperator
from pydmd.linalg import build_linalg_module
from .utils import compute_svd


@@ -39,12 +40,13 @@ class ActivationBitmaskProxy:
"""

def __init__(self, dmd_operator, amplitudes):
linalg_module = build_linalg_module(dmd_operator.modes)
self._original_modes = dmd_operator.modes
self._original_eigs = np.atleast_1d(dmd_operator.eigenvalues)
self._original_amplitudes = np.atleast_1d(amplitudes)
self._original_eigs = linalg_module.atleast_1d(dmd_operator.eigenvalues)
self._original_amplitudes = linalg_module.atleast_1d(amplitudes)

self.old_bitmask = None
self.change_bitmask(np.full(len(dmd_operator.eigenvalues), True))
self.change_bitmask(np.full(dmd_operator.eigenvalues.shape[-1], True))

def change_bitmask(self, value):
"""
@@ -58,16 +60,9 @@ def change_bitmask(self, value):
`bool` whose size is the same of the number of DMD modes.
:type value: np.ndarray
"""

# apply changes made on the proxied values to the original values
if self.old_bitmask is not None:
self._original_modes[:, self.old_bitmask] = self.modes
self._original_eigs[self.old_bitmask] = self.eigs
self._original_amplitudes[self.old_bitmask] = self.amplitudes

self._modes = np.array(self._original_modes)[:, value]
self._eigs = np.array(self._original_eigs)[value]
self._amplitudes = np.array(self._original_amplitudes)[value]
self._modes = self._original_modes[..., value]
self._eigs = self._original_eigs[..., value]
self._amplitudes = self._original_amplitudes[..., value]

self.old_bitmask = value

@@ -205,10 +200,12 @@ def dmd_timesteps(self):
:return: the time intervals of the original snapshots.
:rtype: numpy.ndarray
"""
return np.arange(
linalg_module = build_linalg_module(self.eigs)
return linalg_module.arange(
self.dmd_time["t0"],
self.dmd_time["tend"] + self.dmd_time["dt"],
self.dmd_time["dt"],
device=linalg_module.device(self.snapshots),
)

@property
@@ -219,10 +216,12 @@ def original_timesteps(self):
:return: the time intervals of the original snapshots.
:rtype: numpy.ndarray
"""
return np.arange(
linalg_module = build_linalg_module(self.eigs)
return linalg_module.arange(
self.original_time["t0"],
self.original_time["tend"] + self.original_time["dt"],
self.original_time["dt"],
device=linalg_module.device(self.snapshots),
)

@property
@@ -285,12 +284,14 @@ def dynamics(self):
row.
:rtype: numpy.ndarray
"""
temp = np.repeat(
self.eigs[:, None], self.dmd_timesteps.shape[0], axis=1
linalg_module = build_linalg_module(self.eigs)
temp = linalg_module.repeat(
self.eigs[..., None], self.dmd_timesteps.shape[0], axis=-1
)
tpow = old_div(
self.dmd_timesteps - self.original_time["t0"],
self.original_time["dt"],
)
tpow = (
self.dmd_timesteps - self.original_time["t0"]
) / self.original_time["dt"]

# The new formula is x_(k+j) = \Phi \Lambda^k \Phi^(-1) x_j.
# Since j is fixed, for a given snapshot "u" we have the following
Expand All @@ -299,7 +300,7 @@ def dynamics(self):
# Therefore tpow must be scaled appropriately.
tpow = self._translate_eigs_exponent(tpow)

return np.power(temp, tpow) * self.amplitudes[:, None]
return linalg_module.pow(temp, tpow) * self.amplitudes[..., None]

def _translate_eigs_exponent(self, tpow):
"""
@@ -332,7 +333,8 @@ def reconstructed_data(self):
:return: the matrix that contains the reconstructed snapshots.
:rtype: numpy.ndarray
"""
return self.modes.dot(self.dynamics)
linalg_module = build_linalg_module(self.modes)
return linalg_module.dot(self.modes, self.dynamics)

@property
def snapshots(self):
@@ -366,7 +368,9 @@ def frequency(self):
:return: the array that contains the frequencies of the eigenvalues.
:rtype: numpy.ndarray
"""
return np.log(self.eigs).imag / (2 * np.pi * self.original_time["dt"])
linalg_module = build_linalg_module(self.eigs)
div = 2 * np.pi * self.original_time["dt"]
return linalg_module.log(self.eigs).imag / div

@property
def growth_rate(self): # To check
@@ -443,7 +447,7 @@ def modes_activation_bitmask(self):

bitmask = self._modes_activation_bitmask_proxy.old_bitmask
# make sure that the array is immutable
bitmask.flags.writeable = False
build_linalg_module(bitmask).make_not_writeable(bitmask)
return bitmask

@modes_activation_bitmask.setter
@@ -452,7 +456,6 @@ def modes_activation_bitmask(self, value):
if not self.fitted:
raise RuntimeError("This DMD instance has not been fitted yet.")

value = np.array(value)
if value.dtype != bool:
raise RuntimeError(
"Unxpected dtype, expected bool, got {}.".format(value.dtype)
@@ -659,36 +662,31 @@ def load(fname):

def _optimal_dmd_matrices(self):
# compute the vandermonde matrix
vander = np.vander(self.eigs, len(self.dmd_timesteps), True)
linalg_module = build_linalg_module(self.eigs)
vander = linalg_module.vander(self.eigs, len(self.dmd_timesteps), True)

P = np.multiply(
np.dot(self.modes.conj().T, self.modes),
np.conj(np.dot(vander, vander.conj().T)),
)
a = linalg_module.dot(self.modes.conj().swapaxes(-1, -2), self.modes)
b = linalg_module.dot(vander, vander.conj().swapaxes(-1, -2)).conj()
P = linalg_module.multiply_elementwise(a, b)

if self._exact:
q = np.conj(
np.diag(
np.linalg.multi_dot(
[vander, self.snapshots.conj().T, self.modes]
)
)
vsm = linalg_module.multi_dot(
(vander, self.snapshots.conj().swapaxes(-1, -2), self.modes)
)
q = linalg_module.extract_diagonal(vsm).conj()
else:
_, s, V = compute_svd(self.snapshots[:, :-1], self.modes.shape[-1])

q = np.conj(
np.diag(
np.linalg.multi_dot(
[
vander[:, :-1],
V,
np.diag(s).conj(),
self.operator.eigenvectors,
]
)
)
_, s, V = compute_svd(
self.snapshots[..., :-1], self.modes.shape[-1]
)

s_conj = linalg_module.diag_matrix(s).conj()
s_conj, V, vander = linalg_module.to(
self.operator.eigenvectors, s_conj, V, vander
)
vVse = linalg_module.multi_dot(
(vander[..., :-1], V, s_conj, self.operator.eigenvectors)
)
q = linalg_module.extract_diagonal(vVse).conj()

return P, q

@@ -710,20 +708,22 @@ def _compute_amplitudes(self):
Jovanovic et al. 2014, Sparsity-promoting dynamic mode decomposition,
https://hal-polytechnique.archives-ouvertes.fr/hal-00995141/document
"""
linalg_module = build_linalg_module(self.modes)
if isinstance(self._opt, bool) and self._opt:
# b optimal
a = np.linalg.solve(*self._optimal_dmd_matrices())
A, b = self._optimal_dmd_matrices()
a = linalg_module.solve(A, b)
else:
if isinstance(self._opt, bool):
amplitudes_snapshot_index = 0
else:
amplitudes_snapshot_index = self._opt

a = np.linalg.lstsq(
selected_snapshots = self.snapshots[..., amplitudes_snapshot_index]
a = linalg_module.lstsq(
self.modes,
self.snapshots.T[amplitudes_snapshot_index],
linalg_module.to(self.modes, selected_snapshots),
rcond=None,
)[0]
)

return a

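Throughout the diff, `[:, :-1]`-style slicing is replaced by `[..., :-1]` so that the same code handles both a single snapshot matrix and a batch of them. The ellipsis leaves any leading batch dimensions untouched and slices only the trailing time axis, as this small sketch shows (shapes chosen for illustration):

```python
import numpy as np


def split_snapshots(snapshots):
    # Ellipsis indexing works for both a (space, time) matrix and a
    # (batch, space, time) stack: only the last axis is sliced.
    X = snapshots[..., :-1]  # snapshots 0 .. T-2
    Y = snapshots[..., 1:]   # snapshots 1 .. T-1
    return X, Y


single = np.random.rand(5, 10)      # 10 snapshots of dimension 5
batched = np.random.rand(3, 5, 10)  # 3 independent trajectories

X1, Y1 = split_snapshots(single)    # shapes (5, 9) and (5, 9)
Xb, Yb = split_snapshots(batched)   # shapes (3, 5, 9) and (3, 5, 9)
```

This is why the PR can add a `batch=False` flag to `fit` without duplicating the slicing logic for the batched case.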