# Import libraries

In [11]:
from collections import namedtuple
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib import animation
from matplotlib.figure import Figure
from matplotlib.quiver import Quiver
from scipy import signal
from tqdm import tqdm

# 0. Utils for Vortex Simulation

In [13]:
@dataclass
class Point:
    """
    A position on the cartesian grid.

    x is the horizontal coordinate (used for longitude in this file),
    y the vertical one (used for latitude).
    """

    x: float  # horizontal coordinate
    y: float  # vertical coordinate
        
        
@dataclass
class EllipsisRadius:
    """
    Semi-axes of an ellipse on the cartesian grid.

    x is the radius along the horizontal axis, y the radius along the vertical axis.
    """

    x: float  # radius along the horizontal axis
    y: float  # radius along the vertical axis
        
        
@dataclass
class VortexVariationFunctions:
    """
    Time-evolution laws for the parameters of a vortex.

    Each function is called as ``f(initial_value, time_step)`` and returns the
    parameter value at that time step (see update_vortex_at_time_t).
    """

    # new center x / y from the initial center x / y
    x_center_t: Callable[[float, int], float]
    y_center_t: Callable[[float, int], float]
    # new ellipse radius x / y from the initial radius x / y
    x_radius_t: Callable[[float, int], float]
    y_radius_t: Callable[[float, int], float]

        
