In [None]:
import sys, os
# Make the parent directory importable so that the `qgs` package can be found.
sys.path.append(os.path.abspath('../'))

In [None]:
import numpy as np
import sympy as sy
import sparse as sp
import math
from numba import njit

In [None]:
from qgs.params.params import QgParams
from qgs.functions.tendencies import create_tendencies
from qgs.functions.symbolic_tendencies import create_symbolic_equations

In [None]:
from qgs.inner_products.symbolic import AtmosphericSymbolicInnerProducts, OceanicSymbolicInnerProducts
from qgs.tensors.qgtensor import QgsTensor, QgsTensorDynamicT, QgsTensorT4
from qgs.tensors.symbolic_qgtensor import SymbolicQgsTensor, SymbolicQgsTensorDynamicT

In [None]:
# Scale parameters handed to QgParams; `dynamic_T=True` is forwarded unchanged.
# NOTE(review): presumably enables the dynamical reference temperature — confirm in the qgs docs.
scale_params = {'n': 1.5}
model_parameters = QgParams(scale_params, dynamic_T=True)

In [None]:
# Both sets of basis functions are requested in symbolic mode.
basis_mode = "symbolic"

# Atmospheric channel modes, requested with the arguments (2, 2)
model_parameters.set_atmospheric_channel_fourier_modes(2, 2, mode=basis_mode)
# Mode truncation at the wavenumber 2 in the x and at the
# wavenumber 4 in the y spatial coordinates for the ocean
model_parameters.set_oceanic_basin_fourier_modes(2, 4, mode=basis_mode)

In [None]:
# Setting MAOOAM parameters according to the publication linked above.
# The values are gathered per parameter group first, then applied in order.
general_settings = {'kd': 0.0290, 'kdp': 0.0290, 'r': 1.e-7,
                    'h': 136.5, 'd': 1.1e-7}
atm_temperature_settings = {'eps': 0.7, 'hlambda': 15.06}
go_temperature_settings = {'gamma': 5.6e8}

model_parameters.set_params(general_settings)
model_parameters.atemperature_params.set_params(atm_temperature_settings)
model_parameters.gotemperature_params.set_params(go_temperature_settings)

In [None]:
# Insolation values at positions 0 and 1: 103 on `atemperature_params`,
# then 310 on `gotemperature_params`.
for position in (0, 1):
    model_parameters.atemperature_params.set_insolation(103., position)
for position in (0, 1):
    model_parameters.gotemperature_params.set_insolation(310., position)

In [None]:
# Print the configured parameter set so it can be checked before the tendencies are built.
model_parameters.print_params()

In [None]:
# Numerical tendencies built from the same parameters. `f_num` is integrated further down
# and serves as the reference for the symbolic output; `Df` is not used in the cells shown here.
f_num, Df = create_tendencies(model_parameters)

In [None]:
# `eps` and `kd` are declared as continuation variables; the wrapper `f_fix` defined
# further down passes them to `f` as two extra arguments after (t, U).
continuation_parameters = [
    model_parameters.atemperature_params.eps,
    model_parameters.atmospheric_params.kd,
]
f_sym, jac_sym = create_symbolic_equations(
    model_parameters,
    language='python',
    continuation_variables=continuation_parameters,
    return_jacobian=True,
)

In [None]:
# Show the generated tendencies output (it is executed with `exec` further down).
print(f_sym)

In [None]:
# Show the Jacobian output returned alongside `f_sym`.
print(jac_sym)

In [None]:
# Run the generated code in the notebook namespace.
# NOTE(review): this presumably defines the function `f` that `f_fix` calls below —
# confirm against the printed source. The string comes from qgs, not from external input.
exec(f_sym)

In [None]:
from qgs.integrators.integrator import RungeKuttaIntegrator, RungeKuttaTglsIntegrator
import matplotlib.pyplot as plt

In [None]:
# Current values of the two parameters that were declared as continuation variables;
# they are read as globals by `f_fix`.
eps, kd = (
    model_parameters.atemperature_params.eps,
    model_parameters.atmospheric_params.kd,
)

In [None]:
@njit
def f_fix(t, U):
    """Evaluate the symbolic tendencies `f` with `eps` and `kd` held fixed.

    Gives `f` the two-argument signature `(t, U)` that is handed to the
    integrator. `eps` and `kd` are read from the notebook's global namespace;
    numba treats globals as compile-time constants, so this function must be
    redefined after changing either of them.

    Parameters
    ----------
    t:
        Time, forwarded unchanged to `f`.
    U:
        State vector, forwarded unchanged to `f`.

    Returns
    -------
    Whatever `f(t, U, eps, kd)` returns.
    """
    tendencies = f(t, U, eps, kd)
    return tendencies

In [None]:
# Integrator driven by the symbolic tendencies (through the `f_fix` wrapper).
integrator = RungeKuttaIntegrator()
integrator.set_func(f_fix)

In [None]:
# Second integrator driven by the numerical tendencies, used as the reference.
integrator_num = RungeKuttaIntegrator()
integrator_num.set_func(f_num)

