In [6]:
import numpy as np
import matplotlib.pyplot as plt
from mpmath import mp, sqrt, sin, cos, acos, pi, exp, matrix, arange
from tqdm.auto import tqdm

# Step 1: Define the mapping function
def mapping(x, alpha):
    """
    Map x through a normalised sine: sin(alpha*pi*x/2) / sin(alpha*pi/2).

    The normalisation makes mapping(1, alpha) == 1 and mapping(-1, alpha) == -1.

    Parameters:
    x (float): Input value (typically an mpmath mpf in [-1, 1]).
    alpha (float): Scaling parameter for the mapping.

    Returns:
    float: Mapped value (an mpmath number, since pi comes from mpmath).
    """
    # Division by 2 is exact in binary floating point, so factoring it out
    # of the numerator does not change the rounded result.
    half_angle = alpha * pi / 2
    numerator = sin(half_angle * x)
    denominator = sin(half_angle)
    return numerator / denominator

# Step 2: Define the basis functions φ_n(x)
def phi_n(x, n, alpha):
    """
    Evaluate the n-th basis function at x.

    The value is c_n * cos(n * acos(mapping(x, alpha))), i.e. the Chebyshev
    polynomial T_n evaluated at the mapped coordinate. c_n is sqrt(1/pi) for
    n == 0 and sqrt(2/pi) otherwise, the orthonormalisation constants of T_n
    under the weight 1/sqrt(1 - y^2).

    Parameters:
    x (float): Input value (typically an mpmath mpf in [-1, 1]).
    n (int): Order of the basis function.
    alpha (float): Scaling parameter for the mapping function.

    Returns:
    float: Value of the basis function at x (an mpmath number).
    """
    theta = acos(mapping(x, alpha))
    if n == 0:
        norm = sqrt(1 / pi)
    else:
        norm = sqrt(2 / pi)
    return norm * cos(n * theta)

def _fourier_matrix(points, ks, omega):
    """
    Build the complex-exponential matrix sampled at the given points.

    Parameters:
    points (sequence of mpf): Sample locations.
    ks (sequence of int): Integer wavenumbers.
    omega (mpf): Fundamental angular frequency, 2*pi / period.

    Returns:
    mpmath.matrix: Matrix E of shape (len(points), len(ks)) with
    E[i, j] = exp(1j * ks[j] * omega * points[i]).
    """
    E = matrix(len(points), len(ks))
    for i, x in enumerate(points):
        for j, k in enumerate(ks):
            E[i, j] = exp(1j * (k * omega * x))
    return E


