# Imports

In [None]:
import pickle
import xarray as xr
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colors as mat_colors
from functools import partial
import mpl_toolkits.axisartist as axisartist
import mpl_toolkits
from mpl_toolkits.axes_grid1 import Size, Divider
import os

# Load Data

In [None]:
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
with open('plot_data/model_comparison.pickle', mode='rb') as pickle_file:
    data_dict = pickle.load(pickle_file)

# Define plotting functions

## Performance measures (BIAS, RMSE, maximum absolute difference)

In [None]:
def BIAS(a1, a2):
    """Return the mean of ``a1 - a2`` as a plain Python scalar."""
    difference = a1 - a2
    return difference.mean().item()


def RMSE(a1, a2):
    """Return the root-mean-square of ``a1 - a2`` as a plain Python scalar."""
    squared_difference = (a1 - a2) ** 2
    mean_squared = squared_difference.mean()
    return np.sqrt(mean_squared).item()


def DIFF(a1, a2):
    """Return the largest absolute value of ``a1 - a2`` as a plain Python scalar."""
    absolute_difference = np.abs(a1 - a2)
    return np.max(absolute_difference).item()

## heatmap

In [None]:
def heatmap(data_dict,
            bed_shape,
            annotation=None,
            annotation_x_position=0,
            fig=None, ax=None,
            cmap='vlag',
            cmap_levels=None,
            grid_color='grey',
            grid_linewidth=1.5,
            presentation=False,
            labels_pad=-320,
            xlim=None,
            nr_of_iterations=None):
    """Draw a heatmap of surface-height differences between COMBINE and other models.

    Row 0 is an empty (all-NaN) heading row whose y tick label holds the
    column titles of the statistics. Row A shows ``COMBINE - OGGM`` and row B
    ``COMBINE - MUSCL`` (element-wise difference of ``surface_h``). If no
    MUSCL result is available, row B is left empty and labelled
    'NO MUSCL RUN'.

    Parameters
    ----------
    data_dict : dict
        Must contain 'COMBINE' and 'OGGM' entries, each providing
        'surface_h' and 'volume'. An optional 'MUSCL' entry with the same
        keys is compared too if present and non-empty.
    bed_shape, xlim, nr_of_iterations :
        Currently unused; kept for call compatibility.
    annotation : str, optional
        Text placed at ``(annotation_x_position, 1)`` in axes coordinates.
    annotation_x_position : float
        x position of ``annotation`` in axes coordinates.
    fig, ax :
        Target figure and axes; default to the current ones. ``ax`` must be
        an ``axisartist`` axes because its ``axis`` dict is accessed.
    cmap : {'vlag', 'icefire', 'Spectral', 'rainbow'}
        Colour scheme name.
    cmap_levels :
        Ignored; levels are always derived symmetrically from the data.
    grid_color, grid_linewidth :
        Style of the horizontal lines separating the rows.
    presentation : bool
        If True, row labels show DIFF_s and DIFF_V only. Otherwise they show
        BIAS_s, RMSE_s, DIFF_s and DIFF_V.
    labels_pad : float
        Pad of the y tick labels; negative values move them over the heatmap.

    Returns
    -------
    matplotlib.image.AxesImage
        The image created by ``ax.imshow``.

    Raises
    ------
    ValueError
        If ``cmap`` is not one of the supported names.
    """
    if ax is None:
        ax = plt.gca()
    if fig is None:
        fig = plt.gcf()

    combine = data_dict['COMBINE']
    sfc_h_1 = combine['surface_h']
    array_length = len(sfc_h_1)

    def compare_with(reference):
        """Return (difference, BIAS, RMSE, DIFF, volume difference) of COMBINE - reference."""
        sfc_h_ref = reference['surface_h']
        return (sfc_h_1 - sfc_h_ref,
                BIAS(sfc_h_1, sfc_h_ref),
                RMSE(sfc_h_1, sfc_h_ref),
                DIFF(sfc_h_1, sfc_h_ref),
                combine['volume'] - reference['volume'])

    # row A: COMBINE - OGGM (required); row B: COMBINE - MUSCL (optional,
    # a missing or empty entry yields an empty row instead of a KeyError)
    muscl = data_dict.get('MUSCL')
    comparisons = [compare_with(data_dict['OGGM']),
                   compare_with(muscl) if muscl else None]

    # heading row: all NaN (drawn with the "bad" colour), label = column titles
    data = [np.full(array_length, np.nan)]
    if presentation:
        y_labels = [r'  DIFF_s  DIFF_V']
        y_label_variable_format = '{: 5.0e}, {: 6.0e}'
    else:
        y_labels = ['    BIAS_s  RMSE_s  DIFF_s  DIFF_V']
        y_label_variable_format = '{: 6.2f}, {:6.2f}, {:6.2f}, {: 6.0e}'

    for i, comparison in enumerate(comparisons):
        letter = chr(65 + i)  # 'A', 'B', ...
        if comparison is None:
            data.append(np.full(array_length, np.nan))
            y_labels.append(letter + ': NO MUSCL RUN')
            continue
        difference, bias, rmse, max_diff, volume_diff = comparison
        data.append(difference)
        if presentation:
            values = (max_diff, volume_diff)
        else:
            values = (bias, rmse, max_diff, volume_diff)
        y_labels.append((letter + ':' + y_label_variable_format).format(*values))

    data = np.array(data)

    # symmetric colour limits rounded outwards to whole numbers, at least +-1
    color_nr = 100
    cmap_limit = np.max(np.array([np.abs(np.floor(np.nanmin(data))),
                                  np.abs(np.ceil(np.nanmax(data)))]))
    if cmap_limit < 1:
        cmap_limit = 1
    cmap_levels = np.linspace(-cmap_limit, cmap_limit, color_nr, endpoint=True)

    rel_color_steps = np.arange(color_nr) / color_nr
    if cmap == 'rainbow':
        colors = cm.rainbow(rel_color_steps)
    elif cmap in ('vlag', 'icefire'):
        colors = sns.color_palette(cmap, color_nr)
    elif cmap == 'Spectral':
        colors = sns.color_palette('Spectral_r', color_nr)
    else:
        raise ValueError("Unknown cmap '{}'! Use 'vlag', 'icefire', "
                         "'Spectral' or 'rainbow'.".format(cmap))

    color_map = LinearSegmentedColormap.from_list('custom', colors, N=color_nr)
    color_map.set_bad(color='white')  # NaN cells (heading / missing run)
    norm = mat_colors.BoundaryNorm(cmap_levels, color_map.N)

    # draw on the given axes (plt.imshow would use pyplot's current axes)
    im = ax.imshow(data, aspect='auto', interpolation=None, cmap=color_map,
                   norm=norm, alpha=1.)

    # hide ticks and axis lines of the axisartist axes
    for axis_artist in ax.axis.values():
        axis_artist.major_ticks.set_visible(False)
        axis_artist.minor_ticks.set_visible(False)
        axis_artist.line.set_visible(False)

    # one label per row, left aligned and shifted over the heatmap
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(y_labels)
    ax.axis["left"].major_ticklabels.set_ha("left")
    ax.axis["left"].major_ticklabels.set_pad(labels_pad)

    # minor ticks between rows carry the separating grid lines
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", axis='y', color=grid_color, linestyle='-',
            linewidth=grid_linewidth)

    ax.set_xticklabels([])

    # colorbar just right of the heatmap with ticks at -limit, 0, +limit
    cax = ax.inset_axes([1.01, 0.1, 0.03, 0.8])
    cbar = fig.colorbar(im, cax=cax, boundaries=cmap_levels, spacing='proportional')
    cbar.set_ticks([np.min(cmap_levels), 0, np.max(cmap_levels)])
    cbar.set_ticklabels(['{:d}'.format(int(-cmap_limit)), '0',
                         '{:d}'.format(int(cmap_limit))])

    if annotation is not None:
        ax.text(annotation_x_position, 1,
                annotation,
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes)

    return im

