In [1]:
import numpy as np
import xarray as xr
from scipy.stats import genextreme
import datetime
import matplotlib.pyplot as plt # Assumed installed for example plot
from GEV import * 
from utils import *

Data Generation functions

In [31]:
def create_independent_gev_grid(loc, scale, shape,
                                num_x=5, num_y=5, num_t=5):
    """
    Generates an xarray.Dataset with independent GEV samples at each point.

    Each value in the grid (x, y, t) is drawn independently from the same
    GEV(loc, scale, shape) distribution.

    Args:
        loc (float): The location parameter (mu) of the GEV distribution.
        scale (float): The scale parameter (sigma > 0) of the GEV distribution.
        shape (float): The shape parameter (xi) of the GEV distribution.
                       Note: Internally uses c = -shape for scipy.stats.genextreme.
                       shape < 0 (c > 0) => Weibull type (bounded above)
                       shape = 0 (c = 0) => Gumbel type
                       shape > 0 (c < 0) => Fréchet type (heavy tail)
        num_x (int): Number of points along the x dimension. Default is 5.
        num_y (int): Number of points along the y dimension. Default is 5.
        num_t (int): Number of points along the t dimension. Default is 5.

    Returns:
        xarray.Dataset: An xarray Dataset with dimensions ('x', 'y', 't')
                        and a single data variable 'extreme_value' containing
                        independent GEV samples.

    Raises:
        ValueError: If the scale parameter is not positive.
    """
    # --- Input Validation ---
    if scale <= 0:
        raise ValueError("Scale parameter (sigma) must be positive.")

    # --- Define Coordinates ---
    x_coords = np.arange(num_x)
    y_coords = np.arange(num_y)
    t_coords = np.arange(num_t)

    # --- Generate GEV Data ---
    total_values = num_x * num_y * num_t
    # Generate independent samples using the specified parameters
    # Note: Using c = -shape consistent with user input convention
    gev_data_flat = genextreme.rvs(c=-shape, loc=loc, scale=scale, size=total_values)
    gev_data_3d = gev_data_flat.reshape(num_x, num_y, num_t)

    # --- Create xarray DataArray ---
    data_variable = xr.DataArray(
        data=gev_data_3d,
        coords={'lon': x_coords, 'lat': y_coords, 'time': t_coords},
        dims=['lon', 'lat', 'time'],
        name='extreme_value',
        attrs={
            'description': 'Independent synthetic data generated from a GEV distribution.',
            'units': 'unitless',
            'gev_location': loc,
            'gev_scale': scale,
            'gev_shape_xi': shape, # Store the input shape parameter (xi)
            'scipy_shape_c': -shape # Store the c parameter used by scipy
        }
    )
    # --- Create xarray Dataset ---
    dataset = xr.Dataset({'extreme_value': data_variable})
    dataset.attrs['creation_timestamp'] = datetime.datetime.now().isoformat()
    dataset.attrs['source'] = 'Generated by create_independent_gev_grid'
    dataset.attrs['grid_type'] = 'dimensionless integer coordinates'
    dataset.attrs['dependencies'] = 'Identically distributed, independent samples'

    return dataset



