# CYGNSS CD: Fixed vs Original Diagnostics

This notebook compares the rerun ("fixed") CYGNSS experiments against the original deliveries for both the open-loop (OL) and the CYGNSS data-assimilation (DA) configurations. It mirrors the structure of `cygnss_cd_ofa_figures_081825.ipynb`, but the file paths point to the 2018-08 through 2019-07 subset provided for this validation exercise.

In [None]:
import numpy as np
from pathlib import Path
from datetime import datetime
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import pickle

import sys; sys.path.append('../util/shared/python/')
from read_GEOSldas import read_tilecoord
from mapper_functions import plot_aus_tight_pcm, plot_global_tight_pcm
from geospatial_plotting import plot_region, REGION_BOUNDS

In [None]:
# Define which species belong to each sensor group
# Maps a sensor name to a list of species column indices; the helper functions
# below use these lists to index the second axis of the statistics arrays
# (e.g. ``n_data[:, idx]``).
species_groups = {
    "SMOS": [0, 1, 2, 3],
    "SMAP": [4, 5, 6, 7],
    "ASCAT": [8, 9, 10],
    "CYGNSS": [11]
}
NMIN = 20  # minimum obs count threshold for diagnostics

In [None]:
# File configuration for the original and fixed deliveries
# All inputs live under ``base_dir``; the OL and DA experiments each have their
# own ``figures`` directory holding the pre-computed statistics files.
base_dir = Path('/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/CYGNSS_Experiments')
fig_dir_ol = base_dir / 'OLv8_M36_cd' / 'OLv8_M36_cd' / 'output' / 'SMAP_EASEv2_M36_GLOBAL' / 'figures'
# Alternative DA experiment directory, currently disabled:
#fig_dir_da = base_dir / 'DAv8_M36_cd_ssa' / 'DAv8_M36_cd_ssa' / 'output' / 'SMAP_EASEv2_M36_GLOBAL' / 'figures'
fig_dir_da = base_dir / 'DAv8_M36_cd_a' / 'DAv8_M36_cd_a' / 'output' / 'SMAP_EASEv2_M36_GLOBAL' / 'figures'

# Experiment label -> the two statistics files that are loaded for it:
#   'temporal': netCDF file, read by read_temporal_stats()
#   'spatial' : pickle file, read by read_spatial_pickle()
data_paths = {
    'OL_orig': {
        'temporal': fig_dir_ol / 'temporal_stats_OL_20180801_20190731.nc4',
        'spatial': fig_dir_ol / 'spatial_stats_OL_201808_201907.pkl'
    },
    'OL_fixed': {
        'temporal': fig_dir_ol / 'temporal_stats_OL_fixed_20180801_20190731.nc4',
        'spatial': fig_dir_ol / 'spatial_stats_OL_fixed_201808_201907.pkl'
    },
    'DA_orig': {
        'temporal': fig_dir_da / 'temporal_stats_DA_20180801_20190731.nc4',
        'spatial': fig_dir_da / 'spatial_stats_DA_201808_201907.pkl'
    },
    'DA_fixed': {
        'temporal': fig_dir_da / 'temporal_stats_DA_fixed_20180801_20190731.nc4',
        'spatial': fig_dir_da / 'spatial_stats_DA_fixed_201808_201907.pkl'
    },
    'DA_dedup': {
        'temporal': fig_dir_da / 'temporal_stats_DA_dedup_20180801_20190731.nc4',
        'spatial': fig_dir_da / 'spatial_stats_DA_dedup_201808_201907.pkl'
    },
    # NOTE(review): 'DA_64' / 'DA_45' point at the "lowerr" / "v_lowerr" files;
    # the plot titles further down label them "6.4%" and "4.5%" - presumably
    # reduced observation-error settings, confirm.
    'DA_64': {
        'temporal': fig_dir_da / 'temporal_stats_DA_lowerr_20180801_20190731.nc4',
        'spatial': fig_dir_da / 'spatial_stats_DA_lowerr_201808_201907.pkl'
    },
    'DA_45': {
        'temporal': fig_dir_da / 'temporal_stats_DA_v_lowerr_20180801_20190731.nc4',
        'spatial': fig_dir_da / 'spatial_stats_DA_v_lowerr_201808_201907.pkl'
    },
}

