In [11]:
import xarray as xr 

# Load the raw NetCDF download from the absolute workspace path.
ds = xr.open_dataset(
    filename_or_obj='/teamspace/studios/this_studio/ml-drought-forecasting/ml-modeling-pipeline/data/01_raw/f0e93607df96295f1c9e14ccfe96032f.nc',
)
# Local-machine equivalent:
# ds = xr.open_dataset('/Users/adamprzychodni/Documents/Repos/ml-drought-forecasting/ml-modeling-pipeline/data/01_raw/f0e93607df96295f1c9e14ccfe96032f.nc')

In [12]:
ds  # inspect the freshly loaded raw dataset (variables, dims, coords)

In [13]:
import numpy as np
import xarray as xr

# Step 1: Define the target grid at 1° resolution, restricted to the 30°S–30°N band
target_lat = np.arange(-30, 30.1, 1)     # -30 to 30 degrees inclusive (61 points)
target_lon = np.arange(-180, 180, 1)     # -180 to 179 degrees (360 points, [-180, 180) convention)

# Step 2: Create target grid Dataset (optional, for reference only; not used below)
target_grid = xr.Dataset(
    {
        "latitude": (["latitude"], target_lat),
        "longitude": (["longitude"], target_lon),
    }
)

# Step 3: Ensure latitude is ascending
if ds.latitude[0] > ds.latitude[-1]:
    ds = ds.sortby("latitude")

# Step 4: Put the source longitudes in the same [-180, 180) convention as target_lon.
# interp() fills target points outside the source coordinate range with NaN, so a
# source on a 0..360 grid would otherwise yield NaN for every negative target longitude.
# This is a no-op when the source longitudes are already within [-180, 180].
if float(ds.longitude.max()) > 180:
    ds = ds.assign_coords(longitude=((ds.longitude + 180) % 360) - 180).sortby("longitude")

# Step 5: Perform interpolation
# NOTE(review): linear interpolation point-samples the source grid rather than
# area-averaging it, and turns the land-sea mask (lsm) into fractional values near
# coastlines, so downstream land masking should use a threshold rather than `== 1`.
ds = ds.interp(latitude=target_lat, longitude=target_lon, method="linear")


In [14]:
ds  # inspect the dataset after regridding to the 1° target grid

In [10]:
import xarray as xr

def mask_land_and_update_swvl1(ds: "xr.Dataset", land_threshold: float = 0.5) -> "xr.Dataset":
    """
    Mask swvl1 to land cells using the land-sea mask (lsm) and return the updated dataset.

    lsm is compared against a threshold rather than tested for equality with 1. Once lsm
    has been linearly interpolated it holds fractional values near coastlines, so an
    exact `lsm == 1` test keeps only cells whose value is exactly 1.0 and silently
    drops coastal land.

    Args:
        ds (xr.Dataset): Dataset containing `swvl1` and `lsm` variables. It is not modified.
        land_threshold (float): Cells with lsm >= land_threshold are treated as land.
            The default 0.5 keeps majority-land cells; use 1.0 to keep only cells
            whose lsm is (at least) 1.

    Returns:
        xr.Dataset: A new dataset in which swvl1 is NaN wherever lsm < land_threshold
        or lsm is NaN; every other variable is carried over unchanged.
    """
    is_land = ds['lsm'] >= land_threshold
    # assign() builds a new dataset, so re-running the cell or reusing the input is safe.
    return ds.assign(swvl1=ds['swvl1'].where(is_land))

# Restrict soil moisture to land: after this call, ds['swvl1'] keeps values over land
# cells only and is NaN over the sea.
ds = mask_land_and_update_swvl1(ds=ds)


In [15]:
ds  # inspect the dataset after land-masking swvl1

In [16]:
import pandas as pd

# 1. Convert the 'date' coordinate (YYYYMMDD values) to datetime64.
# Guarded so re-running the cell is a no-op: once converted, astype(str) yields ISO
# timestamps that no longer match '%Y%m%d', and to_datetime would raise.
if not pd.api.types.is_datetime64_any_dtype(ds['date'].dtype):
    ds['date'] = pd.to_datetime(ds['date'].astype(str), format='%Y%m%d')

