In [1]:
import numpy as np
import sympy as sp
import scipy.sparse as sparse

import ipywidgets as widgets

import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib notebook

#from IPython.display import display, HTML
#display(HTML("<style>.container { width:100% !important; }</style>"))

# Symbolic space/time variables; solver() builds and lambdifies its exact solution in terms of these.
x, y, t = sp.symbols('x y t')

In [2]:
def mesh2D(N, L, sparse=False):
    """Build a uniform tensor-product mesh on the square [0, L] x [0, L].

    Both directions use N+1 equispaced nodes (N intervals). Indexing is 'ij',
    so the first array index runs along x and the second along y. With
    ``sparse=True`` numpy's broadcastable (N+1, 1) / (1, N+1) arrays are
    returned instead of full (N+1, N+1) arrays.
    """
    # The same 1D node set serves both directions of the square domain.
    nodes = np.linspace(0, L, N + 1)
    # NOTE: the `sparse` argument shadows the scipy.sparse module alias here.
    return np.meshgrid(nodes, nodes, indexing='ij', sparse=sparse)

def D2(N):
    """Return the (N+1) x (N+1) second-difference matrix (unscaled by dx**2).

    Interior rows hold the central stencil [1, -2, 1]. The first and last rows
    are replaced by the one-sided stencils [2, -5, 4, -1] and [-1, 4, -5, 2].
    The matrix is returned in LIL format.
    """
    size = N + 1
    mat = sparse.diags([1, -2, 1], offsets=[-1, 0, 1], shape=(size, size), format='lil')
    # Overwrite the boundary rows with one-sided stencils.
    first_row = (2, -5, 4, -1)
    last_row = (-1, 4, -5, 2)
    mat[0, :4] = first_row
    mat[-1, -4:] = last_row
    return mat

In [3]:
# Adapted from lecture 7.
def solver(N, L, Nt, mx, my, cfl=0.5, c=1, store_data=10):
    """Solve u_tt = c**2 (u_xx + u_yy) on [0, L]^2 with a leapfrog scheme.

    Spatial derivatives use the matrix from D2 along both axes. Homogeneous
    Dirichlet conditions are imposed on all four edges. The initial data is
    the standing mode u(x, y, 0) = sin(mx*pi*x) * sin(my*pi*y), taken from the
    exact solution at t = 0, with zero initial velocity.

    Parameters
    ----------
    N : int
        Number of intervals in each direction (N+1 nodes per direction).
    L : float
        Side length of the square domain.
    Nt : int
        The loop computes time levels up to U^Nt. Levels 1 .. Nt-1 that are
        multiples of ``store_data`` are stored.
    mx, my : int
        Mode numbers of the initial sine mode in x and y.
    cfl : float
        Courant number c*dt/dx.
    c : float
        Wave speed.
    store_data : int
        Store every ``store_data``-th time level. Level 0 is always stored.

    Returns
    -------
    xij, yij : ndarray
        Mesh coordinates from mesh2D(N, L).
    plotdata : dict[int, ndarray]
        Time-level index -> copy of the (N+1, N+1) solution at that level.
    """
    # Dispersion relation of the continuous problem for sin(m*pi*x) modes:
    # w = c*pi*sqrt(mx**2 + my**2). The factor pi was previously missing.
    w = c * sp.pi * sp.sqrt(mx**2 + my**2)
    ue = sp.sin(mx * sp.pi * x) * sp.sin(my * sp.pi * y) * sp.cos(w * t)

    dx = L / N
    dt = cfl * dx / c
    xij, yij = mesh2D(N, L)
    # CSR is much faster than LIL for the repeated products in the time loop.
    D = (D2(N) / dx**2).tocsr()
    DT = D.T  # D is not symmetric (one-sided boundary rows), so keep D.T

    def laplace(U):
        # Apply the 1D second-derivative operator along both axes.
        return D @ U + U @ DT

    def apply_bcs(U):
        # Homogeneous Dirichlet conditions on all four edges.
        U[0] = 0
        U[-1] = 0
        U[:, 0] = 0
        U[:, -1] = 0

    Unp1, Un, Unm1 = np.zeros((3, N + 1, N + 1))
    # U^0: fill in place (a constant ue would otherwise leave Unm1 a scalar).
    Unm1[:] = sp.lambdify((x, y, t), ue)(xij, yij, 0)
    apply_bcs(Unm1)
    # First step with zero initial velocity:
    #   U^1 = U^0 + 0.5*(c*dt)**2 * lap(U^0)
    # The Laplacian must act on U^0; Un is still all zeros at this point.
    Un[:] = Unm1 + 0.5 * (c * dt)**2 * laplace(Unm1)
    # The one-sided boundary rows of D make lap(U^0) nonzero on the edges.
    apply_bcs(Un)

    plotdata = {0: Unm1.copy()}
    for n in range(1, Nt):
        Unp1[:] = 2 * Un - Unm1 + (c * dt)**2 * laplace(Un)
        apply_bcs(Unp1)
        # Shift time levels: Unm1 <- U^n, Un <- U^{n+1}.
        Unm1[:] = Un
        Un[:] = Unp1
        if n % store_data == 0:
            plotdata[n] = Unm1.copy()  # Unm1 now holds U^n
    return xij, yij, plotdata

In [4]:
# Simulate the (mx, my) = (2, 2) standing mode, storing every time level.
N, L, Nt = 40, 1, 50
mx, my = 2, 2
cfl = 0.1
xij, yij, data = solver(N, L, Nt, mx, my, cfl, store_data=1)

# Global extrema over all stored frames, used to fix the z-axis of the animation.
frames = np.stack(list(data.values()))
min_z_value, max_z_value = frames.min(), frames.max()
min_z_value, max_z_value

(-1.0, 1.0)

In [5]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Fixed axis limits so the surface does not jump between frames.
ax.set_xlim(0, xij.max())
ax.set_ylim(0, yij.max())
ax.set_zlim(min_z_value, max_z_value)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('u')
ax.set_title('time step 0')
surf = ax.plot_surface(xij, yij, data[0], cmap=cm.coolwarm, linewidth=0, antialiased=False)

def update_plot(frame=0):
    """Replace the surface with the stored solution at time level ``frame``.

    ``frame`` must be a key of ``data``.
    """
    global surf
    surf.remove()
    surf = ax.plot_surface(xij, yij, data[frame], cmap=cm.coolwarm, linewidth=0, antialiased=False)
    ax.set_title(f'time step {frame}')
    fig.canvas.draw_idle()

# Slide over the keys actually stored in `data`. With store_data > 1 they are
# 0, store_data, 2*store_data, ..., so a 0..len(data)-1 IntSlider would request
# keys that do not exist (KeyError).
_ = widgets.interact(
    update_plot,
    frame=widgets.SelectionSlider(options=sorted(data), value=0, description='frame'),
)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=0, description='frame', max=49), Output()), _dom_classes=('widget-intera…