# Plot results

In [1]:
%%capture
!pip install plotly==5.24.1
!pip install kaleido==0.2.1

In [None]:
import os

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go

from experiments.mupc_paper.utils import compute_metric_stats

## Plotting utils

In [None]:
def plot_loss_stats(metric, yaxis_title, save_path, test_every=1):
    """Plot the across-seed mean of a metric with a shaded ±1 std band and save it.

    Args:
        metric: Per-seed metric histories, handed to `compute_metric_stats`
            (assumed to return equal-length `(means, stds)` arrays — TODO confirm).
        yaxis_title: Label for the y-axis.
        save_path: Output path passed to `fig.write_image` (needs kaleido).
        test_every: Logging interval. When it is not 1, the x tick labels are
            rescaled to `(index + 1) * test_every`.
    """
    means, stds = compute_metric_stats(metric)
    upper = means + stds
    lower = means - stds

    steps = list(range(len(means)))
    # Closed polygon: walk the upper edge forwards, then the lower edge backwards.
    band_x = steps + steps[::-1]
    band_y = list(upper) + list(lower[::-1])

    color = "#EF553B"
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=band_x,
        y=band_y,
        fill="toself",
        fillcolor=color,
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=False,
        opacity=0.3,
    ))
    fig.add_trace(go.Scatter(
        x=steps,
        y=means,
        mode="lines",
        line=dict(width=2, color=color),
        showlegend=False,
    ))

    last_step = steps[-1]
    tick_vals = [0, int(last_step / 2), last_step]
    if test_every == 1:
        tick_text = tick_vals
    else:
        # NOTE(review): this maps index t to step (t+1)*test_every, whereas
        # plot_loss_and_accuracy maps index t to t*test_every — confirm which
        # logging convention is correct.
        tick_text = [(v + 1) * test_every for v in tick_vals]

    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(title="Training step", tickvals=tick_vals, ticktext=tick_text),
        yaxis=dict(title=yaxis_title),
        font=dict(size=16),
    )
    fig.write_image(save_path)


def _add_mean_std_traces(fig, x, means, stds, color, mode, yaxis):
    """Add a shaded mean ± std band plus a mean line to `fig` on axis `yaxis`."""
    x = list(x)
    upper, lower = means + stds, means - stds
    # Closed polygon: upper edge forwards, lower edge backwards.
    fig.add_trace(
        go.Scatter(
            x=x + x[::-1],
            y=list(upper) + list(lower[::-1]),
            fill="toself",
            fillcolor=color,
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
            opacity=0.3,
            yaxis=yaxis,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=means,
            mode=mode,
            line=dict(width=2, color=color),
            showlegend=False,
            yaxis=yaxis,
        )
    )


def plot_loss_and_accuracy(
        loss,
        accuracy,
        save_path,
        test_every=1
):
    """Plot training loss (left axis) and test accuracy (right axis) on one figure.

    Each metric is drawn as its across-seed mean with a shaded ±1 std band.

    Args:
        loss: Per-seed training-loss histories, handed to `compute_metric_stats`.
            These are assumed to be recorded at every training step (TODO confirm).
        accuracy: Per-seed test-accuracy histories, handed to `compute_metric_stats`.
        save_path: Output path passed to `fig.write_image` (needs kaleido).
        test_every: Interval, in training steps, between accuracy evaluations.
            Accuracy index t is plotted at step `t * test_every`.
    """
    loss_means, loss_stds = compute_metric_stats(loss)
    acc_means, acc_stds = compute_metric_stats(accuracy)

    train_iters = list(range(len(loss_means)))
    test_iters = [t * test_every for t in range(len(acc_means))]

    loss_color, accuracy_color = "#EF553B", "#636EFA"
    fig = go.Figure()
    _add_mean_std_traces(
        fig, train_iters, loss_means, loss_stds, loss_color, mode="lines", yaxis="y"
    )
    _add_mean_std_traces(
        fig, test_iters, acc_means, acc_stds, accuracy_color,
        mode="lines+markers", yaxis="y2"
    )

    xtickvals = [0, int(train_iters[-1] / 2), train_iters[-1]]
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training step",
            tickvals=xtickvals,
        ),
        # `title=dict(text=..., font=...)` replaces the deprecated `titlefont`
        # property, which plotly 6 removes.
        yaxis=dict(
            title=dict(text="Train dis. loss", font=dict(color=loss_color)),
            tickfont=dict(color=loss_color),
        ),
        yaxis2=dict(
            title=dict(text="Test accuracy (%)", font=dict(color=accuracy_color)),
            side="right",
            overlaying="y",
            tickfont=dict(color=accuracy_color),
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_metric_phase_diagram(metric, metric_id, save_path, title=None, log=False):
    """Save a heatmap ("phase diagram") of a metric over network width and depth.

    Args:
        metric: 2D array indexed as `metric[width_idx, n_hidden_idx]`. `imshow`
            draws rows along the y-axis (labelled N) and columns along the x-axis
            (labelled H).
        metric_id: Colorbar label (may contain matplotlib mathtext).
        save_path: Output path for `savefig`.
        title: Optional axes title.
        log: If True, colour-map values on a log scale via `LogNorm`.
    """
    metric = np.asarray(metric)
    n_widths, n_hiddens = metric.shape[0], metric.shape[1]

    fig, ax = plt.subplots()
    norm = mcolors.LogNorm() if log else None
    im = ax.imshow(
        metric,
        origin="lower",
        interpolation="bicubic",
        norm=norm
    )
    ax.set_xlabel("$H$", fontsize=30, labelpad=15)
    ax.set_ylabel("$N$", fontsize=30, labelpad=15)
    if title is not None:
        ax.set_title(title, fontsize=30, pad=20)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(metric_id, fontsize=30, labelpad=15)
    cbar.ax.tick_params(labelsize=14)

    # Columns (x) index n_hidden and rows (y) index width, so each axis gets
    # its own tick count. The labels assume both grids are 2^1, 2^2, ...
    # (TODO confirm for any new sweep). Braces keep multi-digit exponents
    # such as 2^{10} as a single superscript.
    ax.set_xticks(range(n_hiddens))
    ax.set_xticklabels([f"$2^{{{i + 1}}}$" for i in range(n_hiddens)], fontsize=16)
    ax.set_yticks(range(n_widths))
    ax.set_yticklabels([f"$2^{{{i + 1}}}$" for i in range(n_widths)], fontsize=16)

    fig.savefig(save_path, bbox_inches="tight")
    # Close only this figure so a loop of calls does not accumulate memory.
    plt.close(fig)


## Plotting scripts

In [None]:
# Training/test curves for PCNs, averaged over seeds, per dataset × layer type × init type.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]
width = 256
n_hidden = 2
layer_types = ["basis_fn", "mlp"]
init_types = ["gen", "amort"]
activity_lr = 5e-1
param_lr = 1e-3
batch_size = 64
n_train_iters = 300
test_every = 50
n_seeds = 3

for dataset in datasets:
    fig_dir = os.path.join(results_dir, dataset)
    for layer_type in layer_types:
        for init_type in init_types:
            seed_dirs = [
                os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    f"{layer_type}_layer",
                    f"{init_type}_init",
                    f"{activity_lr}_activity_lr",
                    f"{param_lr}_param_lr",
                    f"batch_size_{batch_size}",
                    f"{n_train_iters}_train_iters",
                    f"test_every_{test_every}",
                    str(seed)
                )
                for seed in range(n_seeds)
            ]
            # Rebuilt for every configuration, so no data from a previous
            # configuration can leak into these per-seed lists.
            train_gen_losses_all_seeds = [
                np.load(os.path.join(d, "train_gen_losses.npy")) for d in seed_dirs
            ]
            train_amort_losses_all_seeds = [
                np.load(os.path.join(d, "train_amort_losses.npy")) for d in seed_dirs
            ]
            test_amort_accs_all_seeds = [
                np.load(os.path.join(d, "test_amort_accs.npy")) for d in seed_dirs
            ]

            plot_loss_stats(
                metric=train_gen_losses_all_seeds,
                yaxis_title="Train gen. loss",
                save_path=os.path.join(
                    fig_dir, f"train_gen_losses_{layer_type}_layer_{init_type}_init.pdf"
                )
            )
            plot_loss_and_accuracy(
                loss=train_amort_losses_all_seeds,
                accuracy=test_amort_accs_all_seeds,
                save_path=os.path.join(
                    fig_dir,
                    f"train_amort_losses_&_accs_{layer_type}_layer_{init_type}_init.pdf"
                ),
                test_every=test_every
            )


