Imports:

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.cm as cm
import matplotlib as mpl
from matplotlib import pyplot as plt
from scipy.spatial import Voronoi, voronoi_plot_2d

In [None]:
# Read the experiment results table and sort its rows by the `n_model_params` column.
data = (
    pd.read_csv(os.path.join('results', 'exp_res_final.csv'))
    .sort_values(by='n_model_params')
)

In [None]:
# Keep only the rows whose epoch equals the largest epoch present in the table.
data = data.loc[data['epoch'] == data['epoch'].max()]

Main plot, based on https://stackoverflow.com/questions/41244322/how-to-color-voronoi-according-to-a-color-scale-and-the-area-of-each-cell and https://stackoverflow.com/questions/20515554/colorize-voronoi-diagram.

In [None]:
def voronoi_finite_polygons_2d(vor, radius=None):
    """
    Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.

    Every ridge that extends to infinity is cut off at a synthetic
    "far point" placed ``radius`` away from the ridge's finite vertex,
    pointing away from the centroid of the input points.

    Parameters
    ----------
    vor : Voronoi
        Input diagram
    radius : float, optional
        Distance to 'points at infinity'. Defaults to the peak-to-peak
        range taken over all coordinates of the input points.

    Returns
    -------
    regions : list of lists
        Indices of vertices in each revised Voronoi region, in the same
        order as the input points.
    vertices : ndarray
        Coordinates for revised Voronoi vertices. Same as coordinates
        of input vertices, with 'points at infinity' appended to the
        end.

    Raises
    ------
    ValueError
        If the diagram was not built from 2D points.

    """

    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    if radius is None:
        # The ndarray.ptp() method was removed in NumPy 2.0; the np.ptp()
        # function yields the same value and works on all NumPy versions.
        radius = np.ptp(vor.points)

    # Construct a map containing all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(vor.point_region):
        vertices = vor.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region
            new_regions.append(vertices)
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            # Normalise so that a missing (negative) vertex index, if any, is v1
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge

            t = vor.points[p2] - vor.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # Orient the normal so that it points away from the centroid
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)

In [None]:
def get_logspace_ticks(max_oom: int, min_oom: int = 0):
    """
    Build tick positions and labels for an axis whose data was natural-log transformed.

    Parameters
    ----------
    max_oom : int
        Largest order of magnitude (power of ten) to include.
    min_oom : int, optional
        Smallest order of magnitude (power of ten) to include.

    Returns
    -------
    big_ticks_vals : ndarray
        ``ln(10 ** oom)`` for every ``oom`` in ``[min_oom, max_oom]``.
    big_ticks_text : list of str
        Math-text labels of the form ``$10^{oom}$`` matching ``big_ticks_vals``.
    """
    ooms = range(min_oom, max_oom + 1)

    # Float powers avoid integer overflow for large exponents when building the array.
    big_ticks_vals = np.log([10.0 ** oom for oom in ooms])
    # The exponent is wrapped in braces so that multi-character exponents
    # (e.g. 10 or -1) are rendered entirely as the superscript.
    big_ticks_text = [f'$10^{{{oom}}}$' for oom in ooms]

    return big_ticks_vals, big_ticks_text

In [None]:
def calc_interp_peak_model_dims(interp_peak_n_params: int, n_feats: int, min_lat: int, max_lat: int, min_hid: int,
                                max_hid: int, n_points: int = 2 * 10 ** 4) -> tuple:
    lat_dims = np.linspace(min_lat, max_lat, num=n_points)
    hid_dims = 0.5 * (interp_peak_n_params - n_feats - lat_dims) / (n_feats + lat_dims + 1)
    hid_dims[hid_dims < min_hid] = None
    hid_dims[hid_dims > max_hid] = None

    return lat_dims, hid_dims

In [None]:
def plot_loss_phase_diag(data: pd.DataFrame, big_fs: int = 18, small_fs: int = 16,
                         loss_name: str = 'Train', max_lat_oom: int = 4, max_hidden_oom: int = 2,
                         debug: bool = False):
    """
    Draw a loss "phase diagram" over (latent dim, hidden width) on the current axes.

    Each row of ``data`` becomes one Voronoi cell (built in log-space) filled with a
    colour mapped from that row's loss.

    Parameters
    ----------
    data : pd.DataFrame
        Must provide ``latent_dim``, ``hidden_dim`` and the ``train_loss`` /
        ``test_loss`` column that is selected by ``loss_name``.
    big_fs, small_fs : int
        Font sizes for axis/colorbar labels and for tick labels, respectively.
    loss_name : str
        ``'Train'`` selects ``train_loss``; any other value selects ``test_loss``.
        Also used as the prefix of the colorbar label.
    max_lat_oom, max_hidden_oom : int
        Largest power of ten that gets a tick on the x and y axis, respectively.
    debug : bool
        When true, the input points and the cell borders are drawn as well.
    """
    # Voronoi has no notion of log-scaled axes, so the diagram is built from log-transformed dims
    log_dims = np.log(data[['latent_dim', 'hidden_dim']].values)
    diagram = Voronoi(log_dims)
    cell_regions, cell_vertices = voronoi_finite_polygons_2d(diagram)

    loss_values = (data.train_loss if loss_name == 'Train' else data.test_loss).values

    # colour scale spanning the observed loss range
    color_norm = mpl.colors.Normalize(vmin=loss_values.min(), vmax=loss_values.max(), clip=True)
    color_mapper = cm.ScalarMappable(norm=color_norm, cmap=cm.jet)

    # draw the diagram skeleton, then paint each cell with the colour of its loss
    voronoi_plot_2d(diagram, show_points=debug, show_vertices=False, line_alpha=debug, ax=plt.gca())
    for cell_loss, cell in zip(loss_values, cell_regions):
        cell_xs, cell_ys = cell_vertices[cell].T
        plt.fill(cell_xs, cell_ys, color=color_mapper.to_rgba(cell_loss))

    colorbar = plt.colorbar(color_mapper)
    colorbar.set_label(f'{loss_name} Loss', fontsize=big_fs)
    colorbar.ax.tick_params(labelsize=small_fs)

    plt.xlabel('Latent dim.', fontsize=big_fs)
    plt.ylabel('Hidden width', fontsize=big_fs)

    y_tick_vals, y_tick_txt = get_logspace_ticks(max_hidden_oom, min_oom=1)
    x_tick_vals, x_tick_txt = get_logspace_ticks(max_lat_oom, min_oom=0)

    # remember the x-range before the ticks are set and re-apply it afterwards
    x_low, x_high = plt.gca().get_xlim()
    plt.yticks(y_tick_vals, y_tick_txt, fontsize=small_fs)
    plt.xticks(x_tick_vals, x_tick_txt, fontsize=small_fs)
    plt.xlim(x_low, x_high)