# Fail fast: stop at the first configured file that does not exist on disk.
for label, files in data_paths.items():
    for kind, path in files.items():
        if not path.exists():
            raise FileNotFoundError(f"Missing {kind} file for {label}: {path}")

In [None]:
def read_temporal_stats(nc_path):
    """Read every variable of a netCDF diagnostics file into plain arrays.

    Parameters
    ----------
    nc_path : path-like
        Location of the netCDF file; it is opened read-only.

    Returns
    -------
    dict
        Variable name -> the full contents of that variable with masked
        entries replaced by NaN.
    """
    # NOTE(review): ``.filled`` assumes each variable is returned as a masked
    # array - confirm this holds for all variables in the file.
    with Dataset(nc_path, 'r') as handle:
        return {
            name: variable[:].filled(np.nan)
            for name, variable in handle.variables.items()
        }


def read_spatial_pickle(pkl_path):
    """Return the object stored in the pickle file at ``pkl_path``.

    NOTE: unpickling executes arbitrary code from the file, so only use this
    on files from a trusted source.
    """
    with open(pkl_path, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload


def apply_obs_threshold(stats, nmin):
    """Blank out statistics that are based on fewer than ``nmin`` observations.

    Parameters
    ----------
    stats : dict
        Name -> array-like. Must contain ``'N_data'`` and the six OmF/OmA
        statistics listed below; those arrays are indexed with a boolean mask
        derived from ``'N_data'``.
    nmin : number
        Minimum observation count required to keep a statistic.

    Returns
    -------
    dict
        A copy of every input array. Where ``N_data < nmin`` the six
        statistics are set to NaN and ``N_data`` itself is set to 0. The
        input dictionary and its arrays are left untouched.
    """
    gated_keys = ('OmF_mean', 'OmF_stdv', 'OmF_norm_mean',
                  'OmF_norm_stdv', 'OmA_mean', 'OmA_stdv')

    # Work on copies so the caller's arrays are never modified.
    result = {name: np.array(field, copy=True) for name, field in stats.items()}

    counts = result['N_data']
    too_few = counts < nmin
    for name in gated_keys:
        result[name][too_few] = np.nan

    thresholded_counts = counts.copy()
    thresholded_counts[too_few] = 0
    result['N_data'] = thresholded_counts
    return result


def compute_group_metrics(stats, species_groups, *, weighted=False):
    """Collapse per-species statistics into one value per row and sensor group.

    Parameters
    ----------
    stats : dict
        Name -> 2-D array indexed as ``[row, species]``. Must contain
        ``'N_data'`` and the six OmF/OmA statistics listed below.
    species_groups : dict
        Group name -> list of species column indices.
    weighted : bool, keyword-only, default False
        ``False`` (default, the established behaviour): each statistic is the
        plain NaN-ignoring mean over the group's species.
        ``True``: each statistic is the mean weighted by ``N_data``, using
        only species whose value and count are both finite.

    Returns
    -------
    dict
        ``{group: {'Nobs_data': ..., 'OmF_mean': ..., ...}}`` with one 1-D
        array (length = number of rows) per entry. ``'Nobs_data'`` is the
        NaN-ignoring sum of ``N_data`` over the group's species. Rows without
        any usable species yield NaN for the statistics.
    """
    import warnings

    metric_keys = ('OmF_mean', 'OmF_stdv', 'OmF_norm_mean',
                   'OmF_norm_stdv', 'OmA_mean', 'OmA_stdv')
    n_data = stats['N_data']
    group_metrics = {}
    for group, idx in species_groups.items():
        weights = n_data[:, idx]
        metrics = {'Nobs_data': np.nansum(weights, axis=1)}
        for key in metric_keys:
            values = stats[key][:, idx]
            if weighted:
                # Only species with a finite value AND a finite count take part.
                usable = np.isfinite(values) & np.isfinite(weights)
                eff_weights = np.where(usable, weights, 0.0)
                numer = np.sum(np.where(usable, values, 0.0) * eff_weights, axis=1)
                denom = np.sum(eff_weights, axis=1)
                metrics[key] = np.divide(
                    numer,
                    denom,
                    out=np.full(denom.shape, np.nan, dtype=float),
                    where=denom > 0
                )
            else:
                # Rows where every species is NaN are expected; they still
                # evaluate to NaN, only the "Mean of empty slice" warning
                # is silenced.
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', category=RuntimeWarning)
                    metrics[key] = np.nanmean(values, axis=1)
        group_metrics[group] = metrics
    return group_metrics


def convert_stats_dict_to_arrays(stats_dict):
    """Turn every entry of ``stats_dict`` into a NumPy array.

    Entries that NumPy can only represent with ``object`` dtype are rebuilt
    row by row as float arrays: the placeholder string ``'--'`` becomes NaN
    and every other cell is passed through ``float()``. All other entries are
    returned as produced by ``np.array``.

    Returns a new dictionary with the same keys in the same order.
    """
    def _object_rows_to_float(obj_array):
        # Row-wise conversion; '--' marks a missing cell.
        return np.array(
            [[np.nan if cell == '--' else float(cell) for cell in row]
             for row in obj_array],
            dtype=float,
        )

    converted = {}
    for name, raw in stats_dict.items():
        as_array = np.array(raw)
        if as_array.dtype == object:
            as_array = _object_rows_to_float(as_array)
        converted[name] = as_array
    return converted


def calculate_weighted_group_stats(stats_dict, species_groups):
    """Observation-count-weighted group statistics for every time step.

    Parameters
    ----------
    stats_dict : dict
        Name -> 2-D array-like indexed as ``[time, species]``. Must contain
        ``'N_data'`` plus ``'O_mean'``, ``'F_mean'``, ``'OmF_mean'``,
        ``'OmF_stdv'``, ``'OmA_mean'`` and ``'OmA_stdv'``.
    species_groups : dict
        Group name -> list of species column indices.

    Returns
    -------
    dict
        ``{group: {stat: 1-D float array, ..., 'N_data': 1-D float array}}``,
        each array having one entry per time step. A statistic is the average
        over the group's species weighted by ``N_data``; ``'N_data'`` is the
        NaN-ignoring sum of the group's counts (0 when that sum is not
        positive).

    Notes
    -----
    Only species whose statistic and count are both finite take part in the
    average. A species without data (NaN statistic, zero or NaN count)
    therefore no longer turns the whole group value into NaN; the result is
    NaN only when no usable species with positive total weight remains.
    """
    stat_names = ['O_mean', 'F_mean', 'OmF_mean', 'OmF_stdv', 'OmA_mean', 'OmA_stdv']
    n_times = len(stats_dict['OmF_mean'])
    if n_times == 0:
        # Nothing to average; keep the output structure with empty arrays.
        return {
            group: {name: np.zeros(0) for name in stat_names + ['N_data']}
            for group in species_groups
        }

    counts = np.asarray(stats_dict['N_data'], dtype=float)[:n_times]
    group_stats = {}
    for group, indices in species_groups.items():
        weights = counts[:, indices]
        weights_finite = np.isfinite(weights)
        total_weight = np.nansum(weights, axis=1)

        per_group = {}
        for stat in stat_names:
            values = np.asarray(stats_dict[stat], dtype=float)[:n_times, indices]
            usable = weights_finite & np.isfinite(values)
            eff_weights = np.where(usable, weights, 0.0)
            numer = np.sum(np.where(usable, values, 0.0) * eff_weights, axis=1)
            denom = np.sum(eff_weights, axis=1)
            per_group[stat] = np.divide(
                numer,
                denom,
                out=np.full(n_times, np.nan),
                where=denom > 0
            )
        per_group['N_data'] = np.where(total_weight > 0, total_weight, 0.0)
        group_stats[group] = per_group
    return group_stats

In [None]:
# Load diagnostics for each dataset key
# Each experiment_data[label] entry ends up with:
#   'temporal_stats'        - raw arrays from the netCDF file
#   'temporal_stats_masked' - same, with entries below NMIN observations blanked
#   'group_metrics'         - per-sensor-group values derived from the masked stats
#   'spatial_dict'          - raw content of the pickle file
#   'ts_arrays'             - pickle content converted to NumPy arrays
#   'dates'                 - datetime objects parsed from spatial_dict['date_vec']
#   'ts_group_stats'        - weighted per-group time series
experiment_data = {}
for label, files in data_paths.items():
    exp_entry = {}
    exp_entry['temporal_stats'] = read_temporal_stats(files['temporal'])
    exp_entry['temporal_stats_masked'] = apply_obs_threshold(exp_entry['temporal_stats'], NMIN)
    exp_entry['group_metrics'] = compute_group_metrics(exp_entry['temporal_stats_masked'], species_groups)
    spatial_dict = read_spatial_pickle(files['spatial'])
    exp_entry['spatial_dict'] = spatial_dict
    exp_entry['ts_arrays'] = convert_stats_dict_to_arrays(spatial_dict)
    date_vec = spatial_dict.get('date_vec', None)
    if date_vec is None:
        raise ValueError(f'Missing date_vec in {files["spatial"]}')
    # date_vec entries are parsed as 'YYYYMM' strings.
    exp_entry['dates'] = [datetime.strptime(date, '%Y%m') for date in date_vec]
    # Note: the time series use the unthresholded pickle data (NMIN is not applied here).
    exp_entry['ts_group_stats'] = calculate_weighted_group_stats(exp_entry['ts_arrays'], species_groups)
    experiment_data[label] = exp_entry

print('Loaded experiments:', ', '.join(experiment_data.keys()))

In [None]:
# Tile coordinate metadata used for every spatial plot
# NOTE(review): this reads from the 'DAv8_M36_cd' experiment directory, whereas
# fig_dir_da above points at 'DAv8_M36_cd_a' - presumably the tile layout is
# identical across experiments; confirm.
ftc = base_dir / 'DAv8_M36_cd' / 'DAv8_M36_cd' / 'output' / 'SMAP_EASEv2_M36_GLOBAL' / 'rc_out' / 'DAv8_M36_cd.ldas_tilecoord.bin'
tc = read_tilecoord(str(ftc))
n_tile = tc['N_tile']
# One row per tile: column 0 is a NaN placeholder that plot_diff_map fills with
# the value to plot, column 1 holds tc['com_lon'], column 2 holds tc['com_lat'].
map_template = np.empty((n_tile, 3))
map_template[:] = np.nan
map_template[:, 1] = tc['com_lon']
map_template[:, 2] = tc['com_lat']
print('Tile coordinate file:', ftc)

In [None]:
def _resolve_metric_name(metric):
    """Map user-friendly metric names to keys in spatial group metrics."""
    if metric == 'N_data':
        return 'Nobs_data'
    return metric


def plot_diff_map(dataset_a, dataset_b, group, metric='OmF_stdv', percent=False,
                  region='cygnss', units='m3/m3', cmin=None, cmax=None,
                  title_prefix='Fixed - Original'):
    """Map the per-tile difference ``dataset_b - dataset_a`` of a group metric.

    Parameters
    ----------
    dataset_a, dataset_b : str
        Keys into ``experiment_data``; ``dataset_a`` is the reference.
    group : str
        Sensor group whose ``group_metrics`` entry is plotted.
    metric : str
        Metric name; ``'N_data'`` is resolved to ``'Nobs_data'``.
    percent : bool
        If true, plot ``100 * (b - a) / a`` (NaN where ``a`` is 0) and label
        the units as ``'%'`` regardless of ``units``.
    region : str
        Key into ``REGION_BOUNDS``.
    units, cmin, cmax
        Passed through to ``plot_region``.
    title_prefix : str
        Leading text of the plot title.

    Returns
    -------
    tuple
        The ``(fig, ax)`` pair returned by ``plot_region``.
    """
    metric_key = _resolve_metric_name(metric)
    arr_a = experiment_data[dataset_a]['group_metrics'][group][metric_key]
    arr_b = experiment_data[dataset_b]['group_metrics'][group][metric_key]
    diff = arr_b - arr_a
    if percent:
        diff = np.divide(
            diff,
            arr_a,
            out=np.full_like(arr_a, np.nan, dtype=float),
            where=arr_a != 0
        ) * 100.0
        units = '%'
    # Column 0 carries the values; the coordinate columns come from the template.
    payload = map_template.copy()
    payload[:, 0] = diff
    vmax = np.nanmax(payload[:, 0])
    vmin = np.nanmin(payload[:, 0])
    fig, ax = plot_region(
        payload,
        region_bounds=REGION_BOUNDS[region],
        meanflag=True,
        # The trailing space after {metric} keeps the metric name from running
        # into the "(b - a)" part of the title.
        plot_title=(f"{title_prefix} {group} {metric} "
                    f"({dataset_b} - {dataset_a})  Max: {vmax:.3g}  Min: {vmin:.3g}"),
        units=units,
        cmin=cmin,
        cmax=cmax
    )
    fig.tight_layout()
    return fig, ax


def plot_ts_comparison(*datasets_and_group, metric='OmF_stdv', ylabel=None):
    """Draw the group time series of several experiments on one axis.

    Positional arguments are two or more ``experiment_data`` keys followed by
    the sensor group name. All experiments must share the same date vector,
    otherwise ``ValueError`` is raised. Each legend entry includes the
    NaN-ignoring mean of its series. The figure is shown; nothing is returned.
    """
    if len(datasets_and_group) < 3:
        raise ValueError('Provide at least two datasets followed by the group name.')

    group = datasets_and_group[-1]
    dataset_names = list(datasets_and_group[:-1])

    # The first dataset supplies the reference time axis.
    dates = experiment_data[dataset_names[0]]['dates']
    for name in dataset_names[1:]:
        other_dates = experiment_data[name]['dates']
        aligned = (len(other_dates) == len(dates)
                   and all(a == b for a, b in zip(other_dates, dates)))
        if not aligned:
            raise ValueError(f'Date vector mismatch between {dataset_names[0]} and {name}.')

    plt.figure(figsize=(12, 5))
    for name in dataset_names:
        data = experiment_data[name]['ts_group_stats'][group][metric]
        mean_val = np.nanmean(data)
        plt.plot(dates, data, marker='o', label=f"{name} (mean={mean_val:.3g})")

    # Tick every date, unless the series spans a year or more: then tick only
    # the first date encountered in each calendar year.
    xticks = dates
    if len(dates) > 1 and (dates[-1] - dates[0]).days >= 365:
        first_in_year = {}
        for dt in dates:
            first_in_year.setdefault(dt.year, dt)
        xticks = list(first_in_year.values())

    plt.xticks(xticks, rotation=45)
    plt.grid(True, linestyle=':')
    plt.legend()
    plt.xlabel('Date')
    plt.ylabel(ylabel or metric)
    plt.title(f'{metric} comparison for {group}')
    plt.tight_layout()
    plt.show()


def plot_percent_difference_ts(*pairs_and_group, metric='OmF_stdv', ylabel='Percent Difference (%)', title=None):
    """Plot ``100 * (DA - OL) / OL`` time series for one or more dataset pairs.

    Positional arguments are one or more ``(DA, OL)`` or ``(DA, OL, label)``
    lists/tuples of ``experiment_data`` keys, followed by the sensor group
    name. All referenced experiments must share the same date vector,
    otherwise ``ValueError`` is raised. Points where the OL value is 0 are
    NaN. The figure is shown; nothing is returned.
    """
    if len(pairs_and_group) < 2:
        raise ValueError('Provide at least one (DA, OL) pair followed by the group name.')

    group = pairs_and_group[-1]
    parsed_pairs = []
    dataset_names = []
    for pair in pairs_and_group[:-1]:
        well_formed = isinstance(pair, (list, tuple)) and len(pair) in (2, 3)
        if not well_formed:
            raise ValueError('Each pair must be a (DA, OL) or (DA, OL, label) tuple.')
        da_name, ol_name = pair[0], pair[1]
        if len(pair) == 3:
            label = pair[2]
        else:
            label = f"{da_name} vs {ol_name}"
        parsed_pairs.append((da_name, ol_name, label))
        dataset_names += [da_name, ol_name]

    # The DA member of the first pair supplies the reference time axis.
    dates = experiment_data[parsed_pairs[0][0]]['dates']
    for name in dataset_names[1:]:
        other_dates = experiment_data[name]['dates']
        aligned = (len(other_dates) == len(dates)
                   and all(a == b for a, b in zip(other_dates, dates)))
        if not aligned:
            raise ValueError(f'Date vector mismatch between {parsed_pairs[0][0]} and {name}.')

    plt.figure(figsize=(12, 5))
    for da_name, ol_name, label in parsed_pairs:
        da_vals = experiment_data[da_name]['ts_group_stats'][group][metric]
        ol_vals = experiment_data[ol_name]['ts_group_stats'][group][metric]
        ratio = np.full_like(da_vals, np.nan, dtype=float)
        np.divide(da_vals - ol_vals, ol_vals, out=ratio, where=ol_vals != 0)
        percent_diff = ratio * 100.0
        plt.plot(dates, percent_diff, marker='o', label=f"{label} (mean={np.nanmean(percent_diff):.3g}%)")

    # Tick every date, unless the series spans a year or more: then tick only
    # the first date encountered in each calendar year.
    xticks = dates
    if len(dates) > 1 and (dates[-1] - dates[0]).days >= 365:
        first_in_year = {}
        for dt in dates:
            first_in_year.setdefault(dt.year, dt)
        xticks = list(first_in_year.values())

    plt.axhline(0.0, color='k', linestyle='--', linewidth=0.8)
    plt.xticks(xticks, rotation=45)
    plt.grid(True, linestyle=':')
    plt.legend()
    plt.xlabel('Date')
    plt.ylabel(ylabel)
    plt.title(title or f'Percent difference of {metric} for {group}')
    plt.tight_layout()
    plt.show()


def summarize_metric_stats(group, metric='OmF_stdv'):
    """Summarise one group metric for every loaded experiment.

    Returns ``{experiment label: {'mean', 'median', 'min', 'max'}}`` where
    each value is the NaN-ignoring statistic of the experiment's
    ``group_metrics`` array, converted to a Python float.
    """
    metric_key = _resolve_metric_name(metric)
    reducers = (
        ('mean', np.nanmean),
        ('median', np.nanmedian),
        ('min', np.nanmin),
        ('max', np.nanmax),
    )
    summary = {}
    for label, data in experiment_data.items():
        arr = data['group_metrics'][group][metric_key]
        summary[label] = {stat: float(reduce_fn(arr)) for stat, reduce_fn in reducers}
    return summary


In [None]:
# Example: ASCAT OmF StdDev differences (fixed minus original) for OL and DA,
# first in absolute units, then as percent change relative to the original.
_ = plot_diff_map('OL_orig', 'OL_fixed', 'ASCAT', metric='OmF_stdv',
                  title_prefix='OL Fixed - OL Original', units='m3/m3')
_ = plot_diff_map('DA_orig', 'DA_fixed', 'ASCAT', metric='OmF_stdv',
                  title_prefix='DA Fixed - DA Original', units='m3/m3')
_ = plot_diff_map('OL_orig', 'OL_fixed', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='OL Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('DA_orig', 'DA_fixed', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA Percent Change', cmin=-40, cmax=40)

In [None]:
# ASCAT OmF StdDev: percent change of each DA experiment relative to its
# open-loop counterpart (plotted quantity is 100 * (DA - OL) / OL).
_ = plot_diff_map('OL_orig', 'DA_orig', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA orig Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('OL_fixed', 'DA_fixed', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed 9% Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('OL_fixed', 'DA_64', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed 6.4% Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('OL_fixed', 'DA_45', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed 4.5% Percent Change', cmin=-40, cmax=40)

# Same comparison for the DA_dedup experiment.
_ = plot_diff_map('OL_fixed', 'DA_dedup', 'ASCAT', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed Percent Change', cmin=-40, cmax=40)

In [None]:
# SMAP OmF StdDev: percent change of DA relative to OL, original and fixed runs.
_ = plot_diff_map('OL_orig', 'DA_orig', 'SMAP', metric='OmF_stdv', percent=True,
                  title_prefix='DA orig Percent Change', cmin=-40, cmax=40)

_ = plot_diff_map('OL_fixed', 'DA_fixed', 'SMAP', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed Percent Change', cmin=-40, cmax=40)

In [None]:
# CYGNSS OmF StdDev: percent change of DA relative to OL, original and fixed runs.
_ = plot_diff_map('OL_orig', 'DA_orig', 'CYGNSS', metric='OmF_stdv', percent=True,
                  title_prefix='DA orig Percent Change', cmin=-40, cmax=40)

_ = plot_diff_map('OL_fixed', 'DA_fixed', 'CYGNSS', metric='OmF_stdv', percent=True,
                  title_prefix='DA fixed Percent Change', cmin=-40, cmax=40)

In [None]:
# Example: ASCAT observation-count (N_data) differences for OL and DA,
# first as absolute counts, then as percent change relative to the original.
_ = plot_diff_map('OL_orig', 'OL_fixed', 'ASCAT', metric='N_data',
                  title_prefix='OL Fixed - OL Original', units='N_obs')
_ = plot_diff_map('DA_orig', 'DA_fixed', 'ASCAT', metric='N_data',
                  title_prefix='DA Fixed - DA Original', units='N_obs')
_ = plot_diff_map('OL_orig', 'OL_fixed', 'ASCAT', metric='N_data', percent=True,
                  title_prefix='OL Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('DA_orig', 'DA_fixed', 'ASCAT', metric='N_data', percent=True,
                  title_prefix='DA Percent Change', cmin=-40, cmax=40)
_ = plot_diff_map('DA_orig', 'DA_dedup', 'ASCAT', metric='N_data', percent=True,
                  title_prefix='DA Percent Change', cmin=-40, cmax=40)

In [None]:
# Example: ASCAT observation-count (N_data) time series
plot_ts_comparison('OL_orig', 'OL_fixed', 'ASCAT', metric='N_data', ylabel='N_Obs')
plot_ts_comparison('DA_orig', 'DA_fixed', 'ASCAT', metric='N_data', ylabel='N_Obs')
plot_ts_comparison('OL_orig', 'OL_fixed', 'DA_orig', 'DA_fixed', 'DA_dedup', 'ASCAT', metric='N_data', ylabel='N_Obs')

In [None]:
# Helper: print quick statistic summaries for any group/metric
# (shown here for ASCAT OmF_stdv; one mean/median/min/max entry per experiment).
from pprint import pprint
pprint(summarize_metric_stats('ASCAT', metric='OmF_stdv'))

In [None]:
# Example: OmF StdDev time series, one figure per call:
# ASCAT (twice, the second including DA_dedup), then CYGNSS, SMAP and SMOS.
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'DA_64', 'DA_45', 'ASCAT', metric='OmF_stdv', ylabel='OmF StdDev (m3/m3)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'DA_dedup', 'DA_64', 'DA_45', 'ASCAT', metric='OmF_stdv', ylabel='OmF StdDev (m3/m3)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'CYGNSS', metric='OmF_stdv', ylabel='OmF StdDev (m3/m3)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'SMAP', metric='OmF_stdv', ylabel='OmF StdDev (Tb)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'SMOS', metric='OmF_stdv', ylabel='OmF StdDev (Tb)')

In [None]:
# Percent difference time series example
# ASCAT OmF StdDev of each DA experiment relative to its OL reference;
# each tuple is (DA key, OL key, legend label).
plot_percent_difference_ts(
    ('DA_orig', 'OL_orig', 'DA vs OL (orig)'),
    ('DA_fixed', 'OL_fixed', 'DA vs OL (fixed)'),
    ('DA_64', 'OL_fixed', 'DA vs OL (6.4%)'),
    ('DA_45', 'OL_fixed', 'DA vs OL (4.5%)'),
    'ASCAT',
    metric='OmF_stdv',
    ylabel='OmF StdDev Percent Difference (%)'
)


# Second figure: original DA versus the DA_dedup experiment.
plot_percent_difference_ts(
    ('DA_orig', 'OL_orig', 'DA vs OL (orig)'),
    ('DA_dedup', 'OL_fixed', 'DA vs OL (dedup)'),
    'ASCAT',
    metric='OmF_stdv',
    ylabel='OmF StdDev Percent Difference (%)'
)

In [None]:
# Example: ASCAT O_mean, F_mean and OmF_mean time series
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'ASCAT', metric='O_mean', ylabel='O_mean (m3/m3)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'ASCAT', metric='F_mean', ylabel='F_mean (m3/m3)')
plot_ts_comparison('OL_orig', 'OL_fixed','DA_orig', 'DA_fixed', 'ASCAT', metric='OmF_mean', ylabel='OmF_mean (m3/m3)')

In [None]:
def compute_weighted_monthly_mean(dataset, group, metric):
    """Observation-count-weighted mean of a group time series.

    Parameters
    ----------
    dataset : str
        Key into ``experiment_data``.
    group : str
        Sensor group name within the experiment's ``ts_group_stats``.
    metric : str
        Name of the time series to average.

    Returns
    -------
    float
        ``sum(value * N_data) / sum(N_data)`` taken over the time steps whose
        value and ``N_data`` are both finite, or NaN when the remaining
        total weight is zero or not finite.

    Notes
    -----
    Time steps with a NaN value are excluded from the denominator as well as
    the numerator; counting their observations in the total weight would pull
    the mean toward zero.
    """
    series = experiment_data[dataset]['ts_group_stats'][group]
    monthly_vals = np.asarray(series[metric], dtype=float)
    weights = np.asarray(series['N_data'], dtype=float)

    usable = np.isfinite(monthly_vals) & np.isfinite(weights)
    total_weight = np.sum(weights[usable])
    if not np.isfinite(total_weight) or total_weight == 0:
        return np.nan
    return np.sum(monthly_vals[usable] * weights[usable]) / total_weight

# Build one summary row per (experiment, sensor group, metric) combination
# holding the observation-count-weighted mean of the group time series.
summary_rows = []
for dataset in experiment_data:
    for group in species_groups:
        for metric in ['OmF_stdv', 'OmF_mean', 'OmA_stdv', 'OmA_mean']:
            wmean = compute_weighted_monthly_mean(dataset, group, metric)
            summary_rows.append({
                'Dataset': dataset,
                'Group': group,
                'Metric': metric,
                'WeightedMonthlyMean': wmean
            })
# Bare expression so the notebook displays the list as the cell output.
summary_rows
