In [1]:
from itertools import zip_longest
import numpy as np
from pathlib import Path
import pandas as pd
import pickle
import xarray as xr
import panel as pn
pn.extension('tabulator')

import hvplot.pandas
import hvplot.xarray
import geoviews as gv
import holoviews as hv
from bokeh.models.widgets.tables import NumberFormatter

In [3]:
def split_dim(ds, dim):
    """Turn each variable that has `dim` into one variable per coordinate value.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset.
    dim : str
        Name of the dimension to split along.

    Returns
    -------
    xr.Dataset
        `ds` itself when `dim` is not one of its dimensions. Otherwise a new
        dataset (same attrs) whose variables are named ``"<var>_<coord>"``;
        variables of `ds` that do not carry `dim` are not copied over.
    """
    if dim in ds.dims:
        pieces = {
            f"{var.name}_{label}": var.sel({dim: label}, drop=True)
            for var in ds.data_vars.values()
            if dim in var.dims
            for label in var[dim].values
        }
        return xr.Dataset(data_vars=pieces, attrs=ds.attrs)
    return ds

# Dataset names for each comparison level. They are concatenated below in this
# order along a new `source` dimension, and the scoring/plotting code takes
# `source` index 0 as the reference the others are compared against.
sources = {
    'refs' : ['NRCAN', 'RDRS', 'ERA5-Land'],
    'sims' : ['ScenGen', 'ESPOG', 'ESPOR', 'ESPOG-RDRS', 'ESPOG-NRCAN']
}
# Directory holding the per-dataset zarr stores (inputs) and the merged
# `<level>_<type>.nc` files (outputs of this cell).
root = Path('/exec/pbourg/ESPO-G/indicators/')
# `sims_delta.nc` is the last file written by the loops below, so its presence
# is used as the "everything has already been merged" marker.
if not (root / 'sims_delta.nc').is_file():
    # Dependencies only needed for the merge are imported lazily.
    import xesmf as xe
    import dask
    dask.config.set(num_workers=12)

    def split_season(ds):
        """Split every variable with a 'season' dimension into one variable per season.

        Returns a new dataset (same attrs) with variables named
        "<var>_<season>". Variables without a 'season' dimension are dropped.
        Unlike `split_dim`, there is no early return when the dataset has no
        'season' dimension at all.
        """
        das = {}
        for da in ds.data_vars.values():
            if 'season' not in da.dims:
                continue
            for crd in da['season'].values:
                das[f"{da.name}_{crd}"] = da.sel(season=crd, drop=True)
        return xr.Dataset(data_vars=das, attrs=ds.attrs)

    # One bilinear regridder per dataset, from that dataset's climatology grid
    # onto the NRCAN climatology grid. NRCAN itself (first entry of 'refs') is
    # the target and needs none; 'ESPOG-IDM' is added because it is merged into
    # ESPOG further down.
    regs = {
        dataset: xe.Regridder(
            xr.open_zarr(root / f'{dataset}_climato_AS-JAN.zarr'),
            xr.open_zarr(root / 'NRCAN_climato_AS-JAN.zarr'),
            method='bilinear',
            unmapped_to_nan=True
        )
        for dataset in (sources['refs'][1:] + sources['sims'] + ['ESPOG-IDM'])
    }
    for typ in ['climato', 'timeseries', 'timeseries-delta', 'delta']:
        for lvl, datasets in sources.items():
            # No delta products are built for the reference level.
            if 'delta' in typ and lvl == 'refs':
                continue
            outpath = root / f'{lvl}_{typ}.nc'
            # Files already produced by an earlier run are kept as they are.
            if outpath.is_file():
                continue
            dss = []
            for dataset in datasets:
                # All stores of this dataset/type are opened together; seasonal
                # variables are flattened into "<var>_<season>" on the way in.
                ds = xr.open_mfdataset(
                    str(root / f"{dataset}_{typ}_*.zarr"),
                    engine='zarr',
                    preprocess=split_season,
                    decode_timedelta=False
                )
                # Only the 'climato' and 'delta' types are regridded; the other
                # types are indexed by region further down (presumably not
                # gridded -- TODO confirm).
                if dataset in regs and typ in ['climato', 'delta']:
                    ds = regs[dataset](ds)
                if dataset == 'ESPOG':
                    # ESPOG is completed with the separate ESPOG-IDM dataset
                    # (presumably Iles-de-la-Madeleine, given the region name
                    # used below -- TODO confirm).
                    dsidm = xr.open_mfdataset(
                        str(root / f"ESPOG-IDM_{typ}_*.zarr"),
                        engine='zarr',
                        preprocess=lambda ds : split_dim(ds, dim='season'),
                        decode_timedelta=False
                    )
                    # Keep only the variables present in both datasets.
                    vvs = sorted(
                        set.intersection(
                            set(ds.data_vars.keys()), set(dsidm.data_vars.keys())
                        )
                    )
                    ds = ds[vvs]
                    dsidm = dsidm[vvs]
                    if typ in ['climato', 'delta']:
                        dsidm = regs['ESPOG-IDM'](dsidm, skipna=True)
                        # Keep ESPOG wherever it has data, fill the rest from
                        # ESPOG-IDM. The mask is taken from the first variable
                        # at the first experiment and first percentile.
                        vv = list(dsidm.data_vars)[0]
                        ds = ds.where(
                            ds[vv].isel(experiment=0, percentiles=0).notnull(),
                            dsidm
                        )
                    else:
                        # Take the ESPOG-IDM values for that one region and
                        # ESPOG everywhere else.
                        ds = ds.where(
                            ds.region != 'Îles-de-la-Madeleine',
                            dsidm
                        )
                if lvl == 'sims':
                    # Relabel the experiment coordinate to 'low'/'high',
                    # truncated to the number of experiments actually present.
                    ds['experiment'] = ['low', 'high'][:len(ds.experiment)]
                dss.append(ds.expand_dims(source=[dataset]))
            # For every type except 'climato', restrict all datasets to their
            # common variables before concatenating.
            if typ in ['timeseries', 'timeseries-delta', 'delta']:
                vvs = sorted(set.intersection(*(set(ds.data_vars) for ds in dss)))
                dss = [ds[vvs] for ds in dss]
            ds = xr.concat(dss, 'source').drop_vars(['rotated_pole'], errors='ignore')
            print(f'Writing for {lvl}/{typ}')
            ds.to_netcdf(outpath)