In [None]:
def plot_interp_peak_loc(lat_hid_min_max: tuple, interp_peak_n_params: int, n_feats: int, linestyle: str, label: str):
    """
    Overlay, in log-space, the curve of model dims given by ``calc_interp_peak_model_dims``.

    Parameters
    ----------
    lat_hid_min_max : tuple
        ``(min_lat, max_lat, min_hid, max_hid)``, unpacked into ``calc_interp_peak_model_dims``.
    interp_peak_n_params : int
        Parameter count forwarded to ``calc_interp_peak_model_dims``.
    n_feats : int
        Number of features forwarded to ``calc_interp_peak_model_dims``.
    linestyle : str
        Line style forwarded to ``plt.plot``.
    label : str
        Legend label of the curve.
    """
    lat_curve, hid_curve = calc_interp_peak_model_dims(interp_peak_n_params, n_feats, *lat_hid_min_max)

    # thick black line, drawn with an infinite zorder
    line_kwargs = dict(c='k', zorder=np.inf, linestyle=linestyle, linewidth=5, label=label)
    plt.plot(np.log(lat_curve), np.log(hid_curve), **line_kwargs)

To place the legend below the plot instead of above it, set `loc=(0.1, -0.28)` in the `plt.legend` call.

In [None]:
plt.figure(figsize=(11, 6), dpi=300)
# Keep a single row per (latent_dim, hidden_dim) pair before building the diagram.
plot_loss_phase_diag(data.drop_duplicates(subset=['latent_dim', 'hidden_dim']), loss_name='Train', debug=False)

n_feats = 50
data_lat_dim = 20
dataset_size = 5000
dims_buffer_fac: float = 10.

# Range of dims covered by the overlay curves: the observed range widened by `dims_buffer_fac` on both ends.
lat_hid_min_max = (
    data.latent_dim.min() / dims_buffer_fac,
    data.latent_dim.max() * dims_buffer_fac,
    data.hidden_dim.min() / dims_buffer_fac,
    data.hidden_dim.max() * dims_buffer_fac,
)

# One curve per "# model params / dataset size" ratio: (ratio, line style, legend label).
for params_per_sample, curve_linestyle, curve_label in (
        (n_feats, 'dashed', '# of features'),
        (data_lat_dim, 'dotted', "Data's latent dim."),
        (1, None, '1'),
):
    plot_interp_peak_loc(lat_hid_min_max, dataset_size * params_per_sample, n_feats,
                         linestyle=curve_linestyle, label=curve_label)

plt.legend(title=r'# model params/dataset size=', fontsize=14, title_fontsize=14, loc=(0.1, 1.01), ncol=3)

In [None]:
plt.figure(figsize=(11, 6), dpi=300)
# Keep a single row per (latent_dim, hidden_dim) pair before building the diagram.
plot_loss_phase_diag(data.drop_duplicates(subset=['latent_dim', 'hidden_dim']), loss_name='Test', debug=False)

n_feats = 50
data_lat_dim = 20
dataset_size = 5000
dims_buffer_fac: float = 10.

# Range of dims covered by the overlay curves: the observed range widened by `dims_buffer_fac` on both ends.
lat_hid_min_max = (
    data.latent_dim.min() / dims_buffer_fac,
    data.latent_dim.max() * dims_buffer_fac,
    data.hidden_dim.min() / dims_buffer_fac,
    data.hidden_dim.max() * dims_buffer_fac,
)

# One curve per "# model params / dataset size" ratio: (ratio, line style, legend label).
for params_per_sample, curve_linestyle, curve_label in (
        (n_feats, 'dashed', '# of features'),
        (data_lat_dim, 'dotted', "Data's latent dim."),
        (1, None, '1'),
):
    plot_interp_peak_loc(lat_hid_min_max, dataset_size * params_per_sample, n_feats,
                         linestyle=curve_linestyle, label=curve_label)

plt.legend(title=r'# model params/dataset size=', fontsize=14, title_fontsize=14, loc=(0.1, 1.01), ncol=3)