# Plot Trends
Use the projection summary files (one CSV per region, model, scenario, and feature) to make trend plots comparing the SSP2-4.5 and SSP5-8.5 scenarios.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os
from glob import glob
import json
import geopandas as gpd
import pandas as pd
import numpy as np
import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import seaborn as sns
import xarray as xr

from rex import Resource, init_logger
from region_classifier import RegionClassifier

from sup3r.preprocessing.data_handling.base import DataHandler
from sup3r.preprocessing.data_handling import DataHandlerNCforCC
from sup3r.preprocessing.data_handling import DataHandlerNCforCCwithPowerLaw
from sup3r.bias.bias_calc import SkillAssessment

from make_projection_summaries_cmip import (get_countries_shape, get_states_shape, get_eez_shape, get_fps, get_targets_shapes, 
                                            make_summary_files, FEATURES, MODELS, TAGS, REGIONS)

In [None]:
# Register the reference datasets so they are loaded and plotted alongside the
# GCMs. TAGS and MODELS are index-aligned, so each (tag, model) pair is appended
# together. Every pair is checked on its own: re-running this cell, or an
# upstream TAGS that already holds one of these tags, must not add duplicates
# (duplicate model columns would make the DataFrame joins in get_trend_df fail).
for _ref_tag, _ref_model in (('nsrdb', 'NSRDB'), ('era5', 'ERA5'), ('daymet', 'DAYMET')):
    if _ref_tag not in TAGS:
        TAGS.append(_ref_tag)
        MODELS.append(_ref_model)

In [None]:
def _scaled_rgb(base_color, factor):
    """Return ``base_color`` as an RGB list, each channel multiplied by ``factor`` and capped at 1."""
    return [np.min([1, channel * factor]) for channel in mcolors.to_rgb(base_color)]


# Per-model colors: model variants reuse their family's base color, scaled
# brighter (factor > 1) or darker (factor < 1); reference datasets are black.
# NOTE: ``colors`` is reassigned in the next cell.
colors = {
    'cesm2': '#1f77b4',
    'cesm2waccm': _scaled_rgb('#1f77b4', 1.5),
    'ecearth3cc': _scaled_rgb('#ff7f0e', 0.75),
    'ecearth3': '#ff7f0e',
    'ecearth3veg': _scaled_rgb('#ff7f0e', 1.4),
    'gfdlcm4': '#2ca02c',
    'gfdlesm4': _scaled_rgb('#2ca02c', 1.5),
    'inmcm48': '#d62728',
    'inmcm50': _scaled_rgb('#d62728', 1.7),
    'mpiesm12hr': '#9467bd',
    'mriesm20': '#e377c2',
    'noresm2mm': '#7f7f7f',
    'taiesm1': '#bcbd22',
    'nsrdb': 'k',
    'era5': 'k',
    'daymet': 'k',
}

In [None]:
# Discrete 20-color 'tab20' palette as hex strings.
# ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` (what
# ``plt.cm.get_cmap`` resolves to) was deprecated in Matplotlib 3.7 and
# removed in 3.9.
tab20 = plt.get_cmap('tab20', 20)
named_colors = [mcolors.rgb2hex(tab20(i)) for i in range(tab20.N)]

# Final per-model colors, hand-picked from tab20. Dark/light tab20 pairs mark
# related model variants (e.g. cesm2 / cesm2waccm); reference datasets are black.
# (An automatic TAGS -> named_colors mapping used to be built here but was
# immediately overwritten by this dict, so it has been dropped.)
colors = {
    'cesm2': '#1f77b4',
    'cesm2waccm': '#aec7e8',
    'ecearth3cc': '#d62728',
    'ecearth3': '#ff7f0e',
    'ecearth3veg': '#ffbb78',
    'gfdlcm4': '#2ca02c',
    'gfdlesm4': '#98df8a',
    'inmcm48': '#9467bd',
    'inmcm50': '#c5b0d5',
    'mpiesm12hr': '#ff9896',
    'mriesm20': '#8c564b',
    'noresm2mm': '#c49c94',
    'taiesm1': '#e377c2',
    'nsrdb': '#000000',
    'era5': '#000000',
    'daymet': '#000000',
}

