In [1]:
import diffrax
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation

from pde_opt.numerics.domains import Domain
from pde_opt.numerics.equations import AdvectionDiffusion2D



In [2]:

Nx, Ny = 128, 128
Lx = 0.01 * Nx
Ly = 0.01 * Ny
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")

t_start = 0.0
t_final = 5.0
dt = 0.0001

ts_save = jnp.linspace(t_start, t_final, 200)



In [3]:
path_fn_ = lambda t: (-0.4 + 0.8*t/t_final, -0.4 + 0.8*t/t_final)

def advection(t, p, xs, ys):
    xi, yi = path_fn_(t)
    r2 = ((xs-xi)**2 + (ys-yi)**2)/(2.0*p[1])
    grad_x = -(xs-xi)/p[1] * jnp.exp(-r2)
    grad_y = -(ys-yi)/p[1] * jnp.exp(-r2)
    return p[0] * grad_x, p[0] * grad_y

In [4]:
eq = AdvectionDiffusion2D(
    domain,
    lambda t, x, y: advection(t, [0.1,0.01], x, y),
    0.1,
    smooth=False,
)


In [6]:
key = random.PRNGKey(0)
u0 = 0.5 * jnp.ones((Nx, Ny)) + 0.01 * random.normal(key, (Nx, Ny))

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(jax.jit(lambda t, y, args: eq.rhs(y, t))),
    diffrax.Tsit5(),
    t0=t_start,
    t1=t_final,
    dt0=dt,
    y0=u0,
    stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
    # saveat=diffrax.SaveAt(ts=ts_save),
    saveat=diffrax.SaveAt(t1=True),
    max_steps=1000000,
)

print(solution.stats)

print(jnp.mean(solution.ys[0]))
print(jnp.mean(solution.ys[-1]))

{'max_steps': 1000000, 'num_accepted_steps': Array(28077, dtype=int32, weak_type=True), 'num_rejected_steps': Array(7715, dtype=int32, weak_type=True), 'num_steps': Array(35792, dtype=int32, weak_type=True)}
0.49999356
0.49999356


In [7]:
print(solution.ys.shape)

(1, 128, 128)


In [51]:
fig, ax = plt.subplots(figsize=(4,4))

ims = []
for i in range(0, len(solution.ys), 2):
    im = ax.imshow(solution.ys[i], animated=True, 
                   vmin=0.0, vmax=1.0,
                   extent=[domain.box[0][0], domain.box[0][1], 
                          domain.box[1][0], domain.box[1][1]])
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True)

plt.title('Cahn-Hilliard Evolution')
plt.xlabel('x')
plt.ylabel('y')

plt.close()

HTML(ani.to_jshtml())
