In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from numba import njit

In [3]:
# Explicit finite-difference scheme from eq. (23) of http://hplgit.github.io/INF5620/doc/notes/wave-sphinx/main_wave.html
# PDE solved: u_tt + b*u_t = c^2*u_xx + f(x,t)  (damped 1D wave equation with source term f)


In [8]:
# Parameters
# --- discretisation ---
dx = 1.0   # m, grid spacing
dt = 10    # time step

L = 1000                              # m, string length
max_niter = 5_000                     # number of time levels computed and stored
grid_length = int(L / dx)             # number of grid nodes
simulated_time = int(dt * max_niter)  # total simulated time

# --- physical coefficients ---
b = 0.001        # friction (damping) term
c_squared = 1.0  # coefficient of the u_xx term; it enters the update without a
                 # dt^2/dx^2 factor, so it acts as the squared Courant number
                 # (the explicit scheme needs it <= 1 to stay stable)
f_ext = np.full(grid_length, 1e-6)  # uniform external force on every node

In [9]:
# Second-difference (discrete u_xx) matrix: -2 on the main diagonal and 1 on the
# first sub- and super-diagonals. The first and last rows only see one neighbour,
# i.e. values beyond the ends of the grid are implicitly taken as zero.
fin_diff_matrix = (np.eye(grid_length, k=-1)
                   + np.eye(grid_length, k=1)
                   - 2 * np.eye(grid_length))

@njit
def run_one_timestep(u, u_old):
    """Advance the damped wave equation by one time level.

    Computes
        u_new = ((b*dt/2 - 1)*u_old + 2*u + c_squared*D2(u) + dt**2*f_ext) / (1 + b*dt/2)
    where D2(u)[i] = u[i-1] - 2*u[i] + u[i+1], with values outside the array
    treated as zero (same result as multiplying by the tridiagonal matrix).

    Args:
        u: solution at the current time level (1D array).
        u_old: solution at the previous time level (same shape as ``u``).

    Returns:
        A new array holding the solution at the next time level.

    Reads the module-level globals b, dt, c_squared and f_ext.
    NOTE(review): numba treats globals as compile-time constants, so changing
    them after the first call presumably has no effect until recompilation.
    """
    # O(N) second-difference stencil, replacing the dense O(N^2)
    # fin_diff_matrix @ u product. Works for any length, including 1.
    lap = -2.0 * u
    lap[1:] += u[:-1]   # left neighbour (absent for the first node)
    lap[:-1] += u[1:]   # right neighbour (absent for the last node)

    # NOTE(review): eq. (23) multiplies this term by C^2 = c^2*dt^2/dx^2.
    # c_squared is used as-is here, so it effectively plays the role of the
    # squared Courant number (explicit scheme is stable only for values <= 1).
    return (1 / (1 + 0.5 * b * dt)) * ((0.5 * b * dt - 1) * u_old
                                       + 2 * u
                                       + c_squared * lap
                                       + dt**2 * f_ext)
 
@njit
def run_all_timesteps(u, u_old):
    """Integrate the wave equation over max_niter time levels.

    Args:
        u: initial solution (1D array of length grid_length). Not modified.
        u_old: solution one time level before ``u`` (same shape).

    Returns:
        Array of shape (max_niter, grid_length). Row i is the solution at time
        level i, with the Dirichlet boundary values and (from level 300 on) the
        drainage-canal value already imposed.
    """
    # Work on a copy so the caller's initial-condition array is left untouched.
    u = u.copy()
    canal_index = int(grid_length / 4)

    # Dirichlet BC: both ends held at zero.
    u[0] = 0.
    u[grid_length - 1] = 0.

    solution = np.zeros(shape=(max_niter, grid_length))
    solution[0] = u
    for i in range(1, max_niter):
        u_new = run_one_timestep(u, u_old)

        # Impose the constraints on the new level *before* storing it, so the
        # saved trajectory satisfies them. They are still in place when this
        # level is used as ``u`` in the next step.
        u_new[0] = 0.
        u_new[grid_length - 1] = 0.
        # Drainage canal: from time level 300 on, pin the height at canal_index to 1.
        if i >= 300:
            u_new[canal_index] = 1.

        solution[i] = u_new

        # u_new is a fresh array each step and nothing mutates these arrays
        # afterwards, so the levels can be rotated without copying.
        u_old = u
        u = u_new

    return solution


In [10]:
# %% Solve
# Initial condition: a flat (all-zero) profile.
u = np.zeros(grid_length)
# Alternative triangular initial condition (peak of 1 at the centre):
# u[:int(grid_length/2)] = np.linspace(start=0, stop=1., num=int(grid_length/2))
# u[int(grid_length/2):] = np.linspace(start=1., stop=0., num=int(grid_length/2))

# Dirichlet BC: pin both end nodes at zero.
u[[0, -1]] = 0.0
# Previous time level equals the initial one (a simple start-from-rest approximation).
u_old = np.copy(u)

solution = run_all_timesteps(u, u_old)

In [12]:
%matplotlib notebook
# %% Plot
plt.figure()
for i in range(0, max_niter):
    if i==0:
        plt.plot(solution[0], color='brown', alpha=1.0)
    if i%int(max_niter/1000) == 0:  
        plt.plot(solution[i], color='brown', alpha=0.2)
# %% Animate
from matplotlib.animation import FuncAnimation 
# initializing a figure in 
# which the graph will be plotted
fig = plt.figure() 
   
# marking the x-axis and y-axis
axis = plt.axes(xlim =(0, 1000), 
                ylim =(0, 12)) 
  
# initializing a line variable
line, = axis.plot([], [], lw = 2, alpha=0.7, color='brown') 

def init(): 
    line.set_data([], [])
    return line,

def animate(i):
    x = np.linspace(0, L, num=grid_length)
   
    line.set_data(x, solution[i])
      
    return line,
   
anim = FuncAnimation(fig, animate, init_func = init,
                     frames = max_niter, interval = 100, blit = True)
  

    
anim.save(r'C:\Users\03125327\github\subsi_wave\peatland.mp4', 
          writer = 'ffmpeg', fps = 30)
# %%


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [18]:
# Plot the peat height at a single grid node over the whole simulation.

location_in_x = 600  # grid index; the physical position is location_in_x * dx

# Convert time-level indices to simulated time so the x-axis matches its label.
time_axis = np.arange(solution.shape[0]) * dt

fig, ax = plt.subplots()
ax.plot(time_axis, solution[:, location_in_x])
ax.set_title(f'peat height at position {location_in_x} over time')
ax.set_ylabel('peat height')
ax.set_xlabel('time')
plt.show()


<IPython.core.display.Javascript object>