# Base Definition

In [2]:
import warnings
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
import jax # Numerical computing library for GPU/TPU accelerators
import jax.random as jr
import jax.numpy as jnp
import optax # Optax: A gradient processing and optimization library for JAX
import haiku as hk # Haiku: A neural network library for JAX
import numpy as np
from scipy.interpolate import griddata
from tqdm import tqdm
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
import seaborn as sns # Seaborn: A statistical data visualization library
import os.path
import pickle

# Diffrax library imports
from diffrax import (
    diffeqsolve,
    ControlTerm,
    Euler,
    MultiTerm,
    ODETerm,
    SaveAt,
    VirtualBrownianTree,
    WeaklyDiagonalControlTerm
)

  PyTreeDef = type(jax.tree_structure(0))
  term_structure = jax.tree_structure(0)
  term_structure = jax.tree_structure(0)
  term_structure = jax.tree_structure((0, 0))
  term_structure = jax.tree_structure(0)
  term_structure = jax.tree_structure(0)
  term_structure = jax.tree_structure((0, 0))
  term_structure = jax.tree_structure((0, 0))
  term_structure = jax.tree_structure(0)
  term_structure = jax.tree_structure((0, 0))


In [3]:
# ---- Hyperparameters -------------------------------------------------------

# Diffusion coefficient; passed to r_process as its noise scaling.
g = 0.03
# Coefficient multiplying the drift in physics_operator.
lambda_ = 0.5
# Start and end time of the integration interval.
t0 = 0.0
t1 = 10.0
# Number of sample paths generated for the dataset.
num_path_dataset = 500

# ---- Global settings -------------------------------------------------------

# Base seed; each path offsets it by its index.
seed = 2022

# NOTE(review): each later sns.set call presumably re-applies seaborn's default
# context and font scale, so the font settings of the earlier calls may not
# survive -- confirm against the seaborn documentation before relying on them.
sns.set_context(
    "paper",
    rc={
        "font.size": 8,
        "axes.titlesize": 8,
        "axes.labelsize": 5,
    },
)
sns.set(
    font_scale=1.5,
    rc={'text.usetex': False},
)
sns.set(
    rc={
        'axes.facecolor': 'black',
        'figure.facecolor': 'white',
        'axes.labelcolor': "black",
        'text.color': "black",
    },
)

## Physics System SDE

We consider the SDE $dx = -\lambda \, \operatorname{sign}(x) \, x^2 \, dt + g \, dW$ with $\lambda = 0.5$ and $g = 0.03$ (the values set in the hyperparameter cell above), which is straightforward to solve numerically with an SDE solver.


In [5]:
def physics_operator(x):
    """Return the SDE drift ``-sign(x) * x * x * lambda_`` for state ``x``.

    ``lambda_`` is read from module scope each time the function is called.
    """
    signed_square = - jnp.sign(x) * x * x
    return signed_square * lambda_

def r_process(initial_value, noise_scaling, seed):
    """
    Integrate one sample path of the SDE with diffrax's Euler solver.

    The drift is the module-level ``physics_operator`` applied to the state and
    the diffusion is the constant ``noise_scaling``. Integration runs over the
    module-level interval ``[t0, t1]`` with ``dt0=0.01``.

    Parameters:
    - initial_value: Value the one-element state vector is initialised to.
    - noise_scaling: Constant coefficient of the diffusion (Brownian) term.
    - seed (int): Seed of the PRNG key that drives the Brownian motion.

    Returns:
    - The object returned by ``diffeqsolve``. It is requested with
      ``SaveAt(dense=True)``, so callers query it via ``sol.evaluate(...)``.
    """
    state_shape = (1,)

    def drift(t, y, args):
        # Time and args are unused: the drift depends on the state only.
        return physics_operator(y)

    def diffusion(t, y, args):
        # State-independent noise of constant magnitude.
        return noise_scaling * jnp.ones(state_shape)

    noise = VirtualBrownianTree(
        t0, t1, tol=1e-3, shape=state_shape, key=jr.PRNGKey(seed)
    )
    sde_terms = MultiTerm(
        ODETerm(drift),
        WeaklyDiagonalControlTerm(diffusion, noise),
    )

    return diffeqsolve(
        sde_terms,
        Euler(),
        t0,
        t1,
        dt0=0.01,
        y0=jnp.ones(shape=state_shape) * initial_value,
        saveat=SaveAt(dense=True),
    )

Generate a dataset of `num_path_dataset` sample paths using the solver.

In [None]:
# Build the dataset: one sampled trajectory per simulated path.
dataset = []

# Number of time points each trajectory is sampled at.
num_time_points = 500
# Sample on the same interval the SDE is integrated over, instead of
# repeating the interval bounds as literals.
x = jnp.linspace(t0, t1, num_time_points)

for n in tqdm(range(num_path_dataset)):
    # Alternate the initial value between +1 and -1 and give every path its
    # own seed.
    sol = r_process((-1) ** n, g, seed + n)
    # Evaluate the dense solution one time point at a time with vmap, then
    # drop the trailing state axis of length 1. This replaces evaluating on
    # the whole time vector and keeping only the diagonal of the result,
    # which materialised a square intermediate array for every path.
    dataset.append(jax.vmap(sol.evaluate)(x)[:, 0])

# NOTE: pickle files must only be loaded from trusted sources.
with open('test1_example_data.p', 'wb') as file:
    pickle.dump(dataset, file)