In [None]:
import os
from math import ceil
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import yaml
from matplotlib import cm
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from pandas.api.types import is_numeric_dtype

In [None]:
# Project root: the parent of the directory the notebook is run from.
root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [None]:
# Keys into experiments-lowest.yaml; one set of final-prediction files is loaded per entry.
sites = "extras ganga_damodar godavari kali kaveri krishna mahanadi narmada penner sharavati tapi".split()

In [None]:
# Per-site experiment configuration. Read inside a context manager so the file
# handle is closed (the previous bare open() leaked it); safe_load is equivalent
# to load(..., Loader=SafeLoader).
with open(os.path.join(root, "bin", "experiments-lowest.yaml"), "r", encoding="utf-8") as fh:
    sites_data = yaml.safe_load(fh)

In [None]:
def _final_preds_path(site, variant):
    """Return the path of the saved final predictions for ``site`` and ``variant``.

    ``variant`` ("gconv" or "no_gconv") is both the key read from
    ``sites_data[site]`` and the tag used in the output file name. The file-name
    suffix is the last two characters of the configured file's stem.
    """
    stem = os.path.splitext(os.path.basename(sites_data[site][variant]))[0]
    return os.path.join(root, "data", "final_preds", f"{site}-{variant}-{stem[-2:]}-preds.nc")


# Final predictions of both model variants, keyed by site (same order as ``sites``).
ds_gconv = {site: xr.load_dataset(_final_preds_path(site, "gconv")) for site in sites}
ds_noconv = {site: xr.load_dataset(_final_preds_path(site, "no_gconv")) for site in sites}

In [None]:
# Combine the per-site gconv prediction datasets into a single dataset.
all_gconv = xr.merge([ds_gconv[group] for group in sites])

In [None]:
# Inspect the merged gconv predictions.
all_gconv

In [None]:
# Sites shown in the sample-prediction figure.
select_sites = [
    "bhatghar",      # krishna
    "dudhganga",     # krishna
    "linganamakki",  # sharavati
]

In [None]:
def norm_func(arr):
    """Min-max scale ``arr`` to the range [0, 1].

    Works on anything supporting ``.min()``, ``.max()`` and arithmetic
    (numpy arrays, pandas Series, xarray DataArrays such as the groups passed
    by ``groupby(...).map``).

    A constant input has zero range. Instead of the 0/0 division (NaN plus a
    RuntimeWarning), it is mapped to all zeros. An all-NaN input still yields NaN.
    """
    lo = arr.min()
    span = arr.max() - lo
    # Divide by 1 when the range is zero so constant series become 0, not NaN.
    return (arr - lo) / (span if span != 0 else 1)

In [None]:
# Raw sharavati inputs/targets.
sharavati_path = os.path.join(root, "data", "data_sharavati.nc")
sharavati_data = xr.load_dataset(sharavati_path)

In [None]:
# Raw krishna inputs/targets.
krishna_path = os.path.join(root, "data", "data_krishna.nc")
krishna_data = xr.load_dataset(krishna_path)

In [None]:
# Inspect the raw krishna dataset before normalisation.
krishna_data

In [None]:
# Min-max normalise the water-volume target independently for each site.
krishna_data = krishna_data.assign(
    targets_WATER_VOLUME=krishna_data["targets_WATER_VOLUME"].groupby("global_sites").map(norm_func)
)

In [None]:
# Min-max normalise the water-volume target independently for each site.
sharavati_data = sharavati_data.assign(
    targets_WATER_VOLUME=sharavati_data["targets_WATER_VOLUME"].groupby("global_sites").map(norm_func)
)

In [None]:
# Align dimension names with the prediction datasets (step, site).
krishna_data = krishna_data.rename(steps="step", global_sites="site")

In [None]:
# Align dimension names with the prediction datasets (step, site).
sharavati_data = sharavati_data.rename(steps="step", global_sites="site")

In [None]:
# Re-inspect the merged predictions before converting changes to levels.
all_gconv

