In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jax as jax 
import jax.numpy as jnp 



In [2]:
## FV Solver Functions

def periodic_bc(u:jnp.ndarray): 
    """
    Extend a 1D array u by one ghost cell on each side for periodic boundaries.
    Args:
        u: 1D array of length N.

    Returns:
        u_ext: 1D array of length N+2, where
               u_ext[0]   = u[-1],
               u_ext[1:-1] = u[:],
               u_ext[-1]  = u[0].
    """
    # u[-1:] is the last element as a length-1 array
    # u[:1]  is the first element as a length-1 array
    return jnp.concatenate([u[-1:], u, u[:1]], axis=0)

def linear_advection_split_form(u:jnp.ndarray, a:float):
    """
    Symmetric flux‐splitting for linear advection.

    Args:
      u_ext: array of length N+2 (with periodic ghost cells).
      a:     scalar wave‐speed.
      
    Returns:
      flux:  array of length N+1, flux[i] = F_{i+1/2}.
    """
    # positive and negative parts of a
    a_plus  = 0.5 * (a + jnp.abs(a))   # >= 0
    a_minus = 0.5 * (a - jnp.abs(a))   # <= 0

    # left state = u_ext[:-1], right state = u_ext[1:]
    return a_plus * u[:-1] + a_minus * u[1:] 

def rhs(u:jnp.ndarray, a:float, dx):
    """ 
    Computes the Right-Hand-Side of the conservation law. 
     Args:
      u : length‐N cell‐averages
      a : wave‐speed
      dx: cell width

    Returns:
      dudt: length‐N array of ∂u/∂t
    """
    
    # Setting Periodic Boundary Conditions 
    u_bc = periodic_bc(u) # This makes the boundaries of the u vector loop (Adding ghost cells)
    
    # Computing the Fluxes
    f = linear_advection_split_form(u_bc, a)
    return -(f[:1] - f[-1:]) # This will return a vector of length n
  
def rk4(u:jnp.ndarray, a:float, dx:float, dt:float): 
  """
  Numerical Time Integration using the Runge-kutta 4-Stage Scheme. 
  This function advances by a single time step. 

  Args:
      u (jnp.ndarray): Solution Vector
      a (float): Advection Wave Speed
      dx (float): Grid Spacing
      dt (float): Time Step Size
  """
  k1 = rhs(u,             a, dx)
  k2 = rhs(u + 0.5*dt*k1, a, dx)
  k3 = rhs(u + 0.5*dt*k2, a, dx)
  k4 = rhs(u +    dt*k3,  a, dx)
  return u + (dt/6.0)*(k1 + 2*k2 + 2*k3 + k4)

    

In [3]:
## Driver Block

N = 500
x_min = -1
x_max = 1
dx = (x_max-x_min)/N
dt = 0.25e-2
a = 1 # Advection coefficient 
cfl = a*dt/dx 


##########################################
########### Initial Conditions ###########
x_centers = x_min + (jnp.arange(N) + 0.5)*dx # Defining the grid 
μ     = -0.75          # center of the pulse
σ     = 0.05 
A     = 1.0          # peak amplitude
u0 = A * jnp.exp(-((x_centers - μ)**2) / (2 * σ**2))
############ Initial COnditions ######################
######################################################


In [4]:
# ─── Cell: Driver Block ─────────────────────────────────────────────────────
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 1) Problem parameters
N       = 500
x_min   = -1.0
x_max   =  1.0
a       =  1.0     # advection speed
CFL     =  0.8     # choose a CFL number
T_final =  1.0     # final time

# 2) Grid spacing & time‐step
dx      = (x_max - x_min) / N
dt      = CFL * dx / abs(a)        # or keep your 0.25e-2 if you prefer: dt = 0.25e-2
n_steps = int(jnp.ceil(T_final / dt))
dt      = T_final / n_steps        # adjust so n_steps*dt = T_final exactly

print(f"dx = {dx:.3e}, dt = {dt:.3e}, n_steps = {n_steps}, CFL ≃ {a*dt/dx:.2f}")

# 3) Cell centers
x_centers = x_min + (jnp.arange(N) + 0.5) * dx

# 4) Initial condition: Gaussian pulse
μ = -0.75    # center
σ =  0.05    # width
A =  1.0     # amplitude
u0 = A * jnp.exp(-((x_centers - μ)**2) / (2 * σ**2))

# 5) Run the solver
#    Make sure advect_1d_rk4_split is defined or imported in this notebook!
u_final = advect_1d_rk4_split(u0, a, dx, dt, n_steps)

# 6) Bring data back to CPU for plotting
x = jax.device_get(x_centers)
u0_np      = jax.device_get(u0)
u_final_np = jax.device_get(u_final)

# 7) Plot
plt.figure(figsize=(6,3))
plt.plot(x, u0_np,    label='t = 0')
plt.plot(x, u_final_np,label=f't = {T_final:.2f}')
plt.xlabel('x')
plt.ylabel('u')
plt.title('1D Linear Advection (RK4 + Flux Splitting)')
plt.legend()
plt.tight_layout()
plt.show()


dx = 4.000e-03, dt = 3.195e-03, n_steps = 313, CFL ≃ 0.80


NameError: name 'advect_1d_rk4_split' is not defined