## Function for text-only panels

In [None]:
def set_text(ax, text, alignment='left', facecolor=(0.5, 0.5, 0.5), fontsize=12):
    """Turn ``ax`` into a plain text panel that shows ``text``.

    The text is placed at data coordinate (0, 0) and vertically centred in
    the y-range [-1, 1]. The x-range is chosen so that the text is visible for
    the given horizontal ``alignment``. Ticks, tick labels and axis lines are
    hidden and the background is set to ``facecolor``.

    Parameters
    ----------
    ax : matplotlib axes (plain or ``axisartist``)
    text : str
    alignment : {'left', 'center', 'right'}
        Horizontal alignment of the text. Other values leave the x-range
        unchanged.
    facecolor : color
        Background colour of the panel.
    fontsize : float
    """
    ax.text(0, 0, text,
            fontsize=fontsize,
            verticalalignment='center', horizontalalignment=alignment)

    ax.set_ylim([-1, 1])

    # anchor x=0 at the left edge / centre / right edge of the panel
    if alignment == 'left':
        ax.set_xlim([0, 1])
    elif alignment == 'center':
        ax.set_xlim([-1, 1])
    elif alignment == 'right':
        ax.set_xlim([-1, 0])

    # axisartist axes expose their axis lines as an AxisDict; plain axes use axis('off')
    if isinstance(ax.axis, mpl_toolkits.axes_grid1.mpl_axes.Axes.AxisDict):
        for axis_artist in ax.axis.values():
            axis_artist.major_ticks.set_visible(False)
            axis_artist.minor_ticks.set_visible(False)
            axis_artist.line.set_visible(False)
    else:
        ax.axis('off')

    ax.set_yticklabels([])
    ax.set_xticklabels([])

    ax.set_facecolor(facecolor)

