In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Spatial grid and physical parameters of the advection problem.
N = 51                  # number of grid points
L = 1.0                 # system size
h = L / (N - 1)         # grid spacing: N points spanning [0, L] including both ends
c = 1.0                 # wave speed
omega = 10 * np.pi      # angular frequency of the driven left boundary value
print('Time for wave to move one grid spacing is ', h/c)

In [None]:
# Time step, plus the finite-difference coefficients derived from it.
tau = 0.015
# -c*tau/(2h): multiplies the centered first difference in every scheme.
coeff = -c*tau/(2.*h)
# c^2*tau^2/(2h^2): multiplies the second-difference term of Lax-Wendroff.
coefflw = 2*coeff**2
print('Wave crosses system in ', L/(c*tau), ' steps')

In [None]:
# Total number of time steps taken by the main loop.
nStep = 50

In [None]:
# Initial condition: the amplitude starts at zero on every grid point.
# The Dirichlet boundary values are imposed inside the time loop.
a = np.zeros(N, dtype=float)

In [None]:
#* Initialize plotting variables.
iplot = 1           # Plot counter (column 0 already holds the initial state)
nplots = 20         # Desired number of plots (capacity of aplot/tplot)
aplot = np.empty((N, nplots))
tplot = np.empty(nplots)
aplot[:, 0] = np.copy(a)     # Record the initial state
tplot[0] = 0                 # Record the initial time (t=0)
# Integer number of steps between recorded snapshots. Floor division keeps
# the snapshots evenly spaced; true division would give e.g. 3.5 for
# nStep=50, nplots=20 and record at irregular steps 4, 7, 11, 14, ...
# Since nStep < (nStep//nplots + 1)*nplots, at most nplots-1 snapshots are
# taken after the initial one, so they always fit in aplot/tplot.
plotStep = nStep//nplots + 1

In [None]:
method = 1   # 1 = FTCS, 2 = Lax, 3 = Lax-Wendroff, anything else = upwind

#* Loop over desired number of steps.
for iStep in range(nStep):  ## MAIN LOOP ##

    # Set boundary conditions: sinusoidally driven left end, fixed right end.
    a[0] = np.sin(omega*tau*iStep)
    a[-1] = 0
    #* Compute new values of wave amplitude using the FTCS, Lax,
    #  Lax-Wendroff or upwind method. Each right-hand side is fully
    #  evaluated before the slice assignment, so every scheme reads only
    #  the amplitudes from the previous time step.
    if method == 1 :      ### FTCS method ###
        a[1:-1] = a[1:-1] + coeff*( a[2:] - a[:-2] )
    elif method == 2 :    ### Lax method ###
        a[1:-1] = .5*( a[2:] + a[:-2] ) + coeff*( a[2:] - a[:-2] )
    elif method == 3 :    ### Lax-Wendroff method ###
        # The centered difference needs the left-neighbour slice a[:-2];
        # the scalar a[-2] would broadcast silently and give a wrong scheme.
        a[1:-1] = ( a[1:-1] + coeff*( a[2:] - a[:-2] ) +
                    coefflw*( a[2:] + a[:-2] - 2*a[1:-1] ) )
    else:                 ### Upwind method (backward difference) ###
        a[1:-1] = a[1:-1] + 2*coeff*( a[1:-1] - a[:-2] )

    #* Periodically record a(t) for plotting. The iplot < nplots guard
    #  prevents writing past the end of the preallocated aplot/tplot.
    if (iStep+1) % plotStep < 1 and iplot < nplots:
        aplot[:,iplot] = np.copy(a)      # Record a(i) for plotting
        tplot[iplot] = tau*(iStep+1)
        iplot += 1

#* Plot the initial and final states.
x = np.arange(N)*h   # Grid point positions 0, h, ..., (N-1)*h
plt.plot(x,aplot[:,0],'-',x,a,'--')
plt.legend(['Initial  ','Final'])
plt.xlabel('x')
plt.ylabel('a(x,t)')
plt.show()