In [64]:
# -*- coding: utf-8 -*-
"""
改进的Kelvin波东传动画 - 更美观的视觉效果
@author: Hohai University
"""
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.patches import Circle

plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

def create_improved_colormap():
    """Return a 256-entry diverging blue-white-red colormap named 'improved_bwr'.

    The low end runs from dark navy through lighter blues to white at the
    centre, and the high end continues from pale pink through red to dark red.
    """
    blue_half = ['#00008B', '#0000FF', '#4169E1', '#87CEEB', '#E0F6FF']
    red_half = ['#FFE0E0', '#FFB6C1', '#FF6B6B', '#FF0000', '#8B0000']
    palette = blue_half + ['#FFFFFF'] + red_half
    return LinearSegmentedColormap.from_list('improved_bwr', palette, N=256)

def kelvin_wave_field(x, y, t, k=1.0, omega=1.0, c=1.0, y_scale=7.0):
    """
    Evaluate an idealised Kelvin wave field.

    Parameters:
        x, y: zonal and meridional coordinates (scalars or broadcastable arrays)
        t: time
        k: zonal wavenumber
        omega: angular frequency
        c: phase speed (enters only the Phi amplitude, as c / k)
        y_scale: standard-deviation width of the Gaussian meridional envelope

    Returns:
        (u, v, Phi): zonal velocity, meridional velocity (identically zero)
        and geopotential anomaly, all with the broadcast shape of the inputs.
    """
    # Meridional trapping: Gaussian decay away from y = 0
    envelope = np.exp(-0.5 * (y / y_scale)**2)
    # Eastward propagation: phase grows with x and falls with t
    oscillation = np.cos(k * x - omega * t)

    zonal = envelope * oscillation
    # A Kelvin wave has no meridional velocity component
    meridional = np.zeros_like(zonal)
    # NOTE(review): linear shallow-water Kelvin waves satisfy Phi = c * u;
    # the c / k amplitude used here is kept as-is -- confirm it is intended.
    geopotential = (c / k) * envelope * oscillation

    return zonal, meridional, geopotential