## Function to create the heatmap axes

In [None]:
def setup_axes(fig, rect):
    """Create an ``axisartist`` subplot for ``rect``, add it to ``fig`` and return it."""
    new_axes = axisartist.Subplot(fig, rect)
    fig.add_subplot(new_axes)
    return new_axes

## legend function

In [None]:
def add_legend(ax,
               fontsize,
               labels):
    """Show ``labels`` as a text-only legend centred on an otherwise hidden axes.

    Each label gets an empty, colourless dummy line, so that only a legend
    entry is produced. The legend handles are then hidden and the axes are
    switched off.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes hosting the legend.
    fontsize : float or str
        Legend font size.
    labels : sequence of str
        Legend entries, in display order.
    """
    for label in labels:
        ax.plot([], [], c='none', label=label)

    leg = ax.legend(loc='center',
                    fontsize=fontsize,
                    handlelength=0,
                    handletextpad=0,
                    fancybox=True)

    # Matplotlib 3.7 renamed ``legendHandles`` to ``legend_handles``;
    # the old name was removed in 3.9, so support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for handle in handles:
        handle.set_visible(False)
    ax.axis('off')

# Create figure

## All experiment settings

In [None]:
# bed height profiles (first identifier column of each figure)
different_bed_h = [
    'linear',
    'cliff',
    'random',
]
# width variants (second identifier column of each figure)
different_widths = [
    'constant',
    'wide_top',
]
# bed cross-section shapes (one figure per shape)
different_bed_shapes = [
    'rectangular',
    'parabolic',
    'trapezoidal',
]
# glacier states (one heatmap column per state)
different_glacier_states = [
    'equilibrium',
    'advancing',
    'retreating',
]
different_models = [
    'OGGM',
    'COMBINE',
    'MUSCL',
]

## Figure function