Writing for sims/climato
Writing for refs/timeseries
Writing for sims/timeseries
Writing for sims/timeseries-delta
Writing for sims/delta


In [4]:
# Merged datasets, indexed as dss[<type>][<level>]. The reference level has no
# delta products, so those combinations are left out.
dss = {
    typ: {
        lvl: xr.open_dataset(
            root / f"{lvl}_{typ}.nc",
            decode_timedelta=False,
            drop_variables=['rotated_pole'],
        )
        for lvl in sources
        if lvl != 'refs' or 'delta' not in typ
    }
    for typ in ('climato', 'timeseries', 'timeseries-delta', 'delta')
}
# Region name -> value of the `geom` coordinate, taken from the reference
# time series.
regions = dict(
    zip(
        dss['timeseries']['refs'].region.values,
        dss['timeseries']['refs'].geom.values,
    )
)

In [5]:
# Registry filled by the `score` decorator: aspect name -> scoring function.
score_functions = {}


def score(aspect):
    """Return a decorator that registers a function as the scorer for `aspect`.

    The decorated function is stored in `score_functions[aspect]` and handed
    back unchanged.
    """
    def register(func):
        score_functions.update({aspect: func})
        return func

    return register

def get_score(aspect, datasets):
    """Return the score table for `aspect`/`datasets`, cached on disk.

    Parameters
    ----------
    aspect : str
        Key into `score_functions` and into `dss`.
    datasets : str
        Second-level key into `dss[aspect]`.

    Returns
    -------
    object
        Whatever the registered scoring function returns. On a cache hit, the
        unpickled content of ``.scores_<datasets>_<aspect>.obj`` (relative to
        the current working directory).

    Notes
    -----
    The cache is never invalidated: delete the file to force a recomputation.
    """
    cache = Path(f'.scores_{datasets}_{aspect}.obj')
    if cache.is_file():
        # NOTE: unpickling runs arbitrary code; this is only acceptable because
        # the file is a cache written by this very function.
        with cache.open('rb') as f:
            return pickle.load(f)

    scores = score_functions[aspect](dss[aspect][datasets])

    # Write to a temporary file and rename it afterwards, so that a failed or
    # interrupted dump can never leave a truncated cache that would make every
    # later call fail in `pickle.load`.
    tmp = cache.with_name(cache.name + '.tmp')
    try:
        with tmp.open('wb') as f:
            pickle.dump(scores, f)
        tmp.replace(cache)
    finally:
        # Only still present when the dump or the rename failed.
        if tmp.exists():
            tmp.unlink()
    return scores