In [None]:
def revert_to_levels(
    data: xr.Dataset,
    preds: xr.Dataset,
    target_var: str,
) -> xr.Dataset:
    """Convert predicted per-step changes back into absolute levels.

    Each prediction variable present in ``preds`` is cumulatively summed
    along ``step``. It is then offset by the step-0 value of ``data[target_var]``
    on the dates that also occur in ``preds``.

    Parameters
    ----------
    data : xr.Dataset
        Source dataset with ``date`` and ``step`` dimensions that holds
        ``target_var`` (the baseline).
    preds : xr.Dataset
        Predicted changes with a ``step`` dimension. It is NOT modified; the
        converted variables are written to a shallow copy.
    target_var : str
        Name of the baseline variable in ``data``.

    Returns
    -------
    xr.Dataset
        Copy of ``preds`` with the prediction variables replaced by levels.
        Assignment aligns to ``preds``' indexes, so sites/dates missing from
        ``data`` become NaN.
    """
    # NOTE(review): "sim-std" is cumulated and offset like the other variables,
    # as in the original implementation; confirm that is intended.
    candidates = ("obs", "sim", "sim-frozen", "sim-mean", "sim-std", "ci-95+", "ci-95-")
    level_vars = [var for var in candidates if var in preds]

    # Shallow copy: reassigning variables below leaves the caller's dataset untouched.
    out = preds.copy(deep=False)
    if not level_vars:
        return out

    # The baseline is the same for every variable, so compute it once.
    baseline = data[target_var].sel(
        {"date": data["date"].isin(preds["date"])}
    ).isel({"step": 0})
    for var in level_vars:
        out[var] = preds[var].cumsum(dim="step") + baseline

    return out

In [None]:
# Volumes reconstructed against the krishna baseline (valid for krishna sites only).
preds_levels_krishna = revert_to_levels(
    target_var="targets_WATER_VOLUME",
    data=krishna_data,
    preds=all_gconv.copy(deep=True),  # independent copy so all_gconv cannot be altered
)

In [None]:
# Volumes reconstructed against the sharavati baseline (valid for sharavati sites only).
preds_levels_sharaviti = revert_to_levels(
    target_var="targets_WATER_VOLUME",
    data=sharavati_data,
    preds=all_gconv.copy(deep=True),  # independent copy so all_gconv cannot be altered
)

In [None]:
# Keep each plotted site from the baseline it belongs to: linganamakki from
# sharavati, the first two select_sites from krishna.
preds_levels = xr.merge(
    [
        preds_levels_sharaviti.sel(site=["linganamakki"]),
        preds_levels_krishna.sel(site=select_sites[:2]),
    ]
)

In [None]:
# Boolean mask over all_gconv.date: 2020-01-01 through 2021-01-01, both ends inclusive.
plot_start = pd.to_datetime("2020-01-01")
plot_end = pd.to_datetime("2021-01-01")
date_idx = (all_gconv["date"] >= plot_start) & (all_gconv["date"] <= plot_end)

In [None]:
def _interpolate_1d(data, dim="date", method="linear", limit=15):
    """Fill short NaN gaps in every numeric data variable along ``dim``.

    Parameters
    ----------
    data : xr.Dataset
        Dataset to gap-fill. It is not modified; a shallow copy holding the
        filled variables is returned.
    dim : str, default "date"
        Dimension to interpolate along. Variables without it are left as-is
        (``interpolate_na`` would fail on them).
    method : str, default "linear"
        Passed to ``DataArray.interpolate_na``.
    limit : int, default 15
        Maximum run of consecutive NaNs to fill, passed to ``interpolate_na``.

    Returns
    -------
    xr.Dataset
    """
    # Local import: the notebook only imports is_numeric_dtype in a later cell.
    from pandas.api.types import is_numeric_dtype

    filled = data.copy(deep=False)
    for var in list(data.keys()):
        arr = data[var]
        if is_numeric_dtype(arr.dtype) and dim in arr.dims:
            filled[var] = arr.interpolate_na(dim=dim, method=method, limit=limit)

    return filled