def create_gev_random_linear_trend(loc, scale, shape,
                                   num_x=5, num_y=5, num_t=5, seed=0):
    """
    Generates GEV data with a location parameter varying linearly with time.

    The location parameter at time t is loc(t) = b0 + b1 * t.
    The intercept b0 is randomly chosen based on the input 'loc'.
    The slope b1 is randomly chosen based on 'scale' and 'num_t'.
    Both are generated using the specified 'seed'.
    Scale and shape parameters are constant.

    Args:
        loc (float): The base location parameter (mu) around which the trend starts (t=0).
        scale (float): The constant scale parameter (sigma > 0). Used for GEV generation
                       and influences the magnitude of the random trend slope.
        shape (float): The constant shape parameter (xi).
                       Note: Internally uses c = -shape for scipy.stats.genextreme.
        num_x (int): Number of points along the x dimension. Default is 5.
        num_y (int): Number of points along the y dimension. Default is 5.
        num_t (int): Number of points along the t dimension. Default is 5.
        seed (int): Seed for the random number generator to determine b0 and b1. Default is 0.

    Returns:
        xarray.Dataset: An xarray Dataset with dimensions ('x', 'y', 't')
                        containing GEV samples with a time-varying location parameter.
                        The generated b0 and b1 are stored as attributes.

    Raises:
        ValueError: If scale <= 0.
    """
    # --- Input Validation ---
    if scale <= 0:
        raise ValueError("Scale parameter (sigma) must be positive.")

    # --- Generate Trend Parameters Internally ---
    rng = np.random.default_rng(seed) # Use modern generator with seed

    # Generate b0: Centered around 'loc' with slight variation based on 'scale'
    # Ensures loc(t=0) is close to the provided 'loc'
    b0 = rng.normal(loc=loc, scale=abs(scale * 0.1)) # Use abs(scale) just in case

    # Generate b1: Slope determined relative to scale and time duration
    # Aim for total change over the period to be within a few 'scale' units
    # k defines target total change factor (e.g., k=1 means loc changes by ~1*scale over num_t steps)
    k = rng.uniform(0, 10) # Random factor for total change magnitude/direction
    time_horizon = max(1, num_t - 1) # Avoid division by zero if num_t=1
    b1 = (k * scale) / time_horizon # Calculate the resulting slope per time step

    # --- Define Coordinates ---
    x_coords = np.arange(num_x)
    y_coords = np.arange(num_y)
    t_coords = np.arange(num_t)

    # --- Calculate Time-Varying Location Parameter ---
    loc_t = b0 + b1 * t_coords # Shape (num_t,)

    # --- Generate GEV Data (Iterating through time) ---
    gev_data_3d = np.empty((num_x, num_y, num_t))
    num_spatial_points = num_x * num_y

    for i in range(num_t):
        current_loc = loc_t[i]
        # Generate samples for all x,y points at this time step 'i'
        # Note: Using c = -shape consistent with input convention
        time_slice_flat = genextreme.rvs(c=-shape, loc=current_loc, scale=scale,
                                         size=num_spatial_points)
        # Reshape and store in the main array
        gev_data_3d[:, :, i] = time_slice_flat.reshape(num_x, num_y)

    # --- Create xarray DataArray ---
    data_variable = xr.DataArray(
        data=gev_data_3d,
        coords={'lon': x_coords, 'lat': y_coords, 'time': t_coords},
        dims=['lon', 'lat', 'time'],
        name='extreme_value', # New name for variable
        attrs={
            'description': 'Synthetic GEV data with internally generated linear trend in location (loc = b0 + b1*t).',
            'units': 'unitless',
            'trend_intercept_b0': b0, # Store generated b0
            'trend_slope_b1': b1,     # Store generated b1
            'trend_rng_seed': seed,   # Store seed used
            'gev_scale': scale,       # Constant scale
            'gev_shape_xi': shape,    # Constant shape (xi)
        }
    )
    # --- Create xarray Dataset ---
    dataset = xr.Dataset({'extreme_value': data_variable})
    dataset.attrs['creation_timestamp'] = datetime.datetime.now().isoformat()
    dataset.attrs['source'] = 'Generated by create_gev_random_linear_trend'
    dataset.attrs['grid_type'] = 'dimensionless integer coordinates'
    dataset.attrs['dependencies'] = 'Time dependence via loc(t), spatial independence at fixed t.'
    dataset.attrs['trend_slope_b1'] = b1

    return dataset

Data visualization functions

In [3]:
def plot_time(data, x_index, y_index):
    """
    Plots the time series for a specific (x, y) point in the xarray data.

    Args:
        data (xarray.Dataset or xarray.DataArray): The dataset or data array
            containing the 'extreme_value' variable with dims ('x', 'y', 't').
        x_index (int): The integer index for the x dimension.
        y_index (int): The integer index for the y dimension.
    """
    # If input is a Dataset, extract the DataArray
    if isinstance(data, xr.Dataset):
        if 'extreme_value' not in data:
            print("Error: Dataset does not contain 'extreme_value' variable.")
            return
        da = data['extreme_value']
    elif isinstance(data, xr.DataArray):
        da = data
    else:
        print("Error: Input must be an xarray Dataset or DataArray.")
        return

    try:
        # Select the data point using integer indices
        selected_point_data = da.isel(x=x_index, y=y_index)

        # Create the plot
        plt.figure(figsize=(10, 5)) # Adjust figure size if needed
        selected_point_data.plot.line(marker='.', linestyle='-') # Use xarray's plotting

        # Customize
        plt.title(f"Time Series at x={selected_point_data['x'].item()}, y={selected_point_data['y'].item()}")
        plt.xlabel("Dimensionless Time (t)")
        plt.ylabel(da.name or "Value") # Use variable name if available
        plt.grid(True, alpha=0.5)
        plt.tight_layout() # Adjust layout
        plt.show()

    except IndexError:
        print(f"Error: Indices x={x_index}, y={y_index} are out of bounds "
              f"for data with shape {da.shape[:2]}.")
    except Exception as e:
        print(f"An error occurred during time series plotting: {e}")