# Climatology maps
@score('climato')
def climato_score(ds):
    """Score each non-reference source by its absolute relative error on the climatology.

    The first entry along `source` is the reference. The absolute percentage
    error ``|100 * (source - reference) / reference|`` is reduced over
    lat/lon in two ways: its mean ('mean') and its 0.95 quantile ('p95').

    Returns
    -------
    pandas.DataFrame
        Rows indexed by indicator (and experiment when that dimension exists),
        columns by source / spatial_stat (/ percentiles when present).
    """
    reference = ds.isel(source=0)
    candidates = ds.isel(source=slice(1, None))
    abs_pct_err = abs(100 * (candidates - reference) / reference)

    spatial_mean = abs_pct_err.mean(['lat', 'lon'])
    spatial_p95 = abs_pct_err.quantile(0.95, ['lat', 'lon']).drop_vars(['quantile'])
    scores = xr.concat([spatial_mean, spatial_p95], 'spatial_stat')
    scores['spatial_stat'] = ['mean', 'p95']
    if 'percentiles' in scores.dims:
        scores['percentiles'] = ['p10', 'p50', 'p90']

    # Optional dimensions only take part in the pivot when they exist.
    row_keys = ['indicator']
    if 'experiment' in scores.dims:
        row_keys.append('experiment')
    col_keys = [
        name
        for name in ('source', 'spatial_stat', 'percentiles')
        if name in scores.dims
    ]

    long_df = scores.to_array('indicator').rename('score').to_dataframe().reset_index()
    return long_df.pivot(index=row_keys, columns=col_keys)

@score('delta')
def delta_score(ds):
    """Score each non-reference source on its climate-change deltas.

    The first entry along `source` is the reference. Two statistics are
    produced per source:

    * 'all': spatial mean of ``|reference - source|`` divided by the spatial
      standard deviation averaged over all sources;
    * 'disagree': number of grid points where reference and source have
      opposite signs, counted only where both are non-null and the latitude is
      below 51.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by (indicator, experiment), columns by
        (source, spatial_stat, percentiles).
    """
    reference = ds.isel(source=0, drop=True)
    candidates = ds.isel(source=slice(1, None))

    spread = ds.std(['lat', 'lon']).mean('source')
    normalized = abs(reference - candidates) / spread

    opposite_sign = (np.sign(reference) * np.sign(candidates)) == -1
    countable = reference.notnull() & candidates.notnull() & (reference.lat < 51)
    conflict = opposite_sign.where(countable)

    scores = xr.concat(
        [normalized.mean(['lat', 'lon']), conflict.sum(['lat', 'lon'])],
        'spatial_stat'
    )
    scores['spatial_stat'] = ['all', 'disagree']
    scores['percentiles'] = ['p10', 'p50', 'p90']

    long_df = scores.to_array('indicator').rename('score').to_dataframe().reset_index()
    return long_df.pivot(
        index=['indicator', 'experiment'],
        columns=['source', 'spatial_stat', 'percentiles']
    )
   

@score('timeseries')
def timeseries_score(ds):
    """Score each non-reference source on its regional annual series.

    The first entry along `source` is the reference. The score is the time
    mean of ``|reference - source|`` divided by the temporal standard
    deviation averaged over regions (`geom`) and sources. When the dataset
    has an `experiment` dimension, only 2050-2100 is considered.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by indicator, region (and experiment when present),
        columns by source (and percentiles when present).
    """
    if 'experiment' in ds.dims:
        ds = ds.sel(time=slice('2050', '2100'))
    reference = ds.isel(source=0, drop=True)
    candidates = ds.isel(source=slice(1, None))

    spread = ds.std(['time']).mean(['geom', 'source'])
    scores = (abs(reference - candidates) / spread).mean('time')

    has_percentiles = 'percentiles' in scores.dims
    if has_percentiles:
        scores['percentiles'] = ['p10', 'p50', 'p90']

    # Optional dimensions only take part in the pivot when they exist.
    row_keys = ['indicator', 'region']
    if 'experiment' in scores.dims:
        row_keys.append('experiment')
    col_keys = ['source']
    if has_percentiles:
        col_keys.append('percentiles')

    long_df = (
        scores
        .drop_vars(['lat', 'lon', 'rotated_pole'], errors='ignore')
        .to_array('indicator')
        .rename('score')
        .to_dataframe()
        .reset_index()
        .drop(columns=['geom'])
    )
    return long_df.pivot(index=row_keys, columns=col_keys)