In [17]:
import xarray as xr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from typing import Optional, List
from datetime import datetime

def _filter_date_range(
    da: xr.DataArray,
    time_dim: str,
    start_date: Optional[str],
    end_date: Optional[str],
) -> xr.DataArray:
    """Restrict `da` to start_date <= time_dim <= end_date; either bound may be None/empty."""
    if start_date:
        try:
            start_datetime = pd.to_datetime(start_date)
            da = da.sel({time_dim: da[time_dim] >= start_datetime})
        except Exception as e:
            raise ValueError(f"Invalid start_date '{start_date}': {e}") from e
    if end_date:
        try:
            end_datetime = pd.to_datetime(end_date)
            da = da.sel({time_dim: da[time_dim] <= end_datetime})
        except Exception as e:
            raise ValueError(f"Invalid end_date '{end_date}': {e}") from e
    return da


def _default_title(variable: str, start_date: Optional[str], end_date: Optional[str]) -> str:
    """Build a plot title from the variable name and whichever date bounds were given."""
    if start_date and end_date:
        return f"{variable} from {start_date} to {end_date}"
    if start_date:
        return f"{variable} from {start_date} onwards"
    if end_date:
        return f"{variable} up to {end_date}"
    return f"{variable} Visualization"


def _scatter_geo_figure(
    da: xr.DataArray,
    variable: str,
    lat_dim: str,
    lon_dim: str,
    animation_frame: Optional[str],
    animate: bool,
    projection: str,
    color_scale,
    title: str,
    hover_precision: int,
    colorbar_title: str,
    plot_kwargs: dict,
) -> go.Figure:
    """Render `da` as one marker per non-NaN grid cell on a geographic projection."""
    df = da.to_dataframe().reset_index()
    df = df.dropna(subset=[variable])

    if animate:
        # Use string labels for the animation frames (e.g. datetimes -> their str form).
        df[animation_frame] = df[animation_frame].astype(str)
    else:
        # A constant frame column gives px.scatter_geo a single frame.
        df['Frame'] = 'Frame'

    fig = px.scatter_geo(
        df,
        lat=lat_dim,
        lon=lon_dim,
        color=variable,
        animation_frame=animation_frame if animate else 'Frame',
        projection=projection,
        color_continuous_scale=color_scale,
        title=title,
        labels={variable: variable.upper()},
        hover_data={variable: f':.{hover_precision}f'},
        **plot_kwargs
    )

    fig.update_layout(
        coloraxis_colorbar=dict(
            title=colorbar_title,
            ticks="outside"
        )
    )

    # The single dummy frame needs no play/pause controls.
    if not animate:
        fig.update_layout(updatemenus=[])

    return fig


def _imshow_figure(
    da: xr.DataArray,
    variable: str,
    time_dim: str,
    lat_dim: str,
    lon_dim: str,
    animation_frame: Optional[str],
    animate: bool,
    color_scale,
    title: str,
    colorbar_title: str,
    plot_kwargs: dict,
) -> go.Figure:
    """Render `da` as a lat/lon heatmap, animated over `animation_frame` when requested."""
    # Sort latitude ascending so z rows and y labels stay aligned (z and y are always
    # taken from the same sorted array, so no manual row flipping is needed).
    da = da.sortby(lat_dim)

    # Each heatmap must be 2-D (lat, lon): collapse the time dimension to its mean over
    # the selected date range unless it is the dimension being animated over.
    if time_dim in da.dims and not (animate and animation_frame == time_dim):
        da = da.mean(dim=time_dim)

    # Make rows = latitude and columns = longitude regardless of the source dim order.
    da = da.transpose(..., lat_dim, lon_dim)
    x_values = da[lon_dim].values
    y_values = da[lat_dim].values

    if animate:
        times = da[animation_frame].values
        frames = [
            go.Frame(
                data=[go.Heatmap(
                    z=da.sel({animation_frame: t}).values,
                    x=x_values,
                    y=y_values,
                    colorscale=color_scale,
                    **plot_kwargs
                )],
                name=str(t)
            )
            for t in times
        ]

        fig = go.Figure(
            data=go.Heatmap(
                z=da.sel({animation_frame: times[0]}).values,
                x=x_values,
                y=y_values,
                colorscale=color_scale,
                **plot_kwargs
            ),
            layout=go.Layout(
                title=title,
                updatemenus=[dict(
                    type='buttons',
                    buttons=[
                        dict(label='Play',
                             method='animate',
                             args=[None, {"frame": {"duration": 500, "redraw": True},
                                          "fromcurrent": True}]),
                        dict(label='Pause',
                             method='animate',
                             args=[[None], {"frame": {"duration": 0, "redraw": False},
                                            "mode": "immediate",
                                            "transition": {"duration": 0}}])
                    ],
                    showactive=False,
                    x=0.1,
                    y=0
                )]
            ),
            frames=frames
        )

        fig.update_traces(colorbar=dict(
            title=colorbar_title,
            ticks="outside"
        ))
        return fig

    # px.imshow defaults to origin='upper' (first row drawn at the top), which would put
    # the southernmost latitude at the top; 'lower' keeps north up. Caller kwargs win.
    imshow_kwargs = {'origin': 'lower', **plot_kwargs}
    fig = px.imshow(
        da.values,
        labels=dict(x=lon_dim, y=lat_dim, color=variable),
        x=x_values,
        y=y_values,
        color_continuous_scale=color_scale,
        title=title,
        **imshow_kwargs
    )

    fig.update_layout(
        coloraxis_colorbar=dict(
            title=colorbar_title,
            ticks="outside"
        )
    )
    return fig


