Testing the shift-reflect symmetry and its implementation in the Newton solver

In [None]:
# Enable 64-bit floats in JAX (must run before any arrays are created).
# `from jax.config import config` is deprecated / removed in recent JAX; use jax.config directly.
import jax

jax.config.update("jax_enable_x64", True)

In [None]:
import jax.numpy as jnp
import jax_cfd.base as cfd

import interact_spectral as insp
import newton_spectral as nt_sp
from time_forward_map_spectral import generate_time_forward_map

Configure the flow, then generate a realistic initial condition by evolving (burning in) a random velocity field.

In [None]:
# Square domain of side L, resolved on an Nx x Ny grid.
Nx, Ny = 256, 256
L = 2 * jnp.pi

# Reynolds number; the viscosity passed to the solvers below is 1 / Re.
Re = 40.

In [None]:
# Nx x Ny grid covering [0, L] x [0, L].
grid = cfd.grids.Grid((Nx, Ny), domain=((0, L), (0, L)))

# Time step from jax_cfd's stability estimate. Arguments: a guessed peak velocity,
# 0.5 (presumably the maximum Courant number -- confirm in jax_cfd), viscosity 1 / Re, grid.
max_velocity = 5.  # rough estimate of the peak velocity
dt_stable = cfd.equations.stable_time_step(
    max_velocity, 0.5, 1. / Re, grid
)

In [None]:
# Map used only to burn in the random initial field. The arguments are presumably
# (dt, number of steps, grid, viscosity), i.e. ~50 time units at step dt_stable --
# confirm against time_forward_map_spectral.generate_time_forward_map.
burn_in_time_forward_map = generate_time_forward_map(
    dt_stable,
    int(50. / dt_stable),
    grid,
    1. / Re,
)

In [None]:
import jax

# Random, filtered velocity field with a fixed seed (123) for reproducibility.
# The trailing 4 is presumably the peak wavenumber -- confirm in jax_cfd.
v0 = cfd.initial_conditions.filtered_velocity_field(
    jax.random.PRNGKey(123), grid, max_velocity, 4
)

# Scalar vorticity on the grid, then its real FFT (the form the time-forward map consumes).
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity0_rft = jnp.fft.rfftn(vorticity0)

In [None]:
# Evolve the random vorticity through the burn-in map to reach a realistic state
# (the result is still in real-FFT space; it is inverted with irfftn further down).
omega_rft = burn_in_time_forward_map(
    vorticity0_rft
)

In [None]:
# Number of shift-reflect applications used for this visual check.
n_shift_reflects = 3

# Back to physical space. Pass the grid shape explicitly: without `s`, irfftn assumes the
# last axis had even length 2*(m-1), which would give the wrong shape for odd Ny.
omega = jnp.fft.irfftn(omega_rft, s=(Nx, Ny))

# Apply the y shift-reflect operation n_shift_reflects times (semantics defined in interact_spectral).
omega_shift_reflect = insp.y_shift_reflect(omega, grid, n_shift_reflects)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison of the burnt-in vorticity and its shift-reflected copy.
# Fields are transposed so the first array axis (length Nx) runs horizontally.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (omega, r"$\omega$"),
    (omega_shift_reflect, rf"$\omega$ after {n_shift_reflects} shift-reflect(s)"),
]
for ax, (field, title) in zip(axes, panels):
    ax.contourf(field.T, 101)  # 101 contour levels
    ax.set_title(title)
    ax.set_aspect("equal")  # the domain is square, so keep it square on screen
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()

Now verify that the Newton solver works when the guess uses a shift-reflect symmetry (n_shift_reflects > 0).

In [None]:
# Spectral Newton solver on the same grid. nmax_hook / eps_newt presumably cap the
# hook-step iterations and set the Newton tolerance -- confirm in newton_spectral.
newton_solver = nt_sp.rpoSolverSpectral(
    grid,
    nmax_hook=10,
    eps_newt=1e-10,
)

In [None]:
# Guess built from the burnt-in state; the positional 20. and 0. are presumably the
# period and shift guesses. n_shift_reflects=1 exercises the shift-reflect code path.
initial_guess = nt_sp.poGuessSpectral(
    omega_rft,
    20.,
    0.,
    n_shift_reflects=1,
)

