In [1]:
from SRGAN_funcs import load_and_combine_channels
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np
import os

# Velocity field

In [None]:
# Crop the coordinate vectors to the 64-point high-resolution grid, then keep
# every 4th point as the low-resolution grid coordinates (presumably matching
# the downsampling of vx_lr / vy_lr — confirm).
# NOTE(review): lng/lat (and vx, vy, vx_lr, vy_lr) must be created by an
# earlier cell; none is visible here and load_and_combine_channels is imported
# but never called — the notebook will not survive Restart & Run All as shown.
lng, lat = lng[:64], lat[:64]
lng_lr, lat_lr = lng[::4], lat[::4]

In [None]:
# Number of animation frames — assumes time is the leading axis of vx (TODO confirm).
T = vx.shape[0]
# High-resolution grid size; X and Y are not referenced by the visible cells.
X = 64
Y = X

In [None]:
def _axis_ticks(coords, n_cells, step, reverse=False):
    """Return tick positions and rounded coordinate labels for one image axis.

    Parameters
    ----------
    coords : numpy.ndarray
        1-D coordinate values for the axis; expected to have ``n_cells``
        entries so positions and labels have the same length.
    n_cells : int
        Number of pixels along the axis.
    step : int
        Spacing, in pixels, between consecutive ticks.
    reverse : bool
        If True, labels are read from the end of ``coords`` backwards, so the
        tick at pixel ``k`` is labelled ``coords[-1 - k]``.

    Returns
    -------
    positions : range
        Pixel indices where ticks are placed.
    labels : numpy.ndarray
        Coordinate values rounded to two decimals.
    """
    positions = range(0, n_cells, step)
    labels = (coords[::-step] if reverse else coords[::step]).round(2)
    return positions, labels


def plot_field(t, vx, vy, lng, lat, mask=None, scale=15, width=0.0025, d=2):
    """Draw frame ``t`` of a velocity field on the current pyplot figure.

    The speed sqrt(vx**2 + vy**2) is shown as an image, with white velocity
    arrows on top, a speed colorbar and longitude/latitude tick labels.

    Parameters
    ----------
    t : int
        Frame index into the first axis of ``vx`` / ``vy``.
    vx, vy : numpy.ndarray
        Velocity components; ``vx[t]`` / ``vy[t]`` must be 2-D (rows, cols).
    lng, lat : numpy.ndarray
        Coordinates used to label the columns / rows.
    mask : array-like, optional
        Non-zero entries are hidden (``numpy.ma`` semantics) in both the speed
        image and the arrows.
    scale, width : float
        Passed through to ``plt.quiver``.
    d : int
        Draw an arrow at every ``d``-th grid point in each direction.
    """
    # Velocity components at time t
    v_x = vx[t]
    v_y = vy[t]

    # Speed magnitude
    v = np.sqrt(v_x**2 + v_y**2)

    if mask is not None:
        # Mask the components as well as the speed; otherwise arrows are
        # still drawn over masked (land) cells.
        v = np.ma.array(v, mask=mask)
        v_x = np.ma.array(v_x, mask=mask)
        v_y = np.ma.array(v_y, mask=mask)

    n_rows, n_cols = v.shape

    # Subsample for the arrows. Separate x/y positions keep non-square grids
    # aligned with the (rows, cols) layout of the arrow arrays.
    x_arrows = v_x[::d, ::d]
    y_arrows = v_y[::d, ::d]
    x_grid = np.arange(0, n_cols, d)
    y_grid = np.arange(0, n_rows, d)

    # Keep the image handle: plt.quiver registers the arrows as pyplot's
    # current mappable, so a bare plt.colorbar() would not describe speed.
    speed_img = plt.imshow(v, cmap='viridis', origin='lower')
    plt.quiver(x_grid, y_grid, x_arrows, y_arrows, scale=scale, color='white', width=width)
    plt.colorbar(speed_img, label=r'Speed = $\sqrt{v_x^2+v_y^2}$')

    # Dense grids get a tick every 8 cells, small grids every 2.
    if len(lng) > 10:
        step, reverse_lat = 8, True
    else:
        step, reverse_lat = 2, False
    # NOTE(review): latitude labels are reversed on dense grids but not on
    # small ones (kept as before). Only one orientation can match the row
    # order of vx/vy under origin='lower' — confirm against the data loader.
    x_ticks, x_labels = _axis_ticks(lng, n_cols, step)
    y_ticks, y_labels = _axis_ticks(lat, n_rows, step, reverse=reverse_lat)
    plt.xticks(ticks=x_ticks, labels=x_labels, rotation=45)
    plt.yticks(ticks=y_ticks, labels=y_labels)
    plt.xlabel('Longitude (°E)')
    plt.ylabel('Latitude (°N)')
    plt.title('Velocity Field')
    plt.tight_layout()

In [None]:
def animate_v(t, vx, vy, lng, lat, mask=None):
    """FuncAnimation callback that redraws frame ``t`` of the velocity field.

    Wipes the current pyplot figure, then delegates all drawing to
    ``plot_field`` with the same arguments (mask is optional).
    """
    plt.gcf().clear()
    plot_field(t, vx, vy, lng, lat, mask)

In [None]:
# Create the exact directory the GIFs are written to; saving does not create
# missing parent directories.
os.makedirs('reports/figures/animations', exist_ok=True)

# Create a figure
fig = plt.figure()

# Animate the high-resolution field, one frame per time step
anim = FuncAnimation(fig, animate_v, frames=range(T), interval=150, fargs=(vx, vy, lng, lat))

# Save the animation as a GIF file
anim.save('reports/figures/animations/v_hr.gif', writer='pillow')

In [None]:
# Low-resolution counterpart of the previous animation, drawn on the
# downsampled coordinate grid and saved as a GIF via the pillow writer.
fig = plt.figure()
anim = FuncAnimation(
    fig,
    animate_v,
    frames=range(T),
    interval=150,
    fargs=(vx_lr, vy_lr, lng_lr, lat_lr),
)
anim.save('reports/figures/animations/v_lr.gif', writer='pillow')

## Masks

In [None]:
# Synthetic 64x64 land mask: 1 = land (masked out / no data), 0 = water.
# With origin='lower', row 0 is drawn at the bottom of the image.
mask = np.zeros((64, 64), dtype=int)

# Coastline along the left edge
mask[:20, :5] = 1  # wider strip over rows 0-19
mask[20:, :3] = 1  # narrower strip over the remaining rows
mask[5:13, 5:7] = 1  # Small extension into the water
mask[7:11, 7:8] = 1  # Small extension into the water

# Coastline along the bottom edge (rows 0-2 are at the bottom with origin='lower')
mask[:3, :18] = 1
mask[3:5, 10:15] = 1  # Extension into the water

# Add an island
mask[16:22, 14:19] = 1
mask[15:16, 15:18] = 1
mask[17:20, 19:20] = 1

# Visualize the mask
plt.figure(figsize=[3, 3])
plt.imshow(mask, cmap="Greys", origin="lower")
plt.title("Land Mask")
plt.colorbar(label="1 = Land (no data), 0 = Water (data)")
plt.show()

In [None]:
# High-resolution animation again, this time passing the land mask so that
# plot_field hides cells where mask is non-zero.
fig = plt.figure()
anim = FuncAnimation(
    fig,
    animate_v,
    frames=range(T),
    interval=150,
    fargs=(vx, vy, lng, lat, mask),
)
anim.save('reports/figures/animations/v_masked.gif', writer='pillow')