# Visualize results for talks
Small notebook for plotting confusion-matrix figures in isolation.

In [None]:
# Enable IPython's autoreload extension so that edits to imported libraries
# (e.g. the a2 package used below) take effect without reloading them by hand;
# mode 2 reloads all modules before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
import a2.dataset
import a2.plotting
import matplotlib.pyplot as plt
import numpy as np
import matplotlib

In [None]:
def single_plot(
    cm,
    ax,
    vmin=0,
    vmax=0.45,
    xedges=None,
    yedges=None,
    text_color="seashell",
    title="Dataset 2020",
):
    """Draw a confusion matrix as a colored mesh with its values written on top.

    Args:
        cm: Confusion matrix, array-like (nested lists or ``np.ndarray``). It is
            transposed before plotting, so its first index runs along the x axis
            of ``pcolormesh``.
        ax: Matplotlib axes to draw on.
        vmin, vmax: Color-scale limits passed to ``get_norm("linear", ...)``.
        xedges, yedges: Cell edges along x and y; both default to ``[0, 0.5, 1]``.
        text_color: Color of the overplotted cell values.
        title: Axes title.

    Returns:
        The mesh returned by ``ax.pcolormesh`` (useful for building a colorbar).
    """
    # Accept nested lists as well as arrays: ``.T`` below needs an ndarray.
    # ``np.array`` (not ``asarray``) keeps the original's copy semantics, so the
    # caller's array is never shared with the plotting helpers.
    cm = np.array(cm)
    if xedges is None:
        xedges = [0, 0.5, 1]
    if yedges is None:
        yedges = [0, 0.5, 1]

    norm = a2.plotting.utils_plotting.get_norm("linear", vmin=vmin, vmax=vmax)
    mesh = ax.pcolormesh(xedges, yedges, cm.T, norm=norm, cmap="Blues")
    # NOTE(review): the two 2s look like the grid size (2x2) — confirm against
    # overplot_values before using this for non-2x2 matrices.
    a2.plotting.utils_plotting.overplot_values(cm, ax, 2, 2, color=text_color)
    ax.set_title(title)
    return mesh


def prepare_axis(ax0, no_ylabels=False):
    """Put the confusion-matrix tick labels and axis labels on an axes.

    The x axis always receives the "Predicted label" ticks. The y axis receives
    the "True label" ticks and label, unless ``no_ylabels`` is set; in that case
    its ticks are cleared instead, e.g. for a panel placed to the right of one
    that already shows them.

    Returns:
        The same axes object that was passed in.
    """
    tick_positions = [0.25, 0.75]

    ax0.set_xticks(tick_positions)
    ax0.set_xticklabels(["not raining", "raining"])
    ax0.set_xlabel("Predicted label")

    if no_ylabels:
        ax0.set_yticks([])
        return ax0

    ax0.set_yticks(tick_positions)
    ax0.set_yticklabels(["raining", "not raining"])
    ax0.set_ylabel("True label")
    return ax0