In [None]:
from pandas.api.types import is_numeric_dtype

In [None]:
# Gap-fill the reconstructed levels before plotting.
preds_levels = preds_levels.pipe(_interpolate_1d)

In [None]:
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

In [None]:
# Sample-prediction figure: one row per site in select_sites.
# Left panels show forecast volume changes (all_gconv); right panels show the
# reconstructed volumes (preds_levels), both over the date_idx window.
PLOT_STEPS = [5, 15, 25, 50, 75, 89]
OBS_COLOR = "#f200ff"
CI_VARS = ["ci-95+", "ci-95-"]


def _step_color(cmap, step):
    """Hex colour for forecast step ``step``, sampled from ``cmap`` at step/90."""
    rgb = cmap(int(step / 90 * 255))[:3]
    return "#" + "".join(f"{int(c * 255):02x}" for c in rgb)


def _plot_site_panel(ax, ds, site, dates, cmap):
    """Plot ``sim-mean`` (plus CI bounds if present) per step and ``obs`` for one site.

    Each step's series is shifted forward by ``step`` along ``date``.
    """
    for step in PLOT_STEPS:
        color = _step_color(cmap, step)
        shifted = ds.sel({"site": site, "step": step, "date": dates}).shift(date=step)
        shifted["sim-mean"].plot(ax=ax, c=color)
        for ci_var in CI_VARS:
            if ci_var in shifted:  # plot CI only if available
                shifted[ci_var].plot(ax=ax, c=color, ls=":")
    ds.sel({"site": site, "step": 0, "date": dates})["obs"].plot(ax=ax, c=OBS_COLOR)


cmap = plt.get_cmap("winter_r")  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
n_rows = len(select_sites)
fig, axs = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows), sharex=True, squeeze=False)

for row, site in enumerate(select_sites):
    _plot_site_panel(axs[row, 0], all_gconv, site, date_idx, cmap)
    _plot_site_panel(axs[row, 1], preds_levels, site, date_idx, cmap)
    axs[row, 0].text(0.02, 0.9, site, weight="bold", transform=axs[row, 0].transAxes)
    for ax in axs[row]:
        ax.set_title("")
        ax.set_xlabel("")
    axs[row, 0].set_ylabel("Resv. Vol. Changes")
    axs[row, 1].set_ylabel("Resv. Volumes")

leg_items = [Line2D([0], [0], color=OBS_COLOR, lw=2, label="obs")]
leg_items += [
    Line2D([0], [0], color=_step_color(cmap, step), lw=2, label=f"{step} days")
    for step in PLOT_STEPS
]
# Raw string: "\m" and "\s" are invalid escapes in a normal string literal.
leg_items.append(Line2D([0], [0], color="gray", lw=2, ls=":", label=r"$\mp$2$\sigma$"))

fig.legend(handles=leg_items, ncol=len(leg_items), loc="lower center")
fig.savefig("./sample_predictions.pdf", bbox_inches="tight")

In [None]:
# Exploratory grid in the style of plot_test_preds: one panel per (date window, site).
# Nothing earlier in the notebook defines the inputs this loop needs, so they are
# bound here. The values mirror the sample-prediction figure; adjust as needed.
# ``ceil`` comes from the top import cell.
preds = preds_levels
test_chunks = [["2020-01-01", "2021-01-01"]]
main_dim = "sim-mean"
ci_dims = ["ci-95+", "ci-95-"]