def normalize_vectors(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Scale every 2D vector (x[i], y[i]) to unit norm.

    Zero-length vectors are returned as zeros instead of producing NaN.

    Args:
        - x: first vector components (any shape)
        - y: second vector components (same shape as x)
    Returns:
        (new_x, new_y): normalized components, in a floating dtype even for integer inputs
    """
    norms = np.sqrt(x**2 + y**2)
    # The output buffer takes the norms' floating dtype: an integer buffer (zeros_like of an
    # integer x) would make np.divide refuse to store its float results.
    new_x = np.divide(x, norms, out=np.zeros_like(x, dtype=norms.dtype), where=norms != 0)
    new_y = np.divide(y, norms, out=np.zeros_like(y, dtype=norms.dtype), where=norms != 0)
    return new_x, new_y
    

class Vortex:
    """
    Elliptical vortex on a cartesian grid.

    Inside the ellipse (normalized distance to the center <= 1) the velocity field is a
    rotation around the center; outside the ellipse it is zero.
    """
    def __init__(self, center: Point, radius: EllipsisRadius, orientation: int) -> None:
        """
        Constructor for Vortex
        Args:
            - center: center of the ellipse (x: horizontal axis, e.g. longitude; y: vertical axis, e.g. latitude)
            - radius: ellipse radius along the x and y axes
            - orientation: 1 for counter-clockwise (trigonometric), -1 for clockwise
        Raises:
            - AssertionError: if orientation is neither 1 nor -1
        """
        # Validate before storing anything on the instance
        assert orientation in (1, -1), "orientation should be among 1 or -1"
        self.center = center
        self.radius = radius
        self.orientation = orientation

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"center=Point(x={self.center.x}, y={self.center.y}), "
            f"radius=EllipsisRadius(x={self.radius.x}, y={self.radius.y}), "
            f"orientation={self.orientation})"
        )

    def compute_distance_to_center(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Compute the normalized elliptical distance of each point to the vortex center.

        The value is 1 on the ellipse boundary, below 1 inside and above 1 outside.
        Args:
            - x: horizontal coordinates (e.g. longitude meshgrid)
            - y: vertical coordinates (e.g. latitude meshgrid)
        """
        distance = np.sqrt(((y - self.center.y) / self.radius.y)**2 + ((x - self.center.x) / self.radius.x)**2)
        return distance

    def get_velocity_from_position(self, x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the velocity vectors (u, v) at the given positions.

        Public entry point; subclasses may override it to post-process the raw field.
        Args:
            - x: horizontal coordinates (e.g. longitude meshgrid)
            - y: vertical coordinates (e.g. latitude meshgrid)
        """
        return self._get_velocity_from_position(x, y)

    def plot_static(self, x: np.ndarray, y: np.ndarray, **kwargs) -> Tuple[Figure, Quiver]:
        """
        Plot the vortex velocity field as arrows in a new figure.
        Args:
            - x: horizontal coordinates (e.g. longitude meshgrid)
            - y: vertical coordinates (e.g. latitude meshgrid)
            - **kwargs: extra keyword arguments forwarded to plt.quiver; they override the
              defaults units="xy", scale=1
        Returns:
            the new figure and the quiver artist (e.g. to update it in an animation)
        """
        u, v = self.get_velocity_from_position(x, y)
        fig = plt.figure(figsize=(10, 7))
        quiver_kwargs = {"units": "xy", "scale": 1, **kwargs}
        quiver = plt.quiver(x, y, u, v, **quiver_kwargs)
        return fig, quiver

    def _get_velocity_from_position(self, x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the raw (non-normalized) rotational velocity field.

        The velocity is perpendicular to the radius vector, scaled by the ellipse radii,
        and set to 0 outside the ellipse.
        Args:
            - x: horizontal coordinates (e.g. longitude meshgrid)
            - y: vertical coordinates (e.g. latitude meshgrid)
        """
        # Rotation around the center: counter-clockwise for orientation == 1
        u = - self.orientation * (y - self.center.y) / self.radius.y
        v = self.orientation * (x - self.center.x) / self.radius.x
        # Zero the field outside the ellipse
        distances = self.compute_distance_to_center(x, y)
        u, v = np.where(distances <= 1, u, 0), np.where(distances <= 1, v, 0)
        return u, v
    

class NormalizedVortex(Vortex):
    """
    Vortex whose velocity field only keeps the flow direction.

    Each vector of the base Vortex field is rescaled to unit norm; zero vectors
    (outside the ellipse, or exactly at the center) stay zero. The constructor
    is inherited unchanged from Vortex.
    """

    def get_velocity_from_position(self, x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return unit-norm velocity vectors at the given positions.
        Args:
            - x: horizontal coordinates (e.g. longitude meshgrid)
            - y: vertical coordinates (e.g. latitude meshgrid)
        """
        return normalize_vectors(*self._get_velocity_from_position(x, y))

In [5]:
def update_vortex_at_time_t(
    time_step: int,
    quiver: Quiver,
    x: np.ndarray,
    y: np.ndarray,
    initial_vortex: Vortex,
    variation_functions: VortexVariationFunctions,
) -> Quiver:
    """
    Animation callback: redraw the vortex as it is at `time_step`.

    Each center/radius component of `initial_vortex` is passed, together with `time_step`,
    to the matching function of `variation_functions`. The velocity field of the resulting
    NormalizedVortex (same orientation) replaces the arrows of `quiver` in place.

    Returns:
        the same quiver object, updated
    """
    funcs = variation_functions
    start_center = initial_vortex.center
    start_radius = initial_vortex.radius

    moved_center = Point(
        x=funcs.x_center_t(start_center.x, time_step),
        y=funcs.y_center_t(start_center.y, time_step),
    )
    resized_radius = EllipsisRadius(
        x=funcs.x_radius_t(start_radius.x, time_step),
        y=funcs.y_radius_t(start_radius.y, time_step),
    )
    current_vortex = NormalizedVortex(
        center=moved_center,
        radius=resized_radius,
        orientation=initial_vortex.orientation,
    )

    quiver.set_UVC(*current_vortex.get_velocity_from_position(x=x, y=y))
    return quiver


def run_vortex_animation(
    initial_vortex: Vortex,
    x: np.ndarray,
    y: np.ndarray,
    variation_functions: VortexVariationFunctions,
    frames: int = 50,
    interval: int = 20,
) -> Tuple["Figure", "animation.FuncAnimation"]:
    """
    Run a Vortex animation
    Args:
        - initial_vortex: the vortex at time step 0; plot_static draws it to create the figure
        - x: horizontal positions from meshgrid (e.g. longitude)
        - y: vertical positions from meshgrid (e.g. latitude)
        - variation_functions: temporal functions to update Vortex parameters
        - frames: number of time steps per animation cycle (passed to FuncAnimation)
        - interval: delay between frames in milliseconds (passed to FuncAnimation)
    Returns:
        (fig, anim). Keep a reference to `anim`: matplotlib stops an animation
        whose FuncAnimation object has been garbage-collected.
    """
    # Initialize plot (frames are then redrawn as NormalizedVortex fields by update_vortex_at_time_t)
    fig, quiver = initial_vortex.plot_static(x, y)

    anim = animation.FuncAnimation(
        fig,
        update_vortex_at_time_t,
        fargs=(quiver, x, y, initial_vortex, variation_functions),
        frames=frames,
        interval=interval,
        blit=False,
        repeat=True,
    )
    fig.tight_layout()
    return fig, anim

In [9]:
def convert_ms_to_knot(velocity_in_ms: np.ndarray) -> np.ndarray:
    """
    Convert a speed from metres per second to knots (1 knot = 1852 m per hour).
    """
    seconds_per_hour = 3600
    metres_per_nautical_mile = 1852
    return velocity_in_ms * seconds_per_hour / metres_per_nautical_mile

# 1. Utils for Vortex Recognition

In [8]:
def identify_depression(wind_speed, kernel_size, min_wind_depression):
    """
    Flag grid cells whose locally averaged wind speed reaches a threshold.

    Args:
        - wind_speed: 2D wind speed field (numpy array or xarray DataArray), presumably in m/s
          since it is converted with the m/s -> knot factor 3600 / 1852
        - kernel_size: side of the square averaging window, in grid cells
        - min_wind_depression: threshold on the window-mean speed, in knots
    Returns:
        2D int array (same shape as wind_speed): 1 where the mean speed >= threshold, 0 elsewhere
    """
    # np.asarray accepts both numpy arrays and DataArrays (the latter exposing `.values`).
    values = np.asarray(wind_speed)
    # "same" convolution keeps the grid shape; scipy zero-fills outside the grid, so the
    # mean is underestimated near the borders.
    window_sum = signal.convolve2d(values, np.ones((kernel_size, kernel_size)), "same")
    mean_speed_knots = window_sum * 3600 / (1852 * kernel_size**2)
    return np.where(mean_speed_knots >= min_wind_depression, 1, 0)

def convolve_vectors(vectors: List[np.ndarray], kernel):
    """
    Lazily convolve each 2D array of `vectors` with `kernel`.

    Uses scipy.signal.convolve2d in "same" mode (output has the input's shape).
    This is a generator: results are produced one by one, in input order, only
    when iterated.
    """
    yield from (signal.convolve2d(component, kernel, mode="same") for component in vectors)
        
def get_centers_from_potential_area(lon_candidates: np.ndarray, lat_candidates: np.ndarray, deg_threshold: float) -> Tuple[List[float], List[float]]:
    """
    Merge candidate points into centers by averaging runs of neighbouring candidates.

    Candidates are scanned in the given order; a candidate joins the current group when
    both its longitude and latitude differ by at most `deg_threshold` from the PREVIOUS
    candidate, otherwise it starts a new group. Each group yields its mean position.

    Args:
        - lon_candidates, lat_candidates: candidate coordinates, same length
        - deg_threshold: maximal coordinate gap between consecutive candidates of a group
    Returns:
        (lon_centers, lat_centers) lists; both empty when there are no candidates
    Raises:
        - AssertionError: if the two inputs have different lengths
    """
    if len(lon_candidates) == 0 and len(lat_candidates) == 0:
        return [], []
    assert len(lon_candidates) > 0 and len(lon_candidates) == len(lat_candidates)

    lon_centers, lat_centers = [], []
    lon_group, lat_group = [], []
    prev_lon, prev_lat = lon_candidates[0], lat_candidates[0]
    for current_lon, current_lat in zip(lon_candidates, lat_candidates):
        is_neighbour = (
            np.abs(current_lon - prev_lon) <= deg_threshold
            and np.abs(current_lat - prev_lat) <= deg_threshold
        )
        if not is_neighbour:
            # Close the current group and start a new one
            lon_centers.append(np.mean(lon_group))
            lat_centers.append(np.mean(lat_group))
            lon_group, lat_group = [], []
        lon_group.append(current_lon)
        lat_group.append(current_lat)
        prev_lon, prev_lat = current_lon, current_lat
    # The last group is never closed inside the loop
    lon_centers.append(np.mean(lon_group))
    lat_centers.append(np.mean(lat_group))
    return lon_centers, lat_centers


def get_candidates_rot(longitude, latitude, rotational, divergence):
    """
    List grid points with rotational >= 1 and divergence <= 0.

    The fields are indexed [latitude, longitude]. Returns a tuple of
    (longitude, latitude, rotational) triples in row-major grid order.
    """
    mask = (rotational >= 1) & (divergence <= 0)
    rows, cols = np.where(mask)
    return tuple(zip(longitude[cols], latitude[rows], rotational[mask]))

def get_centers_from_potential_area_rot(candidates: Tuple[np.ndarray], deg_threshold: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Group consecutive (lon, lat, rotational) candidates and keep the strongest of each group.

    A candidate stays in the current group when both coordinates differ by at most
    `deg_threshold` from the previous candidate. Each group contributes the position of
    its member with the largest |rotational| (the first one on ties).
    Returns (lon_centers, lat_centers) lists, both empty for no candidates.
    """
    if len(candidates) == 0:
        return [], []

    def _strongest(group):
        # Member with the largest absolute rotational
        return group[np.argmax(np.abs([rot for _, _, rot in group]))]

    lon_centers, lat_centers = [], []
    group = []
    prev_lon, prev_lat = candidates[0][0], candidates[0][1]
    for lon_c, lat_c, rot_c in candidates:
        if not (np.abs(lon_c - prev_lon) <= deg_threshold and np.abs(lat_c - prev_lat) <= deg_threshold):
            best = _strongest(group)
            lon_centers.append(best[0])
            lat_centers.append(best[1])
            group = []
        group.append((lon_c, lat_c, rot_c))
        prev_lon, prev_lat = lon_c, lat_c

    best = _strongest(group)
    lon_centers.append(best[0])
    lat_centers.append(best[1])
    return lon_centers, lat_centers


def calculate_rotational2d(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """
    Vertical component of the curl: dv/dx - du/dy.

    Fields are indexed [y, x] (axis 0 = rows, axis 1 = columns); derivatives use
    np.gradient with unit grid spacing.
    """
    dv_dx = np.gradient(v, axis=1)
    du_dy = np.gradient(u, axis=0)
    return dv_dx - du_dy

def calculate_divergence2d(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """
    Divergence of the 2D field: du/dx + dv/dy.

    Fields are indexed [y, x] (axis 0 = rows, axis 1 = columns); derivatives use
    np.gradient with unit grid spacing.
    """
    du_dx = np.gradient(u, axis=1)
    dv_dy = np.gradient(v, axis=0)
    return du_dx + dv_dy
    

def plot_rotational_and_divergence(lon, lat, rotational, divergence, u=None, v=None):
    """
    Show the rotational and divergence fields side by side.

    Args:
        - lon, lat: grid coordinates given to pcolormesh / quiver
        - rotational, divergence: 2D fields to color-map
        - u, v: optional wind components; when both are given they are drawn as
          arrows on both panels (previously the function read undefined globals)
    Returns:
        the matplotlib Figure
    """
    fig = plt.figure(figsize=(20, 6))
    panels = (("Rotational", rotational), ("Divergence", divergence))
    for position, (title, field) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        # Colormap by name: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        plt.pcolormesh(lon, lat, field, shading="nearest", cmap="coolwarm")
        plt.colorbar()
        if u is not None and v is not None:
            plt.quiver(lon, lat, u, v)
    return fig
    
    
def plot_area_and_convolution(lon, lat, research_area, convolved_vectors, u=None, v=None):
    """
    Show the research-area mask next to the convolved wind vectors.

    Args:
        - lon, lat: grid coordinates given to pcolormesh / quiver
        - research_area: 2D mask to color-map on the left panel
        - convolved_vectors: (u, v) pair drawn as arrows on the right panel
        - u, v: optional raw wind components drawn over the mask when both are given
          (previously the function read undefined globals)
    Returns:
        the matplotlib Figure
    """
    fig = plt.figure(figsize=(20, 6))
    plt.subplot(1, 2, 1)
    plt.title("Potential area to convolve")
    # Colormap by name: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    plt.pcolormesh(lon, lat, research_area, shading="nearest", cmap="coolwarm")
    plt.colorbar()
    if u is not None and v is not None:
        plt.quiver(lon, lat, u, v)
    plt.subplot(1, 2, 2)
    plt.title("Convolved vectors")
    plt.quiver(lon, lat, convolved_vectors[0], convolved_vectors[1])
    return fig
    

def identify_vortex_centers_using_convolution(longitude, latitude, u, v, plot: Optional[bool] = False) -> Tuple[List[float], List[float]]:
    """
    From meshgrid and wind, identify vortex centers by vector cancellation.

    Steps:
        1. windy area: cells whose 8x8 mean wind speed is >= 15 knots (identify_depression)
        2. wind vectors are normalized, then summed over 3x3 neighbourhoods
        3. candidates: windy cells where that sum has magnitude <= 5, i.e. neighbouring
           directions largely cancel out
        4. consecutive close candidates are averaged into centers (deg_threshold=5)

    Args:
        - longitude, latitude: 1D grid coordinates; u and v are indexed [latitude, longitude]
        - u, v: wind components
        - plot: if True, also draw the windy area and the convolved vectors
    Returns:
        (lon_centers, lat_centers) lists
    """
    u, v = u.copy(), v.copy()
    wind_speed = np.sqrt(u**2 + v**2)
    depression_indexes = identify_depression(wind_speed, kernel_size=8, min_wind_depression=15)
    u, v = normalize_vectors(u, v)
    u, v = convolve_vectors(vectors=[u, v], kernel=np.ones((3, 3)))
    # Summing 9 unit vectors gives a magnitude of at most 9; <= 5 means they partly cancel.
    potential_areas = np.where((np.sqrt(u**2 + v**2) <= 5) & (depression_indexes == 1), 1, 0)
    lat_potential_idx, lon_potential_idx = np.where(potential_areas == 1)
    lon_candidates, lat_candidates = longitude[lon_potential_idx], latitude[lat_potential_idx]
    lon_centers, lat_centers = get_centers_from_potential_area(lon_candidates, lat_candidates, deg_threshold=5)

    if plot:
        # Use the function's own coordinates (the original referenced undefined `lon`/`lat`)
        plot_area_and_convolution(longitude, latitude, depression_indexes, (u, v))
    return lon_centers, lat_centers


def identify_vortex_centers_with_rot(longitude, latitude, u, v, plot: Optional[bool] = False) -> Tuple[List[float], List[float]]:
    """
    From meshgrid and wind, identify vortex centers using the rotational of the raw wind.

    Candidates are cells where the rotational is >= 6 and the divergence <= 0 (both computed
    with unit grid spacing). Consecutive close candidates are averaged into centers
    (deg_threshold=5).

    Args:
        - longitude, latitude: 1D grid coordinates; u and v are indexed [latitude, longitude]
        - u, v: wind components
        - plot: if True, also draw the rotational and divergence fields
    Returns:
        (lon_centers, lat_centers) lists
    """
    u, v = u.copy(), v.copy()
    rot_wind = calculate_rotational2d(u, v)
    div_wind = calculate_divergence2d(u, v)
    lat_potential_idx, lon_potential_idx = np.where((rot_wind >= 6) & (div_wind <= 0))
    lon_candidates, lat_candidates = longitude[lon_potential_idx], latitude[lat_potential_idx]
    lon_centers, lat_centers = get_centers_from_potential_area(lon_candidates, lat_candidates, deg_threshold=5)

    if plot:
        # Use the function's own coordinates (the original referenced undefined `lon`/`lat`)
        plot_rotational_and_divergence(longitude, latitude, rot_wind, div_wind)
    return lon_centers, lat_centers


def identify_vortex_centers_with_combined(longitude, latitude, u, v, kernel_size: int, plot: Optional[bool] = False) -> Tuple[List[float], List[float]]:
    """
    From meshgrid and wind, identify vortex centers from the smoothed wind direction.

    Steps:
        1. wind components are summed over kernel_size x kernel_size neighbourhoods
        2. the smoothed vectors are normalized to unit norm
        3. candidates: rotational >= 1 and divergence <= 0 (get_candidates_rot)
        4. consecutive close candidates are grouped; each group keeps its member with
           the largest |rotational| (deg_threshold=5)

    Args:
        - longitude, latitude: 1D grid coordinates; u and v are indexed [latitude, longitude]
        - u, v: wind components
        - kernel_size: side of the smoothing kernel, in grid cells
        - plot: if True, also draw the rotational and divergence fields
    Returns:
        (lon_centers, lat_centers) lists
    """
    u, v = u.copy(), v.copy()
    u, v = convolve_vectors(vectors=[u, v], kernel=np.ones((kernel_size, kernel_size)))
    u, v = normalize_vectors(u, v)
    rot_wind = calculate_rotational2d(u, v)
    div_wind = calculate_divergence2d(u, v)
    candidates = get_candidates_rot(longitude, latitude, rot_wind, div_wind)
    lon_centers, lat_centers = get_centers_from_potential_area_rot(candidates, deg_threshold=5)

    if plot:
        # Use the function's own coordinates (the original referenced undefined `lon`/`lat`)
        plot_rotational_and_divergence(longitude, latitude, rot_wind, div_wind)
    return lon_centers, lat_centers


def identify_vortex_radius(
    longitude_centers, latitude_centers, lon, lat, u, v, n_iterations
) -> Tuple[List[int], List[int]]:
    """
    Estimate, for each vortex center, a radius in grid cells along longitude and latitude.

    For each window half-width k in 0 .. n_iterations-1, the absolute value of the summed
    wind component on both sides of the center is recorded: u along the center's latitude
    column, v along its longitude row (the center cell is counted in both halves).
    The radius is 1 + the position of the steepest decrease (minimum of np.gradient)
    of that sequence.

    Args:
        - longitude_centers, latitude_centers: center coordinates (same length)
        - lon, lat: 1D xarray coordinates named "longitude" / "latitude"; each center must be
          within 0.5 of a grid coordinate (`.sel(..., tolerance=0.5)` fails otherwise)
        - u, v: 2D wind components indexed [latitude, longitude] (DataArray or numpy array)
        - n_iterations: number of half-widths tested; np.gradient needs at least 2 values
    Returns:
        (radius_long, radius_lat): one integer radius per center, in grid cells
    """
    radius_long, radius_lat = [], []
    for lon_center, lat_center in zip(longitude_centers, latitude_centers):
        # Grid indices of the coordinates nearest to the center
        lon_idx = np.where(lon == lon.sel({"longitude": lon_center}, method="nearest", tolerance=0.5))[0][0]
        lat_idx = np.where(lat == lat.sel({"latitude": lat_center}, method="nearest", tolerance=0.5))[0][0]
        sum_long, sum_lat = [], []
        for k in range(n_iterations):
            # Clamp window starts at 0: a negative slice start would wrap to the opposite
            # edge of the grid and yield an empty or unrelated slice near the borders.
            lat_start = max(lat_idx - k, 0)
            lon_start = max(lon_idx - k, 0)
            diff_lat = np.abs(
                np.asarray(u[lat_start:lat_idx + 1, lon_idx]).sum()
                + np.asarray(u[lat_idx:lat_idx + k + 1, lon_idx]).sum()
            )
            diff_long = np.abs(
                np.asarray(v[lat_idx, lon_start:lon_idx + 1]).sum()
                + np.asarray(v[lat_idx, lon_idx:lon_idx + k + 1]).sum()
            )
            sum_lat.append(diff_lat)
            sum_long.append(diff_long)
        # Radius where the summed component drops fastest
        radius_long.append(np.argmin(np.gradient(sum_long)) + 1)
        radius_lat.append(np.argmin(np.gradient(sum_lat)) + 1)
    return radius_long, radius_lat


def identify_vortex_parameters(
    longitude: np.ndarray,
    latitude: np.ndarray,
    u: np.ndarray,
    v: np.ndarray,
    center_method: str = "combined",
    n_radius_iterations: int = 20,
    kernel_size_center: int = 10,
) -> List["NormalizedVortex"]:
    """
    Detect vortexes in a wind field and describe each one as a NormalizedVortex.

    Args:
        - longitude, latitude: 1D grid coordinates (xarray coordinates named "longitude" /
          "latitude", as required by identify_vortex_radius)
        - u, v: 2D wind components indexed [latitude, longitude]
        - center_method: center detector, one of "convolution", "rotational", "combined"
        - n_radius_iterations: number of window half-widths tried by identify_vortex_radius
        - kernel_size_center: smoothing kernel size, only used by the "combined" method
    Returns:
        one NormalizedVortex per detected center (radius in grid cells; orientation is
        currently always 1)
    Raises:
        - AssertionError: if n_radius_iterations is not a positive int
        - ValueError: if center_method is unknown
    """
    # Check the cheap argument first so a bad value fails before the center search runs.
    assert isinstance(n_radius_iterations, int) and n_radius_iterations > 0

    if center_method == "convolution":
        lon_centers, lat_centers = identify_vortex_centers_using_convolution(longitude, latitude, u, v, plot=False)
    elif center_method == "rotational":
        lon_centers, lat_centers = identify_vortex_centers_with_rot(longitude, latitude, u, v, plot=False)
    elif center_method == "combined":
        lon_centers, lat_centers = identify_vortex_centers_with_combined(longitude, latitude, u, v, kernel_size=kernel_size_center, plot=False)
    else:
        raise ValueError("center_method should be among 'convolution', 'rotational', 'combined'")

    radius_long, radius_lat = identify_vortex_radius(
        lon_centers, lat_centers, longitude, latitude, u, v, n_iterations=n_radius_iterations
    )

    # TODO: derive from the sign of the rotational (> 0: depression, < 0: anticyclone);
    # every vortex is assumed counter-clockwise (1) for now.
    orientations = [1] * len(lon_centers)

    vortexes = []
    for x_center, y_center, x_radius, y_radius, orientation in zip(lon_centers, lat_centers, radius_long, radius_lat, orientations):
        center = Point(x=x_center, y=y_center)
        radius = EllipsisRadius(x=x_radius, y=y_radius)
        vortexes.append(NormalizedVortex(center=center, radius=radius, orientation=orientation))
    return vortexes


In [13]:
def plot_pressure_and_land(data, pressure_mask, no_land_mask):
    """
    Show the low-pressure mask and the no-land/coast mask side by side, with wind arrows.

    Args:
        - data: dataset providing "longitude", "latitude", "u_wind" and "v_wind"
        - pressure_mask, no_land_mask: 2D masks to color-map
    Returns:
        the matplotlib Figure
    """
    lon, lat = data["longitude"], data["latitude"]
    fig = plt.figure(figsize=(20, 6))
    panels = (("Low pressure", pressure_mask), ("Not Land or Coast", no_land_mask))
    for position, (title, mask) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        # Colormap by name: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        plt.pcolormesh(lon, lat, mask, shading="nearest", cmap="coolwarm")
        plt.colorbar()
        plt.quiver(lon, lat, data["u_wind"], data["v_wind"])
    return fig

def get_center_candidates(data: xr.Dataset, plot: Optional[bool] = False) -> Tuple[Tuple[xr.DataArray]]:
    """
    List grid points that may be vortex centers.

    A point qualifies when pressure <= 1015, land_mask + coast_mask == 0, and the
    normalized wind there has rotational >= 1 and divergence <= 0.

    Args:
        - data: dataset with "u_wind", "v_wind", "pressure", "land_mask", "coast_mask"
          variables and "longitude" / "latitude" coordinates
        - plot: if True, also draw the diagnostic fields and masks
    Returns:
        tuple of (longitude, latitude, rotational) triples in row-major grid order
    """
    u_unit, v_unit = normalize_vectors(data["u_wind"], data["v_wind"])
    rotational = calculate_rotational2d(u_unit, v_unit)
    divergence = calculate_divergence2d(u_unit, v_unit)

    low_pressure = data["pressure"] <= 1015  # low pressure (1013.25 was an alternative threshold)
    off_land = data["land_mask"] + data["coast_mask"] == 0  # land and coast excluded (interferences)
    is_candidate = low_pressure & off_land & (rotational >= 1) & (divergence <= 0)

    rows, cols = np.where(is_candidate)
    lons = data["longitude"][cols]
    lats = data["latitude"][rows]
    rot_values = rotational[is_candidate]

    if plot:
        plot_rotational_and_divergence(data["longitude"], data["latitude"], rotational, divergence)
        plot_pressure_and_land(data, low_pressure, off_land)
    return tuple(zip(lons, lats, rot_values))


def get_centers_from_candidates(candidates: Tuple[Tuple[np.ndarray]], deg_threshold: float) -> Tuple[Tuple[np.ndarray]]:
    """
    Find centers among (lon, lat, rotational) candidates.

    Candidates are split into groups: a candidate starts a new group when its longitude
    or latitude differs by more than `deg_threshold` from the previous candidate.
    Each group is represented by its member with the largest |rotational| (first on ties).
    Returns a tuple of (lon, lat) pairs, or an empty tuple for no candidates.
    """
    if len(candidates) == 0:
        return tuple()

    groups, current = [], []
    prev_lon, prev_lat = candidates[0][0], candidates[0][1]
    for candidate in candidates:
        lon, lat, _ = candidate
        if not (np.abs(lon - prev_lon) <= deg_threshold and np.abs(lat - prev_lat) <= deg_threshold):
            groups.append(current)
            current = []
        current.append(candidate)
        prev_lon, prev_lat = lon, lat
    groups.append(current)

    centers = []
    for group in groups:
        strongest = group[np.argmax(np.abs([rot for _, _, rot in group]))]
        centers.append((strongest[0], strongest[1]))
    return tuple(centers)


def identify_vortexes_center(data: xr.Dataset, threshold: Optional[float] = 5, plot: Optional[bool] = False) -> List[np.ndarray]:
    """
    Locate vortex centers in a dataset.

    Candidate points come from get_center_candidates; they are then merged by
    get_centers_from_candidates using `threshold` as the coordinate gap.
    Returns a tuple of (lon, lat) pairs.
    """
    return get_centers_from_candidates(
        get_center_candidates(data, plot=plot),
        deg_threshold=threshold,
    )


# 2. Collect Vortex Timeseries

In [16]:
# Pairs a timestamp with the vortexes found at that time
# (presumably a list of Vortex objects — confirm against the code that fills it).
TimeVortexMapping = namedtuple("TimeVortexMapping", "timestamp vortexes")