In [None]:
# Initial condition shared by both integrations: 38 values, matching the 38 rows
# allocated for the tendencies further down.
ic = np.array([ 3.84101549e-02, -8.29674554e-03,  3.04587364e-02,  2.80766373e-02,
       -9.14885177e-03, -9.17520676e-04, -1.76115081e-02,  1.32010146e-02,
        1.62515224e-02,  1.08600254e-03,  1.53918671e+00,  4.13205067e-02,
       -9.25169842e-04,  4.01449139e-03,  6.97326597e-03, -9.93383832e-03,
        8.88594931e-03, -6.07097456e-03,  4.34490969e-03,  4.19834122e-03,
       -2.91974161e-03,  1.03085300e-05,  5.98444985e-04, -2.57753313e-05,
        5.22115566e-06, -3.01445438e-05,  3.26249104e-04, -1.92171554e-05,
        1.38469482e-05,  3.17552667e+00,  2.46854576e-03,  1.44249578e-01,
       -5.94828283e-03,  2.34242352e-02, -3.08095487e-03,  9.15501463e-02,
        1.17932987e-03, -4.34659450e-05])

In [None]:
%%time
# Integrate the symbolic tendencies from `ic` and fetch the stored trajectory.
# NOTE(review): the positional arguments look like (t0, t_end, dt) with every 10th step
# stored — confirm against the RungeKuttaIntegrator documentation.
integrator.integrate(0., 1000000., 0.1, ic=ic, write_steps=10)
reference_time, reference_traj = integrator.get_trajectories()

In [None]:
%%time
# Same integration with the numerical tendencies, using identical settings and the same `ic`.
# NOTE(review): the positional arguments look like (t0, t_end, dt) with every 10th step
# stored — confirm against the RungeKuttaIntegrator documentation.
integrator_num.integrate(0., 1000000., 0.1, ic=ic, write_steps=10)
reference_time_num, reference_traj_num = integrator_num.get_trajectories()

In [None]:
# Indices of the two state variables shown against each other.
varx = 22
vary = 31
fig, ax = plt.subplots(figsize=(10, 8))

# Symbolic trajectory first, numerical trajectory second, both as tiny unconnected markers.
for trajectory in (reference_traj, reference_traj_num):
    ax.plot(trajectory[varx], trajectory[vary], marker='o', ms=0.07, ls='')

ax.set_xlabel('$'+model_parameters.latex_var_string[varx]+'$')
ax.set_ylabel('$'+model_parameters.latex_var_string[vary]+'$');

In [None]:
# Stored series of component `varx` for both integrations. The curves are labelled
# so they can be told apart; the symbolic run is always drawn first.
plt.figure(figsize=(12, 8))

plt.plot(reference_traj[varx, :], label='symbolic tendencies')
plt.plot(reference_traj_num[varx, :], label='numerical tendencies')
plt.xlabel('stored step index')
plt.ylabel('$'+model_parameters.latex_var_string[varx]+'$')
# Fixed location: the default 'best' placement is slow on long series.
plt.legend(loc='upper right')
plt.show()

In [None]:
# Stored series of component `vary` for both integrations. The curves are labelled
# so they can be told apart; the symbolic run is always drawn first.
plt.figure(figsize=(12, 8))

plt.plot(reference_traj[vary, :], label='symbolic tendencies')
plt.plot(reference_traj_num[vary, :], label='numerical tendencies')
plt.xlabel('stored step index')
plt.ylabel('$'+model_parameters.latex_var_string[vary]+'$')
# Fixed location: the default 'best' placement is slow on long series.
plt.legend(loc='upper right')
plt.show()

In [None]:
# Component `varx` for both integrations, keeping every 1000th stored point.
# The curves are labelled so they can be told apart; the symbolic run is drawn first.
plt.figure(figsize=(12, 8))

plt.plot(reference_traj[varx, ::1000], label='symbolic tendencies')
plt.plot(reference_traj_num[varx, ::1000], label='numerical tendencies')
plt.xlabel('sample index (every 1000th stored step)')
plt.ylabel('$'+model_parameters.latex_var_string[varx]+'$')
plt.legend(loc='upper right')
plt.show()

In [None]:
tendencies_sym = np.empty((38, reference_time[::1000].shape[0]))
for n, x in enumerate(reference_time[::1000]):
    x = reference_traj_num[:, n]
    tendencies_sym[:, n] = f_fix(0, x)

In [None]:
tendencies_sym.shape

In [None]:
tendencies_num = np.empty_like(tendencies_sym)
for n, x in enumerate(reference_time[::1000]):
    x = reference_traj_num[:, n]
    tendencies_num[:, n] = f_num(0, x)

In [None]:
# Element-wise difference between the symbolic and the numerical tendencies.
tendencies_err = np.subtract(tendencies_sym, tendencies_num)

In [None]:
# Difference between symbolic and numerical tendencies, one curve per state variable.
plt.figure(figsize=(12, 8))
plt.plot(tendencies_err.T)
plt.xlabel('sample index')
plt.ylabel('symbolic minus numerical tendency')

plt.show()

In [None]:
# Same difference restricted to two components. A dedicated name is used for the
# selection so that the scalar `varx` index used by the earlier plotting cells is kept.
plt.figure(figsize=(12, 8))
err_vars = [1, 12]
err_lines = plt.plot(tendencies_err[err_vars, :].T)
plt.legend(err_lines,
           ['$'+model_parameters.latex_var_string[i]+'$' for i in err_vars],
           loc='upper right')
plt.xlabel('sample index')
plt.ylabel('symbolic minus numerical tendency')

plt.show()