In [6]:
import numpy as np
import matplotlib.pyplot as plt
import jax as jax 
import jax.numpy as jnp 



In [3]:
## FV Solver Functions

def periodic_bc(u: jnp.ndarray):
    """
    Pad a 1D array with one periodic ghost cell at each end.

    Args:
        u: 1D array of length N.

    Returns:
        Array of length N+2 laid out as [u[N-1], u[0], ..., u[N-1], u[0]]:
        the last cell is copied in front and the first cell is copied behind.
    """
    # Slices (not scalar indexing) keep each ghost as a length-1 array,
    # so all three pieces can be concatenated directly.
    left_ghost = u[-1:]
    right_ghost = u[:1]
    return jnp.concatenate((left_ghost, u, right_ghost), axis=0)

def linear_advection_split_form(u: jnp.ndarray, a: float):
    """
    Upwind flux splitting for 1D linear advection (flux F(u) = a * u).

    The wave speed is split as a = a+ + a-, where a+ = (a + |a|)/2 >= 0
    and a- = (a - |a|)/2 <= 0. The non-negative part transports the state on
    the left of an interface; the non-positive part transports the state on
    the right.

    Args:
      u: 1D array of length M (in this file, N cell averages padded with
         one periodic ghost cell per side, so M = N+2).
      a: scalar wave speed.

    Returns:
      Array of length M-1 whose entry k is the numerical flux through the
      interface between u[k] and u[k+1].
    """
    speed = jnp.abs(a)
    rightward = 0.5 * (a + speed)   # a+ : nonzero only when a > 0
    leftward = 0.5 * (a - speed)    # a- : nonzero only when a < 0

    left_states, right_states = u[:-1], u[1:]
    return rightward * left_states + leftward * right_states

def rhs(u: jnp.ndarray, a: float, dx: float):
    """
    Compute the semi-discrete right-hand side of u_t + a u_x = 0.

    Uses a first-order finite-volume update with periodic boundaries and
    upwind interface fluxes:

        dudt[i] = -(F_{i+1/2} - F_{i-1/2}) / dx

    Args:
      u : length-N cell averages.
      a : wave speed.
      dx: cell width.

    Returns:
      dudt: length-N array of du/dt, one entry per cell.
    """
    # Pad with one periodic ghost cell per side -> length N+2.
    u_bc = periodic_bc(u)

    # Fluxes between consecutive entries of u_bc -> length N+1.
    # With the ghost offset, f[i] is the left face of cell i and f[i+1]
    # is its right face.
    f = linear_advection_split_form(u_bc, a)

    # Per-cell flux difference (right face minus left face), scaled by the
    # cell width. Previously this used f[:1] - f[-1:], which produced a
    # single value instead of N, and dx was never applied.
    return -(f[1:] - f[:-1]) / dx
  
def rk4(u: jnp.ndarray, a: float, dx: float, dt: float):
    """
    Advance the solution by one step of the classical 4-stage Runge-Kutta scheme.

    Each stage evaluates `rhs` with the same wave speed and grid spacing.

    Args:
        u (jnp.ndarray): Solution vector at the current time.
        a (float): Advection wave speed.
        dx (float): Grid spacing.
        dt (float): Time step size.

    Returns:
        Solution vector after one step of size dt.
    """
    half_dt = 0.5 * dt

    stage1 = rhs(u, a, dx)
    stage2 = rhs(u + half_dt * stage1, a, dx)
    stage3 = rhs(u + half_dt * stage2, a, dx)
    stage4 = rhs(u + dt * stage3, a, dx)

    # Classical RK4 weights 1, 2, 2, 1, normalised by 6.
    weighted = stage1 + 2 * stage2 + 2 * stage3 + stage4
    return u + (dt / 6.0) * weighted

    