In [None]:
# matplotlib marker code -> closest plotly ``marker_symbol`` name, so the
# static (matplotlib) and interactive (plotly) plots use matching symbols.
# Each key appears once; the literal previously repeated 'd' three times.
plotly_marker_map = {
    'o': 'circle',
    '.': 'circle-dot',
    '8': 'octagon',
    '^': 'triangle-up',
    '>': 'triangle-right',
    'v': 'triangle-down',
    '1': 'y-down-open',
    '2': 'y-up-open',
    's': 'square',
    'p': 'pentagon',
    'P': 'cross',
    '*': 'star',
    'X': 'x',
    'h': 'hexagon',
    'H': 'hexagon2',
    'd': 'diamond-tall',
}

# matplotlib marker code per model tag; every value must be a key of
# ``plotly_marker_map``. The reference datasets share the diamond marker.
markers = {
    'cesm2': 'h',
    'cesm2waccm': 'H',
    'ecearth3cc': '^',
    'ecearth3': '>',
    'ecearth3veg': 'v',
    'gfdlcm4': '1',
    'gfdlesm4': '2',
    'inmcm48': 's',
    'inmcm50': 'p',
    'mpiesm12hr': 'P',
    'mriesm20': '*',
    'noresm2mm': 'X',
    'taiesm1': 'o',
    'nsrdb': 'd',
    'era5': 'd',
    'daymet': 'd',
}

In [None]:
# Shared matplotlib line width and marker size for the trend plots.
linewidth, markersize = 0.8, 5

In [None]:
# Preview: one point per model on the diagonal, drawn with that model's
# color and marker, so the legend can be checked before plotting trends.
for i, (model, tag) in enumerate(zip(MODELS, TAGS)):
    plt.scatter(
        i,
        i,
        marker=markers[tag],
        c=colors[tag],
        label=model,
    )
plt.legend()

In [None]:
# Features to plot. This overrides the FEATURES list imported from
# make_projection_summaries_cmip; relative-humidity max/min are left out.
FEATURES = [
    'temperature_2m',
    'temperature_max_2m',
    'temperature_min_2m',
    'relativehumidity_2m',
    'rsds',
    'pr',
    'windspeed_100m',
]

In [None]:
# Path template of one projection summary CSV; fill it with
# str.format(reg=<region>, tag=<model tag>, scen=<scenario>, feat=<feature>).
FP_BASE = ('/projects/alcaps/gcm_eval/analysis/projections/'
           '{reg}_{tag}_{scen}_{feat}.csv')

In [None]:
# Observational / reanalysis datasets. Their summary files have no scenario
# and no separate max/min variants, and their rows are aggregated to daily
# values below (they may be sub-daily -- TODO confirm against the files).
_REFERENCE_TAGS = ('nsrdb', 'era5', 'daymet')