def create_kelvin_animation_improved(
        lon_range=(60, 180),
        lat_range=(-30, 30),
        nx=300, ny=150,
        total_frames=100,
        wave_number=2,
        save_path=None):
    """
    Build (and optionally save) an animation of an eastward-propagating Kelvin wave.

    All frames are precomputed with ``kelvin_wave_field`` over one wave period
    (t = frame * 2*pi / total_frames with omega = 1), then each frame is drawn
    on a cartopy PlateCarree map by ``matplotlib.animation.FuncAnimation``.

    Parameters:
        lon_range: (west, east) longitude extent in degrees.
        lat_range: (south, north) latitude extent in degrees.
        nx, ny: number of grid points in longitude / latitude.
        total_frames: number of animation frames per wave period.
        wave_number: wavenumber k passed to ``kelvin_wave_field``; also sets
            the phase speed c = omega / k and is shown in the title/info box.
        save_path: if given, the animation is written there with the Pillow
            writer (fps=10, dpi=100); if that raises, ``save_frame_by_frame``
            is tried instead.

    Returns:
        (fig, anim): the matplotlib Figure and the FuncAnimation object.
    """
    
    # Geographic grid used for plotting (degrees)
    lon = np.linspace(lon_range[0], lon_range[1], nx)
    lat = np.linspace(lat_range[0], lat_range[1], ny)
    LON, LAT = np.meshgrid(lon, lat)
    
    # Model coordinates (normalised): x spans 0..2*pi*wave_number, y is latitude.
    # NOTE(review): the phase is k * x with k = wave_number, so the domain holds
    # wave_number**2 wavelengths rather than wave_number -- confirm intent.
    x_phys = np.linspace(0, 2*np.pi*wave_number, nx)
    y_phys = np.linspace(lat_range[0], lat_range[1], ny)
    X, Y = np.meshgrid(x_phys, y_phys)
    
    # Wave parameters
    k = wave_number  # wavenumber
    omega = 1.0      # angular frequency
    c = omega / k    # phase speed
    
    # Colormap
    cmap = create_improved_colormap()
    
    # Precompute every frame. t covers one full period of omega = 1 (the
    # final t stops one step short of 2*pi), so the animation loops seamlessly.
    print(f"预计算 {total_frames} 帧数据...")
    frames_data = []
    
    for frame in range(total_frames):
        t = frame * 2 * np.pi / total_frames
        u, v, Phi = kelvin_wave_field(X, Y, t, k=k, omega=omega, c=c, y_scale=8.0)
        frames_data.append({'u': u, 'v': v, 'Phi': Phi, 't': t})
    
    print("开始渲染动画...")
    
    # Create the figure
    # plt.switch_backend('Agg')
    projection = ccrs.PlateCarree()
    fig = plt.figure(figsize=(15, 9), dpi=200)
    ax = fig.add_subplot(111, projection=projection)
    
    # Adjust subplot margins so the tick labels stay visible
    fig.subplots_adjust(left=0.08, right=0.95, top=0.92, bottom=0.08)
    
    def animate(frame):
        """Clear the axes and redraw everything for frame index ``frame``."""
        ax.clear()
        
        data = frames_data[frame]
        u = data['u']
        v = data['v']
        Phi = data['Phi']
        
        # Map background
        ax.coastlines(resolution='110m', linewidth=0.8, color='black', alpha=0.6)
        ax.add_feature(cfeature.LAND, facecolor='white', alpha=0.3, 
                      edgecolor='black', linewidth=0.3)
        
        # Filled contours of the geopotential field (main visual element)
        levels_fill = np.linspace(-1., 1., 21)
        cf = ax.contourf(LON, LAT, Phi, levels=levels_fill, 
                        cmap=cmap, extend='both', 
                        transform=projection, alpha=0.85)
        
        # Overlay contour lines (styled like the reference figure): solid for
        # positive levels, dashed for negative ones; the 0 level is excluded
        # by both masks.
        levels_contour = np.linspace(-1.0, 1.0, 11)
        cs_solid = ax.contour(LON, LAT, Phi, levels=levels_contour[levels_contour > 0], 
                             colors='black', linewidths=1.5, 
                             transform=projection, alpha=0.8)
        cs_dashed = ax.contour(LON, LAT, Phi, levels=levels_contour[levels_contour < 0], 
                              colors='black', linewidths=1.5, linestyles='dashed',
                              transform=projection, alpha=0.8)
        
        # Label the contour values
        ax.clabel(cs_solid, inline=True, fontsize=9, fmt='%.2f')
        ax.clabel(cs_dashed, inline=True, fontsize=9, fmt='%.2f')
        
        # Wind vectors (velocity field), drawn at every `step`-th grid point
        step = 10
        Q = ax.quiver(LON[::step, ::step], LAT[::step, ::step],
                    u[::step, ::step], v[::step, ::step],
                    scale=50, 
                    #  scale_units='inches', width=0.003,
                    color='darkblue', alpha=0.7, transform=projection,
                    headwidth=4, headlength=5,)
        # Alternative (unused) quiver settings kept for reference:
        # step_quiver = 10
        # Q = ax.quiver(LON[::step_quiver, ::step_quiver], 
        #              LAT[::step_quiver, ::step_quiver],
        #              u[::step_quiver, ::step_quiver], 
        #              v[::step_quiver, ::step_quiver],
        #              scale=8, 
        #             #  scale_units='inches', width=0.003,
        #              color='black', alpha=0.7, transform=projection,
        #              headwidth=4, headlength=5, headaxislength=4)
        
        # Mark the equator
        ax.plot([lon_range[0], lon_range[1]], [0, 0], color='red', linewidth=2, 
                linestyle='-', alpha=0.7, transform=projection)
       
        
        # Map extent and gridlines (labels on the bottom and left only)
        ax.set_extent([lon_range[0], lon_range[1], lat_range[0], lat_range[1]], 
                     crs=projection)
        
        gl = ax.gridlines(draw_labels=True, linewidth=0.75, 
                         color='gray', alpha=0.3, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        gl.xlabel_style = {'size': 18}
        gl.ylabel_style = {'size': 18}
        
        # Title: time, phase in degrees and frame counter
        phase_deg = (frame / total_frames) * 360
        ax.set_title(
            f'Kelvin Wave (K*={wave_number}) - Eastward Propagation\n'
            f'Time: {data["t"]:.2f}s  |  Phase: {phase_deg:.1f}°  |  Frame: {frame+1}/{total_frames}',
            fontsize=13, fontweight='bold', pad=15
        )
        
        # Info box in the upper-left corner of the axes
        info_text = (
            f'Wave Parameters:\n'
            f'• Wavenumber k = {wave_number}\n'
            f'• Direction: West → East\n'
            f'• Equatorial trapped'
        )
        ax.text(0.02, 0.98, info_text, transform=ax.transAxes,
               fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        return [cf, cs_solid, cs_dashed, Q]
    
    # Build the animation: 100 ms between frames, looping, full redraw (no blitting)
    anim = animation.FuncAnimation(fig, animate, frames=total_frames,
                                  interval=100, blit=False, repeat=True)
    
    # Save the animation; if the Pillow writer raises, fall back to
    # rendering and merging the frames one by one.
    if save_path:
        print(f"保存动画到 {save_path}...")
        try:
            anim.save(save_path, writer='pillow', fps=10, dpi=100)
            print(f"✓ 成功保存!")
        except Exception as e:
            print(f"✗ 保存失败: {e}")
            print("尝试逐帧保存...")
            save_frame_by_frame(frames_data, projection, lon_range, lat_range, 
                              LON, LAT, save_path, wave_number, total_frames, cmap)
    
    return fig, anim

def save_frame_by_frame(frames_data, projection, lon_range, lat_range, 
                       LON, LAT, save_path, wave_number, total_frames, cmap):
    """Fallback GIF writer: render each frame to PNG, then merge them with Pillow.

    Used when ``FuncAnimation.save`` fails. Every precomputed frame is drawn
    on its own figure and written to a temporary directory. The PNGs are then
    combined into a looping GIF (100 ms per frame) at ``save_path``.

    Parameters:
        frames_data: sequence of dicts with keys 'u', 'v' and 'Phi' (arrays on
            the LON/LAT grid), one per frame.
        projection: cartopy CRS used for the axes and as the data transform.
        lon_range, lat_range: (min, max) map extent in degrees.
        LON, LAT: meshgrid coordinate arrays matching the field arrays.
        save_path: output GIF path.
        wave_number: value shown in each frame title.
        total_frames: frame count shown in the progress line and titles.
        cmap: colormap for the filled contours.

    Raises:
        ValueError: if ``frames_data`` is empty (there is nothing to write).
    """
    # Validate before touching Pillow or the filesystem so callers get a
    # clear error instead of an IndexError after a temp dir was created.
    if not frames_data:
        raise ValueError("frames_data is empty; nothing to save")

    from PIL import Image
    import tempfile
    import os
    import shutil

    temp_dir = tempfile.mkdtemp()
    print(f"临时目录: {temp_dir}")

    images = []
    try:
        for idx, data in enumerate(frames_data):
            print(f"渲染第 {idx+1}/{total_frames} 帧", end='\r')
            frame_path = os.path.join(temp_dir, f'frame_{idx:04d}.png')

            fig_temp = plt.figure(figsize=(18, 8), dpi=80)
            try:
                ax_temp = fig_temp.add_subplot(111, projection=projection)

                u, v, Phi = data['u'], data['v'], data['Phi']

                # Background
                ax_temp.coastlines(resolution='110m', linewidth=0.8, color='black', alpha=0.6)
                ax_temp.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)

                # Filled geopotential contours
                levels_fill = np.linspace(-1, 1, 21)
                ax_temp.contourf(LON, LAT, Phi, levels=levels_fill, cmap=cmap,
                                 extend='both', transform=projection, alpha=0.85)

                # Contour lines: solid for positive, dashed for negative levels
                levels_contour = np.linspace(-1.0, 1.0, 11)
                cs_pos = ax_temp.contour(LON, LAT, Phi,
                                         levels=levels_contour[levels_contour > 0],
                                         colors='black', linewidths=1.5,
                                         transform=projection, alpha=0.8)
                cs_neg = ax_temp.contour(LON, LAT, Phi,
                                         levels=levels_contour[levels_contour < 0],
                                         colors='black', linewidths=1.5,
                                         linestyles='dashed',
                                         transform=projection, alpha=0.8)

                ax_temp.clabel(cs_pos, inline=True, fontsize=8, fmt='%.2f')
                ax_temp.clabel(cs_neg, inline=True, fontsize=8, fmt='%.2f')

                # Wind vectors
                step_quiver = 30
                ax_temp.quiver(LON[::step_quiver, ::step_quiver],
                               LAT[::step_quiver, ::step_quiver],
                               u[::step_quiver, ::step_quiver],
                               v[::step_quiver, ::step_quiver],
                               scale=8,
                               color='black', alpha=0.7, transform=projection,
                               headwidth=4, headlength=5, headaxislength=4)

                # Equator
                ax_temp.plot([lon_range[0], lon_range[1]], [0, 0], color='red',
                             linewidth=2, alpha=0.7, transform=projection)

                ax_temp.set_extent([lon_range[0], lon_range[1],
                                    lat_range[0], lat_range[1]], crs=projection)

                gl = ax_temp.gridlines(draw_labels=True, linewidth=0.5,
                                       color='gray', alpha=0.3, linestyle='--')
                gl.top_labels = False
                gl.right_labels = False

                phase_deg = (idx / total_frames) * 360
                ax_temp.set_title(
                    f'Kelvin Wave (K*={wave_number}) | Phase: {phase_deg:.1f}° | Frame {idx+1}/{total_frames}',
                    fontsize=12, fontweight='bold'
                )

                # Save this frame
                fig_temp.savefig(frame_path, dpi=80, bbox_inches='tight')
            finally:
                # Release the figure even if drawing or saving fails
                plt.close(fig_temp)

            # Image.open is lazy and keeps the file open. Decode the pixels
            # now and close the handle immediately so descriptors don't pile
            # up and the temp directory can be removed (notably on Windows).
            with Image.open(frame_path) as img:
                images.append(img.copy())

        print("\n合并为GIF...")
        images[0].save(save_path, save_all=True, append_images=images[1:],
                       duration=100, loop=0, optimize=False)
    finally:
        # Best-effort cleanup of the temporary PNGs, whether or not the GIF
        # was written; never let cleanup mask the original error.
        shutil.rmtree(temp_dir, ignore_errors=True)

    print(f"✓ 保存成功: {save_path}")

def plot_single_snapshot_improved(
        lon_range=(60, 180),
        lat_range=(-30, 30),
        t=0.0,
        wave_number=2,
        save_path=None):
    """Draw a single high-resolution snapshot of the Kelvin wave at time ``t``.

    Uses a fixed 400 x 200 grid and the same wave setup as the animation
    (omega = 1, c = omega / k, y_scale = 8.0). It adds a horizontal colorbar
    and a quiver key, optionally saves the figure (dpi=150), then calls
    ``plt.show()``.

    Parameters:
        lon_range, lat_range: (min, max) map extent in degrees.
        t: model time at which the field is evaluated.
        wave_number: wavenumber k (also shown in the title).
        save_path: optional output image path.

    Returns:
        (fig, ax): the Figure and the cartopy map axes.
    """
    
    # Plotting grid (degrees) and normalised model coordinates
    nx, ny = 400, 200
    lon = np.linspace(lon_range[0], lon_range[1], nx)
    lat = np.linspace(lat_range[0], lat_range[1], ny)
    LON, LAT = np.meshgrid(lon, lat)
    
    x_phys = np.linspace(0, 2*np.pi*wave_number, nx)
    y_phys = np.linspace(lat_range[0], lat_range[1], ny)
    X, Y = np.meshgrid(x_phys, y_phys)
    
    # Wave parameters: wavenumber, angular frequency, phase speed
    k = wave_number
    omega = 1.0
    c = omega / k
    
    u, v, Phi = kelvin_wave_field(X, Y, t, k=k, omega=omega, c=c, y_scale=8.0)
    
    cmap = create_improved_colormap()
    projection = ccrs.PlateCarree()
    
    fig = plt.figure(figsize=(18, 8), dpi=150)
    ax = fig.add_subplot(111, projection=projection)
    
    # Map background
    ax.coastlines(resolution='50m', linewidth=1, color='black', alpha=0.7)
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3,
                  edgecolor='black', linewidth=0.5)
    ax.add_feature(cfeature.OCEAN, facecolor='white', alpha=0.2)
    
    # Filled contours of Phi.
    # NOTE(review): the animation uses extend='both'; with 'neither' here,
    # values outside [-1, 1] are not colored -- confirm this is intended.
    levels_fill = np.linspace(-1., 1., 51)
    cf = ax.contourf(LON, LAT, Phi, levels=levels_fill, cmap=cmap,
                    extend='neither', transform=projection, alpha=0.9)
    
    cbar = plt.colorbar(cf, ax=ax, orientation='horizontal',
                       pad=0.08, shrink=0.75, aspect=40)
    cbar.set_label('Geopotential Height Anomaly (Φ)', fontsize=11, fontweight='bold')
    
    # Contour lines: solid for positive, dashed for negative levels
    # (the 0 level is excluded by both masks)
    levels_contour = np.linspace(-1.0, 1.0, 21)
    cs_pos = ax.contour(LON, LAT, Phi, levels=levels_contour[levels_contour > 0],
                       colors='black', linewidths=2, transform=projection, alpha=0.9)
    cs_neg = ax.contour(LON, LAT, Phi, levels=levels_contour[levels_contour < 0],
                       colors='black', linewidths=2, linestyles='dashed',
                       transform=projection, alpha=0.9)
    
    ax.clabel(cs_pos, inline=True, fontsize=10, fmt='%.2f', inline_spacing=10)
    ax.clabel(cs_neg, inline=True, fontsize=10, fmt='%.2f', inline_spacing=10)
    
    # Wind vectors at every `step`-th grid point
    step = 10
    Q = ax.quiver(LON[::step, ::step], LAT[::step, ::step],
                 u[::step, ::step], v[::step, ::step],
                 scale=50, 
                #  scale_units='inches', width=0.003,
                 color='darkblue', alpha=0.7, transform=projection,
                 headwidth=4, headlength=5,)
    
    # NOTE(review): u from kelvin_wave_field is dimensionless, so the
    # '1.0 m/s' label is nominal -- confirm.
    ax.quiverkey(Q, 0.9, 0.95, 1.0, '1.0 m/s', labelpos='E',
                coordinates='axes', fontproperties={'size': 10, 'weight': 'bold'})
    
    # Equator
    ax.plot([lon_range[0], lon_range[1]], [0, 0], color='red', linewidth=2.5, 
            linestyle='-', alpha=0.8, transform=projection)

    
    # Map extent and gridlines (labels on the bottom and left only)
    ax.set_extent([lon_range[0], lon_range[1], lat_range[0], lat_range[1]],
                 crs=projection)
    
    gl = ax.gridlines(draw_labels=True, linewidth=0.8,
                     color='gray', alpha=0.4, linestyle='--')
    gl.top_labels = False
    gl.right_labels = False
    gl.xlabel_style = {'size': 20}
    gl.ylabel_style = {'size': 20}
    
    # Title
    ax.set_title(
        f'Kelvin Wave Snapshot (K* = {wave_number}) - t = {t:.2f}s\n'
        f'Equatorial Trapped Eastward Propagating Wave',
        fontsize=14, fontweight='bold', pad=20
    )

 
    # Save first (if requested), then display
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ 快照已保存: {save_path}")
    
    plt.show()
    return fig, ax

