# Plot utilities

> A set of convenience functions for plotting

In [None]:
#| default_exp plot_utils

In [None]:
#| export
import fastcore.test
import matplotlib as mpl
import matplotlib.pyplot as plt
from nbdev.showdoc import *
import nptyping
import numpy as np
import pandas

In [None]:
#| export
def plot_strategy_distribution(data, # The dataset containing data on parameters and the strategy distribution
                               strategy_set, # The strategies to plot from the dataset
                               x="pr", # The parameter to place on the x-axis of the plot
                               x_label='Risk of an AI disaster, pr', # The x-axis label
                               title='Strategy distribution', # The plot title
                               ) -> None:
    """Plot the strategy distribution as we vary `x`.

    Draws a stacked area chart on a new figure with one band per entry of
    `strategy_set`; the band for a strategy is read from the
    `<strategy>_frequency` column of `data`. Two vertical lines are overlaid
    at the first values of the `threshold_society_prefers_safety` and
    `threshold_risk_dominant_safety` columns of `data`.
    """
    # The figure handle is not needed here; only the axes are drawn on.
    _, ax = plt.subplots()
    ax.stackplot(data[x],
                 [data[strategy + "_frequency"] for strategy in strategy_set],
                 labels=strategy_set,
                 alpha=0.8)
    ax.legend(loc='upper left')
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel('Proportion')

    # Add threshold boundaries to convey dilemma region.
    # Drawn through `ax` (not `plt`) so the lines always land on the axes
    # created above rather than on whatever pyplot considers "current".
    thresholds = [data['threshold_society_prefers_safety'].values[0],
                  data['threshold_risk_dominant_safety'].values[0]]
    ax.vlines(thresholds,
              0,
              0.995,
              colors=['C2', 'C3'],
              linewidth=3)

In [None]:
#| export
def plot_heatmap(table, # A pivot table, created using `pandas.pivot` function
                 xlabel="x", # The x-axis label
                 ylabel="y", # The y-axis label
                 zlabel="z", # The colorbar label
                 cmap='inferno', # The colormap passed to `imshow`
                 zmin=0, # Lower bound of the color scale
                 zmax=1, # Upper bound of the color scale
                ):
    """Render `table` as a heatmap on a new figure.

    The values of `table` supply the colors, its columns span the x-axis and
    its index spans the y-axis. A colorbar labelled `zlabel` is attached.
    """
    fig, axes = plt.subplots()

    # imshow expects extent as (left, right, bottom, top).
    col_labels, row_labels = table.columns, table.index
    bounds = [col_labels.min(), col_labels.max(),
              row_labels.min(), row_labels.max()]

    image = axes.imshow(
        table.values,
        cmap=cmap,
        extent=bounds,
        vmin=zmin,
        vmax=zmax,
        interpolation='nearest',
        origin='lower',
        aspect='auto',
    )
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)

    colorbar = fig.colorbar(image)
    colorbar.ax.set_ylabel(zlabel)

In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()