# Exploratory Spatio-Temporal Data Analysis
# Part 1 - Data Inspection and Descriptive Statistics - Summary Statistics

In [3]:
import numpy as np
import pandas as pd
import polars as pl

import matplotlib.pyplot as plt

import folium
from mpl_toolkits.basemap import Basemap

%matplotlib inline

In [4]:
import xarray as xr 

# NOTE: absolute, machine-specific path - adjust it to your local checkout before running.
ds = xr.open_dataset('/Users/adamprzychodni/Documents/Repos/ml-drought-forecasting/ml-modeling-pipeline/data/01_raw/ERA5_monthly_averaged_data_on_single_levels.nc') # raw file; too heavy to visualize directly
# Alternative input: the preprocessed intermediate file.
# ds = xr.open_dataset('/Users/adamprzychodni/Documents/Repos/ml-drought-forecasting/ml-modeling-pipeline/data/02_intermediate/preprocessed_data.nc')

In [None]:
# Render the dataset's repr as the cell output.
ds

## Data Inspection

### Area of interest

In [6]:
import xarray as xr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from typing import Optional, List
from datetime import datetime

def visualize_variable_on_map(
    dataset: xr.Dataset,
    variable: str,
    time_dim: str = 'date',
    lat_dim: str = 'latitude',
    lon_dim: str = 'longitude',
    plot_type: str = 'scatter_geo',
    downsample_factor: Optional[int] = 1,
    projection: str = 'natural earth',
    color_scale: str = 'Viridis',
    title: Optional[str] = None,
    animation_frame: Optional[str] = None,
    hover_precision: int = 2,
    custom_colorbar_title: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    **kwargs
) -> go.Figure:
    """
    Visualize a variable from an xarray Dataset on a map using Plotly.

    Parameters:
    - dataset (xr.Dataset): The xarray Dataset containing the data.
    - variable (str): The name of the variable to visualize.
    - time_dim (str): Name of the time dimension.
    - lat_dim (str): Name of the latitude dimension.
    - lon_dim (str): Name of the longitude dimension.
    - plot_type (str): Type of plot: 'scatter_geo' (one point per grid cell,
      optionally animated) or 'imshow' (static heatmap of the mean over `time_dim`).
    - downsample_factor (int, optional): Number of neighbouring grid cells averaged
      together along latitude and longitude. `None` or values <= 1 disable downsampling.
    - projection (str): Map projection style for Plotly (used by 'scatter_geo' only).
    - color_scale (str or list): Color scale for the plot.
    - title (str, optional): Title of the plot. Defaults to the variable's `GRIB_name`
      attribute (or the variable name), followed by the date range if one is given.
    - animation_frame (str, optional): Column to animate over (used by 'scatter_geo' only).
    - hover_precision (int): Decimal precision for hover data (used by 'scatter_geo' only).
    - custom_colorbar_title (str, optional): Custom title for the color bar.
      Defaults to the variable's `GRIB_units` attribute.
    - start_date (str, optional): Inclusive start date (anything `pd.to_datetime` accepts).
    - end_date (str, optional): Inclusive end date (anything `pd.to_datetime` accepts).
    - **kwargs: Additional keyword arguments passed to the Plotly plotting function.
      For 'imshow', `origin` defaults to 'lower' unless given here.

    Returns:
    - fig (plotly.graph_objs._figure.Figure): The Plotly figure object.

    Raises:
    - ValueError: If a required dimension is missing from the variable or
      `plot_type` is not supported.
    """
    # Extract the DataArray once; its attrs carry the GRIB metadata
    da = dataset[variable]
    grib_name = da.attrs.get("GRIB_name", variable)
    grib_units = da.attrs.get("GRIB_units", "")

    # Use GRIB_name as title if not provided
    if not title:
        title = f"{grib_name}"
        if start_date or end_date:
            title += f" ({start_date or ''} to {end_date or ''})"

    # Set colorbar title
    colorbar_title = custom_colorbar_title if custom_colorbar_title else grib_units

    # Check that the variable has all required dimensions
    for dim in [time_dim, lat_dim, lon_dim]:
        if dim not in da.dims:
            raise ValueError(f"Dimension '{dim}' not found in variable '{variable}'.")

    # Reject unsupported plot types before filtering/coarsening a potentially large array
    if plot_type not in ('scatter_geo', 'imshow'):
        raise ValueError(f"Plot type '{plot_type}' is not supported.")

    # Filter by date range if specified (both bounds inclusive)
    if start_date:
        da = da.sel({time_dim: da[time_dim] >= pd.to_datetime(start_date)})
    if end_date:
        da = da.sel({time_dim: da[time_dim] <= pd.to_datetime(end_date)})

    # Downsample if required; None is accepted and means "no downsampling"
    if downsample_factor is not None and downsample_factor > 1:
        da = da.coarsen({lat_dim: downsample_factor, lon_dim: downsample_factor}, boundary='trim').mean()

    if plot_type == 'scatter_geo':
        # One row per (time, lat, lon) cell; cells without data are dropped
        df = da.to_dataframe().reset_index().dropna(subset=[variable])

        fig = px.scatter_geo(
            df,
            lat=lat_dim,
            lon=lon_dim,
            color=variable,
            animation_frame=animation_frame,
            projection=projection,
            color_continuous_scale=color_scale,
            title=title,
            labels={variable: grib_name},
            hover_data={variable: f":.{hover_precision}f"},
            **kwargs
        )
    else:
        # Aggregate over time for a static plot and force a (lat, lon) layout so
        # rows/columns line up with the y/x coordinate arrays passed below.
        mean_field = da.mean(dim=time_dim).transpose(lat_dim, lon_dim, ...)

        # Rows are positioned by their latitude values, so the data must NOT be
        # flipped independently of `y`. origin='lower' makes the y axis increase
        # upwards (north at the top); an explicit `origin` in kwargs wins.
        imshow_kwargs = {'origin': 'lower', **kwargs}

        fig = px.imshow(
            mean_field.values,
            labels=dict(x=lon_dim, y=lat_dim, color=grib_name),
            x=mean_field[lon_dim].values,
            y=mean_field[lat_dim].values,
            color_continuous_scale=color_scale,
            title=title,
            **imshow_kwargs
        )

    fig.update_layout(
        coloraxis_colorbar=dict(
            title=colorbar_title,
            ticks="outside"
        )
    )

    return fig