def _load_model_column(region, scenario, feature, tag, model):
    """Load one model's summary series for ``feature`` as a one-column frame named ``model``.

    A missing file yields an empty float column on a DatetimeIndex, so it joins
    cleanly with the other models (and keeps the joined frame numeric even
    when the missing file is the first one). Reference datasets are restricted
    to 1980-2019 and aggregated per calendar day: mean for the averaged
    features, max/min for the ``_max_``/``_min_`` features, none for others.
    """
    fp = FP_BASE.format(reg=region, tag=tag, scen=scenario, feat=feature)
    is_reference = tag in _REFERENCE_TAGS

    if is_reference:
        # reference files: no scenario, no max/min variants, and ERA5
        # temperature and relative humidity share a single "trh" file
        fp = fp.replace(f'_{scenario}', '')
        fp = fp.replace('_max', '')
        fp = fp.replace('_min', '')
        fp = fp.replace('era5_temperature_2m', 'era5_trh')
        fp = fp.replace('era5_relativehumidity_2m', 'era5_trh')

    if not os.path.exists(fp):
        return pd.DataFrame(columns=[model], index=pd.DatetimeIndex([]), dtype=float)

    idf = pd.read_csv(fp, index_col=0)
    idf.index = pd.to_datetime(idf.index)

    # max/min features may be stored under the plain feature name
    for candidate in (feature, feature.replace('_max', ''), feature.replace('_min', '')):
        if candidate in idf:
            idf = idf.rename({candidate: model}, axis=1)
            break
    idf = idf[[model]]

    if is_reference and len(idf) > 1:
        idf = idf[idf.index.year.isin(range(1980, 2020))]
        idf.index = idf.index.tz_localize(None)
        if feature in ('temperature_2m', 'windspeed_100m', 'relativehumidity_2m', 'rsds'):
            idf = idf.groupby(idf.index.date).mean()
        elif '_max_' in feature:
            idf = idf.groupby(idf.index.date).max()
        elif '_min_' in feature:
            idf = idf.groupby(idf.index.date).min()
        # grouping by .date leaves an object index of datetime.date; convert it
        # back so the outer join with the GCM frames is on a real DatetimeIndex
        idf.index = pd.to_datetime(idf.index)

    return idf


def _apply_baseline(df, baseline, relative):
    """Shift the columns of ``df`` according to ``baseline``; see ``get_trend_df``."""
    in_base_period = df.index.year.isin(range(1980, 2020))

    if baseline is True:
        # each column becomes its change relative to its own 1980-2019 mean
        for col in df.columns:
            arr = df.loc[in_base_period, col].values
            if not np.isnan(arr).all():
                base_value = np.nanmean(arr)
                df[col] -= base_value
                if relative:
                    df[col] /= base_value
                    df[col] *= 100

    elif baseline in df.columns:
        # align every column's 1980-2019 mean with the reference column's
        ref_value = np.nanmean(df.loc[in_base_period, baseline].values)
        non_cesm_cols = [c for c in df.columns if 'CESM' not in c]
        cesm_cols = [c for c in df.columns if 'CESM' in c]
        for col in non_cesm_cols:
            offset = np.nanmean(df.loc[in_base_period, col].values)
            df[col] = df[col] - offset + ref_value
        # CESM columns are aligned over their own valid rows, against the
        # already-shifted non-CESM columns (presumably because CESM lacks data
        # in the 1980-2019 window -- TODO confirm)
        for col in cesm_cols:
            has_data = ~df[col].isna()
            offset = np.nanmean(df[col].values)
            target = np.nanmean(df.loc[has_data, non_cesm_cols].values)
            df[col] = df[col] - offset + target

    return df


def get_trend_df(region, scenario, feature, period=365*10, option='mean', baseline=True, tslice=None, relative=False, years=None):
    """Build a per-model trend table for one region, scenario and feature.

    Parameters
    ----------
    region : str
        Region key used in the summary file names.
    scenario : str
        Scenario key used in the summary file names (e.g. 'ssp245').
    feature : str
        Feature name (e.g. 'temperature_2m', 'pr').
    period : int or None
        Centered rolling-window length in rows (rows are presumably daily, so
        the default 365*10 is ~10 years). ``None`` disables smoothing.
    option : {'mean', 'max', 'min'}
        Rolling aggregation; any other value skips the rolling step.
    baseline : True, str, or other
        ``True``: subtract each column's 1980-2019 mean. A column name (e.g.
        'ERA5'): shift every column so its 1980-2019 mean matches that
        column's. Anything else: no shift.
    tslice : slice, optional
        Positional slice (``iloc``) applied after dropping all-NaN rows, e.g.
        ``slice(None, None, 365*5)`` to subsample.
    relative : bool
        With ``baseline=True``, express the change as a percentage of the
        1980-2019 mean.
    years : iterable of int, optional
        Keep only rows whose year is in ``years``.

    Returns
    -------
    pd.DataFrame
        Indexed by date, one column per entry of ``MODELS`` (leap days removed).
    """
    df = None
    for tag, model in zip(TAGS, MODELS):
        idf = _load_model_column(region, scenario, feature, tag, model)
        df = idf if df is None else df.join(idf, how='outer')

    # drop leap days (some GCMs have them, some do not)
    leap_day = (df.index.month == 2) & (df.index.day == 29)
    df = df[~leap_day].copy()

    if period is not None:
        window = df.rolling(period, center=True)
        if option == 'mean':
            df = window.mean()
        elif option == 'max':
            df = window.max()
        elif option == 'min':
            df = window.min()

    df = _apply_baseline(df, baseline, relative)

    df = df.dropna(axis=0, how='all')

    if tslice is not None:
        df = df.iloc[tslice]

    if years is not None:
        df = df[df.index.year.isin(years)]

    return df