cmap = plt.get_cmap("winter_r")  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
n_cols = 3
n_rows = ceil(len(preds["site"]) / n_cols) * len(test_chunks)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3 * n_rows), squeeze=False)
axs = axs.flatten()
_ii = 0
for chunk in test_chunks:
    # Own mask name so the notebook-level date_idx used by the figure above is not overwritten.
    chunk_idx = (preds["date"] >= pd.to_datetime(chunk[0])) & (
        preds["date"] <= pd.to_datetime(chunk[1])
    )
    for site in preds["site"].data:
        for step in [5, 15, 25, 50, 75, 89]:
            hexcolor = "#" + "".join(
                f"{int(el * 255):02x}" for el in cmap(int(step / 90 * 255))[:3]
            )
            window = preds.sel({"site": site, "step": step, "date": chunk_idx}).shift(date=step)
            window[main_dim].plot(ax=axs[_ii], c=hexcolor)

            # plot CI if requested
            if ci_dims is not None:
                for dim in ci_dims:
                    window[dim].plot(ax=axs[_ii], c=hexcolor, ls=":")

        preds.sel({"site": site, "step": 0, "date": chunk_idx})["obs"].plot(
            ax=axs[_ii], c="#f200ff"
        )
        _ii += 1


In [None]:
def plot_test_preds(
    filepath: Path,
    preds: xr.Dataset,
    test_chunks: List[List[str]],
    site_dim: Optional[str] = "site",
    main_dim: Optional[str] = "sim",
    ci_dims: Optional[List[str]] = None,
):
    """Save a grid of test-period forecasts, one panel per (date window, site).

    Each panel contains:
    - ``main_dim`` for several forecast steps, each shifted forward by its step
      along ``date``;
    - optional dotted ``ci_dims`` bounds;
    - the step-0 ``obs`` series.

    Each window starts on its own block of rows.

    Parameters
    ----------
    filepath : Path
        Destination passed to ``Figure.savefig``.
    preds : xr.Dataset
        Predictions with ``date``, ``step`` and ``site_dim`` coordinates and
        the ``main_dim``/``obs`` variables (plus ``ci_dims`` if given).
    test_chunks : list of [start, end]
        Inclusive date windows; ``start``/``end`` are parsed by ``pd.to_datetime``.
    site_dim : str, default "site"
        Name of the site coordinate.
    main_dim : str, default "sim"
        Variable drawn as the solid forecast line.
    ci_dims : list of str, optional
        Variables drawn as dotted bounds.
    """
    steps = [5, 15, 25, 50, 75, 89]
    cmap = plt.get_cmap("winter_r")  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9

    def _step_color(step):
        # Sample the colormap at step/90 of its 256 entries.
        return "#" + "".join(
            f"{int(c * 255):02x}" for c in cmap(int(step / 90 * 255))[:3]
        )

    sites = preds[site_dim].data
    n_cols = 3
    rows_per_chunk = ceil(len(sites) / n_cols)
    panels_per_chunk = rows_per_chunk * n_cols
    n_rows = rows_per_chunk * len(test_chunks)
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(6 * n_cols, 3 * n_rows), squeeze=False
    )
    axs = axs.flatten()

    for chunk_no, chunk in enumerate(test_chunks):
        date_idx = (preds["date"] >= pd.to_datetime(chunk[0])) & (
            preds["date"] <= pd.to_datetime(chunk[1])
        )
        # Start every chunk on a fresh row block so chunks never share a row.
        first_panel = chunk_no * panels_per_chunk
        for site_no, site in enumerate(sites):
            ax = axs[first_panel + site_no]
            for step in steps:
                color = _step_color(step)
                window = preds.sel(
                    {site_dim: site, "step": step, "date": date_idx}
                ).shift(date=step)
                window[main_dim].plot(ax=ax, c=color)
                if ci_dims is not None:
                    for dim in ci_dims:
                        window[dim].plot(ax=ax, c=color, ls=":")

            preds.sel({site_dim: site, "step": 0, "date": date_idx})["obs"].plot(
                ax=ax, c="#f200ff"
            )

    # Hide padding panels left when the site count is not a multiple of n_cols.
    for idx, ax in enumerate(axs):
        if idx % panels_per_chunk >= len(sites):
            ax.set_visible(False)

    fig.savefig(filepath)
    # Close only this figure instead of every open figure.
    plt.close(fig)