In [6]:
# Line/area colour for each curve of the time-series plot, looked up by source
# name. The "X - REF" keys are the names `maybe_diff` gives to the sources once
# the reference has been subtracted; each reuses the colour of "X".
colors = {
    'NRCAN': 'red', 'RDRS': 'green', 'ERA5-Land': 'blue',
    'RDRS - NRCAN': 'green', 'ERA5-Land - NRCAN': 'blue',
    'ScenGen': 'red', 'ESPOG': 'blue', 'ESPOR': 'purple', 
    'ESPOG-RDRS': 'turquoise', 'ESPOG-NRCAN': 'salmon',
    'ESPOG - ScenGen': 'blue', 'ESPOR - ScenGen': 'purple',
    'ESPOG-RDRS - ScenGen': 'turquoise', 'ESPOG-NRCAN - ScenGen': 'salmon',
}

# Registry filled by the `plot` decorator: aspect name -> plot factory.
plot_functions = {}


def plot(aspect):
    """Return a decorator that registers a function as the plot factory for `aspect`.

    The decorated function is stored in `plot_functions[aspect]` and handed
    back unchanged.

    NOTE(review): a later cell defines a selection callback that is also named
    `plot`; once that cell has run, this decorator is shadowed.
    """
    def register(func):
        plot_functions.update({aspect: func})
        return func

    return register

# Climatology maps
@plot('climato')
def climato_plot(datasets, indicator, experiment=None, filt=None):
    """Build the map-plot callback for one climatology indicator.

    Parameters
    ----------
    datasets : str
        Level key into `dss['climato']` ('refs' or 'sims').
    indicator : str
        Variable to plot.
    experiment : str, optional
        Experiment to select; no selection when None.
    filt : dict, optional
        Maps a level to the list of sources to keep (reference first); no
        filtering when None.

    Returns
    -------
    callable
        ``f(options)`` returning a `pn.Column` with one raw map per source and,
        below, the difference maps of each non-reference source against the
        first source. With 'percent' in `options` the differences are relative
        (in %).
    """
    ds = dss['climato'][datasets][indicator]
    if experiment is not None:
        ds = ds.sel(experiment=experiment)
    if filt is not None:
        ds = ds.sel(source=filt[datasets])
    old = ds.isel(source=0, drop=True)
    new = ds.isel(source=slice(1, None))
    has_percentiles = 'percentiles' in ds.dims

    def _climato_plot(options):
        # Raw maps: the middle percentile (index 1) when percentiles exist.
        da = ds.isel(percentiles=1) if has_percentiles else ds
        clim_raw = tuple(da.quantile([0.05, 0.95]).values)
        # The colourbar is only drawn next to the last source.
        lastsrc = ds.source[-1].item()
        raw_suffix = ' p50' if has_percentiles else ''
        raw_plots = [
            da.sel(source=src).rename('raw_climato').hvplot(
                x='lon', y='lat',
                frame_width=250,
                clabel='Climatologie', clim=clim_raw,
                cmap='bmy', colorbar=(src == lastsrc),
                geo=True, title=f"{src} climato{raw_suffix}"
            )
            for src in ds.source.values
        ]
        if new.sizes['source'] == 0:
            # Only the reference is selected: nothing to compare against.
            return pn.Column(hv.Layout(raw_plots))

        diff = new - old
        clabel = 'Différence (NEW - OLD)'
        if 'percent' in options:
            diff = 100 * diff / old
            clabel = 'Erreur 100 * (NEW - OLD) / OLD'
        # Symmetric colour range from the 5th/95th quantiles of the differences.
        clim = abs(diff.quantile([0.05, 0.95])).max('quantile').item()

        def _diff_map(da, title, colorbar):
            return da.hvplot(
                x='lon', y='lat',
                frame_width=500,
                clabel=clabel, clim=(-clim, clim),
                cmap='coolwarm', colorbar=colorbar,
                geo=True, title=title
            )

        if 'percentiles' in diff.dims:
            diff_plots = [
                _diff_map(
                    diff.sel(source=src, percentiles=perc),
                    f"{src} diff p{perc}",
                    src == lastsrc
                )
                for src in diff.source.values
                for perc in diff.percentiles.values
            ]
        else:
            diff_plots = [
                _diff_map(diff.sel(source=src), src, src == lastsrc)
                for src in diff.source.values
            ]

        return pn.Column(
            hv.Layout(raw_plots),
            hv.Layout(diff_plots).cols(3)
        )
    return _climato_plot