In [None]:
def plot(df_245, df_585, ylabel, xlabel, 
         figsize=(10, 4), legend_offset=1.15, 
         del_degf_ylabel=False, abs_degf_ylabel=False,
         y_offset_frac=0.05, option='line',
         fp_out=None, show=False):
    """Plot SSP2-4.5 and SSP5-8.5 trend tables side by side with matplotlib.

    Parameters
    ----------
    df_245, df_585 : pd.DataFrame
        Trend tables (e.g. from ``get_trend_df``), one column per model named
        after an entry of ``MODELS``.
    ylabel, xlabel : str
        Axis labels; ``xlabel`` is skipped when empty.
    figsize : tuple
        Matplotlib figure size.
    legend_offset : float
        x position (axes fraction of the right panel) of the legend's left edge.
    del_degf_ylabel : bool
        Add a right-hand °F axis for temperature differences (x * 9/5).
    abs_degf_ylabel : bool
        Add a right-hand °F axis for absolute temperatures (x * 9/5 + 32).
    y_offset_frac : float
        Y padding added below the minimum and above the maximum, as a fraction
        of |max|.
    option : {'line', 'scatter'}
        'line' draws each series with a marker on its last point; 'scatter'
        draws markers only.
    fp_out : str, optional
        If given, save the figure there at 300 dpi.
    show : bool
        Call ``plt.show()`` before closing the figure.
    """
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    tag_by_model = dict(zip(MODELS, TAGS))
    data = {'SSP2 4.5': df_245, 'SSP5 8.5': df_585}

    for subax, (title, df) in zip(ax, data.items()):
        subax.set_title(title)
        for col_idx, (model, series) in enumerate(df.items()):
            series = series.dropna()
            if series.empty:
                continue
            # columns come from get_trend_df in TAGS order; look the tag up by
            # model name and fall back to the column position
            tag = tag_by_model.get(model, TAGS[col_idx])
            if option == 'line':
                subax.plot(series, c=colors[tag], linewidth=linewidth)
                subax.plot(series.index[-1], series.values[-1], marker=markers[tag],
                           c=colors[tag], markersize=markersize, label=model)
            elif option == 'scatter':
                subax.scatter(series.index, series.values, marker=markers[tag],
                              c=colors[tag], s=markersize*3, label=model)

    # shared y-limits so both scenarios are directly comparable
    y0 = np.min((np.nanmin(df_245.values), np.nanmin(df_585.values)))
    y1 = np.max((np.nanmax(df_245.values), np.nanmax(df_585.values)))
    offset = y_offset_frac * np.abs(y1)
    ticks = pd.date_range(df_245.index.values[0], df_245.index.values[-1], freq='10y')
    for subax in ax:
        subax.set_xticks(ticks, labels=ticks.year)
        subax.set_ylim(y0 - offset, y1 + offset)
        subax.grid(True)
        if xlabel:
            subax.set_xlabel(xlabel)

    ax[0].set_ylabel(ylabel)

    if del_degf_ylabel:
        secax = ax[1].secondary_yaxis('right', functions=(lambda x: x*9/5, lambda x: x*5/9))
        secax.set_ylabel(ylabel.replace(r'$^\circ$C', r'$^\circ$F'))
    elif abs_degf_ylabel:
        secax = ax[1].secondary_yaxis('right', functions=(lambda x: x*9/5+32, lambda x: (x-32)*5/9))
        secax.set_ylabel(ylabel.replace(r'$^\circ$C', r'$^\circ$F'))

    # labels only exist on the per-model markers; the legend sits right of ax[1]
    ax[1].legend(loc='center left', bbox_to_anchor=(legend_offset, 0.5))
    fig.tight_layout()
    if fp_out is not None:
        fig.savefig(fp_out, bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    plt.close(fig)

In [None]:
def plotly(df_245, df_585, ylabel, xlabel, 
         figsize=(10, 4), legend_offset=1.15, 
         del_degf_ylabel=False, abs_degf_ylabel=False,
         y_offset_frac=0.05, option='line',
         fp_out=None, show=False):
    """Interactive plotly version of ``plot``: SSP2-4.5 and SSP5-8.5 side by side.

    It takes the same arguments as ``plot`` so one kwargs dict can drive both.
    ``figsize``, ``legend_offset``, ``del_degf_ylabel``, ``abs_degf_ylabel``,
    ``y_offset_frac`` and ``option`` are accepted but ignored here.

    Parameters
    ----------
    df_245, df_585 : pd.DataFrame
        Trend tables with identical columns (one per model); a model is drawn
        only if its SSP2-4.5 column has at least one value.
    ylabel, xlabel : str
        Axis labels; LaTeX degree signs are converted to '°'.
    fp_out : str, optional
        Output path; its extension is replaced with '.html'.
    show : bool
        Display the figure.
    """
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.0,
                        subplot_titles=("SSP2 4.5", "SSP5 8.5"))
    tag_by_model = dict(zip(MODELS, TAGS))

    for col in df_245.columns:
        if not df_245[col].notna().any():
            continue
        # same model -> tag mapping as ``plot``; fall back to the
        # name-derived tag (e.g. 'CESM2-WACCM' -> 'cesm2waccm')
        tag = tag_by_model.get(col, col.lower().replace('-', ''))
        trace_style = dict(marker_color=colors[tag],
                           mode='lines+markers',
                           marker_size=6,
                           marker_symbol=plotly_marker_map[markers[tag]],
                           connectgaps=False)
        fig.add_trace(go.Scatter(x=df_245.index, y=df_245[col],
                                 name=col + ' (SSP2 4.5)', **trace_style),
                      row=1, col=1)
        fig.add_trace(go.Scatter(x=df_585.index, y=df_585[col],
                                 name=col + ' (SSP5 8.5)', **trace_style),
                      row=1, col=2)

    fig.update_yaxes(title_text=ylabel.replace(r'$^\circ$', '°'), row=1, col=1)
    fig.update_xaxes(title_text=xlabel)

    # shared, boxed, gridded axes; y tick labels only on the left panel
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_xaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
    fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=True)
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='Grey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='Grey')
    fig.update_xaxes(zeroline=True, zerolinewidth=1, zerolinecolor='Grey')
    fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor='Grey')
    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_layout(height=600, width=1200)
    if fp_out is not None:
        fig.write_html(os.path.splitext(fp_out)[0] + '.html')
    if show:
        fig.show()