In [None]:
# Run the Newton iteration on the shift-reflect guess at the base Reynolds number.
newton_solver.iterate(
    initial_guess, Re, dt_stable
)

Now load a previously computed solution, check that we can converge it, then attempt some branch continuation.

In [None]:
# Index of the stored Re = 40 solution to load.
soln_number = 0

# Directory holding the saved solution arrays and their metadata (relative to this notebook).
# Kept in one place so switching solution sets only needs one edit.
soln_dir = '../../../newton-jax-minimal/newt_minimal_spectral/soln_ars_Re40/'
upo_file_path = soln_dir + f'soln_array_Re40_{soln_number}.npy'
meta_file_path = soln_dir + f'soln_meta_Re40_{soln_number}.npy'

# Stored spectral state of the solution.
upo_rft = jnp.load(upo_file_path)
# Metadata; entries 0 and 1 are passed to poGuessSpectral below, presumably as the
# period and shift -- confirm against how the files were written.
upo_meta = jnp.load(meta_file_path)

Warning: the regular UPO solve overwrites the original guess (via its `record_outcome` method).

In [None]:
# Fresh guess from the loaded solution; metadata entries 0 and 1 presumably hold its period and shift.
upo_guess = nt_sp.poGuessSpectral(
    upo_rft,
    upo_meta[0],
    upo_meta[1],
    n_shift_reflects=0,
)

Converge the solution at a slightly higher value of $Re$ to start the branch continuation.

In [None]:
# Step slightly away from the base Re and converge the loaded orbit there.
Re_new = 40.25
upo_perturbed = newton_solver.iterate(
    upo_guess, Re_new, dt_stable
)

Re-generate the original guess, which was overwritten by the solve above. (TODO: add a dedicated method for resetting a guess.)

In [None]:
# Rebuild the guess from the stored arrays, then record its initial state, period and
# shift as the outcome (the last two record_outcome arguments are left as None).
upo_guess = nt_sp.poGuessSpectral(
    upo_rft, upo_meta[0], upo_meta[1], n_shift_reflects=0
)
upo_guess.record_outcome(
    upo_rft,
    upo_guess.T_init,
    upo_guess.shift_init,
    None,
    None,
)

In [None]:
# No `importlib.reload(nt_sp)` here on purpose. Reloading at this point would leave
# `newton_solver`, `upo_guess` and `upo_perturbed` as instances of the pre-reload classes,
# while the continuation solver built next would come from the reloaded module.
# On a fresh kernel the reload does nothing useful. To pick up live edits to
# newton_spectral.py, enable IPython's autoreload extension in the first cell instead.

In [None]:
# Continuation driver on the same grid as the Newton solver above.
branch_continuation_solver = nt_sp.rpoBranchContinuation(
    grid
)

In [60]:
# Continue from the pair of converged solutions (upo_perturbed at Re_new, upo_guess at Re);
# returns the next solution on the branch and its Reynolds number.
new_soln, new_Re = branch_continuation_solver.iterate(
    upo_perturbed, upo_guess, Re_new, Re, dt_stable
)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/jpage2/miniconda3/envs/jax_cfd/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/0t/c6rdvjkx6msgst60p0y4wg1h0000gn/T/ipykernel_48791/63916131.py", line 1, in <module>
    new_soln, new_Re = branch_continuation_solver.iterate(upo_perturbed, upo_guess, Re_new, Re, dt_stable)
  File "/Users/jpage2/code/DEV/newton-jax-minimal/spectral/newton_spectral.py", line 524, in iterate
  File "/Users/jpage2/code/DEV/newton-jax-minimal/spectral/arnoldi.py", line 87, in gmres
    basis.add_basis_vector(A_operator)
  File "/Users/jpage2/code/DEV/newton-jax-minimal/spectral/arnoldi.py", line 30, in add_basis_vector
    v = A_operator(self.basis[:, self.n_basis_vec])
  File "/Users/jpage2/code/DEV/newton-jax-minimal/spectral/newton_spectral.py", line 605, in _timestep_A
  File "/Users/jpage2/code/DEV/newton-jax-minimal/spectral/newton_spectral.p

In [None]:
# The continuation solve already ran in the previous cell. Re-running it here would repeat
# the expensive Newton iteration, and the regular solve is documented above as overwriting
# its guess, so a second call may not start from the same inputs. Just show the new Re.
new_Re