@plot('delta')
def delta_plot(datasets, indicator, experiment, filt):
    """Build the map-plot callback for the climate deltas of one indicator.

    Parameters
    ----------
    datasets : str
        Level key into `dss['delta']`.
    indicator : str
        Variable to plot.
    experiment : str
        Experiment to select.
    filt : dict
        Maps a level to the list of sources to keep (reference first).

    Returns
    -------
    callable
        ``f(options)`` returning a `pn.Column` with one raw delta map per
        source (middle percentile) and, below, one difference map per
        non-reference source and percentile. Recognised options:
        'percent' (relative differences, in %) and 'sign' (show sign
        disagreements instead of differences; takes precedence over 'percent').
    """
    ds = dss['delta'][datasets][indicator].sel(experiment=experiment, source=filt[datasets])
    old = ds.isel(source=0, drop=True)
    new = ds.isel(source=slice(1, None))

    def _delta_plot(options):
        # Raw maps use the middle percentile (index 1) and a symmetric range.
        mid = ds.isel(percentiles=1)
        clim_raw = abs(mid.quantile([0.05, 0.95])).max('quantile').item()
        # The colourbar is only drawn next to the last source.
        lastsrc = ds.source[-1].item()
        raw_plots = [
            ds.sel(source=src).isel(percentiles=1).rename('raw_delta').hvplot(
                x='lon', y='lat',
                frame_width=250,
                clabel='Delta', clim=(-clim_raw, clim_raw),
                cmap='PuOr_r', colorbar=(src == lastsrc),
                geo=True, title=f"{src} delta p50"
            )
            for src in ds.source.values
        ]
        if new.sizes['source'] == 0:
            # Only the reference is selected: nothing to compare against.
            return pn.Column(hv.Layout(raw_plots))

        if 'sign' in options:
            # +1/-1 (sign of NEW) where OLD and NEW disagree in sign, 0 where
            # they do not, NaN where either is missing.
            diff = (
                ((np.sign(old) * np.sign(new)) == -1) * 1 * np.sign(new)
            ).where(old.notnull() & new.notnull())
            clim_diff = 1
            clabel = 'Signes (OLD > 0, NEW < 0) -> -1 (et vice-versa)'
        else:
            diff = new - old
            clabel = 'Différence (NEW - OLD)'
            if 'percent' in options:
                diff = 100 * diff / old
                clabel = 'Erreur 100 * (NEW - OLD) / OLD'
            clim_diff = abs(diff.quantile([0.05, 0.95])).max('quantile').item()

        # Laid out three per row; the colourbar goes on the p90 panel of each row.
        diff_plots = [
            diff
            .sel(source=src, percentiles=perc)
            .hvplot(
                x='lon', y='lat', frame_width=250,
                clabel=clabel, clim=(-clim_diff, clim_diff),
                cmap='coolwarm', colorbar=(str(perc) == '90'),
                geo=True, title=f"{src} diff p{perc}"
            )
            for src in diff.source.values
            for perc in diff.percentiles.values
        ]
        return pn.Column(
            hv.Layout(raw_plots),
            hv.Layout(diff_plots).cols(3)
        )
    return _delta_plot


def maybe_smooth(ds, opts):
    """Return `ds` smoothed along time when 'smooth' is in `opts`, else `ds` itself.

    The smoothing is a centred rolling mean over 30 time steps that needs at
    least 5 valid values.
    """
    if 'smooth' not in opts:
        return ds
    window = ds.rolling(time=30, center=True, min_periods=5)
    return window.mean()


def maybe_delta(ds, opts):
    """Return `ds` minus its 1981-2010 time mean when 'delta' is in `opts`, else `ds` itself."""
    if 'delta' not in opts:
        return ds
    baseline = ds.sel(time=slice('1981', '2010')).mean('time')
    return ds - baseline

def maybe_diff(ds, opts):
    """Subtract the first source from the others when 'diff' is in `opts`.

    Without 'diff', `ds` is returned untouched. Otherwise the result holds
    every source but the first, minus the first one, and its `source` labels
    become ``"<source> - <first source>"``.
    """
    if 'diff' not in opts:
        return ds
    others = ds.isel(source=slice(1, None))
    first = ds.isel(source=0, drop=True)
    out = others - first
    ref = ds.source.values[0]
    out['source'] = [f'{src} - {ref}' for src in out.source.values]
    return out

