In [None]:
# Paths: base_dir is machine-specific; the other paths are relative to it.
base_dir = '/home/mkoch/coreai/tc_tracking_e2s/earth2studio/recipes/tc_tracking'
field_data = 'plotting_data_helene/helene_2024-09-24T00.00.00_mems0000-0015.nc'
track_dir = 'plotting_data_helene/cyclone_tracks_te'
reference_track = 'test/aux_data/reference_track_helene_2024_north_atlantic.csv'

variable = 'wind_speed'  # a dataset variable, or 'wind_speed' (derived from u10m/v10m)
ensemble_member = 12

max_frames = 99  # maximum number of frames to plot
scale = 1  # plotted grid spacing is scale * 0.25 degrees

region = 'gulf_of_mexico'  # one of the keys of region_bounds below

# Map extent per region: ((lat_min, lat_max), (lon_min, lon_max)) in degrees,
# longitudes in the 0-360 convention.
region_bounds = {
    'global': ((-90, 90), (0, 360)),
    'north_atlantic': ((9, 65), (250, 360)),
    'gulf_of_mexico': ((9, 45), (250, 310)),
}
# Fail here with a clear message instead of a NameError on lat_min in a later cell.
if region not in region_bounds:
    raise ValueError(f"unknown region {region!r}; expected one of {sorted(region_bounds)}")
(lat_min, lat_max), (lon_min, lon_max) = region_bounds[region]

save_frames = False  # if True, out_dir is created for PNG frame output

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import TwoSlopeNorm
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import os

import xarray as xr


field_data = os.path.join(base_dir, field_data)
# File name without directory and extension. splitext strips only the final '.nc';
# split('.')[0] cut at the first dot and truncated
# 'helene_2024-09-24T00.00.00_mems0000-0015' to 'helene_2024-09-24T00'.
base_name = os.path.splitext(os.path.basename(field_data))[0]
out_dir = f'outputs_{base_name}'
if save_frames:
    os.makedirs(out_dir, exist_ok=True)


# Spacing of the selected lat/lon points in degrees.
# NOTE(review): assumes the data lies on a 0.25-degree grid — confirm.
dx = scale * .25
ds = xr.open_dataset(field_data)

# NOTE(review): not referenced by the cells shown here.
countries = cfeature.NaturalEarthFeature(
    category="cultural",
    name="admin_0_countries",
    scale="110m",
    facecolor="none",
    edgecolor="black",
)

# Exact-label selection of the region; np.arange excludes lat_max/lon_max,
# and every generated value must exist in the dataset's lat/lon coordinates.
reg_ds = ds.sel(lat=list(np.arange(lat_min, lat_max, dx)),
                lon=list(np.arange(lon_min, lon_max, dx)))

time_str = 'lead time:'
projection = ccrs.PlateCarree()
if variable == 'wind_speed':
    # 10 m wind speed magnitude from the u/v components.
    var_ds = np.sqrt(np.square(reg_ds.u10m) + np.square(reg_ds.v10m))
else:
    var_ds = reg_ds[variable]

# Colour limits come from the ensemble member that is animated (make_frame plots
# var_ds[ensemble_member, ...]). Using member 0 here let the animated member's
# values fall outside [min_val, max_val] and be clipped.
member_field = var_ds[ensemble_member]
min_val = float(member_field.min())
max_val = float(member_field.max())

ds

In [None]:
from data_handling import extract_tracks_from_file, match_tracks

# Valid times of the field data: initial condition + each lead time.
# Only the first initial condition is animated (var_ds[..., 0, ...]), so use ds.time[0];
# with a single initial condition this equals broadcasting over ds.time.
time_stamps = ds.time.values[0] + ds.lead_time.values


def _member_of(fname):
    """Return the ensemble member encoded as '..._mem_<m>_seed_...' in a CSV name, or None."""
    if not fname.endswith('.csv') or '_mem_' not in fname:
        return None
    try:
        return int(fname.split('_mem_')[-1].split('_seed_')[0])
    except ValueError:
        # CSVs that do not follow the naming scheme are ignored instead of crashing the cell.
        return None


# Select the track file of the chosen ensemble member. Sort the names so the choice
# is deterministic when several files match (os.listdir order is arbitrary).
track_dir_path = os.path.join(base_dir, track_dir)
member_files = sorted(f for f in os.listdir(track_dir_path) if _member_of(f) == ensemble_member)
if not member_files:
    raise FileNotFoundError(
        f'no track file for ensemble member {ensemble_member} in {track_dir_path}')
track_file = os.path.join(track_dir_path, member_files[0])

