In [6]:
import os
import sys
import argparse
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import warnings

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from scipy.ndimage import maximum_filter, minimum_filter
from datetime import datetime
from metpy.units import units
import metpy.calc as mpcalc

In [None]:
# ============================================================================
# CONFIGURATION AND DATA CLASSES
# ============================================================================

@dataclass
class VariableNames:
    """Variable names for different model outputs."""
    msl_pressure: str = "psl"  # Mean sea-level pressure
    pressure: str = "plev"  # Pressure levels
    temperature: str = "ta"  # Temperature
    specific_humidity: str = "hus"  # Specific humidity
    u_wind: str = "ua"  # Zonal wind
    v_wind: str = "va"  # Meridional wind
    w_wind: str = "wa"  # Meridional wind
    u_wind_10m: str = "uas"  # 10m zonal wind
    v_wind_10m: str = "vas"  # 10m meridional wind
    latitude: str = "lat"
    longitude: str = "lon"
    time: str = "time"


@dataclass
class CycloneTrackingConfig:
    """Configuration for cyclone tracking parameters."""
    search_radius_km: float = 500.0  # Initial search radius
    max_speed_kmh: float = 100.0  # Maximum cyclone movement speed
    min_pressure_threshold: float = 1015.0  # Pressure threshold for detection
    footprint_size: int = 10  # Footprint for local minima detection
    intensity_threshold: float = 2.0  # Minimum pressure gradient (hPa)


@dataclass
class PlotConfig:
    """Configuration for plotting."""
    figsize: Tuple[int, int] = (12, 8)
    dpi: int = 300
    colors: List[str] = None

    def __post_init__(self):
        if self.colors is None:
            self.colors = [
                "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
                "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
            ]

In [8]:
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """
    Calculate the great circle distance between two points on Earth.

    Parameters
    ----------
    lat1, lon1 : float
        Latitude and longitude of first point in degrees
    lat2, lon2 : float
        Latitude and longitude of second point in degrees

    Returns
    -------
    float
        Distance in kilometers
    """
    R = 6371.0  # Earth radius in km
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
    c = 2 * np.arcsin(np.sqrt(a))
    return R * c


def setup_logger(verbose: bool = True):
    """Set up logging configuration."""
    import logging
    level = logging.INFO if verbose else logging.WARNING
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    return logging.getLogger(__name__)


# ============================================================================
# THERMODYNAMIC CALCULATIONS
# ============================================================================

class ThermodynamicCalculator:
    """Calculate thermodynamic variables for cyclone analysis."""

    @staticmethod
    def calculate_equivalent_potential_temperature(
        temperature: xr.DataArray,
        pressure: xr.DataArray,
        specific_humidity: xr.DataArray,
        var_names: VariableNames
    ) -> xr.DataArray:
        """
        Calculate equivalent potential temperature.

        Parameters
        ----------
        temperature : xr.DataArray
            Temperature in Kelvin
        pressure : xr.DataArray
            Pressure in Pa or hPa
        specific_humidity : xr.DataArray
            Specific humidity in kg/kg

        Returns
        -------
        xr.DataArray
            Equivalent potential temperature in Kelvin
        """
        # Convert to MetPy quantities
        T = temperature.values * units.kelvin

        # Handle pressure units
        p_values = pressure.values
        if np.nanmax(p_values) > 10000:  # Likely in Pa
            p = p_values * units.Pa
        else:  # Likely in hPa
            p = p_values * units.hPa

        q = specific_humidity.values * units('kg/kg')

        # Calculate mixing ratio from specific humidity
        mixing_ratio = mpcalc.mixing_ratio_from_specific_humidity(q)

        # Calculate dewpoint
        dewpoint = mpcalc.dewpoint_from_specific_humidity(p, T, q)

        # Calculate equivalent potential temperature
        theta_e = mpcalc.equivalent_potential_temperature(p, T, dewpoint)

        # Create output DataArray
        theta_e_da = xr.DataArray(
            theta_e.magnitude,
            coords=temperature.coords,
            dims=temperature.dims,
            attrs={'units': 'K', 'long_name': 'Equivalent Potential Temperature'}
        )

        return theta_e_da