In [None]:
# Phase diagrams (width × n_hidden) of activity-Hessian spectra, one set per seed.
SAVE_DIR = "activity_hessian_results"
IN_OUT_DIMS = [[10, 784]]
ACT_FNS = ["linear", "tanh"]  # "relu" currently excluded from the sweep
WIDTHS = [2**i for i in range(1, 8)]
N_HIDDENS = [2**i for i in range(1, 8)]
N_SEEDS = 3

for in_out_dims in IN_OUT_DIMS:
    for act_fn in ACT_FNS:
        act_fn_dir = os.path.join(SAVE_DIR, f"{in_out_dims}_in_out_dims", act_fn)
        for seed in range(N_SEEDS):
            # Extreme eigenvalues and condition number, indexed [width_idx, n_hidden_idx].
            H_max_eigen = np.zeros((len(WIDTHS), len(N_HIDDENS)))
            H_min_eigen = np.zeros_like(H_max_eigen)
            H_cond_num = np.zeros_like(H_max_eigen)

            for i, width in enumerate(WIDTHS):
                for j, n_hidden in enumerate(N_HIDDENS):
                    run_dir = os.path.join(
                        act_fn_dir,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        str(seed)
                    )
                    H_eigens = np.load(os.path.join(run_dir, "hessian_eigenvals.npy"))
                    eig_max, eig_min = np.max(H_eigens), np.min(H_eigens)
                    H_max_eigen[i, j] = eig_max
                    H_min_eigen[i, j] = eig_min
                    # Ratio |λ_max| / |λ_min|; a zero λ_min yields inf (with a numpy warning).
                    H_cond_num[i, j] = np.abs(eig_max) / np.abs(eig_min)

            # Each diagram spans every (width, n_hidden) run, so it is saved at the
            # activation level with the seed in the filename. It is not saved inside
            # whichever run directory the loop visited last.
            # Raw strings keep the LaTeX backslashes (e.g. \lambda) from being
            # parsed as invalid escape sequences.
            plot_metric_phase_diagram(
                H_max_eigen,
                r"$\lambda_{max}(H_{\mathbf{z}})$",
                os.path.join(act_fn_dir, f"H_max_eigen_seed_{seed}.pdf")
            )
            plot_metric_phase_diagram(
                H_min_eigen,
                r"$\lambda_{min}(H_{\mathbf{z}})$",
                os.path.join(act_fn_dir, f"H_min_eigen_seed_{seed}.pdf")
            )
            plot_metric_phase_diagram(
                H_cond_num,
                r"$\kappa(H_{\mathbf{z}})$",
                os.path.join(act_fn_dir, f"H_cond_num_seed_{seed}.pdf"),
                log=False
            )
                            