# Timing SpDMD applied directly to the snapshots and to their FJLT embedding

In [None]:
from tqdm import tqdm
import sys
import gc
import matplotlib.pyplot as plt
from time import perf_counter
sys.path.append('../../../utils')
from TurboFJLT import *
from TurboFJLT_helpers import FJLT, TurboHDF5Reader
%config InlineBackend.figure_format='retina'

In [None]:
# Configure figure rendering and load the shared paper_half.mplstyle style sheet.
# NOTE(review): this %config duplicates the one in the import cell; it may be
# repeated on purpose so the setting sticks once matplotlib's inline backend is
# initialised — confirm before removing either copy.
%config InlineBackend.figure_format='retina'
plt.style.use("../../../mplstyles/paper_half.mplstyle")

In [None]:
# How many repeated SpDMD fits are timed and averaged for each measurement.
timing_averaging_iterations = 5

In [None]:
# Open the fine_airfoil_cascade.h5 dataset and show the reader's description.
datafile = "../../../data/fine_airfoil_cascade.h5"
reader = TurboHDF5Reader(datafile)
print(reader)

### Extract the data

In [None]:
# Snapshot indices 0..num_snapshots-1, shared by the direct and FJLT runs.
num_snapshots = 300
snapshot_sequence = [*range(num_snapshots)]

In [None]:
def extract_dmd_result(dmd_result):
    """Keep only the DMD components whose amplitude is non-zero.

    Prints how many amplitudes are non-zero, then returns the matching
    amplitudes, eigenvalues and mode columns, all cast to ``complex128``.

    Parameters
    ----------
    dmd_result : fitted DMD object
        Must expose ``amplitudes``, ``eigs`` and ``modes``. ``modes`` is
        indexed as ``modes[:, k]``, so column ``k`` pairs with
        ``amplitudes[k]`` and ``eigs[k]``.
        NOTE(review): assumes all three agree in mode count, as for a fitted
        PyDMD model — confirm for other result types.

    Returns
    -------
    amps : numpy.ndarray, shape (K,)
    eigs : numpy.ndarray, shape (K,)
    modes : numpy.ndarray, shape (dmd_result.modes.shape[0], K)
        ``K`` is the number of non-zero amplitudes (may be 0).
    """
    amplitudes = np.asarray(dmd_result.amplitudes)
    # One boolean mask selects the retained components everywhere;
    # `x != 0` is equivalent to `abs(x) != 0` for complex values.
    keep = amplitudes != 0
    print("Non-zero amplitudes: ", np.count_nonzero(keep))
    amps = amplitudes[keep].astype("complex128")
    eigs = np.asarray(dmd_result.eigs)[keep].astype("complex128")
    modes = np.asarray(dmd_result.modes)[:, keep].astype("complex128")
    return amps, eigs, modes

In [None]:
def formQ(reader, seq_to_extract, chunks_dim=50):
    """Assemble the mean-subtracted snapshot matrix Q.

    Parameters
    ----------
    reader : TurboHDF5Reader
        Reader whose ``load_meanflow``, ``state_dim``,
        ``reset_chunked_loading`` and ``load_next`` members are used.
    seq_to_extract : iterable of int
        Snapshot indices to load, in order. Sized sequences (list, range,
        array) are passed to the reader unchanged; any other iterable
        (e.g. a generator) is materialised once first.
    chunks_dim : int, optional
        Forwarded to ``reader.reset_chunked_loading`` (default 50).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(reader.state_dim, len(seq_to_extract))``
        whose column ``i`` is the ``i``-th loaded snapshot minus the mean flow.
    """
    # Materialise un-sized iterables exactly once; calling list() on a
    # generator just to take its length would exhaust it and leave the
    # loader and loop below with no indices (Q would stay all zeros).
    seq = (seq_to_extract if hasattr(seq_to_extract, "__len__")
           else list(seq_to_extract))
    q_mf = reader.load_meanflow()
    Q = np.zeros((reader.state_dim, len(seq)))
    reader.reset_chunked_loading(seq, chunks_dim=chunks_dim)
    for i, _ in enumerate(tqdm(seq)):
        Q[:, i] = reader.load_next() - q_mf
    return Q

In [None]:
def extract_DMD(Q, gamma=600, svd_rank=50, rho=1.e4):
    """Fit SpDMD to a snapshot matrix and measure how long it takes.

    Parameters
    ----------
    Q : numpy.ndarray
        Snapshot matrix with one snapshot per column (as built by
        ``formQ`` or ``formB``).
    gamma : float, optional
        Forwarded to ``SpDMD`` as ``gamma`` (default 600).
    svd_rank : int, optional
        Forwarded to ``SpDMD`` as ``svd_rank`` (default 50).
    rho : float, optional
        Forwarded to ``SpDMD`` as ``rho`` (default 1e4).

    Returns
    -------
    dmd_sol
        Value returned by ``SpDMD(...).fit(Q)`` (the fitted model in PyDMD).
    dt : float
        Wall-clock seconds, via ``time.perf_counter``, spent constructing
        and fitting the model.
    """
    t0 = perf_counter()
    dmd_sol = SpDMD(svd_rank=svd_rank, gamma=gamma, rho=rho).fit(Q)
    dt = perf_counter() - t0
    return dmd_sol, dt