def compute_svd_fourier_extension(n_coll, n_ext, n_phys, n_fourier, precision=64, cutoff=0):
    """
    Compute a Fourier-extension operator via a truncated SVD pseudo-inverse.

    A Fourier series with period L_phys + L_ext (L_phys = 2, L_ext = n_ext * dx_phys)
    is least-squares fitted to samples on n_coll equispaced collocation points in
    [0, 2]. The fit is then evaluated on the extended grid x_ext.

    Parameters:
    n_coll (int): Number of overcollocation points (>= 2).
    n_ext (int): Number of extension grid points.
    n_phys (int): Number of physical grid points (>= 2).
    n_fourier (int): Number of Fourier modes (exactly this many wavenumbers are used).
    precision (int, optional): Decimal precision for mpmath calculations. Default is 64.
        NOTE: this sets the *global* ``mp.dps``, which stays in effect after the call.
    cutoff (float, optional): Singular values <= cutoff are discarded (absolute
        threshold). Default is 0.

    Returns:
    tuple: (M_int, x_ext, dx_phys) where
        M_int is an mpmath matrix of shape (n_ext + n_phys, n_coll) that maps the
        n_coll collocation samples to samples of the Fourier extension on x_ext;
        x_ext is an mpmath column matrix with x_ext[i] = i * dx_phys for
        i = 0 .. n_ext + n_phys - 1;
        dx_phys = 2 / (n_phys - 1) is the physical grid spacing.
    """
    mp.dps = precision  # global side effect: later mpmath calls by the caller rely on it

    L_phys  = mp.mpf(2)
    dx_coll = L_phys / (n_coll - 1)  # Grid spacing for collocation points
    dx_phys = L_phys / (n_phys - 1)  # Grid spacing for physical points
    L_ext   = dx_phys * n_ext        # Extension domain length
    omega   = 2 * pi / (L_phys + L_ext)  # Fundamental frequency of the extension period

    # Exactly n_fourier integer wavenumbers. For even n_fourier this is
    # -(n/2 - 1) .. n/2 (as before); for odd n_fourier it is the symmetric
    # -(n-1)/2 .. (n-1)/2 (previously one mode short).
    # Python ints keep all arithmetic inside mpmath rather than numpy scalars.
    ks = list(range(-((n_fourier - 1) // 2), n_fourier // 2 + 1))

    # Collocation matrix on [0, 2].
    x_coll = [i * dx_coll for i in range(n_coll)]
    M = _fourier_matrix(x_coll, ks, omega)

    # mp.svd returns U, S, V with M = U * diag(S) * V, so the pseudo-inverse
    # is V^H * diag(1/S) * U^H; small singular values are dropped for stability.
    U, s, Vh = mp.svd(M)
    sinv = mp.diag([1 / si if si > cutoff else 0 for si in s])
    M_inv = Vh.H * sinv * U.H

    # Extended grid. Its last point is (n_ext + n_phys - 1) * dx_phys = L_phys + L_ext,
    # exactly one period after the first, so the first and last samples coincide.
    x_ext_points = [i * dx_phys for i in range(n_ext + n_phys)]
    rec = _fourier_matrix(x_ext_points, ks, omega)

    M_int = rec * M_inv  # Samples on x_ext from collocation samples
    return M_int, matrix(x_ext_points), dx_phys

def plot_and_store_basis_functions(order, alpha, n_coll, n_ext, n_phys, n_fourier, precision=64, cutoff=0, plot_function=False, plot_summary=False):
    """
    Compute (and optionally plot) SVD-based Fourier extensions of the basis functions φ_n.

    For each n in 0 .. order-1, φ_n is sampled on the n_coll equispaced collocation
    points of [-1, 1]. It is then extended with the operator returned by
    compute_svd_fourier_extension.

    Parameters:
    order (int): Number of basis functions (orders 0 .. order-1).
    alpha (float): Scaling parameter for the mapping function.
    n_coll (int): Number of overcollocation points.
    n_ext (int): Number of extension grid points.
    n_phys (int): Number of physical grid points.
    n_fourier (int): Number of Fourier modes.
    precision (int, optional): Precision for mpmath calculations. Default is 64.
    cutoff (float, optional): Cutoff value for singular values. Default is 0.
    plot_function (bool, optional): If True, plots each basis function, its extension
        and the pointwise error on the physical grid. Default is False.
    plot_summary (bool, optional): If True, plots (and saves as PNG) the maximum
        extension value and the maximum physical-domain error versus order.
        Default is False.

    Returns:
    tuple: (basis_functions, x_plot). For order >= 1, basis_functions is a complex
    array of shape (order, n_ext + n_phys, 1) holding the extended samples of each φ_n.
    x_plot is the 1000-point plotting grid on [-1, 1].
    """
    # Compute the SVD Fourier extension (also sets the global mpmath precision).
    M_int, x_ext, dx = compute_svd_fourier_extension(n_coll, n_ext, n_phys, n_fourier, precision, cutoff)

    # Sample points in mpmath precision. They coincide with the grids used to build
    # M_int (shifted from [0, 2] to [-1, 1]). float64 points from np.linspace would
    # carry ~1e-16 errors and defeat the high-precision SVD.
    dx_coll = mp.mpf(2) / (n_coll - 1)
    x_coll = [-1 + i * dx_coll for i in range(n_coll)]
    x_phys = [-1 + i * dx for i in range(n_phys)]

    x_plot = np.linspace(-1, 1, 1000)
    x_phys_plot = np.linspace(-1, 1, n_phys)
    x_ext_plot = np.linspace(-1, 1 + float(dx) * n_ext, len(x_ext))

    basis_functions = []
    max_vals = []
    max_diffs = []
    orders = np.arange(order)

    for n in tqdm(orders):
        # SVD extension of the basis function from its collocation samples.
        f = matrix([phi_n(x, n, alpha) for x in x_coll])
        f_ext = M_int * f
        f_ext_numpy = np.array(f_ext.tolist(), dtype=complex)
        basis_functions.append(f_ext_numpy)

        # Pointwise error on the physical grid, evaluated in mpmath precision and
        # then converted to float for reporting/plotting.
        diff = np.array([float(abs(phi_n(x, n, alpha) - f_ext[i, 0].real)) for i, x in enumerate(x_phys)])

        max_vals.append(np.max(np.abs(f_ext_numpy)))
        max_diffs.append(np.max(diff))

        # Plot the basis function and its periodic extension
        if plot_function:
            # Only evaluated when needed: 1000 mpmath evaluations per order.
            basis_plot = [phi_n(mp.mpf(x), n, alpha) for x in x_plot]

            fig, axs = plt.subplots(1, 2, dpi=600, figsize=(12, 6))
            axs[0].plot(x_plot, basis_plot, label=f'φ_{n}(x)')
            axs[0].plot(x_ext_plot, f_ext_numpy.real, label=f'Extended φ_{n}(x)', linestyle='--')
            axs[0].set_xlabel('x')
            axs[0].set_ylabel('Basis Function and Extension')
            axs[0].set_title(f'Basis Function φ_{n}(x) and its Extension')
            axs[0].legend()
            axs[0].grid(True)

            axs[1].plot(x_phys_plot, diff, label=f'Difference φ_{n}(x)', linestyle='-.')
            axs[1].set_yscale('log')
            axs[1].set_xlabel('x')
            axs[1].set_ylabel('Absolute Difference')
            axs[1].set_title('Absolute Difference on Physical Domain')
            axs[1].legend()
            axs[1].grid(True)

            plt.show()
            plt.close(fig)  # release the high-dpi figure; many are created in this loop

    if plot_summary:  # Plot maximum values and differences versus order
        fig, axs = plt.subplots(1, 2, dpi=600, figsize=(12, 6))
        axs[0].plot(orders, max_vals, label='Maximum value in extension domain')
        axs[0].set_xlabel('Mapped polynomial order')
        axs[0].set_ylabel('Value')
        axs[0].set_yscale('log')
        axs[0].set_title('Maximum value of extension of Basis Function φ_n(x)')
        axs[0].legend()
        axs[0].grid(True)

        axs[1].plot(orders, max_diffs, label='Max difference φ_n(x)', linestyle='-.')
        axs[1].set_yscale('log')
        axs[1].set_xlabel('Mapped polynomial order')
        axs[1].set_ylabel('Absolute Difference')
        axs[1].set_title('Absolute Difference on Physical Domain')
        axs[1].legend()
        axs[1].grid(True)

        fig.savefig(f'extended_basis_functions_alpha{alpha}_order{order}_ncoll{n_coll}_next{n_ext}_nphys{n_phys}_nfourier{n_fourier}_precision{precision}_cutoff{cutoff}.png')
        plt.show()
        plt.close(fig)

    return np.array(basis_functions), x_plot

# Parameters shared by every run
n_phys    = 32
n_coll    = 150
n_ext     = 32
precision = 64
requested_orders = [28, 30, 32]

for cutoff in [1e-15, 1e-13]:
    for n_fourier in [52, 56, 60]:
        for alpha in [0.5]:
            # The extension of φ_n does not depend on how many orders are requested,
            # so the result for a smaller order is a prefix of the largest one.
            # Compute the largest set once (one SVD) and save each prefix, instead of
            # redoing the SVD and all lower orders for every requested order.
            all_basis_functions, x_plot = plot_and_store_basis_functions(max(requested_orders), alpha, n_coll, n_ext, n_phys, n_fourier, precision, cutoff)

            for order in requested_orders:
                basis_functions = all_basis_functions[:order]

                # Save the basis functions to a file
                filename = f'extended_basis_functions_alpha{alpha}_order{order}_ncoll{n_coll}_next{n_ext}_nphys{n_phys}_nfourier{n_fourier}_precision{precision}_cutoff{cutoff}.npy'
                np.save(filename, basis_functions)
