# Plot Evolution of Storm Tracks and Atmospheric Fields

This notebook demonstrates how to visualize tropical cyclone (TC) tracks overlaid on atmospheric field data. Use this to validate tracking algorithm outputs or create publication-ready animations.

### Workflow Overview
1. Configure paths and plotting parameters
2. Load and subset field data from NetCDF
3. Load and align track data from CSV files
4. Generate animated visualization
5. (Optional) Save animation to file

### Data Prerequisites
- NetCDF file containing atmospheric field data
- Track CSV files generated by the TC tracking algorithm

Both files can be produced by configuring a tc_hunt run with `store_type: "netcdf"` and cyclone tracking switched on, as is done, for example, in `cfg/reproduce_helene.yaml`.

### Configuration

**Path Settings**
- `field_data` - path to NetCDF file (e.g. `/path/to/outputs_reproduce_helene/helene_2024-09-24T00.00.00_mems0000-0013.nc`)
- `track_dir` - folder containing track CSVs (e.g. `/path/to/outputs_reproduce_helene/cyclone_tracks_te`)

**Plotting Settings**
- `variable` - field variable to plot (e.g., `'u10m'`, `'mslp'`). Use `'wind_speed'` to compute 10m wind magnitude from `u10m` and `v10m`
- `ensemble_member` - which ensemble member to visualize
- `region` - geographic region (`'global'`, `'west_pacific'`, `'north_atlantic'`, `'gulf_of_mexico'`, `'north_indian'`, `'southern_pacific'`)

**Animation Settings**
- `max_frames` - limit number of frames (set high to include full forecast)
- `scale` - spatial coarsening factor (1 = full resolution, 2 = half, etc.)
- `fps` - frames per second for animation playback
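
As a rough sketch of what `scale` does, the snippet below builds the latitude values that would be selected for the `gulf_of_mexico` range at two coarsening factors. It assumes the field data sit on a regular 0.25° grid, as the loading cell does; `GRID_SPACING` and `coarsened_axis` are illustrative names, not part of the notebook.

```python
import numpy as np

GRID_SPACING = 0.25  # degrees; assumed native resolution of the field data


def coarsened_axis(start, stop, scale):
    """Coordinates from start (inclusive) to stop (exclusive),
    keeping every `scale`-th point of the native grid."""
    return np.arange(start, stop, scale * GRID_SPACING)


full = coarsened_axis(9, 45, scale=1)  # every grid point
half = coarsened_axis(9, 45, scale=2)  # every second grid point
print(len(full), len(half))
```

Because coarsening is applied along both latitude and longitude, `scale = 2` reduces the number of plotted grid cells by roughly a factor of four.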


In [None]:
field_data = '/path/to/outputs_reproduce_helene/reproduce_helene_2024-09-24T00.00.00_mems0000-0013.nc'
track_dir = '/path/to/outputs_reproduce_helene/cyclone_tracks_te'

variable = 'wind_speed'
ensemble_member = 3
region = 'gulf_of_mexico'

max_frames = 99 # maximum number of frames to plot
scale = 1
fps = 4

if region == 'global':
    lat_min, lat_max = -90, 90
    lon_min, lon_max = 0, 359.75
elif region == 'west_pacific':
    lat_min, lat_max = 9, 60
    lon_min, lon_max = 100, 180
elif region == 'north_atlantic':
    lat_min, lat_max = 9, 65
    lon_min, lon_max = 250, 360
elif region == 'gulf_of_mexico':
    lat_min, lat_max = 9, 45
    lon_min, lon_max = 250, 310
elif region == 'north_indian':
    lat_min, lat_max = 9, 40
    lon_min, lon_max = 50, 100
elif region == 'southern_pacific':
    lat_min, lat_max = -40, 5
    lon_min, lon_max = 140, 240
else:
    raise ValueError(f'region {region} not yet implemented. Feel free to add it by providing lat/lon coords of the bounding box')

### Step 1: Load Field Data

- Read field data from NetCDF file
- Sub-select region and variable; coarsen data if `scale > 1` for faster iteration
- Compute min/max values for consistent colormap across all timesteps
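
The wind-speed derivation and the fixed colour limits can be illustrated on synthetic data. This is a minimal NumPy sketch in which random arrays stand in for `u10m` and `v10m`; the array shape and the names `vmin`, `vmax` and `frame_max` are assumptions for illustration, and the notebook itself works on xarray objects.

```python
import numpy as np

# synthetic stand-ins for the 10 m wind components, shape (time, lat, lon)
rng = np.random.default_rng(0)
u10m = rng.normal(scale=5.0, size=(4, 3, 5))
v10m = rng.normal(scale=5.0, size=(4, 3, 5))

# wind speed is the magnitude of the horizontal wind vector
wind_speed = np.sqrt(np.square(u10m) + np.square(v10m))

# one pair of limits over all time steps keeps colours comparable between frames
vmin = float(wind_speed.min())
vmax = float(wind_speed.max())

# per-frame maxima generally differ, so per-frame limits would make the
# colour scale jump from one frame to the next
frame_max = wind_speed.max(axis=(1, 2))
print(vmin, vmax, frame_max)
```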

In [None]:
import os
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import xarray as xr


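# file name without directory and extension; splitting on the first '.'
# would cut the name inside the time stamp (e.g. '...T00.00.00')
base_name = os.path.splitext(os.path.basename(field_data))[0]
out_dir = f'outputs_{base_name}'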

ds = xr.open_dataset(field_data)

countries = cfeature.NaturalEarthFeature(
    category="cultural",
    name="admin_0_countries",
    scale="110m",
    facecolor="none",
    edgecolor="black",
)

# subselect lat/lon box, coarsen data if scale factor > 1
sub_ds = ds.sel(lat=list(np.arange(lat_min, lat_max, scale*.25)),
                lon=list(np.arange(lon_min, lon_max, scale*.25)))

# extract the variable to plot; 10 m wind speed is derived from u10m and v10m
if variable == 'wind_speed':
    sub_ds = np.sqrt(np.square(sub_ds.u10m) + np.square(sub_ds.v10m))
else:
    sub_ds = sub_ds[variable]

display(ds)

### Step 2: Load Track Data

This step loads and processes the cyclone track data to align with field timestamps.

**Processing steps:**
1. **Select file** - pick the CSV whose filename matches `ensemble_member`, ignoring the random seed
2. **Separate tracks** - split into individual DataFrames for each track
3. **Align timestamps** - reindex each track to match field data timesteps
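
The file-selection step can be sketched in isolation. The snippet assumes, as the loading cell does, that each filename contains `_mem_<number>_seed_<number>`; the filenames and the helper names `member_from_filename` and `select_track_file` are hypothetical and only serve as illustration.

```python
def member_from_filename(filename):
    """Extract the ensemble member number from a track file name.

    Assumes the name contains '_mem_<number>_seed_<number>'.
    """
    return int(filename.split('_mem_')[-1].split('_seed_')[0])


def select_track_file(filenames, ensemble_member):
    """Return the first CSV belonging to `ensemble_member`, whatever its seed."""
    matches = [f for f in filenames
               if f.endswith('.csv') and '_mem_' in f and '_seed_' in f
               and member_from_filename(f) == ensemble_member]
    if not matches:
        raise FileNotFoundError(
            f'no track file found for ensemble member {ensemble_member}')
    return matches[0]


# hypothetical file names, for illustration only
example_files = [
    'tracks_mem_0002_seed_17.csv',
    'tracks_mem_0003_seed_98.csv',
    'notes.txt',
]
selected = select_track_file(example_files, ensemble_member=3)
print(selected)
```

Raising an explicit error when no file matches makes a wrong `ensemble_member` or `track_dir` easier to spot than a bare index error would.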

**Why reindexing matters:**
Tracks may not have positions at every timestep (e.g., storm genesis/dissipation). Reindexing fills gaps with NaN, enabling frame-by-frame plotting without index errors.
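
A small, self-contained sketch of the reindexing idea, using made-up time stamps and positions. The column names `time`, `lat` and `lon` follow the notebook; all values are invented for illustration.

```python
import numpy as np
import pandas as pd

# four 6-hourly field time steps (illustrative values)
time_stamps = np.array(
    ['2024-09-24T00:00', '2024-09-24T06:00',
     '2024-09-24T12:00', '2024-09-24T18:00'],
    dtype='datetime64[ns]',
)

# a track that only exists at the second and third time step
track = pd.DataFrame({
    'time': time_stamps[1:3],
    'lat': [20.0, 21.5],
    'lon': [275.0, 274.0],
})

# one row per field time step; steps without a track position become NaN
aligned = (track.set_index('time')
                .reindex(time_stamps)
                .rename_axis('time')  # keep the index name explicit
                .reset_index())
print(aligned)
```

After alignment, row `i` of every track corresponds to frame `i` of the field data, so the same integer index can be used for both.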

In [None]:
from data_handling import extract_tracks_from_file

track_dir = os.path.abspath(track_dir)

# extract time steps of field data
time_stamps = ds.time.values + ds.lead_time.values

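# select the track file of the chosen ensemble member (the seed in the file name is ignored)
track_files = [f for f in os.listdir(track_dir) if f.endswith('.csv') and
               int(f.split('_mem_')[-1].split('_seed_')[0]) == ensemble_member]
if not track_files:
    raise FileNotFoundError(f'no track file for ensemble member {ensemble_member} in {track_dir}')
track_file = os.path.join(track_dir, track_files[0])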

# extract tracks from prediction
tracks = extract_tracks_from_file(track_file)

# separate individual tracks in prediction
n_tracks = int(tracks["track_id"].max()) + 1  # does not rely on rows being sorted by track_id
tracks = [tracks.loc[tracks["track_id"] == ii].copy() for ii in range(n_tracks)]

# align track data with simulation time steps
for ii in range(n_tracks):
    # extract the lines of tracks for which track['time'] is in time_stamps
    tracks[ii] = tracks[ii][tracks[ii]['time'].isin(time_stamps)]

    # fill the 'time' column with the time_stamps values
    tracks[ii] = tracks[ii].set_index('time').reindex(time_stamps).reset_index()

### Step 3: Create Animation

Creates an interactive animation where you can click through timesteps. Each frame shows the field data with tracks drawn progressively up to that point.
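
The progressive drawing relies on slicing each aligned track up to the current frame. The sketch below shows which positions would be visible at each frame for one made-up track; `visible_segment` is an illustrative helper, not a function from the notebook.

```python
import numpy as np

# one aligned track: NaN where the storm does not exist yet (illustrative values)
track_lon = np.array([np.nan, 275.0, 274.0, 273.5])
track_lat = np.array([np.nan, 20.0, 21.5, 23.0])


def visible_segment(lon, lat, frame):
    """Return the part of the track up to and including `frame`."""
    return lon[:frame + 1], lat[:frame + 1]


for frame in range(len(track_lon)):
    lon_seg, lat_seg = visible_segment(track_lon, track_lat, frame)
    n_points = int(np.isfinite(lon_seg).sum())
    print(f'frame {frame}: {n_points} track point(s) drawn')
```

Matplotlib leaves gaps at NaN entries when drawing a line, so frames before genesis show no track. With `frame = -1`, as used for the initial frame, the slice is empty.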

In [None]:
import warnings

colour_map = 'plasma'
projection=ccrs.PlateCarree()

# get index of ensemble member
ensemble_idx = np.argwhere(ds.ensemble.values == ensemble_member)[0,0]

# min/max of the selected member over all time steps, so the colour scale stays fixed
min_val = float(np.min(sub_ds[ensemble_idx, ...]))
max_val = float(np.max(sub_ds[ensemble_idx, ...]))

# suppress line warnings stemming from potential NaNs in track data
warnings.filterwarnings("ignore", message="invalid value encountered in linestrings")

# define plots
def make_figure():
    fig = plt.figure(figsize=(11,7))
    ax = fig.add_subplot(1, 1, 1, projection=projection)

    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)

    return fig, ax


# make animation
%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()

def make_frame(frame):
    print(f'\rprocessing frame {frame+1} of {min(max_frames, sub_ds.shape[2])}', end='')

    # Clear previous plot objects
    for artist in ax.get_children():
        if hasattr(artist, 'get_array'):  # This targets pcolormesh objects
            artist.remove()


    plot_ds = sub_ds[ensemble_idx, 0, max(frame,0), :, :]
    pc = ax.pcolormesh(sub_ds.lon, sub_ds.lat, plot_ds, transform=projection,
                       cmap=colour_map,
                       vmin=min_val, vmax=max_val,
                       )

    ax.add_feature(cfeature.COASTLINE,lw=.5)
    ax.add_feature(cfeature.RIVERS,lw=.5)
    for track in tracks:
        ax.plot(track["lon"][:frame+1], track["lat"][:frame+1],
                transform=ccrs.PlateCarree(),
                color="lime", linewidth=1., alpha=1)

    # Enforce the plotting region extent
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    if frame==-1:
        cbar = fig.colorbar(pc, extend='both', shrink=0.8, ax=ax)

    # lead time assumes 6-hourly output steps; the initial frame (-1) is shown as 0h
    header = f'{base_name} {variable} lead time {max(frame, 0)*6}h'
    ax.set_title(header, fontsize=14)

    return pc

def animate(frame):
    return make_frame(frame)

def first_frame():
    return make_frame(-1)

ani = animation.FuncAnimation(fig,
                              animate,
                              min(max_frames, sub_ds.shape[2]),
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=1000/fps)
plt.close('all')
ani

### Step 4 (Optional): Save Animation

Run the cell below to save the animation as a GIF.

In [None]:
plt.close('all')

# Recreate figure and animation
cbar = None
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1, projection=projection)

ani = animation.FuncAnimation(fig, animate, min(max_frames, sub_ds.shape[2]),
                              init_func=first_frame, blit=False,
                              repeat=False, interval=1000/fps)

# 'bbox_inches' is discarded by Animation.save (frame size must stay constant), so it is not passed
ani.save('tracks_n_fields_ani.gif', fps=fps)
plt.close(fig)