# Compute the SVD directly and with the Fast Johnson–Lindenstrauss Transform (FJLT)

In [None]:
from tqdm import tqdm
import sys
sys.path.append('../../../utils')
import gc
from TurboFJLT import *
from TurboFJLT_helpers import FJLT, TurboHDF5Reader

In [None]:
# HDF5 snapshot database, located relative to this notebook's directory.
datafile = "../../../data/fine_airfoil_cascade.h5"

# Open the database and display whatever summary the reader prints for itself.
reader = TurboHDF5Reader(datafile)
print(reader)

### Extract the data

In [None]:
# Snapshot indices 0 .. num_snapshots-1, used by both the direct and FJLT SVDs.
num_snapshots = 500
snapshot_sequence = [*range(num_snapshots)]

In [None]:
# Mean flow from the dataset; formQ and formB subtract it from every loaded snapshot.
q_mf = reader.load_meanflow()

In [None]:
def formQ(reader, seq_to_extract, meanflow=None):
    """Assemble the dense matrix of mean-subtracted snapshots.

    Parameters
    ----------
    reader : TurboHDF5Reader
        Reader whose ``state_dim`` sets the number of rows. Snapshots are
        pulled one at a time with ``load_next()`` after
        ``reset_chunked_loading``.
    seq_to_extract : iterable of int
        Snapshot indices to load. Any iterable is accepted; it is
        materialised once, so a generator is not consumed before loading.
    meanflow : array-like, optional
        Vector subtracted from every snapshot. Defaults to the
        notebook-level ``q_mf`` (looked up at call time, as before).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(reader.state_dim, len(seq_to_extract))``.
        Column ``i`` is the ``i``-th snapshot returned by the reader minus
        ``meanflow``. Presumably these are in the order of
        ``seq_to_extract`` -- confirm against ``TurboHDF5Reader``.
    """
    if meanflow is None:
        meanflow = q_mf
    # Materialise once: sizing with list(...) and then iterating the
    # argument again would leave a generator empty for the loading loop.
    seq = list(seq_to_extract)
    Q = np.zeros((reader.state_dim, len(seq)))
    reader.reset_chunked_loading(seq, chunks_dim=50)
    for i in tqdm(range(len(seq))):
        Q[:, i] = reader.load_next() - meanflow
    return Q

In [None]:
# Reference result: thin SVD of the full mean-subtracted snapshot matrix.
Q = formQ(reader, snapshot_sequence)
u, s, vh = np.linalg.svd(Q, full_matrices=False)

# Persist the three factors under fixed dataset names.
with h5.File("../data/direct_svd.h5", 'w') as f:
    f.create_dataset("/U", data=u)
    f.create_dataset("/s", data=s)
    f.create_dataset("/VH", data=vh)

In [None]:
# Drop the direct-SVD arrays and force a collection so their memory is
# reclaimed before the FJLT runs below.
del Q, u, s, vh
gc.collect()

In [None]:
def formB(reader, fjlt, seq_to_extract, meanflow=None):
    """Build the FJLT sketch of the mean-subtracted snapshots, column by column.

    Parameters
    ----------
    reader : TurboHDF5Reader
        Reader that yields one snapshot per ``load_next()`` call after
        ``reset_chunked_loading``.
    fjlt : FJLT
        Transform whose ``embedding_dim`` sets the number of rows. Its
        ``P``, ``s`` and ``D`` attributes are passed to ``applyFJLT``.
    seq_to_extract : iterable of int
        Snapshot indices to load. Any iterable is accepted; it is
        materialised once, so a generator is not consumed before loading.
    meanflow : array-like, optional
        Vector subtracted from every snapshot before sketching. Defaults to
        the notebook-level ``q_mf`` (looked up at call time, as before).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(fjlt.embedding_dim, len(seq_to_extract))``.
        Column ``i`` is ``applyFJLT`` of the ``i``-th loaded snapshot minus
        ``meanflow``.
    """
    if meanflow is None:
        meanflow = q_mf
    # Materialise once: sizing with list(...) and then iterating the
    # argument again would leave a generator empty for the loading loop.
    seq = list(seq_to_extract)
    B = np.zeros((fjlt.embedding_dim, len(seq)))
    reader.reset_chunked_loading(seq, chunks_dim=50)
    for i in tqdm(range(len(seq))):
        # Only one full-size snapshot is held at a time; B stores the sketch.
        q_ss = reader.load_next() - meanflow
        B[:, i] = applyFJLT(q_ss, fjlt.P, fjlt.s, fjlt.D)
    return B

### Compute the SVD with the FJLT

In [None]:
def fjlt_svd(reader, num_linking_snapshots, snapshots_to_extract):
    """Return the thin SVD ``(U, s, VH)`` of the FJLT sketch of the snapshots.

    A fresh ``FJLT`` is built on every call. Its arguments are
    ``reader.state_dim``, ``num_linking_snapshots`` and the fixed value
    0.01, whose meaning is defined by ``FJLT`` and is not visible here.
    ``formB`` then forms the sketch from ``snapshots_to_extract``.
    """
    sketch = formB(
        reader,
        FJLT(reader.state_dim, num_linking_snapshots, 0.01),
        snapshots_to_extract,
    )
    left, sigma, right_h = np.linalg.svd(sketch, full_matrices=False)
    return left, sigma, right_h

In [None]:
# Sweep the FJLT over several linking-snapshot counts and save one file per count.
num_linking_snapshots = [2, 4, 6, 8, 12, 16, 24, 32]
for n_sp in num_linking_snapshots:
    fjlt_u, fjlt_s, fjlt_vh = fjlt_svd(reader, n_sp, snapshot_sequence)

    with h5.File("../data/fjlt_svd_{}_linking_snapshots.h5".format(n_sp), 'w') as f:
        f.create_dataset("/U", data=fjlt_u)
        f.create_dataset("/s", data=fjlt_s)
        f.create_dataset("/VH", data=fjlt_vh)

    # Release this iteration's factors before the next, larger sketch is built.
    del fjlt_u, fjlt_s, fjlt_vh
    gc.collect()