@plot('timeseries')
def timeseries_plot(datasets, indicator, region, experiment='low', filt=None):
    """Build the time-series plot callback for one indicator and region.

    Parameters
    ----------
    datasets : str
        Unused here; kept so all plot factories share the same call signature.
    indicator : str
        Variable to plot.
    region : str
        Region name, a key of `regions`.
    experiment : str, default 'low'
        Experiment selected in the simulations.
    filt : dict, optional
        ``{'refs': [...], 'sims': [...]}`` lists of sources to keep (reference
        first). When None, every source is kept.

    Returns
    -------
    callable
        ``f(options)`` returning an `hv.Overlay`. References are drawn as
        lines; simulations as a p10-p90 band plus the p50 line. Recognised
        options: 'delta', 'diff', 'smooth', 'nosim'.
    """
    ireg = regions[region]
    # `filt=None` is the declared default: treat it as "no source filtering"
    # instead of failing on `None['refs']`.
    ref_sources = slice(None) if filt is None else filt['refs']
    sim_sources = slice(None) if filt is None else filt['sims']

    def _timeseries_plot(options):
        refs_ds = (
            dss['timeseries']['refs'][indicator].sel(source=ref_sources)
            .sel(geom=ireg)
            .pipe(maybe_delta, opts=options)
            .pipe(maybe_diff, opts=options)
        )
        refs = [
            refs_ds.sel(source=src).hvplot(x='time', color=colors[src], label=src)
            for src in refs_ds.source.values
        ]
        # Simulations are skipped on request and for 'dtr_mean_annual'.
        if indicator == 'dtr_mean_annual' or 'nosim' in options:
            return hv.Overlay(refs).opts(title=indicator, ylabel=indicator, width=1250)
        # Deltas of the simulations come from their own pre-computed dataset
        # rather than from `maybe_delta`.
        asp = 'timeseries-delta' if 'delta' in options else 'timeseries'
        sims_ds = (
            dss[asp]['sims'][indicator]
            .sel(geom=ireg, experiment=experiment, source=sim_sources, drop=True)
            .pipe(maybe_diff, opts=options)
            .pipe(maybe_smooth, opts=options)
            .assign_coords(percentiles=['p10', 'p50', 'p90'])
            .to_dataset('percentiles')
        )
        sims = []
        for src in sims_ds.source.values:
            sim_df = sims_ds.sel(source=src).to_dataframe()
            sims.append(
                sim_df.hvplot.area(y='p10', y2='p90', color=colors[src], label=src, alpha=0.33)
            )
            sims.append(sim_df.hvplot(y='p50', by='source', color=colors[src], label=src))
        if 'delta' in options:
            # Zero line as a visual reference for deltas.
            refs.append(hv.HLine(0).opts(color='black', line_width=1, line_dash='dashed'))
        return hv.Overlay(sims + refs).opts(
            # Indicator and active options, separated by spaces.
            title=' '.join([indicator, *options]), ylabel=indicator,
            width=1250, height=600,
            legend_opts={"click_policy": "hide"}
        )
    return _timeseries_plot

# Checkbox choices offered for each aspect: displayed label -> option keyword
# tested by the matching plot callback (`'percent' in options`, etc.).
plot_options = {
    'climato': {'Valeurs relatives (%)': 'percent'},
    'timeseries': {
        'Deltas climatiques': 'delta',
        'Différences avec la REF': 'diff',
        'Projections lissées': 'smooth',
        'Cacher les simulations': 'nosim'},
    'delta': {'Valeurs relatives (%)': 'percent',
              'Changements de signe' : 'sign'},
}

In [17]:
# Aspects available for every dataset level: displayed label -> aspect key.
# `update_aspects` adds the 'delta' aspect on top of these for simulations.
base_aspects = {'Climatologie': 'climato',
                'Séries annuelles': 'timeseries'}

def _table_filter(df, indicator, region, options):
    if (
        ('indicator' not in df.columns)
        or ('region' not in df.columns and (region not in ['<toutes>', None] or 'south' in options))
        or ('rcp45' in options and 'experiment' not in df.columns)
    ):
        # Something went wrong
        return df
    if 'rcp45' in options:
        df = df[df.experiment == 'low']
    if 'south' in options:
        df = df[~(df.region.str.startswith('Nuna') | df.region.str.startswith('Jam'))]
    if indicator not in ['<tous>', None]:
        df = df[df['indicator'].str.startswith(indicator)]
    if region not in ['<toutes>', None]:
        df = df[df['region'] == region]
    return df

# --- Table controls -------------------------------------------------------
# Their options start empty and are filled in by `update_all`.
w_taboptions = pn.widgets.CheckBoxGroup(
    name='Table options',
    options={},
)
w_indicator = pn.widgets.Select(name='Indicateur', options=[])
w_region = pn.widgets.Select(name='Région', options=[])
# `_table_filter` bound to the three widgets above; `update_all` attaches it
# to each score table with `add_filter`.
table_filter = pn.bind(_table_filter, indicator=w_indicator, region=w_region, options=w_taboptions)
# --- Plot controls --------------------------------------------------------
# Options are set per aspect by `update_all` from `plot_options`.
w_pltoptions = pn.widgets.CheckBoxGroup(
    name='Plot options',
    options={},
)

# Card whose first element is replaced by the score table in `update_all`.
w_table_container = pn.layout.Card(pn.pane.Str('rien'), title='Scores')