In [None]:
# Every region defined by the summary module, plus the 'atlantic', 'pacific'
# and 'gulf' regions.
all_regions = list(REGIONS.keys()) + ['atlantic', 'pacific', 'gulf']

In [None]:
# get_trend_df arguments per feature (identical for every region): 10-year
# rolling windows subsampled every 5 (or 10) years; max/min temperatures are
# aligned to ERA5's 1980-2019 level, the rest are changes from each model's
# own 1980-2019 mean (in percent where relative=True).
data_kwargs = {
    'temperature_2m': dict(period=365*10, baseline=True, tslice=slice(None, None, 365*5)),
    'temperature_max_2m': dict(period=365*10, option='max', baseline='ERA5', tslice=slice(None, None, 365*10)),
    'temperature_min_2m': dict(period=365*10, option='min', baseline='ERA5', tslice=slice(None, None, 365*10)),
    'relativehumidity_2m': dict(period=365*10, baseline=True, tslice=slice(None, None, 365*5), relative=False),
    'rsds': dict(period=365*10, baseline=True, tslice=slice(None, None, 365*5), relative=True, years=list(range(2005, 2055))),
    'pr': dict(period=365*10, baseline=True, tslice=slice(None, None, 365*5), relative=True),
    'windspeed_100m': dict(period=365*10, baseline=True, tslice=slice(None, None, 365*5), relative=True, years=list(range(2005, 2055))),
}