def visualize_variable_on_map(
    dataset: xr.Dataset,
    variable: str,
    time_dim: str = 'date',
    lat_dim: str = 'latitude',
    lon_dim: str = 'longitude',
    plot_type: str = 'scatter_geo',  # Options: 'scatter_geo', 'imshow'
    downsample_factor: Optional[int] = 1,
    projection: str = 'natural earth',
    color_scale: str = 'Viridis',
    title: Optional[str] = None,
    animation_frame: Optional[str] = None,
    hover_precision: int = 2,
    custom_colorbar_title: Optional[str] = None,
    start_date: Optional[str] = None,  # Inclusive lower bound on time_dim
    end_date: Optional[str] = None,    # Inclusive upper bound on time_dim
    **kwargs
) -> go.Figure:
    """
    Visualize a variable from an xarray Dataset on a map, optionally animated, using Plotly.

    Parameters:
    - dataset (xr.Dataset): The xarray Dataset containing the data. It is not modified.
    - variable (str): Name of the data variable to visualize (must be in dataset.data_vars).
    - time_dim (str): Name of the time dimension. Default is 'date'. Non-datetime values are
      parsed with pd.to_datetime; NOTE(review): integer YYYYMMDD values would be read as
      epoch nanoseconds, so convert such coordinates to datetime before calling.
    - lat_dim (str): Name of the latitude dimension. Default is 'latitude'.
    - lon_dim (str): Name of the longitude dimension. Default is 'longitude'.
    - plot_type (str): 'scatter_geo' or 'imshow'. Default is 'scatter_geo'.
    - downsample_factor (int, optional): Block size for mean-coarsening latitude and longitude
      (partial edge blocks are trimmed). None or 1 disables downsampling. Default is 1.
    - projection (str): Map projection used by 'scatter_geo'. Default is 'natural earth'.
    - color_scale (str or list): Plotly color scale. Default is 'Viridis'.
    - title (str, optional): Plot title. If None, one is generated from the variable and dates.
    - animation_frame (str, optional): Dimension to animate over (e.g. time_dim). If None, a
      static figure is produced; a static 'imshow' shows the mean over time_dim.
    - hover_precision (int): Decimal places for the variable in 'scatter_geo' hover labels.
      Default is 2.
    - custom_colorbar_title (str, optional): Color bar title. If None, variable.upper() is used.
    - start_date (str, optional): Inclusive start date (any string pd.to_datetime accepts,
      e.g. 'YYYY-MM-DD'). If None, no start filtering is applied.
    - end_date (str, optional): Inclusive end date, same format. If None, no end filtering.
    - **kwargs: Forwarded to px.scatter_geo ('scatter_geo'), go.Heatmap (animated 'imshow')
      or px.imshow (static 'imshow').

    Returns:
    - fig (plotly.graph_objs.Figure): The Plotly figure object.

    Raises:
    - ValueError: Invalid plot_type, unknown variable / dimension / animation_frame,
      unparsable date bound, or no data left in the requested date range.
    """
    valid_plot_types = ['scatter_geo', 'imshow']
    if plot_type not in valid_plot_types:
        raise ValueError(f"Invalid plot_type '{plot_type}'. Valid options are: {valid_plot_types}")

    if variable not in dataset.data_vars:
        raise ValueError(f"Variable '{variable}' not found in the dataset.")

    da = dataset[variable]

    for dim in [time_dim, lat_dim, lon_dim]:
        if dim not in da.dims:
            raise ValueError(f"Dimension '{dim}' not found in variable '{variable}'.")

    # assign_coords returns a new DataArray, so the caller's dataset is never touched.
    if not pd.api.types.is_datetime64_any_dtype(da[time_dim].dtype):
        da = da.assign_coords({time_dim: pd.to_datetime(da[time_dim].values)})

    da = _filter_date_range(da, time_dim, start_date, end_date)

    # Use the configured time dimension (not a hard-coded 'date').
    if da[time_dim].size == 0:
        raise ValueError("No data available for the specified date range.")

    # downsample_factor is Optional: treat None like 1 (no downsampling).
    if downsample_factor is not None and downsample_factor > 1:
        da = da.coarsen({lat_dim: downsample_factor, lon_dim: downsample_factor}, boundary='trim').mean()

    if animation_frame and animation_frame not in da.dims:
        raise ValueError(f"Animation frame dimension '{animation_frame}' not found in variable '{variable}'.")
    animate = bool(animation_frame)

    if not title:
        title = _default_title(variable, start_date, end_date)

    colorbar_title = custom_colorbar_title if custom_colorbar_title else variable.upper()

    if plot_type == 'scatter_geo':
        return _scatter_geo_figure(
            da, variable, lat_dim, lon_dim, animation_frame, animate,
            projection, color_scale, title, hover_precision, colorbar_title, kwargs,
        )
    if plot_type == 'imshow':
        return _imshow_figure(
            da, variable, time_dim, lat_dim, lon_dim, animation_frame, animate,
            color_scale, title, colorbar_title, kwargs,
        )
    raise NotImplementedError(f"Plot type '{plot_type}' is not implemented.")


