# Pentad scatter comparison

This notebook compares two land-sweeper z-score statistic files by plotting scatter comparisons for five key variables (`o_mean`, `o_std`, `m_mean`, `m_std`, `n_data`). For ten randomly selected pentads we draw location-wise scatter plots contrasting the two datasets.

In [None]:
from pathlib import Path
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from scipy import stats
from geospatial_plotting import plot_region, REGION_BOUNDS

# Apply one plot style to every figure drawn below.
plt.style.use('seaborn-v0_8-darkgrid')

In [None]:
# Directory that holds both statistics files being compared.
base_dir = Path('../test_data/land_sweeper/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/stats')

# NOTE(review): the two file names end on different days (doy151 vs doy90) —
# confirm that comparing these two periods is intended.
file_dedup = base_dir.joinpath(
    'M36_dedup_zscore_stats_2007_doy152_2024_doy151_W_75d_Nmin_20_sp_ALL_all_pentads.nc4'
)
file_full = base_dir.joinpath(
    'M36_zscore_stats_2007_doy152_2024_doy90_W_75d_Nmin_20_sp_ALL_all_pentads.nc4'
)
# Alternative second file; uncomment to compare against it instead.
# file_full = base_dir / 'M36_python_dedup_zscore_stats_2007_doy152_2024_doy151_W_75d_Nmin_20_sp_ALL_all_pentads.nc4'

# Echo both paths as the cell output.
(file_dedup, file_full)

In [None]:
# Open both statistics files; the trailing tuple displays them as the cell output.
ds_dedup, ds_full = (xr.open_dataset(path) for path in (file_dedup, file_full))
(ds_dedup, ds_full)

In [None]:
# Dimension name -> length for each file, shown side by side.
# `Dataset.sizes` is the supported name-to-length mapping; indexing
# `Dataset.dims` like a mapping is deprecated in xarray.
dims_dedup = dict(ds_dedup.sizes)
dims_full = dict(ds_full.sizes)
dims_dedup, dims_full

In [None]:
# Variables compared panel by panel in the scatter plots.
vars_to_compare = ['o_mean', 'o_std', 'm_mean', 'm_std', 'n_data']

# Only pentads present in both files can be compared.
common_pentads = np.intersect1d(ds_dedup['pentad'].values, ds_full['pentad'].values)
if not common_pentads.size >= 10:
    raise ValueError('Not enough pentads to sample 10 unique entries.')

# Fixed seed so the same ten pentads are drawn on every run.
rng = np.random.default_rng(42)
selected_pentads = rng.choice(common_pentads, size=10, replace=False)
selected_pentads = np.sort(selected_pentads)
selected_pentads

In [None]:
max_points = 20000  # limit plotted points for readability
rng = np.random.default_rng(123)  # fixed seed keeps the subsampling reproducible


def _flat_pentad_field(ds, var, pentad):
    """Return `var` of `ds` at `pentad` as a 1-D array in (lon, lat) order."""
    field = ds[var].sel(pentad=pentad)
    return field.transpose('lon', 'lat', missing_dims='ignore').values.reshape(-1)


for pentad in selected_pentads:
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.ravel()
    for ax, var in zip(axes, vars_to_compare):
        data_a = _flat_pentad_field(ds_dedup, var, pentad)
        data_b = _flat_pentad_field(ds_full, var, pentad)

        # Keep only locations where both files hold a finite value.
        both_finite = np.isfinite(data_a) & np.isfinite(data_b)
        data_a, data_b = data_a[both_finite], data_b[both_finite]
        if not data_a.size:
            ax.text(0.5, 0.5, 'No overlapping data', ha='center', va='center', transform=ax.transAxes)
            ax.set_axis_off()
            continue

        # Thin out dense panels with one shared index set so pairs stay matched.
        if data_a.size > max_points:
            keep = rng.choice(data_a.size, size=max_points, replace=False)
            data_a, data_b = data_a[keep], data_b[keep]

        ax.scatter(data_a, data_b, s=2, alpha=0.3, edgecolor='none')

        # 1:1 reference line spanning the joint range of both datasets.
        pooled = np.concatenate([data_a, data_b])
        vmin, vmax = np.nanmin(pooled), np.nanmax(pooled)
        if vmin == vmax:
            # Degenerate range: widen it so the line has nonzero length.
            vmin -= 1
            vmax += 1
        ax.plot([vmin, vmax], [vmin, vmax], color='black', linewidth=1, linestyle='--')

        ax.set_title(var)
        ax.set_xlabel('Deduplicated dataset')
        ax.set_ylabel('Original dataset')

    # Hide grid slots that have no variable assigned.
    for leftover_ax in axes[len(vars_to_compare):]:
        leftover_ax.set_visible(False)
    fig.suptitle(f'Pentad {int(pentad)} scatter comparisons')
    plt.tight_layout()
    plt.show()

## Cross-file observation vs model tests

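We test whether the deduplicated and original files yield different Gaussian summaries for the same quantities (observations and model). For each pentad/location we run Welch t-tests for mean differences and two-sided F-tests for variance differences. A test is flagged when its p-value falls below α = 0.05, and the results are summarised as the fraction of flagged tests per pentad and per location.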

In [None]:

# Significance level: a test is flagged when its p-value is below this
# threshold (read as a module-level name by compute_cross_significance).
alpha = 0.05


def compute_lon_lat(ds):
    """Build 1-D longitude and latitude vectors from a dataset's grid fields.

    Each axis starts at `ll_lon` / `ll_lat` and advances by `d_lon` / `d_lat`
    for as many steps as the `lon` / `lat` dimension is long.
    NOTE(review): `ll_*` presumably denote the lower-left grid corner —
    confirm against the file format.

    Parameters
    ----------
    ds : dataset-like
        Must provide scalar entries `ll_lon`, `ll_lat`, `d_lon`, `d_lat`
        and a `sizes` mapping containing `lon` and `lat`.

    Returns
    -------
    tuple of numpy.ndarray
        `(lon, lat)` coordinate vectors.
    """
    lon0 = float(ds['ll_lon'])
    lat0 = float(ds['ll_lat'])
    dlon = float(ds['d_lon'])
    dlat = float(ds['d_lat'])
    # Use `sizes` for the name -> length lookup; mapping-style access on
    # `Dataset.dims` is deprecated in xarray.
    lon = lon0 + np.arange(ds.sizes['lon']) * dlon
    lat = lat0 + np.arange(ds.sizes['lat']) * dlat
    return lon, lat


def reorder_dims(da):
    """Transpose `da` so its dims follow the order pentad, lon, lat.

    Only those of the three names that `da` actually has are passed on to
    `transpose`.
    NOTE(review): dims other than these three are not listed in the
    transpose call — confirm the inputs never carry extra dims.
    """
    ordered = []
    for name in ('pentad', 'lon', 'lat'):
        if name in da.dims:
            ordered.append(name)
    return da.transpose(*ordered, missing_dims='ignore')


def _welch_term(variance, count, shape):
    """Return (variance / count)**2 / (count - 1) as float64; 0 where `count > 1` is false."""
    term = np.zeros(shape, dtype=np.float64)
    usable = count > 1
    term[usable] = (variance[usable] / count[usable]) ** 2 / (count[usable] - 1)
    return term


def _welch_mean_pvalues(mu_a, mu_b, var_a, var_b, n_a, n_b):
    """Two-sided Welch t-test p-values per element; NaN where the test is undefined."""
    shape = mu_a.shape
    p_values = np.full(shape, np.nan, dtype=np.float64)

    squared_se = var_a / n_a + var_b / n_b
    testable = (n_a > 1) & (n_b > 1) & np.isfinite(squared_se) & (squared_se > 0)
    if not np.any(testable):
        return p_values

    # Denominator of the Welch-Satterthwaite degrees of freedom.
    ws_denominator = _welch_term(var_a, n_a, shape) + _welch_term(var_b, n_b, shape)
    testable = testable & (ws_denominator > 0)
    if not np.any(testable):
        return p_values

    delta = mu_a - mu_b
    t_values = (delta[testable] / np.sqrt(squared_se[testable])).astype(np.float64)
    dof = (squared_se[testable] ** 2 / ws_denominator[testable]).astype(np.float64)
    p_values[testable] = 2.0 * stats.t.sf(np.abs(t_values), dof)
    return p_values


def _variance_ratio_pvalues(var_a, var_b, n_a, n_b, shape):
    """Two-sided F-test p-values for var_a / var_b; NaN where the test is undefined."""
    p_values = np.full(shape, np.nan, dtype=np.float64)

    testable = (
        (n_a > 2) & (n_b > 2)
        & np.isfinite(var_a) & np.isfinite(var_b)
        & (var_a > 0) & (var_b > 0)
    )
    if not np.any(testable):
        return p_values

    dof_a = n_a - 1
    dof_b = n_b - 1
    testable = testable & (dof_a > 0) & (dof_b > 0)
    if not np.any(testable):
        return p_values

    ratio = (var_a[testable] / var_b[testable]).astype(np.float64)
    lower_tail = stats.f.cdf(ratio, dof_a[testable], dof_b[testable])
    upper_tail = stats.f.sf(ratio, dof_a[testable], dof_b[testable])
    # Two-sided p-value: twice the smaller tail probability.
    p_values[testable] = 2.0 * np.minimum(lower_tail, upper_tail)
    return p_values


def compute_cross_significance(ds_a, ds_b, mean_key, std_key):
    """Test element-wise whether two datasets differ in mean and in variance.

    Means come from `mean_key`, standard deviations from `std_key` and sample
    counts from `n_data`. A Welch t-test (means) and a two-sided F-test
    (variances) are evaluated for every element; a test is flagged when its
    p-value is below the module-level `alpha`.

    Parameters
    ----------
    ds_a, ds_b : datasets holding `mean_key`, `std_key` and `n_data`.
    mean_key, std_key : names of the mean and standard-deviation variables.

    Returns
    -------
    dict
        `mean_flags` / `var_flags`: 1 where flagged, 0 where not, NaN where
        the p-value is not finite.
        `pentad_mean_frac` / `pentad_var_frac`: flag average over lon and lat.
        `location_mean_frac` / `location_var_frac`: flag average over pentad.
    """
    mean_a = reorder_dims(ds_a[mean_key])
    mean_b = reorder_dims(ds_b[mean_key])
    mu_a, mu_b = mean_a.values, mean_b.values

    var_a = np.square(reorder_dims(ds_a[std_key]).values)
    var_b = np.square(reorder_dims(ds_b[std_key]).values)
    n_a = reorder_dims(ds_a['n_data']).values
    n_b = reorder_dims(ds_b['n_data']).values

    # Invalid elements are masked out afterwards, so silence the warnings
    # their divisions would otherwise emit.
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_p = _welch_mean_pvalues(mu_a, mu_b, var_a, var_b, n_a, n_b)
        var_p = _variance_ratio_pvalues(var_a, var_b, n_a, n_b, mu_a.shape)

    def _flags(p_values):
        # Wrap with the first dataset's coords; keep NaN where no test was run.
        p_da = xr.DataArray(p_values, coords=mean_a.coords, dims=mean_a.dims)
        return (p_da < alpha).where(np.isfinite(p_da))

    mean_flags = _flags(mean_p)
    var_flags = _flags(var_p)
    grid_dims = ('lon', 'lat')

    return {
        'mean_flags': mean_flags,
        'var_flags': var_flags,
        'pentad_mean_frac': mean_flags.mean(dim=grid_dims, skipna=True),
        'pentad_var_frac': var_flags.mean(dim=grid_dims, skipna=True),
        'location_mean_frac': mean_flags.mean(dim='pentad', skipna=True),
        'location_var_frac': var_flags.mean(dim='pentad', skipna=True),
    }


def dataarray_to_map_array(data_array, lon_vals, lat_vals):
    """Flatten a (lon, lat) field into rows of (value, lon, lat).

    Rows whose value is not finite are dropped.

    Parameters
    ----------
    data_array : array-like object with `transpose(...)` returning `.values`.
    lon_vals, lat_vals : 1-D coordinate vectors for the two axes.

    Returns
    -------
    numpy.ndarray
        Three columns: value, longitude, latitude.
    """
    values = data_array.transpose('lon', 'lat', missing_dims='ignore').values.ravel()
    # 'ij' indexing makes the grids vary as (lon, lat), matching `values`.
    lon_grid, lat_grid = np.meshgrid(lon_vals, lat_vals, indexing='ij')
    keep = np.isfinite(values)
    columns = (values[keep], lon_grid.ravel()[keep], lat_grid.ravel()[keep])
    return np.column_stack(columns)


In [None]:

# Restrict both files to the pentads they share so the arrays line up.
stats_pentads = common_pentads
subset_dedup, subset_full = (ds.sel(pentad=stats_pentads) for ds in (ds_dedup, ds_full))

# Display label -> (mean variable, standard-deviation variable).
variable_pairs = {
    'Observations (o)': ('o_mean', 'o_std'),
    'Model (m)': ('m_mean', 'm_std'),
}
results = {
    label: compute_cross_significance(subset_dedup, subset_full, mean_key, std_key)
    for label, (mean_key, std_key) in variable_pairs.items()
}
results


In [None]:

colors = {'Observations (o)': 'tab:blue', 'Model (m)': 'tab:orange'}
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# One panel per test: (key into each result dict, panel title).
panels = (
    ('pentad_mean_frac', 'Mean difference test (dedup vs original)'),
    ('pentad_var_frac', 'Variance difference test (dedup vs original)'),
)
for ax, (frac_key, panel_title) in zip(axes, panels):
    for label, res in results.items():
        ax.plot(res[frac_key], label=label, color=colors[label])
    ax.set_ylabel('Fraction flagged')
    ax.set_title(panel_title)
    ax.legend()

# The x axis is shared, so only the bottom panel carries the label.
axes[1].set_xlabel('Pentad index')
plt.tight_layout()
plt.show()


In [None]:

# Longitude/latitude vectors built from the deduplicated file's grid fields.
# NOTE(review): assumes both files share this grid — confirm.
lon_vals, lat_vals = compute_lon_lat(ds_dedup)

# Per-location maps to draw for each result: (result key, name used in the title).
map_specs = (
    ('location_mean_frac', 'mean'),
    ('location_var_frac', 'variance'),
)

for label, res in results.items():
    for frac_key, test_name in map_specs:
        map_array = dataarray_to_map_array(res[frac_key], lon_vals, lat_vals)

        # Replace (near-)zero fractions in the value column with NaN for plotting.
        near_zero = np.isclose(map_array[:, 0], 0.0)
        map_array[near_zero, 0] = np.nan

        plot_region(
            map_array,
            region_bounds=REGION_BOUNDS['global'],
            meanflag=False,
            plot_title=f'{label}: {test_name}-test fraction per location',
            units='Fraction',
            cmin=0.0,
            cmax=1.0,
        )
        plt.show()