# every fp_out below lives here; savefig/write_html do not create directories
os.makedirs('./trend_plots', exist_ok=True)

for region in all_regions:
    display(region)
    region = region.lower().replace(' ', '_')

    plot_kwargs = {
        'temperature_2m': dict(ylabel=r'Change in Temperature ($^\circ$C)', xlabel='Year (10-Year Moving Average)',
                               figsize=(10, 4), legend_offset=1.15, del_degf_ylabel=True, y_offset_frac=0.05,
                               fp_out=f'./trend_plots/{region}_t2m.png'),
        'temperature_max_2m': dict(ylabel=r'10-Year Maximum Temperature ($^\circ$C)', xlabel='Year (10-Year Window Maximum)',
                                   figsize=(10, 4), legend_offset=1.2, abs_degf_ylabel=True, y_offset_frac=0.05,
                                   fp_out=f'./trend_plots/{region}_t2m_max.png'),
        'temperature_min_2m': dict(ylabel=r'10-Year Minimum Temperature ($^\circ$C)', xlabel='Year (10-Year Window Minimum)',
                                   figsize=(10, 4), legend_offset=1.2, abs_degf_ylabel=True, y_offset_frac=0.2,
                                   fp_out=f'./trend_plots/{region}_t2m_min.png'),
        # relative=False above: this is an absolute change in RH (percentage
        # points), not a percent change
        'relativehumidity_2m': dict(ylabel='Change in Relative Humidity (%)', xlabel='Year (10-Year Moving Average)',
                                    figsize=(10, 4), legend_offset=1.0, y_offset_frac=0.05,
                                    fp_out=f'./trend_plots/{region}_rh.png'),
        'rsds': dict(ylabel='Percent Change in GHI (%)', xlabel='Year (10-Year Moving Average)',
                     figsize=(10, 4), legend_offset=1.0, y_offset_frac=0.05,
                     fp_out=f'./trend_plots/{region}_ghi.png'),
        'pr': dict(ylabel='Percent Change in Precipitation (%)', xlabel='Year (10-Year Moving Average)',
                   figsize=(10, 4), legend_offset=1.0, y_offset_frac=0.05,
                   fp_out=f'./trend_plots/{region}_pr.png'),
        'windspeed_100m': dict(ylabel='Percent Change in Windspeed (%)', xlabel='Year (10-Year Moving Average)',
                               figsize=(10, 4), legend_offset=1.0, y_offset_frac=0.05,
                               fp_out=f'./trend_plots/{region}_ws100m.png'),
    }

    for feature in FEATURES:
        df_245 = get_trend_df(region, 'ssp245', feature, **data_kwargs[feature])
        df_585 = get_trend_df(region, 'ssp585', feature, **data_kwargs[feature])

        plot(df_245, df_585, show=False, **plot_kwargs[feature])
        plotly(df_245, df_585, show=False, **plot_kwargs[feature])

        if feature.startswith('temperature'):
            # extra interactive version in °F: deltas scale by 9/5, absolute
            # max/min temperatures also get the +32 offset
            degf_kwargs = dict(plot_kwargs[feature])
            degf_kwargs['ylabel'] = degf_kwargs['ylabel'].replace(r'$^\circ$C', r'$^\circ$F')
            degf_kwargs['fp_out'] = degf_kwargs['fp_out'].replace('.png', '_degf.png')
            df_245 = df_245 * 9 / 5
            df_585 = df_585 * 9 / 5
            if '_max_' in feature or '_min_' in feature:
                df_245 = df_245 + 32
                df_585 = df_585 + 32
            plotly(df_245, df_585, show=False, **degf_kwargs)