In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Model parameters and sampling grids.
k = 1.0                                        # wave number
x = np.linspace(start=-10, stop=10, num=500)   # 500 evenly spaced points on [-10, 10]
t = np.linspace(start=0, stop=5, num=200)      # 200 evenly spaced times on [0, 5]

# Create the figure with a single set of axes.
fig, ax1 = plt.subplots(1, 1, figsize=(8, 3))

# Plot the spatial profile u(x, t) at a few fixed times.
# The profile is a closed-form expression, so it is evaluated at the exact
# requested time. Snapping to the nearest point of the `t` grid would plot a
# slightly different time than the legend reports, because 1 and 2 are not
# grid points of np.linspace(0, 5, 200).
time_slices = [0, 1, 2]
for t_slice in time_slices:
    # Amplitude decays as exp(-k**2 * t); the phase term is cos(k*x - t).
    # NOTE(review): the phase uses angular frequency 1 independent of k; if a
    # dispersion relation such as omega = c*k is intended, this only agrees
    # for k == 1 — confirm.
    u_slice = np.exp(-k**2 * t_slice) * np.cos(k * x - t_slice)
    ax1.plot(x, u_slice, label=f't={t_slice}')

# Axis labels in bold italic.
ax1.set_xlabel('x', weight='bold', style='italic')
ax1.set_ylabel('u(x,t)', weight='bold', style='italic')

# Compact legend pinned to the upper-right corner, plus a faint grid.
ax1.legend(fontsize='small', loc='upper right')
ax1.grid(True, alpha=0.3)

# Tighten spacing, then write a 300-dpi PNG trimmed to the drawn content.
# The figure handle is used explicitly rather than pyplot's "current figure".
fig.tight_layout()
fig.savefig('dissipative_wave.png', bbox_inches='tight', dpi=300)

# Display the figure (after saving, so the saved image is not blank).
plt.show()