In [1]:
# # Example Usage:
# # Assuming your xarray Dataset is loaded as `ds`

# # Visualize 'swvl1' with animation using scatter_geo (uncomment start_date/end_date to restrict the date range)
# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable='swvl1',
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=4,
#     title='swvl1 Over Time',
#     # start_date='2020-01-01',  # Specify the start date
#     # end_date='2020-01-31',    # Specify the end date
#     color_scale = [
#         [0.0, "darkred"],        # 0% - extremely dry
#         [0.1, "red"],            # 10% - very dry
#         [0.2, "orangered"],      # 20% - dry
#         [0.3, "lightgreen"],     # 30% - beginning of optimal moisture
#         [0.4, "limegreen"],      # 40% - optimal moisture for most crops
#         [0.5, "green"],          # 50% - middle of optimal moisture
#         [0.55, "darkseagreen"],  # 55% - upper limit of optimal moisture
#         [0.6, "darkgreen"],      # 60% - end of optimal moisture
#         [0.7, "lightblue"],      # 70% - moist but not excessive
#         [0.8, "skyblue"],        # 80% - very moist
#         [0.9, "deepskyblue"],    # 90% - extremely moist
#         [1.0, "blue"]            # 100% - very wet
#     ]
# )
# fig.show()

In [22]:
# Persist the processed dataset to the intermediate data layer.
# The path is relative, so it resolves against the kernel's current working directory.
# Local-machine equivalent:
# ds.to_netcdf('/Users/adamprzychodni/Documents/Repos/ml-drought-forecasting/ml-modeling-pipeline/data/02_intermediate/merged_data.nc')
ds.to_netcdf(
    path='ml-drought-forecasting/ml-modeling-pipeline/data/02_intermediate/merged_data.nc',
)