In [1]:
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.config import config
config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt

%config InlineBackend.figure_format='retina'

In [2]:
! pip -q install optax

[?25l[K     |██▎                             | 10 kB 18.3 MB/s eta 0:00:01[K     |████▋                           | 20 kB 15.4 MB/s eta 0:00:01[K     |███████                         | 30 kB 10.9 MB/s eta 0:00:01[K     |█████████▎                      | 40 kB 9.4 MB/s eta 0:00:01[K     |███████████▋                    | 51 kB 4.9 MB/s eta 0:00:01[K     |██████████████                  | 61 kB 5.7 MB/s eta 0:00:01[K     |████████████████▎               | 71 kB 5.8 MB/s eta 0:00:01[K     |██████████████████▋             | 81 kB 4.3 MB/s eta 0:00:01[K     |█████████████████████           | 92 kB 4.8 MB/s eta 0:00:01[K     |███████████████████████▎        | 102 kB 5.2 MB/s eta 0:00:01[K     |█████████████████████████▋      | 112 kB 5.2 MB/s eta 0:00:01[K     |████████████████████████████    | 122 kB 5.2 MB/s eta 0:00:01[K     |██████████████████████████████▎ | 133 kB 5.2 MB/s eta 0:00:01[K     |████████████████████████████████| 140 kB 5.2 MB/s 
[?25h[?25l[K

In [3]:
! git init .
! git remote add origin https://github.com/VLSF/SNO
! git pull origin main

Initialized empty Git repository in /content/.git/
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 36 (delta 18), reused 32 (delta 17), pack-reused 0[K
Unpacking objects: 100% (36/36), done.
From https://github.com/VLSF/SNO
 * branch            main       -> FETCH_HEAD
 * [new branch]      main       -> origin/main


In [4]:
from functions import Chebyshev

# Introduction

In this notebook we explain and test the functions of the module `Chebyshev.py` (imported as `functions.Chebyshev`).

This module provides basic functionality for non-periodic functions defined on $[-1, 1]$.

# Chebyshev grid

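`Chebyshev_grid(n)` returns the $n$ Chebyshev points of the second kind, $\cos\left(\frac{\pi k}{n-1}\right)$ for $k = 0, \dots, n-1$, sorted in increasing order. Both endpoints $\pm 1$ are included.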

In [5]:
# Four-node grid on [-1, 1]; shown as the cell output.
grid_4 = Chebyshev.Chebyshev_grid(4)
grid_4



DeviceArray([-1. , -0.5,  0.5,  1. ], dtype=float64)

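Grids are nested: the 4-point grid above is contained in the 7-point grid below. In general, the $n$-point grid is a subset of the $m$-point grid whenever $n-1$ divides $m-1$.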

In [6]:
# Check the nesting instead of comparing outputs by eye: every second node of
# the 7-point grid coincides with a node of the 4-point grid
# (4 - 1 = 3 divides 7 - 1 = 6).
grid_4 = Chebyshev.Chebyshev_grid(4)
grid_7 = Chebyshev.Chebyshev_grid(7)
assert jnp.allclose(grid_7[::2], grid_4), "Chebyshev grids are expected to be nested"
grid_7

DeviceArray([-1.00000000e+00, -8.66025404e-01, -5.00000000e-01,
              6.12323400e-17,  5.00000000e-01,  8.66025404e-01,
              1.00000000e+00], dtype=float64)

# Values to coefficients / coefficients to values

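Test with known coefficients:

1. Draw a random Chebyshev series with $n$ coefficients.
2. Evaluate it on a Chebyshev grid with $n + 10$ points.
3. Recover the coefficients with `values_to_coefficients`.
4. Compare with the coefficients from step 1; the extra recovered coefficients should vanish.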

In [7]:
n = 50         # number of random Chebyshev coefficients (degree n - 1)
n_extra = 10   # oversampling: evaluate on n + n_extra grid points
coeff = random.normal(random.PRNGKey(13), (n, ))*4
T = np.polynomial.chebyshev.Chebyshev(np.array(coeff))

# Reference values of the random series on the grid, reshaped to a column.
values = T(np.array(Chebyshev.Chebyshev_grid(n + n_extra))).reshape(-1, 1)

num_coeff = Chebyshev.values_to_coefficients(values).reshape(-1, )

# The first n recovered coefficients must match the originals ...
error1 = jnp.linalg.norm(coeff - num_coeff[:n], ord=jnp.inf)
# ... and the n_extra trailing ones must vanish, since T has degree n - 1.
error2 = jnp.linalg.norm(num_coeff[n:], ord=jnp.inf)
print(error1 + error2)

# Fail loudly on a regression instead of only printing a larger number.
assert error1 + error2 < 1e-10, "values_to_coefficients did not recover the known coefficients"

1.0171900986294486e-13


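Forward + inverse = identity: applying `coefficients_to_values` and then `values_to_coefficients` (and the reverse composition) should reproduce a random input up to round-off.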

In [8]:
# Random multi-dimensional input; each transform followed by the other must
# reproduce it up to round-off.
x = random.normal(random.PRNGKey(13), (10, 20, 40, 5))*4

roundtrip_c2v2c = Chebyshev.values_to_coefficients(Chebyshev.coefficients_to_values(x))
roundtrip_v2c2v = Chebyshev.coefficients_to_values(Chebyshev.values_to_coefficients(x))

# Max-abs error over all entries (equivalent to the inf-norm of the flattened difference).
error1 = jnp.max(jnp.abs(x - roundtrip_c2v2c))
error2 = jnp.max(jnp.abs(x - roundtrip_v2c2v))

print(error1 + error2)

# Fail loudly on a regression instead of only printing a larger number.
assert error1 + error2 < 1e-10, "coefficients_to_values / values_to_coefficients are not mutual inverses"

3.375077994860476e-14


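# Integration

Compare `Chebyshev.integrate` against NumPy's `Chebyshev.integ` for a random series. The zeroth coefficient (the integration constant) is excluded from the comparison.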

In [9]:
n = 50
coeff = random.normal(random.PRNGKey(13), (n, ))*4
T = np.polynomial.chebyshev.Chebyshev(np.array(coeff))

values = T(np.array(Chebyshev.Chebyshev_grid(n+10))).reshape(-1, 1)

num_coeff = Chebyshev.values_to_coefficients(values)

# NOTE(review): the second argument of `integrate` is presumably the axis — confirm.
num_integral = Chebyshev.integrate(num_coeff, 0).reshape(-1,)
ref_integral = jnp.array(T.integ().coef)  # n + 1 coefficients

# Index 0 (the integration constant) is skipped: the two implementations may
# choose the constant differently, so only coefficients 1..n are compared.
error = jnp.linalg.norm(num_integral[1:(n+1)] - ref_integral[1:], ord=jnp.inf)
print(error)

# Fail loudly on a regression instead of only printing a larger number.
assert error < 1e-10, "Chebyshev.integrate disagrees with numpy's Chebyshev.integ"

8.881784197001252e-15


# Differentiation

Compare `Chebyshev.differentiate` against NumPy's `Chebyshev.deriv` for a random series.

In [10]:
n = 50
coeff = random.normal(random.PRNGKey(13), (n, ))*4
T = np.polynomial.chebyshev.Chebyshev(np.array(coeff))

values = T(np.array(Chebyshev.Chebyshev_grid(n+10))).reshape(-1, 1)

num_coeff = Chebyshev.values_to_coefficients(values)

# NOTE(review): the second argument of `differentiate` is presumably the axis — confirm.
num_derivative = Chebyshev.differentiate(num_coeff, 0).reshape(-1,)
ref_derivative = jnp.array(T.deriv().coef)  # n - 1 coefficients

error = jnp.linalg.norm(num_derivative[:(n-1)] - ref_derivative, ord=jnp.inf)
print(error)

# Differentiation amplifies round-off: the observed error is about 100x larger
# than in the other checks, hence the looser tolerance.
assert error < 1e-8, "Chebyshev.differentiate disagrees with numpy's Chebyshev.deriv"

2.660272002685815e-11