def map(data, t_index):
    """
    Plots a 2D spatial map (image) of the data at a specific time index 't'.

    Args:
        data (xarray.Dataset or xarray.DataArray): The dataset or data array
            containing the 'extreme_value' variable with dims ('x', 'y', 't').
        t_index (int): The integer index for the t dimension.
    """
    # If input is a Dataset, extract the DataArray
    if isinstance(data, xr.Dataset):
        if 'extreme_value' not in data:
            print("Error: Dataset does not contain 'extreme_value' variable.")
            return
        da = data['extreme_value']
    elif isinstance(data, xr.DataArray):
        da = data
    else:
        print("Error: Input must be an xarray Dataset or DataArray.")
        return

    try:
        # Select the time slice using integer index
        selected_time_slice = da.isel(t=t_index)

        # Create the plot using imshow
        plt.figure(figsize=(8, 6)) # Adjust figure size if needed
        selected_time_slice.plot.imshow(cmap='viridis') # Use xarray's imshow plotting

        # Customize
        # Xarray's plot.imshow usually sets title and labels reasonably
        plt.title(f"Spatial Map at t = {selected_time_slice['t'].item()}")
        plt.xlabel("Dimensionless Coordinate (x)")
        plt.ylabel("Dimensionless Coordinate (y)")
        # Colorbar is added automatically by default with plot.imshow
        plt.tight_layout()
        plt.show()

    except IndexError:
        print(f"Error: Time index t={t_index} is out of bounds "
              f"for data with time dimension size {da.shape[2]}.")
    except Exception as e:
        print(f"An error occurred during spatial map plotting: {e}")

In [64]:
location_param = 68
scale_param = 11.8
shape_param = 0.1

# 2. Specify grid dimensions
x_dim_size = 10
y_dim_size = 10
t_dim_size = 100

# 3. Generate the xarray dataset
try:
    gev_dataset_dimless = create_independent_gev_grid(
        loc=location_param,
        scale=scale_param,
        shape=shape_param,
        num_x=x_dim_size,
        num_y=y_dim_size,
        num_t=t_dim_size
    )

    # 4. Inspect the dataset (optional)
    print("Generated xarray Dataset (Dimensionless Coordinates):")
    print(gev_dataset_dimless)

    # 5. Use the new plotting functions
    # --- Plot time series at a specific point (e.g., x=10, y=15) ---
    plot_x_index = 5
    plot_y_index = 5
    print(f"\nPlotting time series for point (lon={plot_x_index}, lat={plot_y_index})...")
    plot_time(gev_dataset_dimless, x_index=plot_x_index, y_index=plot_y_index)

    # --- Plot spatial map at a specific time (e.g., t=0) ---
    plot_t_index = 0
    print(f"\nPlotting spatial map for time index (time={plot_t_index})...")
    map(gev_dataset_dimless, t_index=plot_t_index)

    # --- Example: Plot spatial map at a later time (e.g., t=50) ---
    plot_t_index_later = 50
    print(f"\nPlotting spatial map for time index (time={plot_t_index_later})...")
    map(gev_dataset_dimless, t_index=plot_t_index_later)


except ValueError as e:
    # Handles the ValueError check for scale <= 0 inside the function
    print(f"Error generating dataset: {e}")

Generated xarray Dataset (Dimensionless Coordinates):
<xarray.Dataset> Size: 80kB
Dimensions:        (lon: 10, lat: 10, time: 100)
Coordinates:
  * lon            (lon) int32 40B 0 1 2 3 4 5 6 7 8 9
  * lat            (lat) int32 40B 0 1 2 3 4 5 6 7 8 9
  * time           (time) int32 400B 0 1 2 3 4 5 6 7 ... 92 93 94 95 96 97 98 99
Data variables:
    extreme_value  (lon, lat, time) float64 80kB 73.14 67.8 ... 76.54 60.28
Attributes:
    creation_timestamp:  2025-04-07T16:04:03.614562
    source:              Generated by create_independent_gev_grid
    grid_type:           dimensionless integer coordinates
    dependencies:        Identically distributed, independent samples

Plotting time series for point (lon=5, lat=5)...
An error occurred during time series plotting: Dimensions {'y', 'x'} do not exist. Expected one or more of ('lon', 'lat', 'time')

Plotting spatial map for time index (time=0)...
An error occurred during spatial map plotting: Dimensions {'t'} do not exist. Expecte

Fitting

In [59]:
#print(gev_dataset_dimless.attrs["trend_slope_b1"])
endog, exog_loc, mdt = xarray_to_endog_exog(ds = gev_dataset_dimless,endog_var="extreme_value",include_time_coords=True, include_space_coords=False)
exog_loc.shape

(100, 1, 100)

In [58]:
afit = GEVSample(endog=endog, exog=exog).fit(fit_method='MLE')
print(afit)

ValueError: exog['shape'] must have first and third dimensions (100, 100), but got (75, 100).