if __name__ == '__main__':
    rule = "=" * 70
    for banner_line in (rule, "改进的Kelvin波可视化程序", rule):
        print(banner_line)

    # Domain and wave settings shared by the snapshot and the animation
    common = dict(lon_range=(60, 180), lat_range=(-25, 25), wave_number=1.5)

    # Step 1: one high-quality snapshot (quick preview of the look)
    print("\n[1] 生成高质量快照...")
    plot_single_snapshot_improved(
        t=0.0,
        save_path='kelvin_wave_improved_snapshot.png',
        **common,
    )

    # Step 2: the full animation, written out as a GIF
    print("\n[2] 生成动画...")
    create_kelvin_animation_improved(
        nx=300,
        ny=150,
        total_frames=50,
        save_path='kelvin_wave_improved.gif',
        **common,
    )

    print("\n完成!")

改进的Kelvin波可视化程序

[1] 生成高质量快照...
✓ 快照已保存: kelvin_wave_improved_snapshot.png

[2] 生成动画...
预计算 50 帧数据...
开始渲染动画...
保存动画到 kelvin_wave_improved.gif...
✓ 快照已保存: kelvin_wave_improved_snapshot.png

[2] 生成动画...
预计算 50 帧数据...
开始渲染动画...
保存动画到 kelvin_wave_improved.gif...


  plt.show()


✓ 成功保存!

完成!