# --- View selection -------------------------------------------------------
# Dataset level being compared.
w_datasets = pn.widgets.ToggleGroup(
    options={'Références': 'refs', 'Simulations': 'sims'},
    behavior='radio'
)

# Sources to include for each level, all checked initially. The first source
# of each level is left out of the choices: it is the reference and is always
# prepended by the `plot` callback.
w_refsrcs = pn.widgets.CheckBoxGroup(
    options=list(dss['climato']['refs'].source.values[1:]),
    value=list(dss['climato']['refs'].source.values[1:]),
)

w_simsrcs = pn.widgets.CheckBoxGroup(
    options=list(dss['climato']['sims'].source.values[1:]),
    value=list(dss['climato']['sims'].source.values[1:]),
)

# Aspect being compared; `update_aspects` replaces the options when the
# dataset level changes.
w_aspect = pn.widgets.Select(options=base_aspects.copy())

@pn.depends(datasets=w_datasets, aspect=w_aspect)
def w_title(datasets, aspect):
    """Build the dashboard header for the current view.

    Parameters
    ----------
    datasets : str
        'refs' or 'sims'.
    aspect : str
        'climato', 'timeseries' or 'delta'.

    Returns
    -------
    pn.pane.Markdown
        A title naming what is compared and against which reference, followed
        by a short description of how the score is computed.
    """
    # Title
    ref = {'sims': 'ScenGen', 'refs': 'NRCAN'}[datasets]
    quoi = {
        'climato': 'climatologies (1981-2010)',
        'timeseries': 'indicateurs annuels',
        'delta': 'deltas climatiques (2071-2100 - 1981-2010)'
    }[aspect]
    if aspect == 'climato':
        # The climatology score is built from the absolute relative error,
        # hence the |...| in the formula.
        comment = (
            'Score : 100 * |OLD - NEW| / OLD.\n\t'
            'mean: moyenne spatiale, p95 : 95e centile spatial.'
        )
    elif aspect == 'delta':
        comment = (
            "Score : |OLD - NEW| / S.\n\t"
            "Où S est l'écart-type spatial moyenné sur tous les datasets. Moyenne spatiale du score.\n\t"
            "Disagree : nombre de pixels où les deltas sont de signes différents (0 exclu, sous le 51$^e$ parallèle)."
        )
    elif aspect == 'timeseries':
        # Pick the period of the current level; interpolating the whole dict
        # would print its repr in the header.
        period = {'sims': '(2050-2100)', 'refs': '(1950-2020)'}[datasets]
        comment = (
            f'Score : Moyenne temporelle {period} de |OLD - NEW| / S .\n\t'
            'Où S = écart-type le long du temps moyenné sur les régions et les datasets.'
        )
    return pn.pane.Markdown(
        f"# Comparaison des {quoi} avec {ref}\n\n{comment}",
        width=750
    )

def update_aspects(event):
    """Refresh the aspect choices for the currently selected dataset level.

    The base aspects are always offered; 'Deltas' is added when the level is
    'sims'. `event` is ignored (the widget is read directly).
    """
    extra = {'Deltas': 'delta'} if w_datasets.value == 'sims' else {}
    w_aspect.options = {**base_aspects, **extra}

# Registered before the `update_all` watcher (added further down) so the
# aspect choices are refreshed first when the dataset level changes.
w_datasets.param.watch(update_aspects, 'value')

# Single-slot container; the `plot` selection callback replaces its first
# element with the figure of the selected table row.
w_plot_container = pn.Row(pn.pane.Str('rien'))

def plot(event):
    """Selection callback of the score table: show the plot of the selected row.

    Parameters
    ----------
    event : param event
        `event.new` is the list of selected row positions and
        `event.obj.value` the table's DataFrame.

    Notes
    -----
    This function reuses the name of the `plot` registration decorator
    defined in an earlier cell and shadows it once this cell has run.
    """
    idx = event.new
    if len(idx) == 0 or 'indicator' not in event.obj.value.columns:
        # The return value of a watcher callback is discarded, so the hint has
        # to be placed in the container to become visible.
        w_plot_container[0] = pn.pane.Str('Selectionnez une ligne.')
        return
    row = event.obj.value.iloc[idx[0]]
    # Only pass the keys this table actually has; each plot factory accepts
    # the subset relevant to its aspect.
    kws = {k: row[k] for k in ('indicator', 'region', 'experiment') if k in row}
    w_plot_container[0] = pn.Column(
        pn.pane.Markdown(
            f"# {kws['indicator']} {kws.get('region', '')} {kws.get('experiment', '')}"
        ),
        # The factory returns a function of the plot options; binding it to the
        # checkbox group redraws the figure when an option is toggled.
        pn.depends(options=w_pltoptions)(
            plot_functions[w_aspect.value](
                w_datasets.value, 
                # The reference source always comes first.
                filt={'refs': ['NRCAN', *w_refsrcs.value], 'sims': ['ScenGen', *w_simsrcs.value]},
                **kws
            )
        )
    )