In [None]:
# Build the full snapshot matrix and time SpDMD applied to it directly,
# averaging the fit time over `timing_averaging_iterations` repetitions.
Q = formQ(reader, snapshot_sequence)
direct_timings = []
for _ in range(timing_averaging_iterations):
    dmd_Q, dt = extract_DMD(Q)
    direct_timings.append(dt)
dt_direct = sum(direct_timings) / timing_averaging_iterations
print("Direct SpDMD application time: {}s".format(dt_direct))

# Drop the notebook's reference to the full-size matrix before the FJLT runs.
del Q
gc.collect()

amps, eigs, modes = extract_dmd_result(dmd_Q)

In [None]:
def formB(reader, fjlt, seq_to_extract, chunks_dim=50):
    """Assemble the FJLT-embedded, mean-subtracted snapshot matrix B.

    Parameters
    ----------
    reader : TurboHDF5Reader
        Reader whose ``load_meanflow``, ``reset_chunked_loading`` and
        ``load_next`` methods are used.
    fjlt : FJLT
        Transform providing ``embedding_dim`` and the ``P``, ``s``, ``D``
        attributes passed to ``applyFJLT``.
    seq_to_extract : iterable of int
        Snapshot indices to load, in order. Sized sequences (list, range,
        array) are passed to the reader unchanged; any other iterable
        (e.g. a generator) is materialised once first.
    chunks_dim : int, optional
        Forwarded to ``reader.reset_chunked_loading`` (default 50).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(fjlt.embedding_dim, len(seq_to_extract))``
        whose column ``i`` is ``applyFJLT`` of the ``i``-th loaded snapshot
        minus the mean flow.
    """
    # Materialise un-sized iterables exactly once; calling list() on a
    # generator just to take its length would exhaust it and leave the
    # loader and loop below with no indices (B would stay all zeros).
    seq = (seq_to_extract if hasattr(seq_to_extract, "__len__")
           else list(seq_to_extract))
    q_mf = reader.load_meanflow()
    B = np.zeros((fjlt.embedding_dim, len(seq)))
    reader.reset_chunked_loading(seq, chunks_dim=chunks_dim)
    for i, _ in enumerate(tqdm(seq)):
        q_ss = reader.load_next() - q_mf
        B[:, i] = applyFJLT(q_ss, fjlt.P, fjlt.s, fjlt.D)
    return B

In [None]:
def fjlt_dmd(reader, num_linking_snapshots, snapshots_to_extract):
    """Embed the snapshots with an FJLT, then time SpDMD on the embedding.

    Parameters
    ----------
    reader : TurboHDF5Reader
        Source of the snapshots; its ``state_dim`` sizes the transform.
    num_linking_snapshots : int
        Second argument to ``FJLT``.
        NOTE(review): presumably controls the embedding size — confirm.
    snapshots_to_extract : iterable of int
        Snapshot indices forwarded to ``formB``.

    Returns
    -------
    dmd_B
        Fitted SpDMD result from the last timing repetition.
    dt_fjlt : float
        Mean fit time in seconds over the module-level
        ``timing_averaging_iterations`` repetitions.
    """
    # NOTE(review): the meaning of the 0.01 argument is defined by FJLT.
    transform = FJLT(reader.state_dim, num_linking_snapshots, 0.01)
    B = formB(reader, transform, snapshots_to_extract)

    fit_times = []
    for _ in range(timing_averaging_iterations):
        dmd_B, dt = extract_DMD(B)
        fit_times.append(dt)
    mean_fit_time = sum(fit_times) / timing_averaging_iterations

    # Drop the local reference to the embedded matrix before returning.
    del B
    gc.collect()
    return dmd_B, mean_fit_time

In [None]:
# Values of N (number of linking snapshots) to sweep, and the mean SpDMD
# fit time (seconds) measured on the FJLT embedding for each.
num_linking_snapshots = [2, 4, 6, 8, 12, 16, 24, 32]
perf_counters = []

for n_sp in num_linking_snapshots:
    dmd_B, dt_fjlt = fjlt_dmd(reader, n_sp, snapshot_sequence)
    perf_counters.append(dt_fjlt)
    print("FJLT ({} snapshots) SpDMD application time: {}s".format(n_sp, dt_fjlt))

    # Extraction is done for its printed non-zero-amplitude count; the
    # returned arrays are discarded straight away to limit memory use.
    amps, eigs, modes = extract_dmd_result(dmd_B)
    del amps, eigs, modes
    gc.collect()

### Plot the timing results

In [None]:
# Mean SpDMD fit time (seconds) on the FJLT embedding against N.
fig, ax = plt.subplots(nrows=1, ncols=1)
x_vals, y_vals = num_linking_snapshots, perf_counters
ax.plot(x_vals, y_vals, marker="o")
ax.set_xlabel("N")
ax.set_ylabel("Time (s)")

In [None]:
# Echo the raw timing data: the N values, then the mean SpDMD times (s).
for values in (num_linking_snapshots, perf_counters):
    print(values)