Target variable: swvl3

In [7]:
# # Visualize 'swvl3' (target variable) with animation using scatter_geo; no date range is applied
# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable='swvl3',
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = [
#         [0.0, "darkred"],        # 0% - extremely dry
#         [0.1, "red"],            # 10% - very dry
#         [0.2, "orangered"],      # 20% - dry
#         [0.3, "lightgreen"],     # 30% - beginning of optimal moisture
#         [0.4, "limegreen"],      # 40% - optimal moisture for most crops
#         [0.5, "green"],          # 50% - middle of optimal moisture
#         [0.55, "darkseagreen"],  # 55% - upper limit of optimal moisture
#         [0.6, "darkgreen"],      # 60% - end of optimal moisture
#         [0.7, "lightblue"],      # 70% - moist but not excessive
#         [0.8, "skyblue"],        # 80% - very moist
#         [0.9, "deepskyblue"],    # 90% - extremely moist
#         [1.0, "blue"]            # 100% - very wet
#     ]
# )
# fig.show()

In [8]:
# # Visualize 'swvl1' with animation using scatter_geo; no date range is applied
# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable='swvl1',
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = [
#         [0.0, "darkred"],        # 0% - extremely dry
#         [0.1, "red"],            # 10% - very dry
#         [0.2, "orangered"],      # 20% - dry
#         [0.3, "lightgreen"],     # 30% - beginning of optimal moisture
#         [0.4, "limegreen"],      # 40% - optimal moisture for most crops
#         [0.5, "green"],          # 50% - middle of optimal moisture
#         [0.55, "darkseagreen"],  # 55% - upper limit of optimal moisture
#         [0.6, "darkgreen"],      # 60% - end of optimal moisture
#         [0.7, "lightblue"],      # 70% - moist but not excessive
#         [0.8, "skyblue"],        # 80% - very moist
#         [0.9, "deepskyblue"],    # 90% - extremely moist
#         [1.0, "blue"]            # 100% - very wet
#     ]
# )
# fig.show()

In [9]:
# # Visualize 'swvl2' with animation using scatter_geo; no date range is applied
# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable='swvl2',
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = [
#         [0.0, "darkred"],        # 0% - extremely dry
#         [0.1, "red"],            # 10% - very dry
#         [0.2, "orangered"],      # 20% - dry
#         [0.3, "lightgreen"],     # 30% - beginning of optimal moisture
#         [0.4, "limegreen"],      # 40% - optimal moisture for most crops
#         [0.5, "green"],          # 50% - middle of optimal moisture
#         [0.55, "darkseagreen"],  # 55% - upper limit of optimal moisture
#         [0.6, "darkgreen"],      # 60% - end of optimal moisture
#         [0.7, "lightblue"],      # 70% - moist but not excessive
#         [0.8, "skyblue"],        # 80% - very moist
#         [0.9, "deepskyblue"],    # 90% - extremely moist
#         [1.0, "blue"]            # 100% - very wet
#     ]
# )
# fig.show()

In [10]:
# # Visualize 'swvl4' with animation using scatter_geo; no date range is applied
# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable='swvl4',
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = [
#         [0.0, "darkred"],        # 0% - extremely dry
#         [0.1, "red"],            # 10% - very dry
#         [0.2, "orangered"],      # 20% - dry
#         [0.3, "lightgreen"],     # 30% - beginning of optimal moisture
#         [0.4, "limegreen"],      # 40% - optimal moisture for most crops
#         [0.5, "green"],          # 50% - middle of optimal moisture
#         [0.55, "darkseagreen"],  # 55% - upper limit of optimal moisture
#         [0.6, "darkgreen"],      # 60% - end of optimal moisture
#         [0.7, "lightblue"],      # 70% - moist but not excessive
#         [0.8, "skyblue"],        # 80% - very moist
#         [0.9, "deepskyblue"],    # 90% - extremely moist
#         [1.0, "blue"]            # 100% - very wet
#     ]
# )
# fig.show()

In [11]:
# variable = 'u10'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Bluered'
# )
# fig.show()

In [12]:
# variable = 'v10'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Bluered'
# )
# fig.show()

In [13]:
# variable = 't2m'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Thermal'
# )
# fig.show()

In [14]:
# variable = 'sst'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Plasma'
# )
# fig.show()

In [15]:
# variable = 'sp'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Plasma'
# )
# fig.show()

In [16]:
# variable = 'tp'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Blues'
# )
# fig.show()

In [17]:
# variable = 'ssr'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'YlOrBr'
# )
# fig.show()

In [None]:
# variable = 'tcc'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'BuPu'
# )
# fig.show()

In [18]:
# variable = 'cl'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Blues'
# )
# fig.show()

cl (lake cover): drop it, as it contributes little to the global model.

In [19]:
# variable = 'e'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'YlGnBu'
# )
# fig.show()

In [None]:
# variable = 'ro'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'dense'
# )
# fig.show()

In [None]:
# variable = 'rsn'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'ice'
# )
# fig.show()

In [24]:
# variable = 'stl1'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Hot_r'
# )
# fig.show()

In [25]:
# variable = 'cvh'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'YlGn'
# )
# fig.show()

In [26]:
# variable = 'lai_hv'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'tempo'
# )
# fig.show()

In [27]:
# variable = 'lai_lv'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'tempo'
# )
# fig.show()

Potentially high correlation between these two variables (lai_hv and lai_lv).

In [28]:
# variable = 'tvh'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'YlGn_r'
# )
# fig.show()

tvh needs to be compared with cvh; the two potentially carry the same information.

In [31]:
# variable = 'tvh'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Rainbow'
# )
# fig.show()

In [33]:
# variable = 'tvl'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Rainbow'
# )
# fig.show()

In [36]:
# variable = 'z'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'RdBu_r'
# )
# fig.show()

In [38]:
# variable = 'lsm'

# fig = visualize_variable_on_map(
#     dataset=ds,
#     variable=variable,
#     plot_type='scatter_geo',
#     animation_frame='date',
#     downsample_factor=12,
#     color_scale = 'Purpor'
# )
# fig.show()

# Variable Variability Assessment

After reviewing the visualizations, the following observations were made about the variability of each variable:

| Variable                               | Variability Level |
|----------------------------------------|-------------------|
| 10m_u_component_of_wind                | High             |
| 10m_v_component_of_wind                | High             |
| 2m_temperature                         | High             |
| Sea surface temperature                | High             |
| Surface pressure                       | Low              |
| Total precipitation                    | High             |
| Surface net solar radiation            | High             |
| Total cloud cover                      | High             |
| Lake cover                             | Low              |
| Evaporation                            | High             |
| Runoff                                 | Low              |
| Snow density                           | Medium           |
| Soil temperature level 1               | High             |
| Soil type                              | Low              |
| Volumetric soil water layer 1          | Medium           |
| Volumetric soil water layer 2          | Medium           |
| Volumetric soil water layer 3 (Target) | Medium           |
| Volumetric soil water layer 4          | Medium           |
| High vegetation cover                  | Low              |
| Leaf area index - High vegetation      | Low              |
| Leaf area index - Low vegetation       | Low              |
| Low vegetation cover                   | Low              |
| Type of high vegetation                | Low              |
| Type of low vegetation                 | Low              |
| Geopotential                           | Low              |
| Land-sea mask                          | Low              |

Let's see if those assumptions are true 

In [43]:
import xarray as xr
import matplotlib.pyplot as plt

def plot_variable_distributions(ds: xr.Dataset) -> None:
    """
    Plot a histogram of the values of each numeric variable in the xarray Dataset.

    Non-numeric variables are skipped because they cannot be binned, and
    variables with no finite values are skipped because there is nothing to plot.

    Args:
        ds (xr.Dataset): The xarray dataset containing variables to visualize.
    """
    for var in ds.data_vars:
        values = ds[var].values

        # np.isfinite / plt.hist need numeric input
        if not np.issubdtype(values.dtype, np.number):
            continue

        # ravel() avoids the unconditional copy that flatten() makes
        data = values.ravel()

        # Drop NaN and +/-inf: either one makes the histogram range non-finite
        data = data[np.isfinite(data)]
        if data.size == 0:
            continue

        plt.figure(figsize=(10, 5))
        plt.hist(data, bins=50, alpha=0.7, color='violet')
        plt.title(f'Distribution of {var}')
        plt.xlabel(var)
        plt.ylabel('Frequency')
        plt.grid(True)
        plt.show()

# Usage
# plot_variable_distributions(ds)


### Missing values

In [19]:
def calculate_missing_percentages(
    df: pl.DataFrame,
    missing_value: "float | None" = None,
    include_nan: bool = False,
) -> pl.DataFrame:
    """
    Calculates the percentage of missing values in each column of a Polars DataFrame.

    Parameters:
    - df (pl.DataFrame): The input DataFrame with potentially missing values.
    - missing_value (float, optional): A custom sentinel value to treat as missing in
      addition to nulls. Default is None, which means no sentinel is used.
    - include_nan (bool): If True, NaN entries in Float32/Float64 columns are also
      counted as missing. Polars treats NaN and null as different things, so NaN
      is not counted by default.

    Returns:
    - pl.DataFrame: A DataFrame with one row showing the percentage of missing values for each column.
    """
    total_rows = df.height
    float_dtypes = (pl.Float32, pl.Float64)

    def _missing_count(column: str) -> pl.Expr:
        # Build the "is missing" mask for one column and count its True entries.
        values = pl.col(column)
        is_missing = values.is_null()
        if missing_value is not None:
            is_missing = is_missing | (values == missing_value)
        # is_nan is only applied to float columns, the only ones that can hold NaN
        if include_nan and df.schema[column] in float_dtypes:
            is_missing = is_missing | values.is_nan()
        return is_missing.sum().alias(column)

    missing_counts = df.select([_missing_count(column) for column in df.columns])
    return missing_counts / total_rows * 100

# Example usage
# missing_percentages = calculate_missing_percentages(polars_df)
# missing_percentages


In [None]:
# Percentage of null values per column.
# NOTE(review): `polars_df` is not defined in the visible cells - presumably the
# dataset converted to a Polars DataFrame; confirm where it is created.
missing_percentages = calculate_missing_percentages(polars_df)
missing_percentages


In [None]:
# Example usage: additionally treat the sentinel value -999.0 as missing.
# NOTE(review): `polars_df` is not defined in the visible cells - confirm where it is created.
missing_percentages = calculate_missing_percentages(polars_df, -999.0)
missing_percentages

## Descriptive statistics

### Summary statistics

In [None]:
# Summary statistics for every column, as returned by the DataFrame's describe().
# NOTE(review): `polars_df` is not defined in the visible cells - confirm where it is created.
summary_stats_polars = polars_df.describe()
summary_stats_polars

In [3]:
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns

def plot_boxplots(polars_df, box_width=0.5, exclude_columns=('time', 'lat', 'lon')):
    """
    Plot one horizontal box plot per column of the Polars DataFrame, skipping the
    excluded columns, with an adjustable width for the boxes.

    Parameters:
    - polars_df (pl.DataFrame): The input Polars DataFrame containing the data to be plotted.
    - box_width (float): The width of the box in the box plot. Default is 0.5.
    - exclude_columns (sequence of str): Column names that are not plotted.
      Default is ('time', 'lat', 'lon'); pass other names when the coordinate
      columns are called differently (e.g. 'date', 'latitude', 'longitude').

    Returns:
    None, displays the box plots.
    """
    # Set the aesthetic appearance of the plots
    sns.set(style="whitegrid")

    # Columns to plot: everything that is not explicitly excluded
    excluded = set(exclude_columns)
    columns_to_plot = [col for col in polars_df.columns if col not in excluded]

    # One distinct color per plotted column
    palette = sns.color_palette("husl", len(columns_to_plot))

    for column, color in zip(columns_to_plot, palette):
        # Matplotlib/seaborn work on NumPy arrays, not Polars Series
        values = polars_df[column].to_numpy()

        # Each column gets its own figure so the axes scale independently
        plt.figure(figsize=(10, 6))
        sns.boxplot(x=values, color=color, width=box_width)
        plt.xlabel(f'{column}')
        plt.title(f'Horizontal Box Plot of {column}')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.show()




In [4]:
# # Example usage:
# plot_boxplots(polars_df, box_width=0.25)

In [1]:
import polars as pl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

def plot_histograms(dataframe: pl.DataFrame, exclude_columns: "list | None" = None, n_bins: int = 50):
    """
    Plot histograms for each column in the DataFrame except the specified columns.

    Parameters:
    - dataframe (pl.DataFrame): The Polars DataFrame containing the data.
    - exclude_columns (list, optional): A list of column names to exclude from histogram
      plotting. Default is None, which excludes 'time', 'lat' and 'lon'.
    - n_bins (int): Number of bins for each histogram.

    This function plots each selected column's histogram in a different color.
    Null, NaN and infinite values are dropped before binning. If no column is
    left to plot, the function returns without creating a figure.
    """
    # None instead of a list literal avoids a mutable default shared across calls
    if exclude_columns is None:
        exclude_columns = ['time', 'lat', 'lon']

    # Selecting the columns that are not in the exclude list
    hist_columns = [col for col in dataframe.columns if col not in exclude_columns]

    # plt.subplots rejects zero rows, so there is nothing to draw in that case
    if not hist_columns:
        return

    # squeeze=False always yields a 2-D array of axes, even for a single histogram
    fig, axs = plt.subplots(
        len(hist_columns),
        1,
        figsize=(10, 5 * len(hist_columns)),
        tight_layout=True,
        squeeze=False,
    )

    # One color per histogram, spread evenly over the viridis colormap
    colors = cm.viridis(np.linspace(0, 1, len(hist_columns)))

    for ax, col, color in zip(axs[:, 0], hist_columns, colors):
        # Nulls would turn into NaN in the NumPy array, so drop them first
        data = dataframe[col].drop_nulls().to_numpy()

        # NaN/inf make the automatically detected bin range non-finite
        if data.dtype.kind == 'f':
            data = data[np.isfinite(data)]

        ax.hist(data, bins=n_bins, color=color, edgecolor='black')

        ax.set_title(f'Histogram of {col}')
        ax.set_xlabel(col)
        ax.set_ylabel('Frequency')

    # Display the plots
    plt.show()



In [2]:
# # Example of usage:
# plot_histograms(polars_df) 


# Correlations

## Correlation between all variables 