def plot_pretty_confusion_matrix_2(
    cm0: np.ndarray | None = None,
    cm1: np.ndarray | None = None,
    font_size: int = 12,
    filename: str | None = "test.png",
    titles: tuple[str, str] = ("Dataset 2020", "Whole dataset"),
):
    """Plot two confusion matrices side by side with one shared colorbar.

    Both panels use the same linear color scale (0 to 0.45), so they can be
    compared directly. The horizontal colorbar is centered below the pair.

    Args:
        cm0: Matrix for the left panel (array-like); defaults to a hard-coded
            example.
        cm1: Matrix for the right panel (array-like); defaults to a hard-coded
            example.
        font_size: Font size applied globally via ``matplotlib.rc``. This also
            affects every plot made afterwards.
        filename: Path the figure is saved to; ``None`` skips saving.
        titles: Titles of the left and right panels.
    """
    if cm0 is None:
        cm0 = [[0.18, 0.42], [0.28, 0.13]]
    if cm1 is None:
        cm1 = [[0.19, 0.41], [0.29, 0.11]]
    cm0 = np.array(cm0)
    cm1 = np.array(cm1)

    font = {"family": "DejaVu Sans", "weight": "normal", "size": font_size}
    matplotlib.rc("font", **font)

    xedges = [0, 0.5, 1]
    yedges = [0, 0.5, 1]
    vmin = 0
    vmax = 0.45
    text_color = "red"

    # Layout in figure-fraction coordinates: two square panels next to each
    # other with a small gap between them, and a thin colorbar underneath.
    left = 0.1
    panel_size = 0.4
    vertical_offset = 0.2
    horizontal_separation_axes = 0.02
    colorbar_length = 0.7

    fig = plt.figure()

    ax0 = plt.axes([left, vertical_offset, panel_size, panel_size])
    single_plot(
        cm0,
        ax0,
        vmin=vmin,
        vmax=vmax,
        xedges=xedges,
        yedges=yedges,
        text_color=text_color,
        title=titles[0],
    )
    prepare_axis(ax0)

    ax1 = plt.axes(
        [left + panel_size + horizontal_separation_axes, vertical_offset, panel_size, panel_size]
    )
    single_plot(
        cm1,
        ax1,
        vmin=vmin,
        vmax=vmax,
        xedges=xedges,
        yedges=yedges,
        text_color=text_color,
        title=titles[1],
    )
    # The right panel reuses the left panel's y labels, so it hides its own.
    prepare_axis(ax1, no_ylabels=True)

    # Both panels share one norm and colormap, so a standalone mappable with
    # that norm and cmap gives the colorbar for either of them.
    norm = a2.plotting.utils_plotting.get_norm("linear", vmin=vmin, vmax=vmax)
    scalar_mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap="Blues")

    # Center the colorbar under the combined width of both panels plus the gap.
    colorbar_left = left + (panel_size * 2 + horizontal_separation_axes - colorbar_length) / 2
    axes_colorbar = plt.axes([colorbar_left, 0.0, colorbar_length, 0.02])
    plt.colorbar(scalar_mappable, cax=axes_colorbar, orientation="horizontal")
    axes_colorbar.set_title("Test set fraction")

    if filename is not None:
        fig.savefig(filename, bbox_inches="tight", dpi=400)


plot_pretty_confusion_matrix_2()

In [None]:
def save_single_plot(cm, filename="test.pdf", title="Dataset 2020", font_size=12):
    """Plot a single confusion matrix on its own figure and optionally save it.

    The color scale is fixed to a linear 0 to 0.45 range, and cell values are
    written in red.

    Args:
        cm: Confusion matrix, array-like (nested lists or ``np.ndarray``).
        filename: Output path; ``None`` skips saving.
        title: Axes title.
        font_size: Font size applied globally via ``matplotlib.rc``. This also
            affects every plot made afterwards.
    """
    # Normalize to an ndarray so nested lists work too: the plotting helper
    # transposes the matrix with ``.T``.
    cm = np.array(cm)

    font = {"family": "DejaVu Sans", "weight": "normal", "size": font_size}
    matplotlib.rc("font", **font)

    fig = plt.figure()
    ax = plt.axes()
    ax = prepare_axis(ax)
    single_plot(
        cm,
        ax,
        vmin=0,
        vmax=0.45,
        xedges=[0, 0.5, 1],
        yedges=[0, 0.5, 1],
        text_color="red",
        title=title,
    )
    fig.tight_layout()
    if filename is not None:
        fig.savefig(filename, bbox_inches="tight", dpi=400)


# Confusion matrix of the best model, rendered as a standalone figure and saved.
cm0 = np.array([[0.18, 0.42], [0.28, 0.13]])
save_single_plot(cm0, filename="BestModel.png", title="Best model", font_size=15)