# Solutions for Session 5 Part 2: Animations of particle trajectories with matplotlib and cartopy.

Data used: CaribbeanCurrent_1994.zarr

About this Parcels dataset: 
Particles are released every 5 days on the transect between Venezuela (mainland) and the island of Grenada.
Particles are advected with the 2D geostrophic flow computed from the Copernicus model output GLORYS12V1 
(1/12° horizontal resolution, 50 vertical levels) for the year 1994.
In order to decrease the data size for this workshop, outputs are stored every 12 hours.

In this script we demonstrate how to:
1) Plot a simple animation of particle trajectories (using the correct time dimension)
2) Create a backward animation
3) Start the animation at a chosen time (e.g., 1 February 1994)
4) Use a different particle color for every release time (every 5 days)
5) Add trails to the particles

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import timedelta
import matplotlib.cm as cm

To see the animation inline in the notebook, we need to enable the interactive HTML output:

In [2]:
# Display animations interactively in the notebook output
plt.rcParams.update({"animation.html": "jshtml"})

Load the data and get familiar with it:

In [None]:
# Open the Parcels output and report how many trajectories it contains
ds = xr.open_zarr('data/CaribbeanCurrent_1994.zarr')
n_particles = len(ds.trajectory)
print(f"Loaded: {n_particles} particles")

# For performance, keep only the variables the animations use and
# load them into memory
print("Loading subset into memory...")
needed_variables = ["lon", "lat", "time"]
ds = ds[needed_variables].load()

ds

We need to set up the time dimension so that particles are plotted at the correct times, because they are released at different times. This particular simulation has an output dt of 12 hours. We set up the **timerange**: the sequence of all output times between the first and the last time in the dataset, which the animation frames step through.

In [None]:
# Time axis for the animations.
# The output of this example is stored every 12 hours
outputdt = timedelta(hours=12)

# First and last time present in the dataset (NaT entries are ignored);
# one extra step is added to the end so the last output time is included
time_values = ds["time"].values
first_time = np.nanmin(time_values)
end_time = np.nanmax(time_values) + np.timedelta64(outputdt)

# Regularly spaced times from the first to the last output
timerange = np.arange(first_time, end_time, outputdt)
print(f"Timerange has {len(timerange)} timesteps from {timerange[0]} to {timerange[-1]}")

## **Simple animation**

The following code creates a simple animation: it steps through the **timerange** and draws the particle positions at each time as a scatter plot.

In [None]:
# Number of timesteps to animate
nframes = 50    # use fewer frames for testing purposes

# figure setup
fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_xlim(-71, -59)
ax.set_ylim(9.5, 19.5)
ax.coastlines(color='saddlebrown')
ax.add_feature(cfeature.LAND, alpha=0.5, facecolor='saddlebrown')

# Convert to plain NumPy arrays once, rather than on every frame
time_values = ds["time"].values
lon_values = ds["lon"].values
lat_values = ds["lat"].values


def particles_at(target_time):
    """Return an (n, 2) array of [lon, lat] for particles with an output at ``target_time``.

    Entries whose longitude or latitude is NaN are dropped, so the result
    can have zero rows.
    """
    time_id = np.where(time_values == target_time)
    lons = lon_values[time_id]
    lats = lat_values[time_id]
    valid = ~np.isnan(lons) & ~np.isnan(lats)
    return np.column_stack((lons[valid], lats[valid]))


# plot first timestep
initial_positions = particles_at(timerange[0])
scatter = ax.scatter(
    initial_positions[:, 0],
    initial_positions[:, 1],
    s=2,
    c='b'
)

# Set initial title
t_str = str(timerange[0])[:19]  # keep the first 19 characters of the timestamp string
title = ax.set_title(f"Particles at t = {t_str}")


# called once per frame by FuncAnimation
def animate(i):
    """Update the title and the marker positions to the output time ``timerange[i]``."""
    print(f"Animating frame {i+1}/{len(timerange)} at time {timerange[i]}")
    t_str = str(timerange[i])[:19]
    title.set_text(f"Particles at t = {t_str}")

    # A result with zero rows simply leaves no markers on this frame
    scatter.set_offsets(particles_at(timerange[i]))


# Create animation. Never request more frames than there are output times:
# animate(i) reads timerange[i] and would raise IndexError otherwise
anim = matplotlib.animation.FuncAnimation(
    fig, animate, frames=min(nframes, len(timerange)), interval=100
)
anim

## **Backward animation**

Create animation that starts at the last existing timestep of the Parcels dataset. 

In [None]:
# only line to be added in above code to create the simulation backward.
# Sorting before reversing makes the cell safe to re-run: a plain
# timerange[::-1] would flip the direction back to forward on a second run
timerange = np.sort(timerange)[::-1]

# Number of timesteps to animate
nframes = 50    # use fewer frames for testing purposes

# figure setup
fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_xlim(-71, -59)
ax.set_ylim(9.5, 19.5)
ax.coastlines(color='saddlebrown')
ax.add_feature(cfeature.LAND, alpha=0.5, facecolor='saddlebrown')

# Convert to plain NumPy arrays once, rather than on every frame
time_values = ds["time"].values
lon_values = ds["lon"].values
lat_values = ds["lat"].values


def particles_at(target_time):
    """Return an (n, 2) array of [lon, lat] for particles with an output at ``target_time``.

    Entries whose longitude or latitude is NaN are dropped, so the result
    can have zero rows.
    """
    time_id = np.where(time_values == target_time)
    lons = lon_values[time_id]
    lats = lat_values[time_id]
    valid = ~np.isnan(lons) & ~np.isnan(lats)
    return np.column_stack((lons[valid], lats[valid]))


# plot first timestep (timerange[0] is now the latest output time)
initial_positions = particles_at(timerange[0])
scatter = ax.scatter(
    initial_positions[:, 0],
    initial_positions[:, 1],
    s=2,
    c='b'
)

# Set initial title
t_str = str(timerange[0])[:19]  # keep the first 19 characters of the timestamp string
title = ax.set_title(f"Particles at t = {t_str}")


# called once per frame by FuncAnimation
def animate(i):
    """Update the title and the marker positions to the output time ``timerange[i]``."""
    print(f"Animating frame {i+1}/{len(timerange)} at time {timerange[i]}")
    t_str = str(timerange[i])[:19]
    title.set_text(f"Particles at t = {t_str}")

    # A result with zero rows simply leaves no markers on this frame
    scatter.set_offsets(particles_at(timerange[i]))


# Create animation. Never request more frames than there are output times:
# animate(i) reads timerange[i] and would raise IndexError otherwise
anim = matplotlib.animation.FuncAnimation(
    fig, animate, frames=min(nframes, len(timerange)), interval=100
)
anim

## **Starting animation at a chosen time**

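Simple solution: shift the frame index, **i = i + 31*2**. January has 31 days and there are 2 outputs per day (one every 12 hours), so frame 0 then corresponds to 1 February, provided the time series starts on 1 January at 00:00.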

In [None]:
# make sure you set back the time to forward!
# Sorting (instead of flipping the array again) gives forward time no matter
# how often this cell or the backward-animation cell has been run
timerange = np.sort(timerange)

# Number of timesteps to animate
nframes = 50    # use fewer frames for testing purposes

# First output time to show. The hand-computed shortcut is i + 31*2
# (31 days of January x 2 outputs per day); looking the date up gives the
# same index when the series starts on 1994-01-01 00:00 with 12-hourly output,
# and stays correct if the start time or output interval differs.
start_time = np.datetime64("1994-02-01")
start_index = int(np.searchsorted(timerange, start_time))
if start_index >= len(timerange):
    raise ValueError(
        f"start_time {start_time} is later than the last output time {timerange[-1]}"
    )

# figure setup
fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_xlim(-71, -59)
ax.set_ylim(9.5, 19.5)
ax.coastlines(color='saddlebrown')
ax.add_feature(cfeature.LAND, alpha=0.5, facecolor='saddlebrown')

# Convert to plain NumPy arrays once, rather than on every frame
time_values = ds["time"].values
lon_values = ds["lon"].values
lat_values = ds["lat"].values


def particles_at(target_time):
    """Return an (n, 2) array of [lon, lat] for particles with an output at ``target_time``.

    Entries whose longitude or latitude is NaN are dropped, so the result
    can have zero rows.
    """
    time_id = np.where(time_values == target_time)
    lons = lon_values[time_id]
    lats = lat_values[time_id]
    valid = ~np.isnan(lons) & ~np.isnan(lats)
    return np.column_stack((lons[valid], lats[valid]))


# plot the chosen start time (not timerange[0]) as the first frame
initial_positions = particles_at(timerange[start_index])
scatter = ax.scatter(
    initial_positions[:, 0],
    initial_positions[:, 1],
    s=2,
    c='b'
)

# Set initial title
t_str = str(timerange[start_index])[:19]  # keep the first 19 characters of the timestamp string
title = ax.set_title(f"Particles at t = {t_str}")


# called once per frame by FuncAnimation
def animate(i):
    """Update the title and marker positions for frame ``i``, counted from the chosen start time."""
    i = i + start_index
    print(f"Animating frame {i+1}/{len(timerange)} at time {timerange[i]}")
    t_str = str(timerange[i])[:19]
    title.set_text(f"Particles at t = {t_str}")

    # A result with zero rows simply leaves no markers on this frame
    scatter.set_offsets(particles_at(timerange[i]))


# Create animation. Only the output times from start_index onwards are
# available, so cap the frame count to keep timerange[i] in range
anim = matplotlib.animation.FuncAnimation(
    fig, animate, frames=min(nframes, len(timerange) - start_index), interval=100
)
anim

## **Adding colors to particles**

Use a different particle color for each release time.

In [None]:
# Number of timesteps to animate
nframes = 50    # use fewer frames for testing purposes

#--> Set up the colors and associated trajectories:
# 1) get release times for each particle (earliest valid time of each trajectory)
release_times = ds["time"].min(dim="obs", skipna=True).values

# 2) get unique release times
unique_release_times = np.unique(release_times[~np.isnat(release_times)])
n_release_times = len(unique_release_times)
print(f"Number of unique release times: {n_release_times}")

# 3) choose a qualitative (discrete) colormap
colormap = matplotlib.colormaps['tab20b']

# 4) assign a color to each release time by cycling through the colormap
#    entries. Indexing with an integer picks one entry directly, so
#    consecutive releases always differ in color; a color is reused only
#    after colormap.N releases. (Scaling the index to [0, 1] instead would
#    put several consecutive releases into the same entry whenever there are
#    more release times than colormap entries.)
release_time_to_color = {
    release_time: colormap(k % colormap.N)
    for k, release_time in enumerate(unique_release_times)
}

# 5) one RGBA row per trajectory, so a frame's colors are a single array lookup.
#    Trajectories without any valid time get a placeholder; they never match
#    an output time and are therefore never drawn.
placeholder_color = (0.0, 0.0, 1.0, 1.0)
trajectory_colors = np.array(
    [release_time_to_color.get(rt, placeholder_color) for rt in release_times]
)

# figure setup
fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_xlim(-71, -59)
ax.set_ylim(9.5, 19.5)
ax.coastlines(color='saddlebrown')
ax.add_feature(cfeature.LAND, alpha=0.5, facecolor='saddlebrown')

# Convert to plain NumPy arrays once, rather than on every frame
time_values = ds["time"].values
lon_values = ds["lon"].values
lat_values = ds["lat"].values


def colored_particles_at(target_time):
    """Return ``(positions, colors)`` for particles with an output at ``target_time``.

    ``positions`` is an (n, 2) array of [lon, lat]; ``colors`` holds the
    matching rows of ``trajectory_colors``. Entries whose longitude or
    latitude is NaN are dropped, so both can be empty.
    """
    time_id = np.where(time_values == target_time)
    lons = lon_values[time_id]
    lats = lat_values[time_id]
    valid = ~np.isnan(lons) & ~np.isnan(lats)
    # time_id[0] indexes the same axis as release_times (one entry per trajectory)
    trajectory_index = time_id[0][valid]
    positions = np.column_stack((lons[valid], lats[valid]))
    return positions, trajectory_colors[trajectory_index]


# plot first timestep
initial_positions, initial_colors = colored_particles_at(timerange[0])
scatter = ax.scatter(
    initial_positions[:, 0],
    initial_positions[:, 1],
    s=2,
    c=initial_colors
)

# Set initial title
t_str = str(timerange[0])[:19]  # keep the first 19 characters of the timestamp string
title = ax.set_title(f"Particles at t = {t_str}")


# called once per frame by FuncAnimation
def animate(i):
    """Update title, marker positions and marker colors to the output time ``timerange[i]``."""
    print(f"Animating frame {i+1}/{len(timerange)} at time {timerange[i]}")
    t_str = str(timerange[i])[:19]
    title.set_text(f"Particles at t = {t_str}")

    positions, colors = colored_particles_at(timerange[i])
    scatter.set_offsets(positions)
    # Colors are only updated when there is something to draw
    if len(positions) > 0:
        scatter.set_color(colors)


# Create animation. Never request more frames than there are output times:
# animate(i) reads timerange[i] and would raise IndexError otherwise
anim = matplotlib.animation.FuncAnimation(
    fig, animate, frames=min(nframes, len(timerange)), interval=100
)
anim

## **Particle trails**

Add particle trails showing the last 10 output time steps of each trajectory (5 days, since the output is stored every 12 hours).

In [None]:
# Number of timesteps to animate
nframes = 50        # use fewer frames for testing purposes
nreducedtrails = 10 # every 10th particle will have a trail (if 1, all particles have trails. Adjust for faster performance)
trail_steps = 10    # maximum trail length in output time steps (5 days with 12-hourly output)


# Set up the colors and associated trajectories:
# get release times for each particle (earliest valid time of each trajectory)
release_times = ds["time"].min(dim="obs", skipna=True).values

# get unique release times
unique_release_times = np.unique(release_times[~np.isnat(release_times)])
n_release_times = len(unique_release_times)
print(f"Number of unique release times: {n_release_times}")

# choose a qualitative (discrete) colormap
colormap = matplotlib.colormaps['tab20b']

# assign a color to each release time by cycling through the colormap entries.
# Indexing with an integer picks one entry directly, so consecutive releases
# always differ in color; a color is reused only after colormap.N releases.
release_time_to_color = {
    release_time: colormap(k % colormap.N)
    for k, release_time in enumerate(unique_release_times)
}

# one RGBA row per trajectory, so colors for a frame are a single array lookup.
# Trajectories without any valid time get plain blue; they never match an
# output time and are therefore never drawn.
placeholder_color = (0.0, 0.0, 1.0, 1.0)
trajectory_colors = np.array(
    [release_time_to_color.get(rt, placeholder_color) for rt in release_times]
)


#--> Store data for all timeframes (this is needed for faster performance)
print("Pre-computing all particle positions...")
time_values = ds["time"].values
lon_values = ds["lon"].values
lat_values = ds["lat"].values

all_particles_data = []
for target_time in timerange:
    time_id = np.where(time_values == target_time)
    lons = lon_values[time_id]
    lats = lat_values[time_id]
    valid = ~np.isnan(lons) & ~np.isnan(lats)

    all_particles_data.append({
        'lons': lons[valid],
        'lats': lats[valid],
        # trajectory index of each stored position; ascending, because
        # np.where reports matches in row-major order
        'particle_indices': time_id[0][valid],
        'valid_count': np.sum(valid)
    })


def trail_of(particle_idx, first_frame, last_frame):
    """Return ``(lons, lats)`` lists with the stored positions of one particle.

    Frames ``first_frame`` to ``last_frame`` (inclusive) are scanned in order;
    frames in which the particle has no stored position are skipped.
    """
    trail_lons = []
    trail_lats = []
    for past_data in all_particles_data[first_frame:last_frame + 1]:
        indices = past_data['particle_indices']
        # Binary search in the ascending index array instead of a linear scan
        k = np.searchsorted(indices, particle_idx)
        if k < len(indices) and indices[k] == particle_idx:
            trail_lons.append(past_data['lons'][k])
            trail_lats.append(past_data['lats'][k])
    return trail_lons, trail_lats


# figure setup
fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_xlim(-71, -59)
ax.set_ylim(9.5, 19.5)
ax.coastlines(color='saddlebrown')
ax.add_feature(cfeature.LAND, alpha=0.5, facecolor='saddlebrown')

#--> Use pre-computed data for initial setup
initial_data = all_particles_data[0]
initial_colors = trajectory_colors[initial_data['particle_indices']]

#--> plot first timestep
scatter = ax.scatter(
    initial_data['lons'],
    initial_data['lats'],
    s=2,
    c=initial_colors
)

#--> initialize trails (line artists of the frame currently shown)
trail_plot = []

# Set initial title
t_str = str(timerange[0])[:19]  # keep the first 19 characters of the timestamp string
title = ax.set_title(f"Particles at t = {t_str}")


# called once per frame by FuncAnimation
def animate(i):
    """Draw frame ``i``: update the title, move and color the markers, redraw the trails."""
    print(f"Animating frame {i+1}/{len(timerange)} at time {timerange[i]}")
    t_str = str(timerange[i])[:19]
    title.set_text(f"Particles at t = {t_str}")

    # Remove the previous frame's trails first, so they also disappear on
    # frames that have no particles
    for trail in trail_plot:
        trail.remove()
    trail_plot.clear()

    current_data = all_particles_data[i]
    if current_data['valid_count'] == 0:
        scatter.set_offsets(np.empty((0, 2)))
        return

    particle_indices = current_data['particle_indices']
    scatter.set_offsets(np.c_[current_data['lons'], current_data['lats']])
    scatter.set_color(trajectory_colors[particle_indices])

    #--> add trails
    trail_length = min(trail_steps, i)  # early frames have fewer past steps available
    if trail_length == 0:
        return

    # use all particles or a subsample if you want faster computation
    for particle_idx in particle_indices[::nreducedtrails]:
        trail_lons, trail_lats = trail_of(particle_idx, i - trail_length, i)
        # a line needs at least two points
        if len(trail_lons) > 1:
            trail, = ax.plot(
                trail_lons,
                trail_lats,
                color=tuple(trajectory_colors[particle_idx]),
                linewidth=0.3,
                alpha=0.6
            )
            trail_plot.append(trail)


# Create animation. Never request more frames than were pre-computed:
# animate(i) reads all_particles_data[i] and would raise IndexError otherwise
anim = matplotlib.animation.FuncAnimation(
    fig, animate, frames=min(nframes, len(timerange)), interval=100
)
anim