In [None]:
def create_figure(
        bed_shape=different_bed_shapes[0],
        line_height_multiplier=0.2,
        facecolor='white',
        dpi=300,
        file_format='pdf',
        save_file=True,
        labels_pad=-200,
        fontsize=20,
        show_legend=True,
        presentation=True,
        output_folder='',
        filename='rec_model_comparision',
        cmap='Spectral'):
    """Create the model-comparison figure for one bed shape.

    The figure is a grid of ``heatmap`` panels:
    - one row per combination of ``different_bed_h`` and ``different_widths``;
    - one column per entry of ``different_glacier_states``;
    - two text identifier columns on the left and a header line on top;
    - optionally a legend explaining rows A/B above the header.

    Data are read from the module-level ``data_dict`` as
    ``data_dict[bed_shape][bed_h][bed_width][glacier_state]``.

    Parameters
    ----------
    bed_shape : {'rectangular', 'parabolic', 'trapezoidal'}
    line_height_multiplier : float
        Extra height (inches, scaled by 2) added to each heatmap row and to
        the legend area.
    facecolor : color
        Background of the figure and text panels; also used as the row
        separator colour inside the heatmaps.
    dpi, file_format :
        Passed to ``Figure.savefig``.
    save_file : bool
        Whether to write the figure to ``output_folder/filename.file_format``.
    labels_pad : float
        Passed to ``heatmap``.
    fontsize : float
        Written to ``rcParams['font.size']``. This is global and persists
        after the call.
    show_legend : bool
        Add the legend row on top.
    presentation : bool
        Passed to ``heatmap``; also selects the panel spacing.
    output_folder, filename :
        Location of the saved file.
    cmap : str
        Colour scheme name passed to ``heatmap``.

    Raises
    ------
    ValueError
        If ``bed_shape`` is unknown (checked before any figure is created).
    """
    # header of the second identifier column depends on the bed shape
    shape_headers = {'rectangular': 'widths',
                     'parabolic': 'shape factor',
                     'trapezoidal': 'w0'}
    if bed_shape not in shape_headers:
        raise ValueError('Unknown bed geometry!')
    shape_header = shape_headers[bed_shape]

    mpl.rcParams.update({'font.size': fontsize})

    # Axes are placed in inches by the Divider below, so the nominal figure
    # size is tiny; presumably the saved extent comes from bbox_inches='tight'.
    fig = plt.figure(figsize=(1, 1), facecolor=facecolor)

    # Layout (top to bottom):
    #        legend/title    (only if show_legend)
    #   ------------------------
    #        heading
    #   ------------------------
    #   identifiers | subplots
    height_multiplier = 2
    subplot_width = 3
    subplot_height = .4 + line_height_multiplier * height_multiplier
    if presentation:
        subplot_separation_x = 1.
        subplot_separation_y = .5
    else:
        subplot_separation_x = 1.5
        subplot_separation_y = .1

    identifier_width = 1.2
    identifier_separation = .1
    separation_identifier_plots = .5
    height_header = .5
    separation_heading_plots = .1
    separation_heading_title = .1
    height_title = .5 + line_height_multiplier * height_multiplier

    # x grid (fixed sizes in inches): identifier columns at index 0 and 2,
    # heatmap columns at index 4, 6, ... (separators in between)
    horiz = [Size.Fixed(identifier_width),
             Size.Fixed(identifier_separation),
             Size.Fixed(identifier_width),
             Size.Fixed(separation_identifier_plots)]
    for column in range(len(different_glacier_states)):
        if column > 0:
            horiz.append(Size.Fixed(subplot_separation_x))
        horiz.append(Size.Fixed(subplot_width))

    # y grid, bottom to top: heatmap rows at even indices 0, 2, ...,
    # then the header, then (optionally) the legend/title
    nr_of_rows = len(different_bed_h) * len(different_widths)
    vert = []
    for row in range(nr_of_rows):
        if row > 0:
            vert.append(Size.Fixed(subplot_separation_y))
        vert.append(Size.Fixed(subplot_height))
    vert += [Size.Fixed(separation_heading_plots), Size.Fixed(height_header)]
    header_ny = len(vert) - 1
    if show_legend:
        vert += [Size.Fixed(separation_heading_title), Size.Fixed(height_title)]
    title_ny = len(vert) - 1  # only used when show_legend

    rect = (0., 0., 1., 1.)  # position of the grid in the figure
    divider = Divider(fig, rect, horiz, vert, aspect=False)

    if show_legend:
        ax = fig.subplots()
        add_legend(ax=ax,
                   fontsize=fontsize,
                   labels=['A: COMBINE - OGGM', 'B: COMBINE - MUSCL'])
        ax.set_axes_locator(divider.new_locator(nx=0, nx1=len(horiz), ny=title_ny))

    # header line: one text panel above each identifier / heatmap column
    header_items = ['bed', shape_header] + list(different_glacier_states)
    for i, item in enumerate(header_items):
        ax = setup_axes(fig, 111)
        set_text(ax, item, facecolor=facecolor, alignment='center', fontsize=fontsize)
        ax.set_axes_locator(divider.new_locator(nx=i * 2, ny=header_ny))

    # fill rows from the top (just below the header) downwards
    row_nr = header_ny - 2
    for bed_h in different_bed_h:
        for bed_width in different_widths:
            # the two identifier panels of this row
            for nx, identifier in ((0, bed_h), (2, bed_width)):
                ax = setup_axes(fig, 111)
                set_text(ax, identifier, facecolor=facecolor, alignment='center',
                         fontsize=fontsize)
                ax.set_axes_locator(divider.new_locator(nx=nx, ny=row_nr))

            for column, glacier_state in enumerate(different_glacier_states):
                # monospace keeps the statistic columns in the row labels aligned
                with plt.rc_context({'font.family': 'monospace'}):
                    ax = setup_axes(fig, 111)
                    heatmap(data_dict[bed_shape][bed_h][bed_width][glacier_state],
                            bed_shape=bed_shape,
                            fig=fig,
                            ax=ax,
                            cmap=cmap,
                            grid_color=facecolor,
                            presentation=presentation,
                            labels_pad=labels_pad)
                    ax.set_axes_locator(divider.new_locator(nx=column * 2 + 4, ny=row_nr))
            row_nr -= 2

    if save_file:
        # os.path.join also works when output_folder lacks a trailing separator
        output_path = os.path.join(output_folder, filename + '.' + file_format)
        fig.savefig(output_path, format=file_format, bbox_inches='tight', dpi=dpi)

# Rectangular

In [None]:
# rectangular bed -> rec_model_comparision.pdf
create_figure(bed_shape=different_bed_shapes[0],
              filename='rec_model_comparision',
              output_folder='',
              file_format='pdf',
              dpi=300,
              save_file=True,
              line_height_multiplier=0.4,
              labels_pad=-200,
              fontsize=20,
              facecolor='white',
              show_legend=True,
              presentation=True)

# Parabolic

In [None]:
# parabolic bed -> par_model_comparision.pdf
create_figure(bed_shape=different_bed_shapes[1],
              filename='par_model_comparision',
              output_folder='',
              file_format='pdf',
              dpi=300,
              save_file=True,
              line_height_multiplier=0.4,
              labels_pad=-200,
              fontsize=20,
              facecolor='white',
              show_legend=True,
              presentation=True)

# Trapezoidal

In [None]:
# trapezoidal bed -> tra_model_comparision.pdf
create_figure(bed_shape=different_bed_shapes[2],
              filename='tra_model_comparision',
              output_folder='',
              file_format='pdf',
              dpi=300,
              save_file=True,
              line_height_multiplier=0.4,
              labels_pad=-200,
              fontsize=20,
              facecolor='white',
              show_legend=True,
              presentation=True)