def update_all(event):
    """Rebuild the score table and the dependent widgets for the current view.

    Reads the dataset level, aspect and source selections from the widgets
    (`event` is ignored), builds a fresh Tabulator with the matching scores,
    then updates the indicator/region selectors and the table/plot option
    checkboxes.
    """
    datasets = w_datasets.value
    aspect = w_aspect.value
    refsrcs = w_refsrcs.value
    simsrcs = w_simsrcs.value

    # --- Score table ------------------------------------------------------
    key_cols = ('indicator', 'experiment', 'region')
    dfc = get_score(aspect, datasets).reset_index()
    # Drop rows for which every score is missing.
    all_missing = dfc['score'].isnull().all(axis=1)
    dfc = dfc[~all_missing]
    # Flatten the column levels into single labels, without the 'score' level.
    flat_names = []
    for levels in dfc.columns:
        flat_names.append(' '.join(levels).replace('score', '').strip())
    dfc.columns = flat_names

    if datasets == 'refs':
        filtre = refsrcs
    elif datasets == 'sims':
        filtre = simsrcs
    # Keep the key columns and the score columns of the checked sources (a
    # score label starts with its source name).
    cols = []
    for col in dfc.columns:
        if col in key_cols or col.split(' ')[0] in filtre:
            cols.append(col)
    dfc = dfc[cols].reset_index(drop=True)

    # Detach the shared filter from the table about to be replaced.
    previous = w_table_container[0]
    if isinstance(previous, pn.widgets.Tabulator):
        previous.remove_filter(table_filter)

    score_cols = [c for c in dfc.columns if c not in key_cols]
    w_table = pn.widgets.Tabulator(
        dfc,
        selection=[0],
        # A None editor for every column.
        editors=dict.fromkeys(dfc.columns),
        formatters=dict.fromkeys(score_cols, NumberFormatter(format='0.000')),
        pagination='remote',
        page_size=25,
        show_index=False,
        width=700,
    )
    w_table.add_filter(table_filter)
    w_table.param.watch(plot, 'selection')

    # --- Widgets depending on the table -------------------------------------
    # Indicator choices are the indicator names without their last
    # '_'-separated part.
    prefixes = {'_'.join(ind.split('_')[:-1]) for ind in dfc.indicator.unique()}
    w_indicator.options = ['<tous>'] + sorted(prefixes)
    w_table_container[0] = w_table
    if aspect == 'timeseries':
        w_region.disabled = False
        w_region.options = ['<toutes>'] + sorted(dfc.region.unique())
    else:
        w_region.options = []
        w_region.disabled = True

    tab_options = {}
    if datasets == 'sims':
        tab_options['Ne montrer que le RCP4.5'] = 'rcp45'
    if aspect == 'timeseries':
        tab_options['Cacher le grand nord'] = 'south'
    w_taboptions.options = tab_options

    w_pltoptions.options = plot_options[aspect]
    
# Rebuild the table and dependent widgets whenever the view selection changes.
w_datasets.param.watch(update_all, 'value')
w_refsrcs.param.watch(update_all, 'value')
w_simsrcs.param.watch(update_all, 'value')
w_aspect.param.watch(update_all, 'value')

# Dashboard layout: a row of control boxes plus the header on top, then the
# score table next to the plot area.
dash = pn.Column(
    pn.Row(
        pn.layout.WidgetBox(
            '## Choix de la vue',
            w_datasets,
            w_refsrcs,
            w_simsrcs,
            w_aspect
        ),
        pn.layout.WidgetBox(
            '## Options du tableau',
            w_indicator,
            w_region,
            w_taboptions
        ),
        pn.layout.WidgetBox(
            '## Options des figures',
            w_pltoptions
        ),
        w_title
    ),
    pn.layout.Divider(),
    pn.Row(
        w_table_container,
        w_plot_container
    )
)

# Populate the table and option widgets once for the initial widget values
# (`update_all` ignores its event argument).
update_all(None)

In [19]:
# s = dash.show(port=9988, websocket_origin='*')

Launching server at http://localhost:9988


In [18]:
# s.stop()

In [None]:
# Mark the dashboard as servable (per Panel's documentation this is what
# `panel serve` picks up; in a notebook the call also displays the layout).
dash.servable()