# ============================================================================
# CYCLONE DETECTION AND TRACKING
# ============================================================================

class CycloneDetector:
    """Detect and track cyclones in gridded data."""

    def __init__(self, config: CycloneTrackingConfig, var_names: VariableNames):
        self.config = config
        self.var_names = var_names
        self.logger = setup_logger()

    def find_pressure_minimum(
        self,
        msl_pressure: xr.DataArray,
        search_lat: Optional[Tuple[float, float]] = None,
        search_lon: Optional[Tuple[float, float]] = None
    ) -> Tuple[float, float, float]:
        """
        Find local pressure minimum in the search domain.

        Parameters
        ----------
        msl_pressure : xr.DataArray
            Mean sea-level pressure field
        search_lat : tuple, optional
            (min_lat, max_lat) for search region
        search_lon : tuple, optional
            (min_lon, max_lon) for search region

        Returns
        -------
        tuple
            (latitude, longitude, pressure_value) of minimum
        """
        # Subset to search region if specified
        msl = msl_pressure.copy()
        if search_lat is not None:
            msl = msl.sel({self.var_names.latitude: slice(*search_lat)})
        if search_lon is not None:
            msl = msl.sel({self.var_names.longitude: slice(*search_lon)})

        # Apply local minimum filter
        data = msl.values
        local_min = minimum_filter(data, footprint=np.ones((
            self.config.footprint_size, self.config.footprint_size
        )))

        # Find where data equals local minimum (local minima locations)
        minima_mask = (data == local_min)

        # Find the absolute minimum among local minima
        if np.any(minima_mask):
            min_indices = np.where(minima_mask & (data == np.nanmin(data[minima_mask])))
            if len(min_indices[0]) > 0:
                lat_idx, lon_idx = min_indices[0][0], min_indices[1][0]
                lat = float(msl[self.var_names.latitude][lat_idx].values)
                lon = float(msl[self.var_names.longitude][lon_idx].values)
                pressure = float(data[lat_idx, lon_idx])
                return lat, lon, pressure

        # Fallback: return absolute minimum
        min_idx = np.unravel_index(np.nanargmin(data), data.shape)
        lat = float(msl[self.var_names.latitude][min_idx[0]].values)
        lon = float(msl[self.var_names.longitude][min_idx[1]].values)
        pressure = float(data[min_idx])

        return lat, lon, pressure

    def track_cyclone(
        self,
        dataset: xr.Dataset,
        initial_lat: Optional[float] = None,
        initial_lon: Optional[float] = None,
        initial_time: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Track cyclone through time using pressure minima.

        Parameters
        ----------
        dataset : xr.Dataset
            Dataset containing msl pressure and time dimension
        initial_lat : float, optional
            Initial latitude for tracking
        initial_lon : float, optional
            Initial longitude for tracking
        initial_time : str, optional
            Initial time for tracking (if None, uses first time step)

        Returns
        -------
        pd.DataFrame
            DataFrame with columns: time, lat, lon, msl_pressure, max_wind
        """
        time_var = self.var_names.time
        lat_var = self.var_names.latitude
        lon_var = self.var_names.longitude

        times = dataset[time_var].values
        if initial_time is not None:
            start_idx = np.where(times >= np.datetime64(initial_time))[0][0]
        else:
            start_idx = 0

        track_data = []

        # Initialize tracking
        if initial_lat is not None and initial_lon is not None:
            prev_lat, prev_lon = initial_lat, initial_lon
        else:
            # Find initial position from first time step
            msl_first = dataset[self.var_names.msl_pressure].isel({time_var: start_idx})
            prev_lat, prev_lon, _ = self.find_pressure_minimum(msl_first)

        self.logger.info(f"Starting cyclone tracking from ({prev_lat:.2f}, {prev_lon:.2f})")

        for t_idx in range(start_idx, len(times)):
            time_step = times[t_idx]
            msl_data = dataset[self.var_names.msl_pressure].isel({time_var: t_idx})

            # Calculate search radius based on time step
            if t_idx > start_idx:
                time_diff_hours = (times[t_idx] - times[t_idx-1]) / np.timedelta64(1, 'h')
                max_distance = (self.config.max_speed_kmh * time_diff_hours) / 111.0  # degrees
                search_radius = max_distance + 2.0  # Add buffer
            else:
                search_radius = self.config.search_radius_km / 111.0

            # Define search box
            search_lat = (prev_lat - search_radius, prev_lat + search_radius)
            search_lon = (prev_lon - search_radius, prev_lon + search_radius)

            # Find pressure minimum
            try:
                lat, lon, pressure = self.find_pressure_minimum(
                    msl_data, search_lat, search_lon
                )

                # Calculate maximum wind speed near cyclone center
                max_wind = self._calculate_max_wind(dataset, t_idx, lat, lon)

                track_data.append({
                    'time': pd.Timestamp(time_step),
                    'lat': lat,
                    'lon': lon,
                    'msl_pressure': pressure,
                    'max_wind': max_wind
                })

                prev_lat, prev_lon = lat, lon

            except Exception as e:
                self.logger.warning(f"Failed to track at time {time_step}: {e}")
                continue

        track_df = pd.DataFrame(track_data)
        self.logger.info(f"Tracking complete. Found {len(track_df)} positions.")

        return track_df

    def _calculate_max_wind(
        self,
        dataset: xr.Dataset,
        time_idx: int,
        center_lat: float,
        center_lon: float,
        radius_deg: float = 3.0
    ) -> float:
        """Calculate maximum wind speed near cyclone center."""
        try:
            # Get 10m wind components
            u_var = self.var_names.u_wind_10m
            v_var = self.var_names.v_wind_10m

            if u_var not in dataset or v_var not in dataset:
                return np.nan

            u = dataset[u_var].isel({self.var_names.time: time_idx})
            v = dataset[v_var].isel({self.var_names.time: time_idx})

            # Subset around cyclone center
            lat_slice = slice(center_lat - radius_deg, center_lat + radius_deg)
            lon_slice = slice(center_lon - radius_deg, center_lon + radius_deg)

            u_subset = u.sel({
                self.var_names.latitude: lat_slice,
                self.var_names.longitude: lon_slice
            })
            v_subset = v.sel({
                self.var_names.latitude: lat_slice,
                self.var_names.longitude: lon_slice
            })

            # Calculate wind speed
            wind_speed = np.sqrt(u_subset**2 + v_subset**2)
            max_wind = float(wind_speed.max().values)

            return max_wind

        except Exception as e:
            self.logger.warning(f"Could not calculate max wind: {e}")
            return np.nan


# ============================================================================
# VERTICAL STRUCTURE ANALYSIS
# ============================================================================

class VerticalStructureAnalyzer:
    """Analyze vertical thermal structure of cyclones."""

    def __init__(self, var_names: VariableNames):
        self.var_names = var_names
        self.logger = setup_logger()
        self.thermo_calc = ThermodynamicCalculator()

    def extract_cross_section(
        self,
        dataset: xr.Dataset,
        center_lat: float,
        center_lon: float,
        time_idx: int,
        max_radius_deg: float = 10.0
    ) -> xr.Dataset:
        """
        Extract zonal cross-section through cyclone center.

        Parameters
        ----------
        dataset : xr.Dataset
            3D dataset with pressure levels
        center_lat : float
            Cyclone center latitude
        center_lon : float
            Cyclone center longitude
        time_idx : int
            Time index
        max_radius_deg : float
            Maximum radius for cross-section in degrees

        Returns
        -------
        xr.Dataset
            Cross-section with radial distance coordinate
        """
        # Select time and latitude slice
        lat_tol = 0.5  # Tolerance for latitude selection
        ds_time = dataset.isel({self.var_names.time: time_idx})

        # Select latitude closest to center
        ds_cross = ds_time.sel({
            self.var_names.latitude: center_lat,
            method:'nearest'
        })

        # Select longitude range around center
        lon_slice = slice(center_lon - max_radius_deg, center_lon + max_radius_deg)
        ds_cross = ds_cross.sel({self.var_names.longitude: lon_slice})

        # Calculate radial distance from center
        lons = ds_cross[self.var_names.longitude].values
        distances = []
        for lon in lons:
            dist = haversine_distance(center_lat, center_lon, center_lat, lon)
            # Make western distances negative
            if lon < center_lon:
                dist = -dist
            distances.append(dist)

        # Add distance coordinate
        ds_cross = ds_cross.assign_coords({'distance': (self.var_names.longitude, distances)})

        return ds_cross

    def compute_thermal_structure(
        self,
        dataset: xr.Dataset,
        center_lat: float,
        center_lon: float,
        time_idx: int,
        max_radius_deg: float = 10.0
    ) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Compute equivalent potential temperature cross-section.

        Returns
        -------
        tuple
            (theta_e, tangential_wind) DataArrays with distance and pressure coords
        """
        # Extract cross-section
        ds_cross = self.extract_cross_section(
            dataset, center_lat, center_lon, time_idx, max_radius_deg
        )

        # Calculate equivalent potential temperature
        theta_e = self.thermo_calc.calculate_equivalent_potential_temperature(
            ds_cross[self.var_names.temperature],
            ds_cross[self.var_names.pressure],
            ds_cross[self.var_names.specific_humidity],
            self.var_names
        )

        # Calculate tangential wind if available
        try:
            u_wind = ds_cross[self.var_names.u_wind]
            v_wind = ds_cross[self.var_names.v_wind]
            # Approximate tangential wind (simplified)
            tangential_wind = np.sqrt(u_wind**2 + v_wind**2)
        except:
            tangential_wind = None

        return theta_e, tangential_wind


# ============================================================================
# VISUALIZATION
# ============================================================================

class CycloneVisualizer:
    """Create visualizations for cyclone analysis."""

    def __init__(self, plot_config: PlotConfig):
        self.config = plot_config
        self.logger = setup_logger()

    def plot_tracks(
        self,
        tracks: Dict[str, pd.DataFrame],
        reference_tracks: Optional[Dict[str, pd.DataFrame]] = None,
        output_path: Optional[str] = None,
        domain: Optional[Dict[str, Tuple[float, float]]] = None
    ):
        """
        Plot cyclone tracks for multiple experiments.

        Parameters
        ----------
        tracks : dict
            Dictionary of {experiment_name: track_dataframe}
        reference_tracks : dict, optional
            Dictionary of {reference_name: track_dataframe} for ERA5/best-track
        output_path : str, optional
            Path to save figure
        domain : dict, optional
            Domain bounds {'lat': (min, max), 'lon': (min, max)}
        """
        fig = plt.figure(figsize=self.config.figsize)
        ax = plt.axes(projection=ccrs.PlateCarree())

        # Set domain
        if domain:
            ax.set_extent([
                domain['lon'][0], domain['lon'][1],
                domain['lat'][0], domain['lat'][1]
            ], crs=ccrs.PlateCarree())

        # Add map features
        ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
        ax.add_feature(cfeature.BORDERS, linewidth=0.3, linestyle=':')
        ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)
        ax.add_feature(cfeature.OCEAN, facecolor='lightblue', alpha=0.3)

        # Plot reference tracks first (dashed lines)
        if reference_tracks:
            for idx, (name, track) in enumerate(reference_tracks.items()):
                ax.plot(
                    track['lon'], track['lat'],
                    linestyle='--', linewidth=2.5, alpha=0.8,
                    color='black', label=name, transform=ccrs.PlateCarree()
                )
                # Mark start and end
                ax.scatter(track['lon'].iloc[0], track['lat'].iloc[0],
                          marker='o', s=100, color='green', edgecolors='black',
                          transform=ccrs.PlateCarree(), zorder=5)
                ax.scatter(track['lon'].iloc[-1], track['lat'].iloc[-1],
                          marker='s', s=100, color='red', edgecolors='black',
                          transform=ccrs.PlateCarree(), zorder=5)

        # Plot experiment tracks
        for idx, (name, track) in enumerate(tracks.items()):
            color = self.config.colors[idx % len(self.config.colors)]
            ax.plot(
                track['lon'], track['lat'],
                linestyle='-', linewidth=2, alpha=0.7,
                color=color, label=name, transform=ccrs.PlateCarree()
            )
            # Mark start
            ax.scatter(track['lon'].iloc[0], track['lat'].iloc[0],
                      marker='o', s=80, color=color, edgecolors='black',
                      alpha=0.7, transform=ccrs.PlateCarree(), zorder=4)

        # Add gridlines
        gl = ax.gridlines(draw_labels=True, linestyle='--', alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False

        # Add legend
        ax.legend(loc='upper left', fontsize=10, framealpha=0.9)

        plt.title('Cyclone Tracks Comparison', fontsize=14, fontweight='bold')
        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
            self.logger.info(f"Track plot saved to {output_path}")

        plt.show()

    def plot_pressure_evolution(
        self,
        tracks: Dict[str, pd.DataFrame],
        reference_tracks: Optional[Dict[str, pd.DataFrame]] = None,
        output_path: Optional[str] = None
    ):
        """Plot minimum pressure evolution over time."""
        fig, ax = plt.subplots(figsize=self.config.figsize)

        # Plot reference tracks
        if reference_tracks:
            for name, track in reference_tracks.items():
                ax.plot(track['time'], track['msl_pressure'],
                       linestyle='--', linewidth=2.5, color='black',
                       alpha=0.8, label=name)

        # Plot experiment tracks
        for idx, (name, track) in enumerate(tracks.items()):
            color = self.config.colors[idx % len(self.config.colors)]
            ax.plot(track['time'], track['msl_pressure'],
                   linestyle='-', linewidth=2, color=color,
                   alpha=0.7, label=name, marker='o', markersize=4)

        ax.set_xlabel('Time', fontsize=12, fontweight='bold')
        ax.set_ylabel('Mean Sea-Level Pressure (hPa)', fontsize=12, fontweight='bold')
        ax.set_title('Cyclone Minimum Pressure Evolution', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3)

        # Invert y-axis (lower pressure at top)
        ax.invert_yaxis()

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
            self.logger.info(f"Pressure evolution plot saved to {output_path}")

        plt.show()

    def plot_wind_evolution(
        self,
        tracks: Dict[str, pd.DataFrame],
        reference_tracks: Optional[Dict[str, pd.DataFrame]] = None,
        output_path: Optional[str] = None
    ):
        """Plot maximum wind speed evolution over time."""
        fig, ax = plt.subplots(figsize=self.config.figsize)

        # Plot reference tracks
        if reference_tracks:
            for name, track in reference_tracks.items():
                if 'max_wind' in track.columns:
                    ax.plot(track['time'], track['max_wind'],
                           linestyle='--', linewidth=2.5, color='black',
                           alpha=0.8, label=name)

        # Plot experiment tracks
        for idx, (name, track) in enumerate(tracks.items()):
            if 'max_wind' in track.columns:
                color = self.config.colors[idx % len(self.config.colors)]
                ax.plot(track['time'], track['max_wind'],
                       linestyle='-', linewidth=2, color=color,
                       alpha=0.7, label=name, marker='o', markersize=4)

        ax.set_xlabel('Time', fontsize=12, fontweight='bold')
        ax.set_ylabel('Maximum Wind Speed (m/s)', fontsize=12, fontweight='bold')
        ax.set_title('Cyclone Maximum Wind Speed Evolution', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
            self.logger.info(f"Wind evolution plot saved to {output_path}")

        plt.show()

    def plot_vertical_structure(
        self,
        theta_e: xr.DataArray,
        distance_coord: str = 'distance',
        pressure_coord: str = 'plev',
        tangential_wind: Optional[xr.DataArray] = None,
        output_path: Optional[str] = None,
        title: str = 'Vertical Thermal Structure'
    ):
        """
        Plot vertical cross-section of equivalent potential temperature.

        Parameters
        ----------
        theta_e : xr.DataArray
            Equivalent potential temperature with distance and pressure dimensions
        distance_coord : str
            Name of distance coordinate
        pressure_coord : str
            Name of pressure coordinate
        tangential_wind : xr.DataArray, optional
            Tangential wind component for contouring
        output_path : str, optional
            Path to save figure
        title : str
            Plot title
        """
        fig, ax = plt.subplots(figsize=(14, 8))

        # Prepare data
        distance = theta_e[distance_coord].values
        pressure = theta_e[pressure_coord].values

        # Convert pressure to hPa if needed
        if np.max(pressure) > 10000:
            pressure = pressure / 100.0

        # Create meshgrid
        X, Y = np.meshgrid(distance, pressure)

        # Plot theta_e as filled contours
        levels_theta = np.arange(300, 380, 2)
        cf = ax.contourf(X, Y, theta_e.T, levels=levels_theta,
                        cmap='RdYlBu_r', extend='both')

        # Add colorbar
        cbar = plt.colorbar(cf, ax=ax, orientation='vertical', pad=0.02)
        cbar.set_label('Equivalent Potential Temperature (K)',
                      fontsize=12, fontweight='bold')

        # Overlay tangential wind as contours if available
        if tangential_wind is not None:
            levels_wind = np.arange(10, 50, 5)
            cs = ax.contour(X, Y, tangential_wind.T, levels=levels_wind,
                          colors='black', linewidths=1.5, alpha=0.6)
            ax.clabel(cs, inline=True, fontsize=9, fmt='%d m/s')

        # Format axes
        ax.set_xlabel('Distance from Center (km)', fontsize=12, fontweight='bold')
        ax.set_ylabel('Pressure (hPa)', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')

        # Invert y-axis (pressure increases downward)
        ax.invert_yaxis()

        # Set y-axis to show common pressure levels
        ax.set_yticks([1000, 925, 850, 700, 500, 400, 300, 250, 200])

        # Add vertical line at center
        ax.axvline(x=0, color='black', linestyle='--', linewidth=1.5, alpha=0.7)

        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
            self.logger.info(f"Vertical structure plot saved to {output_path}")

        plt.show()


# ============================================================================
# DATA LOADING AND MANAGEMENT
# ============================================================================

class DataManager:
    """Manage loading and preprocessing of model and reference data."""

    def __init__(self, var_names: VariableNames):
        self.var_names = var_names
        self.logger = setup_logger()

    def load_model_output(
        self,
        file_paths: Union[str, List[str]],
        chunks: Optional[Dict] = None
    ) -> xr.Dataset:
        """
        Load model output files.

        Parameters
        ----------
        file_paths : str or list
            Path(s) to model output NetCDF files
        chunks : dict, optional
            Chunking specification for dask

        Returns
        -------
        xr.Dataset
            Loaded dataset
        """
        try:
            if isinstance(file_paths, str):
                ds = xr.open_dataset(file_paths, chunks=chunks)
            else:
                ds = xr.open_mfdataset(file_paths, combine='by_coords', chunks=chunks)

            self.logger.info(f"Loaded dataset with dimensions: {dict(ds.dims)}")
            return ds

        except Exception as e:
            self.logger.error(f"Failed to load data: {e}")
            raise

    def load_era5_data(
        self,
        file_paths: Union[str, List[str]],
        var_mapping: Optional[Dict[str, str]] = None
    ) -> xr.Dataset:
        """
        Load ERA5 reanalysis data and standardize variable names.

        Parameters
        ----------
        file_paths : str or list
            Path(s) to ERA5 NetCDF files
        var_mapping : dict, optional
            Mapping from ERA5 variable names to standard names

        Returns
        -------
        xr.Dataset
            Loaded and standardized dataset
        """
        ds = self.load_model_output(file_paths)

        # Default ERA5 variable mapping
        if var_mapping is None:
            var_mapping = {
                'msl': self.var_names.msl_pressure,
                't': self.var_names.temperature,
                'q': self.var_names.specific_humidity,
                'u': self.var_names.u_wind,
                'v': self.var_names.v_wind,
                'u10': self.var_names.u_wind_10m,
                'v10': self.var_names.v_wind_10m,
            }

        # Rename variables
        rename_dict = {k: v for k, v in var_mapping.items() if k in ds}
        if rename_dict:
            ds = ds.rename(rename_dict)
            self.logger.info(f"Renamed ERA5 variables: {rename_dict}")

        return ds

    def load_best_track(
        self,
        file_path: str,
        format: str = 'csv'
    ) -> pd.DataFrame:
        """
        Load NOAA best-track data.

        Parameters
        ----------
        file_path : str
            Path to best-track file
        format : str
            File format ('csv', 'txt', 'hurdat2')

        Returns
        -------
        pd.DataFrame
            Best-track data with columns: time, lat, lon, msl_pressure, max_wind
        """
        if format == 'csv':
            df = pd.read_csv(file_path, parse_dates=['time'])
        elif format == 'hurdat2':
            df = self._parse_hurdat2(file_path)
        else:
            raise ValueError(f"Unsupported format: {format}")

        self.logger.info(f"Loaded best-track data with {len(df)} records")
        return df

    def _parse_hurdat2(self, file_path: str) -> pd.DataFrame:
        """Parse HURDAT2 format best-track file."""
        # Implementation for HURDAT2 format
        # This is a placeholder - actual implementation depends on file structure
        raise NotImplementedError("HURDAT2 parsing not yet implemented")

In [2]:
def load_regcm5_multi_file(data_dir, verbose=False):
    """
    Load RegCM5 data from separate variable files and merge into single dataset.
    
    Parameters
    ----------
    data_dir : str
        Directory containing separate NetCDF files for each variable
        
    Returns
    -------
    xr.Dataset
        Merged dataset with all variables
    """
    nc_files = sorted(glob.glob(os.path.join(data_dir, '*.nc')))
    datasets = []

    for file in nc_files:
        ds_temp = xr.open_dataset(file)
        datasets.append(ds_temp)
        
    ds_merged = xr.merge(datasets, compat='override')
    
    if verbose:
        print(f"\nMerged dataset:")
        print(f"  Variables: {list(ds_merged.data_vars)}")
        print(f"  Dimensions: {dict(ds_merged.dims)}")
        print(f"  Coordinates: {list(ds_merged.coords)}")
    
    return ds_merged

In [3]:
data_dir = 'data/domain_small_regridded/ctrl'
ds = load_regcm5_multi_file(data_dir, verbose=True)


Merged dataset:
  Variables: ['hus', 'time_bnds', 'pr', 'psl', 'ta', 'ua', 'va', 'wa']
  Dimensions: {'time': 193, 'lon': 417, 'lat': 417, 'plev': 21, 'bnds': 2}
  Coordinates: ['time', 'lon', 'lat', 'plev']


In [14]:
class CycloneAnalysis:
    """Main pipeline for cyclone tracking and analysis."""

    def __init__(
        self,
        var_names: Optional[VariableNames] = None,
        tracking_config: Optional[CycloneTrackingConfig] = None,
        plot_config: Optional[PlotConfig] = None
    ):
        self.var_names = var_names or VariableNames()
        self.tracking_config = tracking_config or CycloneTrackingConfig()
        self.plot_config = plot_config or PlotConfig()

        self.detector = CycloneDetector(self.tracking_config, self.var_names)
        self.vertical_analyzer = VerticalStructureAnalyzer(self.var_names)
        self.visualizer = CycloneVisualizer(self.plot_config)
        self.data_manager = DataManager(self.var_names)
        self.logger = setup_logger()

    def analyze_experiment(
        self,
        ds: xr.Dataset,
        experiment_name: str,
        initial_position: Optional[Tuple[float, float]] = None,
        initial_time: Optional[str] = None
    ) -> Tuple[pd.DataFrame, xr.Dataset]:
        """
        Analyze a single experiment.

        Parameters
        ----------
        data_path : str or list
            Path(s) to model output files
        experiment_name : str
            Name of experiment
        initial_position : tuple, optional
            (lat, lon) initial position
        initial_time : str, optional
            Initial time for tracking

        Returns
        -------
        tuple
            (track_dataframe, dataset)
        """
        self.logger.info(f"Analyzing experiment: {experiment_name}")

        # Load data
        #ds = self.data_manager.load_model_output(data_path)

        # Track cyclone
        initial_lat = initial_position[0] if initial_position else None
        initial_lon = initial_position[1] if initial_position else None

        track = self.detector.track_cyclone(
            ds, initial_lat, initial_lon, initial_time
        )

        return track, ds

In [15]:
otis = CycloneAnalysis()

In [16]:
data_path = "data/domain_small"
otis.analyze_experiment(ds=ds, experiment_name="ctrl")

2025-11-21 12:02:35 - INFO - Analyzing experiment: ctrl


ValueError: All-NaN slice encountered

In [None]:
def find_pressure_minimum(
        msl_pressure: xr.DataArray,
        search_lat: Optional[Tuple[float, float]] = None,
        search_lon: Optional[Tuple[float, float]] = None
    ) -> Tuple[float, float, float]:
        """
        Find local pressure minimum in the search domain.

        Parameters
        ----------
        msl_pressure : xr.DataArray
            Mean sea-level pressure field
        search_lat : tuple, optional
            (min_lat, max_lat) for search region
        search_lon : tuple, optional
            (min_lon, max_lon) for search region

        Returns
        -------
        tuple
            (latitude, longitude, pressure_value) of minimum
        """
        # Subset to search region if specified
        msl = msl_pressure.copy()
        if search_lat is not None:
            msl = msl.sel({self.var_names.latitude: slice(*search_lat)})
        if search_lon is not None:
            msl = msl.sel({self.var_names.longitude: slice(*search_lon)})

        # Apply local minimum filter
        data = msl.values
        local_min = minimum_filter(data, footprint=np.ones((
            self.config.footprint_size, self.config.footprint_size
        )))

        # Find where data equals local minimum (local minima locations)
        minima_mask = (data == local_min)

        # Find the absolute minimum among local minima
        if np.any(minima_mask):
            min_indices = np.where(minima_mask & (data == np.nanmin(data[minima_mask])))
            if len(min_indices[0]) > 0:
                lat_idx, lon_idx = min_indices[0][0], min_indices[1][0]
                lat = float(msl[self.var_names.latitude][lat_idx].values)
                lon = float(msl[self.var_names.longitude][lon_idx].values)
                pressure = float(data[lat_idx, lon_idx])
                return lat, lon, pressure

        # Fallback: return absolute minimum
        try:
            min_idx = np.unravel_index(np.nanargmin(data), data.shape)
            lat = float(msl[self.var_names.latitude][min_idx[0]].values)
            lon = float(msl[self.var_names.longitude][min_idx[1]].values)
            pressure = float(data[min_idx])

            return lat, lon, pressure
        except:
            return np.nan, np.nan, np.nan


In [None]:
find_pressure_minimum(ds.psl)