# Extract the predicted tracks and match them with the reference (ground-truth) track.
tru_track = extract_tracks_from_file(os.path.join(base_dir, reference_track))
# track_file already contains base_dir; joining it again would double the prefix
# whenever base_dir is a relative path.
pred_tracks = extract_tracks_from_file(track_file)
matched = match_tracks(
    [{"ic": time_stamps[0], "member": ensemble_member, "tracks": pred_tracks}], tru_track)
matched_df = matched[0]['tracks']

# Keep the rows whose time is one of the field-data time stamps, then reindex so there is
# exactly one row per time stamp (NaN where the track has no point at that time).
track = (matched_df[matched_df['time'].isin(time_stamps)]
         .set_index('time')
         .reindex(time_stamps)
         .reset_index())

In [None]:
import matplotlib.colors as colors

# min_val = -30
# max_val = 30

# define plots
def make_figure():
    """Create the map figure with a single cartopy axes in the module-level ``projection``.

    Longitude/latitude tick formatters are attached to the axes.
    NOTE(review): no ticks are placed here (no set_xticks/set_yticks); confirm that the
    tick labels actually render on the GeoAxes.

    Returns:
        tuple: ``(fig, ax)`` — the Figure and its map axes.
    """
    figure = plt.figure(figsize=(11, 5))
    map_ax = figure.add_subplot(1, 1, 1, projection=projection)

    map_ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=False))
    map_ax.yaxis.set_major_formatter(LatitudeFormatter())

    return figure, map_ax


# make animation: configure the backend and create the figure/axes the frames draw on
%matplotlib inline
# Display FuncAnimation objects as an embedded JavaScript/HTML player.
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()

# Artists drawn by the previous make_frame call; removed before the next frame is drawn.
_frame_artists = []
# Holds the colorbar once it exists, so it is created only once.
_frame_colorbar = []


def make_frame(frame):
    """Draw one animation frame onto the module-level ``ax``.

    Plots ``var_ds`` for ``ensemble_member`` at lead-time index ``max(frame, 0)``,
    overlays coastlines, rivers and the track up to ``frame``, and sets the title.

    Args:
        frame: Frame index from FuncAnimation. The init function passes -1, which
            renders lead-time index 0 with an empty track and adds the colorbar.

    Returns:
        The QuadMesh created by ``pcolormesh``.
    """
    step = max(frame, 0)  # the init call (-1) shows the first lead time
    # NOTE(review): assumes 6-hourly lead times — confirm against ds.lead_time.
    lead_hours = step * 6
    n_frames = min(max_frames, var_ds.shape[2])
    print(f'\rprocessing frame {frame+1} of {n_frames}', end='')

    # Remove everything the previous call added (field, map features, track line).
    # The former hasattr(get_array) filter never removed the Line2D track lines,
    # so one more line piled up on every frame.
    while _frame_artists:
        _frame_artists.pop().remove()

    plot_ds = var_ds[ensemble_member, 0, step, :, :]
    pc = ax.pcolormesh(reg_ds.lon, reg_ds.lat, plot_ds, transform=projection,
                       cmap='plasma',
                       vmin=min_val, vmax=max_val,
                       )
    _frame_artists.append(pc)

    _frame_artists.append(ax.add_feature(cfeature.COASTLINE, lw=.5))
    _frame_artists.append(ax.add_feature(cfeature.RIVERS, lw=.5))

    track_lines = ax.plot(track["lon"].iloc[:frame + 1], track["lat"].iloc[:frame + 1],
                          transform=ccrs.PlateCarree(),
                          color="black", linewidth=1., alpha=1)
    _frame_artists.extend(track_lines)

    # Add the colorbar during the init call only. The guard keeps it from being
    # duplicated if the init function is invoked more than once.
    if frame == -1 and not _frame_colorbar:
        _frame_colorbar.append(fig.colorbar(pc, extend='both', shrink=0.8, ax=ax))

    header = f'{base_name} {variable} {time_str} {lead_hours}:00:00'
    ax.set_title(header, fontsize=14)

    # out_dir exists only when save_frames is True, so saving unconditionally raised
    # FileNotFoundError. Skip the init call so lead time 0 is not written twice
    # (and not as a '-006' file).
    if save_frames and frame >= 0:
        fig.savefig(f"{out_dir}/{base_name}_{variable}_{lead_hours:04d}.png", dpi=450)

    return pc

def animate(frame):
    """FuncAnimation update callback: draw ``frame`` via ``make_frame`` and return its artist."""
    drawn_artist = make_frame(frame)
    return drawn_artist

def first_frame():
    """FuncAnimation init callback: draw the initial frame (index -1) via ``make_frame``."""
    init_index = -1
    return make_frame(init_index)

# Number of frames: capped by max_frames and by the length of var_ds along axis 2.
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=min(max_frames, var_ds.shape[2]),
    init_func=first_frame,
    blit=False,
    repeat=False,
    interval=.1,  # delay between frames in milliseconds
)
plt.close('all')
ani