# Plot results

In [1]:
%%capture
!pip install plotly==5.24.1
!pip install kaleido==0.2.1

In [2]:
import os
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from matplotlib.colors import LogNorm

import plotly.graph_objs as go
import plotly.colors as pc

# Headless Chrome binary (presumably used by plotly/kaleido for static image
# export — confirm). A BROWSER_PATH already set in the environment takes
# precedence over this machine-specific default.
os.environ.setdefault(
    "BROWSER_PATH",
    "/home/myhome/chrome-headless-shell/linux-132.0.6834.83/chrome-headless-shell-linux64/chrome-headless-shell",
)

## Hessian analysis results

In [14]:
def plot_hessian_eigenvalues_slice(
        hessian_eigenspectrum, 
        matrix_id,    
        save_path,
        width_slice=None, 
        n_hidden_slice=None,
        n_bins=50
):
    """Save overlaid eigenvalue histograms along one slice of the (N, H) grid.

    Args:
        hessian_eigenspectrum: nested dict ``{width: {n_hidden: eigenvalues}}``.
            The n_hidden keys are read from the first width, so every width is
            assumed to share the same set of n_hidden values.
        matrix_id: LaTeX name of the matrix, shown in the x-axis label
            ``lambda(matrix_id)``.
        save_path: output file passed to ``fig.write_image``.
        width_slice: if given, fix the width to this value and draw one
            histogram per n_hidden (legend ``H=2^k``).
        n_hidden_slice: if given, fix n_hidden to this value and draw one
            histogram per width (legend ``N=2^k``).
        n_bins: maximum number of histogram bins (plotly ``nbinsx``).

    The palette is blue when ``width_slice`` is given and red otherwise.
    Histograms are probability-normalised and drawn on a log-scaled y-axis.

    Raises:
        ValueError: if ``hessian_eigenspectrum`` is empty.
        KeyError: if a requested slice value is not in the grid.
    """
    if not hessian_eigenspectrum:
        raise ValueError("hessian_eigenspectrum is empty")
    widths = list(hessian_eigenspectrum.keys())
    # Take n_hidden keys from the first width (this used to be hard-coded to
    # width 2, which failed for grids without that width).
    n_hiddens = list(next(iter(hessian_eigenspectrum.values())).keys())

    def _legend_label(symbol, value):
        # "2^{k}" for exact powers of two, the raw value otherwise (the old
        # label assumed the keys were exactly 2^1, 2^2, ... in order).
        exponent = float(np.log2(value))
        if exponent.is_integer():
            return f"${symbol}=2^{{{int(exponent)}}}$"
        return f"${symbol}={value}$"

    colorscale = "Blues" if width_slice is not None else "Reds"

    # (eigenvalues, legend label, colour) for each histogram. Colours are
    # sampled per series so their count always matches the number of traces;
    # the first 3 samples (the light end of the scale) are skipped.
    series = []
    if width_slice is not None:
        colors = pc.sample_colorscale(colorscale, len(n_hiddens) + 3)[3:]
        for n_hidden, color in zip(n_hiddens, colors):
            series.append((
                hessian_eigenspectrum[width_slice][n_hidden],
                _legend_label("H", n_hidden),
                color,
            ))
    if n_hidden_slice is not None:
        colors = pc.sample_colorscale(colorscale, len(widths) + 3)[3:]
        for width, color in zip(widths, colors):
            series.append((
                hessian_eigenspectrum[width][n_hidden_slice],
                _legend_label("N", width),
                color,
            ))

    fig = go.Figure()
    for eigenvals, label, color in series:
        fig.add_trace(
            go.Histogram(
                x=eigenvals,
                histnorm="probability",
                nbinsx=n_bins,
                name=label,
                marker=dict(color=color)
            )
        )

    fig.update_layout(
        barmode="overlay",
        height=400,
        width=550,
        # raw f-string: "\L" and "\l" are not valid Python escape sequences
        xaxis=dict(title=rf"$\LARGE{{\lambda({{{matrix_id}}})}}$"),
        yaxis=dict(
            title="Density (log)",
            type="log",
            exponentformat="power",
            dtick=1
        ),
        font=dict(size=18),
        margin=dict(b=100)
    )
    fig.update_traces(opacity=0.75)
    fig.write_image(save_path)


def plot_metric_phase_diagram(metric, metric_id, save_path, title=None, log=False):
    """Save a heat map of ``metric`` over the (width N, n_hidden H) grid.

    Args:
        metric: 2-D array indexed as ``metric[width_idx, n_hidden_idx]``. Rows
            are drawn along the y-axis (N) and columns along the x-axis (H).
        metric_id: colour-bar label (LaTeX allowed).
        save_path: output file passed to ``savefig``.
        title: optional figure title.
        log: if True, use a logarithmic colour scale (``LogNorm``).

    Tick labels assume the grid values along each axis are 2^1, 2^2, ...
    """
    metric = np.asarray(metric)
    n_widths, n_hiddens = metric.shape[0], metric.shape[1]

    fig, ax = plt.subplots()
    im = ax.imshow(
        metric,
        origin="lower",
        interpolation="bicubic",
        norm=mcolors.LogNorm() if log else None
    )
    ax.set_xlabel("$H$", fontsize=30, labelpad=15)
    ax.set_ylabel("$N$", fontsize=30, labelpad=15)
    if title is not None:
        ax.set_title(title, fontsize=30, pad=20)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(metric_id, fontsize=30, labelpad=15)
    cbar.ax.tick_params(labelsize=14)

    # Columns (x-axis) index n_hidden and rows (y-axis) index width; the two
    # axes previously had their tick counts swapped. Braces keep multi-digit
    # exponents (e.g. 2^{10}) rendered as a single superscript.
    ax.set_xticks(range(n_hiddens))
    ax.set_xticklabels([f"$2^{{{k + 1}}}$" for k in range(n_hiddens)], fontsize=16)
    ax.set_yticks(range(n_widths))
    ax.set_yticklabels([f"$2^{{{k + 1}}}$" for k in range(n_widths)], fontsize=16)

    fig.savefig(save_path, bbox_inches="tight")
    # Close only this figure rather than every open figure.
    plt.close(fig)


def get_act_fn_title(act_fn):
    """Map an activation-function id to its plot title.

    Recognised ids are "linear", "tanh" and "relu"; any other value gives None.
    """
    known_titles = (
        ("linear", "Linear"),
        ("tanh", "Tanh"),
        ("relu", "ReLU"),
    )
    return next((title for key, title in known_titles if act_fn == key), None)


In [None]:
# Hessian analysis: for each configuration and seed, load the saved eigenvalues
# over the (width N, n_hidden H) grid, then plot eigenvalue histograms along
# fixed-N / fixed-H slices and phase diagrams of extreme eigenvalues and
# condition numbers.
SAVE_DIR = "activity_hessian_results"
IN_OUT_DIMS = ["width"]
ACT_FNS = ["linear", "tanh", "relu"]
USE_SKIPS = [False]  # alternative: True
WEIGHT_INITS = ["orthogonal"]  # alternatives: ["one_over_N", "standard_gauss", "standard", "orthogonal"]
PARAM_TYPES = ["sp"]  # alternative: "mupc"
ACTIVITY_DECAY = [0]
WIDTHS = [2**i for i in range(1, 8)]
N_HIDDENS = [2**i for i in range(1, 8)]
N_SEEDS = 3

N_slice, H_slice = 128, 128

# Saved eigenvalue file and LaTeX label for each matrix. The numerical Hessian
# is loaded for every activation; the theory matrices only for linear networks.
EIGENVAL_FILES = {
    "H_num": "num_hessian_eigenvals.npy",
    "D": "theory_D_eigenvals.npy",
    "O": "theory_O_eigenvals.npy",
    "H_theory": "theory_hessian_eigenvals.npy",
}
MATRIX_LATEX = {
    "H_num": r"H_{\mathbf{z}}",  # raw strings: "\m", "\l", "\k" are invalid Python escapes
    "D": "D",
    "O": "O",
    "H_theory": "H_{theory}",
}
# Matrices whose condition number is also plotted.
COND_NUM_MATRICES = ("H_num", "H_theory")


def _hessian_run_dir(in_out_dims, act_fn, use_skips, weight_init, param_type,
                     activity_decay, width, n_hidden, seed):
    """Directory holding the saved eigenvalues of one (configuration, N, H, seed) run."""
    return os.path.join(
        SAVE_DIR,
        f"{in_out_dims}_in_out_dims",
        act_fn,
        "no_biases/supervised",
        "skips" if use_skips else "no_skips",
        f"{weight_init}_weight_init",
        f"{param_type}_param",
        f"activity_decay_{activity_decay}",
        f"width_{width}",
        f"{n_hidden}_n_hidden",
        str(seed)
    )


def _plot_seed_results(eigenvals, max_eigen, min_eigen, cond_num, plot_dir,
                       act_fn_title, log_cond):
    """Write slice histograms and phase diagrams for one seed into ``plot_dir``."""
    for name, spectra in eigenvals.items():
        latex = MATRIX_LATEX[name]
        plot_hessian_eigenvalues_slice(
            spectra,
            matrix_id=latex,
            width_slice=N_slice,
            save_path=f"{plot_dir}/{name}_eigenvals_N_{N_slice}.pdf",
        )
        plot_hessian_eigenvalues_slice(
            spectra,
            matrix_id=latex,
            n_hidden_slice=H_slice,
            save_path=f"{plot_dir}/{name}_eigenvals_H_{H_slice}.pdf",
        )
        plot_metric_phase_diagram(
            max_eigen[name],
            rf"$\lambda_{{max}}({latex})$",
            f"{plot_dir}/{name}_max_eigen.pdf"
        )
        plot_metric_phase_diagram(
            min_eigen[name],
            rf"$\lambda_{{min}}({latex})$",
            f"{plot_dir}/{name}_min_eigen.pdf"
        )
        if name in cond_num:
            plot_metric_phase_diagram(
                cond_num[name],
                rf"$\kappa({latex})$",
                f"{plot_dir}/{name}_cond_num.pdf",
                title=act_fn_title,
                log=log_cond
            )


grid_shape = (len(WIDTHS), len(N_HIDDENS))
for in_out_dims in IN_OUT_DIMS:
    print(f"\nin_out_dims: {in_out_dims}")
    for act_fn in ACT_FNS:
        print(f"\n\tact_fn: {act_fn}")
        act_fn_title = get_act_fn_title(act_fn)
        matrices = list(EIGENVAL_FILES) if act_fn == "linear" else ["H_num"]
        for use_skips in USE_SKIPS:
            print(f"\n\t\tuse_skips: {use_skips}")
            for weight_init in WEIGHT_INITS:
                print(f"\n\t\t\tweight_init: {weight_init}")
                for param_type in PARAM_TYPES:
                    print(f"\n\t\t\t\tparam_type: {param_type}")
                    # log colour scale for condition numbers except for SP without skips
                    log_cond = use_skips or param_type != "sp"
                    for activity_decay in ACTIVITY_DECAY:
                        print(f"\n\t\t\t\t\tactivity_decay: {activity_decay}\n")
                        config = (in_out_dims, act_fn, use_skips, weight_init,
                                  param_type, activity_decay)
                        for seed in range(N_SEEDS):
                            print(f"\t\t\t\t\t\tseed: {seed}")

                            # full spectra as {matrix: {width: {n_hidden: eigenvalues}}}
                            eigenvals = {name: {width: {} for width in WIDTHS} for name in matrices}
                            max_eigen = {name: np.zeros(grid_shape) for name in matrices}
                            min_eigen = {name: np.zeros(grid_shape) for name in matrices}
                            cond_num = {
                                name: np.zeros(grid_shape)
                                for name in matrices if name in COND_NUM_MATRICES
                            }

                            for i, width in enumerate(WIDTHS):
                                for j, n_hidden in enumerate(N_HIDDENS):
                                    run_dir = _hessian_run_dir(*config, width, n_hidden, seed)
                                    for name in matrices:
                                        eigens = np.load(f"{run_dir}/{EIGENVAL_FILES[name]}")
                                        eigenvals[name][width][n_hidden] = eigens
                                        max_eigen[name][i, j] = np.max(eigens)
                                        min_eigen[name][i, j] = np.min(eigens)
                                        if name in cond_num:
                                            # NOTE(review): this is |max λ| / |min λ| of the
                                            # signed spectrum, not max|λ| / min|λ|; the two
                                            # differ if the spectrum has both signs. Kept as
                                            # before — confirm it is intended.
                                            cond_num[name][i, j] = (
                                                np.abs(max_eigen[name][i, j])
                                                / np.abs(min_eigen[name][i, j])
                                            )

                            # Aggregate plots go in the directory of the last (N, H)
                            # grid point, as before (previously implicit through the
                            # loop variable save_path leaking out of the loading loop).
                            plot_dir = _hessian_run_dir(*config, WIDTHS[-1], N_HIDDENS[-1], seed)
                            _plot_seed_results(
                                eigenvals, max_eigen, min_eigen, cond_num,
                                plot_dir, act_fn_title, log_cond
                            )
 

## Forward pass results

In [23]:
def plot_metric_per_iv(
        metric, 
        metric_id, 
        ivs,
        yaxis_title, 
        xaxis_title, 
        param_type, 
        save_path
):
    """Plot one line per tracked layer of ``metric`` against an independent variable.

    NOTE: a second function with this same name is defined in the "Training
    results" section and shadows this one from that point on.

    Args:
        metric: 2-D array with one row per tracked layer (at most five, labelled
            1, 1/4L, 1/2L, 3/4L, L) and one column per entry of ``ivs``.
        metric_id: "activities" uses a red palette (and, with ``param_type ==
            "sp"``, a log y-axis); any other value uses a green palette.
        ivs: x-values. Unless ``xaxis_title == "Training iteration"`` they are
            drawn on a log x-axis with ``2^k`` tick labels (from log2 of each value).
        yaxis_title: y-axis title.
        xaxis_title: x-axis title; "Training iteration" switches to a linear
            x-axis with three ticks and lines without markers.
        param_type: parameterisation id; only "sp" changes the layout.
        save_path: output file passed to ``fig.write_image``.

    Raises:
        ValueError: if ``metric`` has more rows than there are layer labels.
    """
    layer_labels = [1, "1/4L", "1/2L", "3/4L", "L"]
    n_layers = metric.shape[0]
    if n_layers > len(layer_labels):
        raise ValueError(
            f"metric has {n_layers} layer rows but only {len(layer_labels)} "
            "layer labels are defined"
        )
    is_over_training = xaxis_title == "Training iteration"

    colorscale = "Reds" if metric_id == "activities" else "Greens"
    # skip the first 3 samples (the light end of the scale)
    colors = pc.sample_colorscale(colorscale, n_layers + 3)[3:]

    fig = go.Figure()
    for layer_metric, layer_label, color in zip(metric, layer_labels, colors):
        fig.add_trace(
            go.Scatter(
                x=ivs,
                y=layer_metric,
                # raw f-string: "\e" is not a valid Python escape sequence
                name=rf"$\ell = {{{layer_label}}}$",
                mode="lines" if is_over_training else "lines+markers",
                line=dict(width=2, color=color),
                opacity=0.8
            )
        )

    fig.update_layout(
        height=350,
        width=550,
        xaxis=dict(title=xaxis_title),
        yaxis=dict(title=yaxis_title),
        font=dict(size=18),
        margin=dict(r=140, b=90, l=110)
    )
    if metric_id == "activities" and param_type == "sp":
        fig.update_layout(
            yaxis=dict(
                type="log",
                exponentformat="power",
                dtick=10 if xaxis_title == "Depth" else 1
            )
        )
    if is_over_training:
        last_iter = ivs[-1]
        fig.update_layout(
            xaxis=dict(
                tickvals=[0, int(last_iter / 2), last_iter],
                ticktext=[0, int(last_iter / 2), last_iter]
            )
        )
    else:
        fig.update_layout(
            xaxis=dict(
                tickvals=ivs,
                ticktext=[f"$\\large{{2^{{{int(np.log2(iv))}}}}}$" for iv in ivs],
                type="log",
                exponentformat="power"
            )
        )

    fig.write_image(save_path)

In [None]:
# Forward-pass statistics per layer (activity norms and weight spectral norms),
# plotted against width, depth and training iteration for every configuration.
RESULTS_DIR = "mlp_fwd_pass_results"
widths = [2 ** i for i in range(7, 11)]
depths = [2 **  i for i in range(4, 10)]
act_fns = ["linear", "tanh", "relu"]
optim_ids = ["sgd", "adam"]
param_types = ["sp", "depth_mup", "orthogonal"]
seed = 54638
n_ts = 3

width_idxs = [0, -1]  # 128 & 1024 units
depth_idx = 0  # 16 layers
L = depths[depth_idx]

# (saved array stem, plot file stem, metric_id, y-axis title) for each quantity
# plotted against width and depth. The saved arrays appear to be indexed as
# [layer, t, width, depth] — inferred from the slicing below; confirm.
# Raw strings: "\L", "\m", "\e" are not valid Python escape sequences.
NORM_PLOTS = [
    ("avg_activity_l1_per_N_L", "avg_activity_l1", "activities", r"$\Large{||\mathbf{z}_\ell||_1}$"),
    ("avg_activity_l2_per_N_L", "avg_activity_l2", "activities", r"$\Large{||\mathbf{z}_\ell||_2}$"),
    ("param_spectral_norms_per_N_L", "params_spectral_norm", "params", r"$\Large{||W_\ell||_2}$"),
]

for act_fn in act_fns:
    for optim_id in optim_ids:
        for param_type in param_types:

            skip_uses = [False, True] if param_type != "orthogonal" else [False]
            for use_skips in skip_uses:
                save_dir = os.path.join(
                    RESULTS_DIR,
                    act_fn,
                    optim_id,
                    param_type,
                    "skips" if use_skips else "no_skips",
                    str(seed)
                )
                arrays = {
                    array_stem: np.load(f"{save_dir}/{array_stem}.npy")
                    for array_stem, _, _, _ in NORM_PLOTS
                }

                # each quantity vs width (at depth L) and vs depth (at each chosen width)
                for array_stem, plot_stem, metric_id, yaxis_title in NORM_PLOTS:
                    values = arrays[array_stem]
                    for t in range(n_ts):
                        plot_metric_per_iv(
                            metric=values[:, t, :, depth_idx],
                            metric_id=metric_id,
                            ivs=widths,
                            yaxis_title=yaxis_title,
                            xaxis_title="Width",
                            param_type=param_type,
                            save_path=f"{save_dir}/{plot_stem}_per_N_at_t{t}_L_{L}.pdf"
                        )
                        for width_idx in width_idxs:
                            N = widths[width_idx]
                            plot_metric_per_iv(
                                metric=values[:, t, width_idx, :],
                                metric_id=metric_id,
                                ivs=depths,
                                yaxis_title=yaxis_title,
                                xaxis_title="Depth",
                                param_type=param_type,
                                save_path=f"{save_dir}/{plot_stem}_per_L_at_t{t}_N_{N}.pdf"
                            )

                # spectral norms over training at depth L
                spectral_norms = arrays["param_spectral_norms_per_N_L"]
                n_train_iters = spectral_norms.shape[1]
                for width_idx in width_idxs:
                    N = widths[width_idx]
                    plot_metric_per_iv(
                        metric=spectral_norms[:, :, width_idx, depth_idx],
                        metric_id="params",
                        ivs=list(range(n_train_iters)),
                        yaxis_title=r"$\Large{||W_\ell||_2}$",
                        xaxis_title="Training iteration",
                        param_type=param_type,
                        save_path=f"{save_dir}/params_spectral_norm_over_t_N_{N}.pdf"
                    )


## Training results

In [5]:
from utils import compute_cond_num, compute_metric_stats

In [6]:
def _format_iv_label(iv):
    """LaTeX-ready label (without surrounding $) for an independent-variable value.

    Integer powers of two >= 2 become ``2^k`` (128 -> "2^7", 1024 -> "2^{10}").
    Values with one significant digit become ``me^k`` (500 -> "5e^2",
    1e-3 -> "1e^{-3}", 1 -> "1e^0"). Anything else falls back to ``str(iv)``.
    """
    def _sup(exponent):
        # multi-character exponents must be braced to render as one superscript
        text = str(exponent)
        return text if len(text) == 1 else f"{{{text}}}"

    try:
        value = float(iv)
    except (TypeError, ValueError):
        return str(iv)
    if not np.isfinite(value):
        return str(iv)

    if value >= 2 and value.is_integer():
        as_int = int(value)
        if as_int & (as_int - 1) == 0:
            return f"2^{_sup(as_int.bit_length() - 1)}"

    sci = f"{value:.0e}"
    if float(sci) == value:
        mantissa, exponent = sci.split("e")
        return f"{mantissa}e^{_sup(int(exponent))}"
    return str(iv)


def plot_metric_per_iv(metric, metric_id, iv_id, test_every, save_path, height=300, width=450):
    """Plot mean ± std curves of a training metric, one curve per independent-variable value.

    NOTE: this redefinition shadows the forward-pass ``plot_metric_per_iv``
    defined earlier in the notebook; cells after this one use this signature.

    Args:
        metric: dict ``{iv: (means, stds)}``; the sequences are assumed to have
            the same length for every iv (the length of the first entry is used).
        metric_id: "test_acc" plots test accuracy; any other value (e.g.
            "cond_num") is plotted with the condition-number y-axis title.
        iv_id: "n_hidden" (legend H), "width" (legend N); anything else is
            labelled eta (e.g. "activity_lr").
        test_every: multiplier converting the evaluation index into the
            training iteration shown on the x-axis.
        save_path: output file passed to ``fig.write_image``.
        height, width: figure size in pixels.

    Raises:
        ValueError: if ``metric`` is empty.
    """
    if not metric:
        raise ValueError("metric is empty")
    ivs = list(metric.keys())
    iters = list(range(len(metric[ivs[0]][0])))

    if metric_id == "test_acc":
        colorscale = {"n_hidden": "Blues", "width": "Reds"}.get(iv_id, "Oranges")
        colors = pc.sample_colorscale(colorscale, len(ivs) + 2)[2:]
    else:
        # Previously only "cond_num" defined colours (NameError otherwise).
        # sample_colorscale divides by (n - 1), so request at least 2 samples.
        colors = pc.sample_colorscale("Viridis", max(len(ivs), 2))[:len(ivs)]

    iv_symbol = {"n_hidden": "H", "width": "N"}.get(iv_id, r"\eta")

    fig = go.Figure()
    for iv, color in zip(ivs, colors):
        # asarray so list inputs are added element-wise, not concatenated
        means = np.asarray(metric[iv][0])
        stds = np.asarray(metric[iv][1])
        y_upper, y_lower = means + stds, means - stds

        # shaded ±1 std band: along the upper edge, then back along the lower edge
        fig.add_trace(
            go.Scatter(
                x=iters + iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        # Labels are derived from the value, so ivs outside the old hard-coded
        # list no longer raise or silently reuse the previous curve's label.
        fig.add_trace(
            go.Scatter(
                x=iters,
                y=means,
                mode="lines+markers",
                line=dict(width=2, color=color),
                name=f"${iv_symbol}={_format_iv_label(iv)}$"
            )
        )

    xtickvals = [0, int(iters[-1] / 2), iters[-1]]
    if metric_id == "test_acc":
        xticktext = [(t + 1) * test_every for t in xtickvals]
        yaxis_title = "Test accuracy (%)"
    else:
        xticktext = [t * test_every for t in xtickvals]
        yaxis_title = r"$\Large{\kappa(\mathrm{H}_{\mathbf{z}})}$"

    fig.update_layout(
        height=height,
        width=width,
        xaxis=dict(
            title="Training iteration",
            tickvals=xtickvals,
            ticktext=xticktext
        ),
        yaxis=dict(title=yaxis_title),
        font=dict(size=16),
        margin=dict(r=120 if iv_id != "activity_lr" else 130)
    )
    fig.write_image(save_path)


### GD

In [6]:
############ GD learning rate sweep ############
# Test accuracy of GD inference for several activity learning rates: one plot
# per (dataset, activation, number of hidden layers), averaged over seeds.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]

act_fns = ["linear", "tanh", "relu"]
n_hiddens = [8, 16, 32]
width = 128
use_skips = False
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_lrs = [5e-1, 1e-1, 5e-2]
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3


def _gd_sweep_run_dir(dataset, act_fn, n_hidden, activity_lr, seed):
    """Directory of one saved run; the other hyperparameters come from the config above."""
    return os.path.join(
        results_dir,
        dataset,
        f"width_{width}",
        f"{n_hidden}_n_hidden",
        act_fn,
        "skips" if use_skips else "no_skips",
        f"{weight_init}_weight_init",
        f"{param_type}_param",
        f"param_optim_{param_optim_id}",
        f"param_lr_{param_lr}",
        f"batch_size_{batch_size}",
        f"{max_infer_iters}_max_infer_iters",
        f"activity_optim_{activity_optim_id}",
        f"activity_lr_{activity_lr}",
        f"activity_decay_{activity_decay}",
        f"weight_decay_{weight_decay}",
        f"spectral_penalty_{spectral_penalty}",
        f"{max_epochs}_epochs",
        str(seed)
    )


for dataset in datasets:
    for act_fn in act_fns:
        for n_hidden in n_hiddens:

            # {activity_lr: (mean, std) test accuracy over seeds}; previously
            # named test_accs_per_H although it is keyed by learning rate.
            test_accs_per_lr = {}
            for activity_lr in activity_lrs:
                test_accs_all_seeds = [
                    np.load(f"{_gd_sweep_run_dir(dataset, act_fn, n_hidden, activity_lr, seed)}/test_accs.npy")
                    for seed in range(n_seeds)
                ]
                test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
                test_accs_per_lr[activity_lr] = (test_acc_means, test_acc_stds)

            plot_metric_per_iv(
                metric=test_accs_per_lr,
                metric_id="test_acc",
                iv_id="activity_lr",
                test_every=test_every,
                save_path=f"{results_dir}/{dataset}/test_accs_GD_{act_fn}_{n_hidden}_n_hidden.pdf"
            )


In [9]:
############ best GD results ############
# Best-learning-rate GD runs: for each dataset and activation, plot test
# accuracy and Hessian condition numbers across numbers of hidden layers.
# NOTE(review): activity_lr is not set here — it is chosen per
# (dataset, act_fn, n_hidden) by the if/elif chain further down this cell.
# That chain has no fallback branch, so an unmatched combination would
# silently reuse an activity_lr left over from an earlier cell; consider a
# dict lookup that raises KeyError instead.
n_hiddens = [8, 16, 32]
act_fns = ["linear", "tanh", "relu"]

# Hyperparameters identifying the saved runs; each of results_dir ... max_epochs
# is a component of the results directory path built further down this cell.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]
width = 128
use_skips = False
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100  # multiplier mapping evaluation index -> training iteration on plot x-axes
n_seeds = 3  # runs are read from seed directories 0 .. n_seeds - 1

for dataset in datasets:
    for act_fn in act_fns:
        
        test_accs_per_H = {} 
        cond_nums_per_H = {} 
        for n_hidden in n_hiddens:
            test_accs_all_seeds = [[] for seed in range(n_seeds)]
            cond_nums_all_seeds = [[] for seed in range(n_seeds)]

            # best lr based on dataset, act fn & n hidden
            if dataset == "MNIST":
                if act_fn == "linear":
                    if n_hidden == 8:
                        activity_lr = 5e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-2
                    elif n_hidden == 32:
                        activity_lr = 5e-2
                
                elif act_fn == "tanh":
                    if n_hidden == 8:
                        activity_lr = 5e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-2
                    elif n_hidden == 32:
                        activity_lr = 5e-2
    
                elif act_fn == "relu":
                    if n_hidden == 8:
                        activity_lr = 5e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-1
                    elif n_hidden == 32:
                        activity_lr = 5e-2

            elif dataset == "Fashion-MNIST":
                if act_fn == "linear":
                    if n_hidden == 8:
                        activity_lr = 5e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-2
                    elif n_hidden == 32:
                        activity_lr = 5e-2
                
                elif act_fn == "tanh":
                    if n_hidden == 8:
                        activity_lr = 5e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-2
                    elif n_hidden == 32:
                        activity_lr = 5e-2
    
                elif act_fn == "relu":
                    if n_hidden == 8:
                        activity_lr = 1e-1
                    elif n_hidden == 16:
                        activity_lr = 5e-1
                    elif n_hidden == 32:
                        activity_lr = 5e-2
            
            for seed in range(n_seeds):
            
                save_path = os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init",
                    f"{param_type}_param",
                    f"param_optim_{param_optim_id}",
                    f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}",
                    f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}",
                    f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                test_accs = np.load(f"{save_path}/test_accs.npy")
                hessian_eigenvals = np.load(f"{save_path}/hessian_eigenvals.npy")
                cond_nums = [compute_cond_num(eig) for eig in hessian_eigenvals]
                
                test_accs_all_seeds[seed] = test_accs
                cond_nums_all_seeds[seed] = cond_nums

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            cond_num_means, cond_num_stds = compute_metric_stats(cond_nums_all_seeds)
            
            test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)
            cond_nums_per_H[n_hidden] = (cond_num_means, cond_num_stds)

        plot_metric_per_iv(
            metric=test_accs_per_H,
            metric_id="test_acc", 
            iv_id="n_hidden",
            test_every=test_every, 
            save_path=f"{results_dir}/{dataset}/best_test_accs_GD_{act_fn}.pdf"
        )
        plot_metric_per_iv(
            metric=cond_nums_per_H, 
            metric_id="cond_num", 
            iv_id="n_hidden",
            test_every=test_every, 
            save_path=f"{results_dir}/{dataset}/best_cond_nums_GD_{act_fn}.pdf"
        )

### Adam

In [8]:
############ Adam learning rate sweep ############
# For every (dataset, activation, depth H) configuration, put the test-accuracy
# curves (mean/std over seeds, as returned by `compute_metric_stats`) of every
# candidate activity learning rate on a single figure.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]

act_fns = ["linear", "tanh", "relu"]
n_hiddens = [8, 16, 32]
width = 128
use_skips = False
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "adam"
activity_lrs = [5e-1, 1e-1, 5e-2]
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        for n_hidden in n_hiddens:

            # One (means, stds) entry per swept activity learning rate
            # (the independent variable of this plot, not the depth H).
            test_accs_per_lr = {}
            for activity_lr in activity_lrs:
                test_accs_all_seeds = []
                for seed in range(n_seeds):
                    save_path = os.path.join(
                        results_dir,
                        dataset,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        act_fn,
                        "skips" if use_skips else "no_skips",
                        f"{weight_init}_weight_init",
                        f"{param_type}_param",
                        f"param_optim_{param_optim_id}",
                        f"param_lr_{param_lr}",
                        f"batch_size_{batch_size}",
                        f"{max_infer_iters}_max_infer_iters",
                        f"activity_optim_{activity_optim_id}",
                        f"activity_lr_{activity_lr}",
                        f"activity_decay_{activity_decay}",
                        f"weight_decay_{weight_decay}",
                        f"spectral_penalty_{spectral_penalty}",
                        f"{max_epochs}_epochs",
                        str(seed)
                    )
                    test_accs = np.load(f"{save_path}/test_accs.npy")
                    test_accs_all_seeds.append(test_accs)

                test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
                test_accs_per_lr[activity_lr] = (test_acc_means, test_acc_stds)

            plot_metric_per_iv(
                metric=test_accs_per_lr,
                metric_id="test_acc",
                iv_id="activity_lr",
                test_every=test_every,
                save_path=f"{results_dir}/{dataset}/test_accs_Adam_{act_fn}_{n_hidden}_n_hidden.pdf"
            )


In [9]:
############ best Adam results ############
n_hiddens = [8, 16, 32]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]
width = 128
use_skips = False
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "adam"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

# Best activity learning rate per (dataset, act fn) and depth H (presumably
# picked from the Adam sweep above). A lookup table, rather than an if/elif
# chain that never resets `activity_lr`, makes an unlisted configuration fail
# with a KeyError instead of silently reusing the previous iteration's (or a
# previous cell's) learning rate and loading the wrong run.
best_adam_activity_lrs = {
    ("MNIST", "linear"): {8: 5e-2, 16: 5e-1, 32: 1e-1},
    ("MNIST", "tanh"): {8: 5e-1, 16: 5e-2, 32: 1e-1},
    ("MNIST", "relu"): {8: 1e-1, 16: 1e-1, 32: 5e-1},
    ("Fashion-MNIST", "linear"): {8: 5e-2, 16: 5e-1, 32: 5e-1},
    ("Fashion-MNIST", "tanh"): {8: 5e-1, 16: 5e-2, 32: 1e-1},
    ("Fashion-MNIST", "relu"): {8: 1e-1, 16: 5e-1, 32: 5e-2},
}

for dataset in datasets:
    for act_fn in act_fns:

        # (means, stds) of test accuracy and Hessian condition number per depth H
        test_accs_per_H = {}
        cond_nums_per_H = {}
        for n_hidden in n_hiddens:
            activity_lr = best_adam_activity_lrs[(dataset, act_fn)][n_hidden]

            test_accs_all_seeds = []
            cond_nums_all_seeds = []
            for seed in range(n_seeds):
                save_path = os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init",
                    f"{param_type}_param",
                    f"param_optim_{param_optim_id}",
                    f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}",
                    f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}",
                    f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                test_accs = np.load(f"{save_path}/test_accs.npy")
                hessian_eigenvals = np.load(f"{save_path}/hessian_eigenvals.npy")
                # one condition number per saved eigenvalue set
                cond_nums = [compute_cond_num(eig) for eig in hessian_eigenvals]

                test_accs_all_seeds.append(test_accs)
                cond_nums_all_seeds.append(cond_nums)

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            cond_num_means, cond_num_stds = compute_metric_stats(cond_nums_all_seeds)

            test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)
            cond_nums_per_H[n_hidden] = (cond_num_means, cond_num_stds)

        plot_metric_per_iv(
            metric=test_accs_per_H,
            metric_id="test_acc",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_test_accs_Adam_{act_fn}.pdf"
        )
        plot_metric_per_iv(
            metric=cond_nums_per_H,
            metric_id="cond_num",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_cond_nums_Adam_{act_fn}.pdf"
        )

### Skips

In [10]:
############ skip learning rate sweep ############
# Same activity-lr sweep as above, but for networks with skip connections and
# GD on the activities. One figure per (dataset, activation, depth H), with
# test-accuracy curves (mean/std over seeds via `compute_metric_stats`) per lr.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]

act_fns = ["linear", "tanh", "relu"]
n_hiddens = [8, 16, 32]
width = 128
use_skips = True
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_lrs = [5e-1, 1e-1, 5e-2]
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        for n_hidden in n_hiddens:

            # One (means, stds) entry per swept activity learning rate
            # (the independent variable of this plot, not the depth H).
            test_accs_per_lr = {}
            for activity_lr in activity_lrs:
                test_accs_all_seeds = []
                for seed in range(n_seeds):
                    save_path = os.path.join(
                        results_dir,
                        dataset,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        act_fn,
                        "skips" if use_skips else "no_skips",
                        f"{weight_init}_weight_init",
                        f"{param_type}_param",
                        f"param_optim_{param_optim_id}",
                        f"param_lr_{param_lr}",
                        f"batch_size_{batch_size}",
                        f"{max_infer_iters}_max_infer_iters",
                        f"activity_optim_{activity_optim_id}",
                        f"activity_lr_{activity_lr}",
                        f"activity_decay_{activity_decay}",
                        f"weight_decay_{weight_decay}",
                        f"spectral_penalty_{spectral_penalty}",
                        f"{max_epochs}_epochs",
                        str(seed)
                    )
                    test_accs = np.load(f"{save_path}/test_accs.npy")
                    test_accs_all_seeds.append(test_accs)

                test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
                test_accs_per_lr[activity_lr] = (test_acc_means, test_acc_stds)

            plot_metric_per_iv(
                metric=test_accs_per_lr,
                metric_id="test_acc",
                iv_id="activity_lr",
                test_every=test_every,
                save_path=f"{results_dir}/{dataset}/test_accs_GD_{act_fn}_{n_hidden}_n_hidden_skips.pdf"
            )


In [11]:
############ best skips results ############
n_hiddens = [8, 16, 32]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]
width = 128
use_skips = True
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

# Best activity learning rate per (dataset, act fn) and depth H (presumably
# picked from the skips sweep above). A lookup table, rather than an if/elif
# chain that never resets `activity_lr`, makes an unlisted configuration fail
# with a KeyError instead of silently reusing the previous iteration's (or a
# previous cell's) learning rate and loading the wrong run.
best_skips_activity_lrs = {
    ("MNIST", "linear"): {8: 5e-1, 16: 5e-1, 32: 1e-1},
    ("MNIST", "tanh"): {8: 5e-2, 16: 5e-1, 32: 5e-1},
    ("MNIST", "relu"): {8: 5e-2, 16: 5e-2, 32: 5e-1},
    ("Fashion-MNIST", "linear"): {8: 5e-1, 16: 1e-1, 32: 5e-1},
    ("Fashion-MNIST", "tanh"): {8: 5e-2, 16: 5e-2, 32: 5e-1},
    ("Fashion-MNIST", "relu"): {8: 5e-2, 16: 5e-2, 32: 5e-1},
}

for dataset in datasets:
    for act_fn in act_fns:

        # (means, stds) of test accuracy and Hessian condition number per depth H
        test_accs_per_H = {}
        cond_nums_per_H = {}
        for n_hidden in n_hiddens:
            activity_lr = best_skips_activity_lrs[(dataset, act_fn)][n_hidden]

            test_accs_all_seeds = []
            cond_nums_all_seeds = []
            for seed in range(n_seeds):
                save_path = os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init",
                    f"{param_type}_param",
                    f"param_optim_{param_optim_id}",
                    f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}",
                    f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}",
                    f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                test_accs = np.load(f"{save_path}/test_accs.npy")
                hessian_eigenvals = np.load(f"{save_path}/hessian_eigenvals.npy")
                # one condition number per saved eigenvalue set
                cond_nums = [compute_cond_num(eig) for eig in hessian_eigenvals]

                test_accs_all_seeds.append(test_accs)
                cond_nums_all_seeds.append(cond_nums)

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            cond_num_means, cond_num_stds = compute_metric_stats(cond_nums_all_seeds)

            test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)
            cond_nums_per_H[n_hidden] = (cond_num_means, cond_num_stds)

        plot_metric_per_iv(
            metric=test_accs_per_H,
            metric_id="test_acc",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_test_accs_GD_skips_{act_fn}.pdf"
        )
        plot_metric_per_iv(
            metric=cond_nums_per_H,
            metric_id="cond_num",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_cond_nums_GD_skips_{act_fn}.pdf"
        )

### Orthogonal init

In [6]:
############ orthogonal init. learning rate sweep ############
# Activity-lr sweep for orthogonally initialised networks without skips.
# One figure per (dataset, activation, depth H), with test-accuracy curves
# (mean/std over seeds via `compute_metric_stats`) per learning rate.
results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]

act_fns = ["linear", "tanh", "relu"]
width = 128
use_skips = False
weight_init = "orthogonal"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_lrs = [5e-1, 1e-1, 5e-2]
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        # ReLU networks are only swept up to H = 32.
        n_hiddens = [8, 16, 32, 64, 128] if act_fn != "relu" else [8, 16, 32]

        for n_hidden in n_hiddens:

            # One (means, stds) entry per swept activity learning rate
            # (the independent variable of this plot, not the depth H).
            test_accs_per_lr = {}
            for activity_lr in activity_lrs:
                test_accs_all_seeds = []
                for seed in range(n_seeds):
                    save_path = os.path.join(
                        results_dir,
                        dataset,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        act_fn,
                        "skips" if use_skips else "no_skips",
                        f"{weight_init}_weight_init",
                        f"{param_type}_param",
                        f"param_optim_{param_optim_id}",
                        f"param_lr_{param_lr}",
                        f"batch_size_{batch_size}",
                        f"{max_infer_iters}_max_infer_iters",
                        f"activity_optim_{activity_optim_id}",
                        f"activity_lr_{activity_lr}",
                        f"activity_decay_{activity_decay}",
                        f"weight_decay_{weight_decay}",
                        f"spectral_penalty_{spectral_penalty}",
                        f"{max_epochs}_epochs",
                        str(seed)
                    )
                    test_accs = np.load(f"{save_path}/test_accs.npy")
                    test_accs_all_seeds.append(test_accs)

                test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
                test_accs_per_lr[activity_lr] = (test_acc_means, test_acc_stds)

            plot_metric_per_iv(
                metric=test_accs_per_lr,
                metric_id="test_acc",
                iv_id="activity_lr",
                test_every=test_every,
                save_path=f"{results_dir}/{dataset}/test_accs_orthogonal_{act_fn}_{n_hidden}_n_hidden.pdf"
            )


In [8]:
############ best orthogonal init. results ############
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
datasets = ["MNIST", "Fashion-MNIST"]
width = 128
use_skips = False
weight_init = "orthogonal"
param_type = "sp"
param_optim_id = "adam"
param_lr = 1e-3
batch_size = 64
max_infer_iters = 500
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

# Best activity learning rate per (dataset, act fn) and depth H (presumably
# picked from the orthogonal-init sweep above). A lookup table, rather than an
# if/elif chain that never resets `activity_lr`, makes an unlisted
# configuration fail with a KeyError instead of silently reusing the previous
# iteration's (or a previous cell's) learning rate and loading the wrong run.
best_orthogonal_activity_lrs = {
    ("MNIST", "linear"): {8: 5e-2, 16: 5e-2, 32: 1e-1, 64: 1e-1, 128: 1e-1},
    ("MNIST", "tanh"): {8: 5e-2, 16: 5e-2, 32: 5e-2, 64: 5e-2, 128: 5e-1},
    ("MNIST", "relu"): {8: 5e-1, 16: 5e-1, 32: 5e-2},
    ("Fashion-MNIST", "linear"): {8: 5e-2, 16: 1e-1, 32: 1e-1, 64: 5e-2, 128: 5e-2},
    ("Fashion-MNIST", "tanh"): {8: 5e-2, 16: 5e-2, 32: 5e-2, 64: 5e-2, 128: 5e-1},
    ("Fashion-MNIST", "relu"): {8: 1e-1, 16: 5e-2, 32: 5e-2},
}

for dataset in datasets:
    for act_fn in act_fns:
        # ReLU networks only go up to H = 32.
        n_hiddens = [8, 16, 32, 64, 128] if act_fn != "relu" else [8, 16, 32]

        # (means, stds) of test accuracy and Hessian condition number per depth H
        test_accs_per_H = {}
        cond_nums_per_H = {}
        for n_hidden in n_hiddens:
            # Hessian eigenvalues are not loaded for Fashion-MNIST / tanh / H=128
            # (presumably they were not computed for that run - confirm), so that
            # depth is simply absent from the condition-number plot.
            compute_hessian = not (
                dataset == "Fashion-MNIST" and act_fn == "tanh" and n_hidden == 128
            )
            activity_lr = best_orthogonal_activity_lrs[(dataset, act_fn)][n_hidden]

            test_accs_all_seeds = []
            cond_nums_all_seeds = []
            for seed in range(n_seeds):
                save_path = os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init",
                    f"{param_type}_param",
                    f"param_optim_{param_optim_id}",
                    f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}",
                    f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}",
                    f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                test_accs = np.load(f"{save_path}/test_accs.npy")
                test_accs_all_seeds.append(test_accs)

                if compute_hessian:
                    hessian_eigenvals = np.load(f"{save_path}/hessian_eigenvals.npy")
                    # one condition number per saved eigenvalue set
                    cond_nums = [compute_cond_num(eig) for eig in hessian_eigenvals]
                    cond_nums_all_seeds.append(cond_nums)

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)

            if compute_hessian:
                cond_num_means, cond_num_stds = compute_metric_stats(cond_nums_all_seeds)
                cond_nums_per_H[n_hidden] = (cond_num_means, cond_num_stds)

        plot_metric_per_iv(
            metric=test_accs_per_H,
            metric_id="test_acc",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_test_accs_orthogonal_{act_fn}.pdf"
        )
        plot_metric_per_iv(
            metric=cond_nums_per_H,
            metric_id="cond_num",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_cond_nums_orthogonal_{act_fn}.pdf"
        )

### $\mu$PC

In [7]:
from plotting import plot_metric_stats

In [15]:
def plot_mupc_vs_pc_vs_bp_accs(
        mupc_accs, 
        pc_accs,
        bp_accs, 
        test_every, 
        act_fn, 
        save_path, 
        height=300, 
        width=450,
        show_bp=True,
        show_pc=True,
        show_mupc=True,
        bp_iv=128,
    ):
    """Plot muPC vs PC test-accuracy curves per IV value, plus a BP baseline.

    Args:
        mupc_accs: dict mapping each IV value (e.g. depth H) to a
            ``(means, stds)`` pair for muPC. The pair must support elementwise
            ``+``/``-`` and slicing (e.g. numpy arrays). The key order sets the
            plotting order and colours.
        pc_accs: same structure as ``mupc_accs`` for PC; must contain every key
            of ``mupc_accs``.
        bp_accs: same structure for BP; only ``bp_accs[bp_iv]`` is read, and
            only when ``bp_iv`` is a key of ``mupc_accs``.
        test_every: training iterations between evaluations; used to label the
            x-axis ticks.
        act_fn: activation name; "relu", "tanh" and "linear" get hand-picked
            y-ticks, and any other value falls back to plotly's automatic ticks.
        save_path: output path passed to ``fig.write_image``.
        height, width: figure size in pixels.
        show_bp, show_pc, show_mupc: visibility of each family. Hidden families
            are still added with zero opacity (presumably so the axis ranges do
            not change between variants of the same figure).
        bp_iv: IV value at which the BP baseline is drawn (default 128).
    """
    ivs = list(mupc_accs.keys())
    n_iters = len(mupc_accs[ivs[0]][0])
    iters = list(range(n_iters))

    mupc_colors = pc.sample_colorscale("Blues", len(ivs)+2)[1:]
    pc_colors = pc.sample_colorscale("Reds", len(ivs)+2)[2:]
    bp_color = "black"

    fig = go.Figure()

    def add_std_band(means, stds, color, show):
        # Shaded mean +/- std region: trace the upper edge left-to-right, then
        # the lower edge right-to-left, and fill the closed shape.
        y_upper, y_lower = means + stds, means - stds
        fig.add_trace(
            go.Scatter(
                x=iters + iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3 if show else 0
            )
        )

    for i, iv in enumerate(ivs):
        mupc_means, mupc_stds = mupc_accs[iv][0], mupc_accs[iv][1]
        pc_means, pc_stds = pc_accs[iv][0], pc_accs[iv][1]

        add_std_band(mupc_means, mupc_stds, mupc_colors[i], show_mupc)
        add_std_band(pc_means, pc_stds, pc_colors[i], show_pc)

        if iv == bp_iv:
            bp_means, bp_stds = bp_accs[iv][0], bp_accs[iv][1]
            add_std_band(bp_means, bp_stds, bp_color, show_bp)
            fig.add_trace(
                go.Scatter(
                    x=iters,
                    y=bp_means,
                    mode="lines",
                    line=dict(width=3, color=bp_color, dash="dash"),
                    showlegend=False,
                    opacity=1 if show_bp else 0
                )
            )

        fig.add_trace(
            go.Scatter(
                x=iters,
                y=mupc_means,
                mode="lines+markers",
                line=dict(width=2, color=mupc_colors[i]),
                showlegend=False,
                opacity=1 if show_mupc else 0
            )
        )
        fig.add_trace(
            go.Scatter(
                x=iters,
                y=pc_means,
                mode="lines+markers",
                line=dict(width=2, color=pc_colors[i], dash="dash"),
                showlegend=False,
                opacity=1 if show_pc else 0
            )
        )

    # x-ticks at the first, middle and last evaluation, labelled in training
    # iterations (evaluation t happens after (t+1) * test_every iterations).
    xtickvals = [0, int(iters[-1] / 2), iters[-1]]
    xticktext = [(t+1) * test_every for t in xtickvals]

    ytickvals_by_act_fn = {
        "relu": [10, 50, 90],
        "tanh": [60, 75, 90],
        "linear": [10, 30, 50, 70, 90],
    }
    yaxis = dict(title="Test accuracy (%)")
    ytickvals = ytickvals_by_act_fn.get(act_fn)
    if ytickvals is not None:
        yaxis.update(tickvals=ytickvals, ticktext=ytickvals)

    fig.update_layout(
        height=height,
        width=width,
        xaxis=dict(
            title="Training iteration",
            tickvals=xtickvals,
            ticktext=xticktext
        ),
        yaxis=yaxis,
        font=dict(size=16),
        margin=dict(r=120)
    )
    fig.write_image(save_path)


def create_legend(n_hiddens, save_path, colorscale="Blues", starting_color=1, marker_size=5, fontsize=8, height=50, width=500):
    """Export a stand-alone horizontal legend with one entry ``$H=h$`` per depth.

    The figure contains only dummy traces and hidden axes, so the written image
    is just the legend.

    Args:
        n_hiddens: Sized iterable of depths; one legend entry per value, in order.
        save_path: Output path passed to ``fig.write_image``.
        colorscale: Plotly colorscale name. Names containing "Blues" get solid
            lines, every other scale gets dashed lines. Names containing "gray"
            are drawn as lines only (no markers).
        starting_color: Number of leading colours of the sampled scale to skip.
        marker_size: Marker size of each legend entry.
        fontsize: Legend font size.
        height, width: Figure size in pixels.
    """
    n_entries = len(n_hiddens)
    # Sample enough colours that `n_entries` remain after skipping the first
    # `starting_color`; previously a fixed `n_entries + 2` samples were drawn, so
    # `starting_color > 2` silently dropped legend entries in the zip below.
    # `n_entries + 2` is kept as a floor so callers with starting_color <= 2 get
    # exactly the same colours as before.
    n_samples = max(n_entries + 2, n_entries + starting_color)
    colors = pc.sample_colorscale(colorscale, n_samples)[starting_color:]

    # "Blues" scales are solid; all other scales (e.g. "Reds", "gray_r") are dashed.
    dash = "solid" if "Blues" in colorscale else "dash"
    mode = "lines" if "gray" in colorscale else "lines+markers"

    traces = [
        go.Scatter(
            x=[None], y=[None],  # no data: the trace exists only for its legend entry
            mode=mode,
            line=dict(
                width=2, 
                color=color, 
                dash=dash
            ),
            marker=dict(
                size=marker_size, 
                color=color
            ),
            name=f"$H={h}$"
        )
        for h, color in zip(n_hiddens, colors)
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.0,
            xanchor="center",
            x=0.5,
            font=dict(size=fontsize)
        ),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=0, r=0, t=0, b=0),
        height=height,
        width=width
    )
    fig.write_image(save_path)


#### Condition numbers

In [17]:
############ cond numbers of muPC results for depth at fixed width N = 512 ############
# For each activation: load the saved Hessian eigenvalues of every seed, turn
# each eigenvalue set into a condition number, and plot mean/std per depth H.
n_hiddens = [8, 16, 32]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
dataset = "MNIST"
width = 512
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 1e-1  # 5e-2 for Fashion, based on the transfer results below
batch_size = 64
activity_optim_id = "gd"
# NOTE(review): fixed at 5e-1 for every activation, although the "best" cells
# below use 1 for linear/tanh -- confirm this is the intended run set.
activity_lr = 5e-1
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

for act_fn in act_fns:
    cond_nums_per_H = {}

    for n_hidden in n_hiddens:
        max_infer_iters = n_hidden  # inference steps tied to depth

        # All path components except the seed are shared by the seeds of this H.
        run_dir = os.path.join(
            results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
            "skips" if use_skips else "no_skips",
            f"{weight_init}_weight_init", f"{param_type}_param",
            f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
            f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
            f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
            f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
            f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
        )

        cond_nums_all_seeds = []
        for seed in range(n_seeds):
            save_path = os.path.join(run_dir, str(seed))
            hessian_eigenvals = np.load(f"{save_path}/hessian_eigenvals.npy")
            cond_nums = [compute_cond_num(eig) for eig in hessian_eigenvals]
            cond_nums_all_seeds.append(cond_nums)

        cond_nums_means, cond_nums_stds = compute_metric_stats(cond_nums_all_seeds)
        cond_nums_per_H[n_hidden] = (cond_nums_means, cond_nums_stds)

    plot_metric_per_iv(
        metric=cond_nums_per_H,
        metric_id="cond_num",
        iv_id="n_hidden",
        test_every=test_every,
        save_path=f"{results_dir}/{dataset}/best_cond_nums_mupc_{act_fn}.pdf"
    )


#### Lr sweep

In [33]:
############ muPC learning rate sweep ############
# For every (dataset, activation, depth H): one test-accuracy curve per
# activity learning rate, averaged over seeds.
results_dir = "pcn_results"
datasets = ["MNIST"]

act_fns = ["linear", "tanh", "relu"]
width = 512
n_hiddens = [2**i for i in range(3, 8)]  # H = 8 ... 128
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 1e-1  # NOTE: best based on transfer results below
batch_size = 64
activity_optim_id = "gd"
activity_lrs = [1000, 500, 100, 50, 10, 5, 1, 5e-1, 1e-1, 5e-2, 1e-2]
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        for n_hidden in n_hiddens:
            max_infer_iters = n_hidden

            # Keyed by activity learning rate (despite the "_per_H" name).
            test_accs_per_H = {}
            for activity_lr in activity_lrs:
                run_dir = os.path.join(
                    results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init", f"{param_type}_param",
                    f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
                )

                test_accs_all_seeds = []
                for seed in range(n_seeds):
                    save_path = os.path.join(run_dir, str(seed))
                    test_accs = np.load(f"{save_path}/test_accs.npy")
                    if len(test_accs) == 0:
                        # No accuracy was saved (presumably a failed run): score it at chance.
                        test_accs = [10]
                    test_accs_all_seeds.append(test_accs)

                test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
                test_accs_per_H[activity_lr] = (test_acc_means, test_acc_stds)

            plot_metric_per_iv(
                metric=test_accs_per_H,
                metric_id="test_acc",
                iv_id="activity_lr",
                test_every=test_every,
                save_path=f"{results_dir}/{dataset}/test_accs_muPC_{act_fn}_{n_hidden}_n_hidden.pdf",
                height=450,
                width=650
            )


#### Width (wider is better)

In [5]:
############ best muPC results for width at fixed depth H = 8 ############
# For each activation: mean/std test accuracy across seeds for every width N,
# using that activation's best activity learning rate.
widths = [64, 128, 256, 512, 1024]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
datasets = ["MNIST"]
n_hidden = 8
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 1e-1
batch_size = 64
max_infer_iters = 8
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        # Best activity lr depends only on the activation, not on the width.
        # (int 1 is kept on purpose: it is formatted into the run path as "1".)
        activity_lr = {"linear": 1, "tanh": 1, "relu": 5e-1}[act_fn]

        test_accs_per_N = {}
        for width in widths:
            run_dir = os.path.join(
                results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
                "skips" if use_skips else "no_skips",
                f"{weight_init}_weight_init", f"{param_type}_param",
                f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
                f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
                f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
                f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
            )

            test_accs_all_seeds = []
            for seed in range(n_seeds):
                save_path = os.path.join(run_dir, str(seed))
                test_accs = np.load(f"{save_path}/test_accs.npy")
                test_accs_all_seeds.append(test_accs)

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            test_accs_per_N[width] = (test_acc_means, test_acc_stds)

        plot_metric_per_iv(
            metric=test_accs_per_N,
            metric_id="test_acc",
            iv_id="width",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_test_accs_width_muP_{act_fn}.pdf"
        )


#### Depth

In [21]:
############ best muPC results for depth at fixed width N = 512 ############
# For each activation: mean/std test accuracy across seeds for every depth H,
# using that activation's best activity learning rate.
n_hiddens = [8, 16, 32, 64, 128]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
datasets = ["MNIST"]
width = 512
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 1e-1  # NOTE: best based on transfer results below
batch_size = 64
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

for dataset in datasets:
    for act_fn in act_fns:
        # Best activity lr depends only on the activation, not on the depth.
        # (int 1 is kept on purpose: it is formatted into the run path as "1".)
        activity_lr = {"linear": 1, "tanh": 1, "relu": 5e-1}[act_fn]

        test_accs_per_H = {}
        for n_hidden in n_hiddens:
            max_infer_iters = n_hidden

            run_dir = os.path.join(
                results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
                "skips" if use_skips else "no_skips",
                f"{weight_init}_weight_init", f"{param_type}_param",
                f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
                f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
                f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
                f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
            )

            test_accs_all_seeds = []
            for seed in range(n_seeds):
                save_path = os.path.join(run_dir, str(seed))
                test_accs = np.load(f"{save_path}/test_accs.npy")
                test_accs_all_seeds.append(test_accs)

            test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
            test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)

        plot_metric_per_iv(
            metric=test_accs_per_H,
            metric_id="test_acc",
            iv_id="n_hidden",
            test_every=test_every,
            save_path=f"{results_dir}/{dataset}/best_test_accs_depth_muP_{act_fn}.pdf"
        )


#### MNIST results after 5 epochs

In [7]:
# muPC, 128-layer relu net on MNIST trained for 5 epochs: mean/std test accuracy over seeds.
results_dir = "pcn_results"
dataset = "MNIST"
act_fn = "relu"
width = 512
n_hidden = 128
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 1e-1
batch_size = 64
max_infer_iters = 128
activity_optim_id = "gd"
activity_lr = 5e-1
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 5
test_every = 300
n_seeds = 5

# Seed-independent part of the run directory.
run_dir = os.path.join(
    results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
    "skips" if use_skips else "no_skips",
    f"{weight_init}_weight_init", f"{param_type}_param",
    f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
    f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
    f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
    f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
    f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
)

test_accs_all_seeds = []
for seed in range(n_seeds):
    save_path = os.path.join(run_dir, str(seed))
    test_accs = np.load(f"{save_path}/test_accs.npy")
    if len(test_accs) == 0:
        # Nothing was saved for this seed: score it at chance.
        test_accs = [10]
    test_accs_all_seeds.append(test_accs)

plot_metric_stats(
    metric=test_accs_all_seeds,
    metric_id="test_acc",
    test_every=test_every,
    save_path=f"{results_dir}/mupc_128_MNIST_acc_5_epochs.pdf"
)

#### Fashion-MNIST results after 15 epochs

In [10]:
# muPC, 128-layer relu net on Fashion-MNIST trained for 15 epochs: mean/std test accuracy over seeds.
results_dir = "pcn_results"
dataset = "Fashion-MNIST"
act_fn = "relu"
width = 512
n_hidden = 128
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
param_lr = 5e-2
batch_size = 64
max_infer_iters = 128
activity_optim_id = "gd"
activity_lr = 5e-1
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 15
test_every = 900
n_seeds = 3

# Seed-independent part of the run directory.
run_dir = os.path.join(
    results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
    "skips" if use_skips else "no_skips",
    f"{weight_init}_weight_init", f"{param_type}_param",
    f"param_optim_{param_optim_id}", f"param_lr_{param_lr}",
    f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
    f"activity_optim_{activity_optim_id}", f"activity_lr_{activity_lr}",
    f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
    f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
)

test_accs_all_seeds = []
for seed in range(n_seeds):
    save_path = os.path.join(run_dir, str(seed))
    test_accs = np.load(f"{save_path}/test_accs.npy")
    test_accs_all_seeds.append(test_accs)

plot_metric_stats(
    metric=test_accs_all_seeds,
    metric_id="test_acc",
    test_every=test_every,
    save_path=f"{results_dir}/mupc_128_Fashion_acc_15_epochs.pdf"
)

#### Main figure for $\mu$PC vs PC

In [19]:
############ muPC vs PC vs BP with Depth-muP ############
# For each activation: collect per-seed test accuracies of muPC and standard PC
# (SP) for every depth H in `n_hiddens`, plus BP (Depth-muP) for H = 128 only,
# then plot them with `plot_mupc_vs_pc_vs_bp_accs`. The cell finishes by
# exporting three stand-alone depth legends via `create_legend`.
n_hiddens = [8, 16, 32, 64, 128]
act_fns = ["relu"]  # add "linear", "tanh" to regenerate the other panels

results_dir = "pcn_results"
dataset = "MNIST"
width = 512
use_skips = True
param_optim_id = "adam"
batch_size = 64
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

for act_fn in act_fns:

    bp_test_accs_per_H = {}
    pc_test_accs_per_H = {}
    mupc_test_accs_per_H = {} 
    for n_hidden in n_hiddens:
        max_infer_iters = n_hidden  # inference steps tied to depth
        
        bp_test_accs_all_seeds = [[] for seed in range(n_seeds)]
        pc_test_accs_all_seeds = [[] for seed in range(n_seeds)]
        mupc_test_accs_all_seeds = [[] for seed in range(n_seeds)]

        # muPC: best lrs based on act (param lr is the same for every activation)
        mupc_param_lr = 1e-1
        if act_fn == "linear":
            mupc_activity_lr = 1
        elif act_fn == "tanh":
            mupc_activity_lr = 1
        elif act_fn == "relu":
            mupc_activity_lr = 5e-1

        # PC: best lrs based on act & H
        # NOTE(review): the relu branch has no case for H = 128, so pc_param_lr /
        # pc_activity_lr keep whatever the previous H (64) left behind. This is
        # harmless only because relu runs with H >= 64 are replaced by chance
        # below (pc_save_path is built but not loaded); on a fresh kernel with
        # H = 128 processed first, building pc_save_path would raise NameError.
        if act_fn == "linear":
            pc_param_lr = 1e-2
            pc_activity_lr = 1e-2
        
        elif act_fn == "tanh":
            if n_hidden == 8:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-2
            elif n_hidden == 16:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-2
            elif n_hidden == 32:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-1
            elif n_hidden == 64:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-1
            elif n_hidden == 128:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-1
        
        elif act_fn == "relu":
            if n_hidden == 8:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-1
            elif n_hidden == 16:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-2
            elif n_hidden == 32:
                pc_param_lr = 1e-2
                pc_activity_lr = 1e-2
            elif n_hidden == 64:
                pc_param_lr = 1e-1
                pc_activity_lr = 1e-2
        
        for seed in range(n_seeds):
        
            # PC uses the standard init / standard parameterisation ("sp")...
            pc_save_path = os.path.join(
                results_dir,
                dataset,
                f"width_{width}",
                f"{n_hidden}_n_hidden",
                act_fn,
                "skips" if use_skips else "no_skips",
                f"standard_weight_init",
                f"sp_param",
                f"param_optim_{param_optim_id}",
                f"param_lr_{pc_param_lr}",
                f"batch_size_{batch_size}",
                f"{max_infer_iters}_max_infer_iters",
                f"activity_optim_{activity_optim_id}",
                f"activity_lr_{pc_activity_lr}",
                f"activity_decay_{activity_decay}",
                f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}",
                f"{max_epochs}_epochs",
                str(seed)
            )
            # ...while muPC uses the Gaussian init with the muPC parameterisation.
            mupc_save_path = os.path.join(
                results_dir,
                dataset,
                f"width_{width}",
                f"{n_hidden}_n_hidden",
                act_fn,
                "skips" if use_skips else "no_skips",
                f"standard_gauss_weight_init",
                f"mupc_param",
                f"param_optim_{param_optim_id}",
                f"param_lr_{mupc_param_lr}",
                f"batch_size_{batch_size}",
                f"{max_infer_iters}_max_infer_iters",
                f"activity_optim_{activity_optim_id}",
                f"activity_lr_{mupc_activity_lr}",
                f"activity_decay_{activity_decay}",
                f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}",
                f"{max_epochs}_epochs",
                str(seed)
            )
            # These PC configurations are treated as failed and are not loaded.
            # The length 3 presumably matches the number of test evaluations of
            # a 1-epoch run with test_every = 300 -- TODO confirm.
            if (n_hidden >= 64 and act_fn != "tanh") or (n_hidden == 32 and act_fn == "linear"):
                 # chance for failed runs
                pc_test_accs = np.array([10.] * 3)
            else:
                pc_test_accs = np.load(f"{pc_save_path}/test_accs.npy")
                
            mupc_test_accs = np.load(f"{mupc_save_path}/test_accs.npy") 

            # The BP (Depth-muP) baseline is only loaded for the deepest net.
            if n_hidden == 128:
                bp_lr = 1e-3 if act_fn == "linear" else 1e-2
                bp_save_path = os.path.join(
                    "bp_results",
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    f"depth_mup_param",
                    param_optim_id,
                    f"lr_{bp_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                bp_test_accs = np.load(f"{bp_save_path}/test_accs.npy")   
                bp_test_accs_all_seeds[seed] = bp_test_accs
            
            pc_test_accs_all_seeds[seed] = pc_test_accs
            mupc_test_accs_all_seeds[seed] = mupc_test_accs

        pc_test_acc_means, pc_test_acc_stds = compute_metric_stats(pc_test_accs_all_seeds)
        mupc_test_acc_means, mupc_test_acc_stds = compute_metric_stats(mupc_test_accs_all_seeds)
        # BP seeds are only filled at H = 128, so BP stats exist only there.
        if n_hidden == 128:
            bp_test_acc_means, bp_test_acc_stds = compute_metric_stats(bp_test_accs_all_seeds)
        
        pc_test_accs_per_H[n_hidden] = (pc_test_acc_means, pc_test_acc_stds)
        mupc_test_accs_per_H[n_hidden] = (mupc_test_acc_means, mupc_test_acc_stds)
        if n_hidden == 128:
            bp_test_accs_per_H[n_hidden] = (bp_test_acc_means, bp_test_acc_stds)

    plot_mupc_vs_pc_vs_bp_accs(
        mupc_accs=mupc_test_accs_per_H, 
        pc_accs=pc_test_accs_per_H, 
        bp_accs=bp_test_accs_per_H,
        test_every=test_every, 
        act_fn=act_fn,
        save_path=f"{results_dir}/{act_fn}_mupc_vs_pc_best_accs.pdf", 
        height=350, 
        width=500,
        show_bp=True,
        show_pc=True,
        show_mupc=True
    )

# Stand-alone depth legends (one entry per H in n_hiddens) in three colour schemes.
create_legend(
    n_hiddens, 
    f"{results_dir}/n_hiddens_blue_legend.pdf", 
    colorscale="Blues",  # reversed or not
    starting_color=1,
    marker_size=5, 
    fontsize=12, 
    height=50, 
    width=1000
)
create_legend(
    n_hiddens, 
    f"{results_dir}/n_hiddens_red_legend.pdf", 
    colorscale="Reds",
    starting_color=2, # 2 if not reversed, 0 otherwise
    marker_size=5, 
    fontsize=12, 
    height=50, 
    width=1000
)
create_legend(
    n_hiddens, 
    f"{results_dir}/n_hiddens_black_legend.pdf", 
    colorscale="gray_r",
    starting_color=2,  # 2 if reversed, 0 otherwise
    marker_size=5, 
    fontsize=12, 
    height=50, 
    width=1000
)

#### 128 relu net ($\mu$PC vs PC vs BP)

In [8]:
def plot_mupc_vs_pc_vs_bp_accs_128_layer_net(
        bp_accs, 
        pc_accs, 
        mupc_accs, 
        dataset,
        test_every,
        save_path, 
        height=300, 
        width=450
):
    """Plot BP, PC and muPC test accuracy (mean curve + ±1 std band) and save it.

    Args:
        bp_accs, pc_accs, mupc_accs: Pairs ``(means, stds)`` with one value per
            test evaluation. They must support elementwise ``+``/``-`` and
            slicing (presumably numpy arrays from ``compute_metric_stats``).
            The number of x positions is taken from ``bp_accs``.
        dataset: "MNIST" labels the x-axis in training iterations
            (``(index + 1) * test_every``) with y-ticks 10/55/100. Any other
            dataset labels it in epochs (``index + 1``, which assumes one
            evaluation per epoch) with y-ticks 10/50/90.
        test_every: Training iterations between test evaluations (MNIST ticks only).
        save_path: Output path for ``fig.write_image``.
        height, width: Figure size in pixels.
    """
    n_iters = len(bp_accs[0])
    iters = list(range(n_iters))

    bp_color = "#222A2A"
    pc_color = "#EF553B"
    mupc_color = "#636EFA"

    fig = go.Figure()

    # All std bands are added before any mean curve so the curves sit on top.
    for accs, color in ((bp_accs, bp_color), (pc_accs, pc_color), (mupc_accs, mupc_color)):
        means, stds = accs[0], accs[1]
        y_upper, y_lower = means + stds, means - stds
        fig.add_trace(
            go.Scatter(
                x=iters + iters[::-1],  # trace upper edge forward, lower edge back
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )

    fig.add_trace(
        go.Scatter(
            x=iters,
            y=bp_accs[0],
            mode="lines",
            line=dict(width=2, color=bp_color),
            name="$BP$"
        )
    )
    fig.add_trace(
        go.Scatter(
            x=iters,
            y=pc_accs[0],
            mode="lines",
            line=dict(width=3, color=pc_color, dash="dash"),
            name="$PC$"
        )
    )
    fig.add_trace(
        go.Scatter(
            x=iters,
            y=mupc_accs[0],
            mode="markers",
            line=dict(width=2, color=mupc_color),
            marker=dict(color=mupc_color),
            # Raw string: "\m" is an invalid escape sequence in a normal literal
            # (DeprecationWarning, and SyntaxWarning since Python 3.12).
            name=r"$\mu PC$"
        )
    )

    xtickvals = [0, iters[-1] // 2, iters[-1]]
    if dataset == "MNIST":
        xtitle = "Training iteration"
        xticktext = [(t + 1) * test_every for t in xtickvals]
        ytickvals = [10, 55, 100]
    else:
        xtitle = "Epoch"
        xticktext = [t + 1 for t in xtickvals]
        ytickvals = [10, 50, 90]

    fig.update_layout(
        height=height,
        width=width,
        xaxis=dict(
            title=xtitle,
            tickvals=xtickvals,
            ticktext=xticktext
        ),
        yaxis=dict(
            title="Test accuracy (%)",
            tickvals=ytickvals,
            ticktext=ytickvals
        ),
        font=dict(size=16),
        margin=dict(r=120)
    )
    fig.write_image(save_path)


In [9]:
# 128-layer relu net: BP (Depth-muP) vs PC vs muPC test accuracy, mean/std over seeds.
results_dir = "pcn_results"
dataset = "MNIST"  # or "Fashion-MNIST"
width = 512
n_hidden = 128
act_fn = "relu"
use_skips = True
param_optim_id = "adam"
batch_size = 64
max_infer_iters = n_hidden
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 5    # 5 (MNIST) or 15 (Fashion-MNIST)
test_every = 300  # 300 (MNIST) or 900 (Fashion-MNIST)
n_seeds = 3

# NOTE: PC runs diverge with the above params for any lrs and dataset,
# so PC is plotted at chance and its lrs only determine the (unused) path.
# Best lrs for relu:
pc_param_lr = 1e-1
pc_activity_lr = 1e-2
mupc_param_lr = 1e-1 if dataset == "MNIST" else 5e-2
mupc_activity_lr = 5e-1
bp_lr = 5e-3  # same for MNIST and Fashion-MNIST

# Seed-independent run directories; the seed is appended inside the loop.
bp_run_dir = os.path.join(
    "bp_results", dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
    "depth_mup_param", param_optim_id, f"lr_{bp_lr}",
    f"batch_size_{batch_size}", f"{max_epochs}_epochs",
)
pc_run_dir = os.path.join(
    results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
    "skips" if use_skips else "no_skips",
    "standard_weight_init", "sp_param",
    f"param_optim_{param_optim_id}", f"param_lr_{pc_param_lr}",
    f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
    f"activity_optim_{activity_optim_id}", f"activity_lr_{pc_activity_lr}",
    f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
    f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
)
mupc_run_dir = os.path.join(
    results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
    "skips" if use_skips else "no_skips",
    "standard_gauss_weight_init", "mupc_param",
    f"param_optim_{param_optim_id}", f"param_lr_{mupc_param_lr}",
    f"batch_size_{batch_size}", f"{max_infer_iters}_max_infer_iters",
    f"activity_optim_{activity_optim_id}", f"activity_lr_{mupc_activity_lr}",
    f"activity_decay_{activity_decay}", f"weight_decay_{weight_decay}",
    f"spectral_penalty_{spectral_penalty}", f"{max_epochs}_epochs",
)

bp_test_accs_all_seeds = []
pc_test_accs_all_seeds = []
mupc_test_accs_all_seeds = []
for seed in range(n_seeds):
    bp_save_path = os.path.join(bp_run_dir, str(seed))
    pc_save_path = os.path.join(pc_run_dir, str(seed))  # not loaded (PC diverges)
    mupc_save_path = os.path.join(mupc_run_dir, str(seed))

    bp_test_accs = np.load(f"{bp_save_path}/test_accs.npy")
    # Chance-level (10%) PC curve with as many points as the BP curve.
    pc_test_accs = np.full(len(bp_test_accs), 10.)
    mupc_test_accs = np.load(f"{mupc_save_path}/test_accs.npy")

    bp_test_accs_all_seeds.append(bp_test_accs)
    pc_test_accs_all_seeds.append(pc_test_accs)
    mupc_test_accs_all_seeds.append(mupc_test_accs)

bp_test_acc_stats = compute_metric_stats(bp_test_accs_all_seeds)
pc_test_acc_stats = compute_metric_stats(pc_test_accs_all_seeds)
mupc_test_acc_stats = compute_metric_stats(mupc_test_accs_all_seeds)

plot_mupc_vs_pc_vs_bp_accs_128_layer_net(
    bp_accs=bp_test_acc_stats,
    pc_accs=pc_test_acc_stats,
    mupc_accs=mupc_test_acc_stats,
    dataset=dataset,
    test_every=test_every,
    save_path=f"{results_dir}/{dataset}/mupc_vs_pc_vs_bp_128_relu_net.pdf",
    height=300,
    width=425
)


### PC baseline

In [None]:
############ best SP results for depth at fixed width N = 512 ############
# Standard-parameterisation PC baseline: for each activation, print and plot
# the mean/std test accuracy across seeds for every depth H, using hard-coded
# best (param lr, activity lr) per (activation, H).
n_hiddens = [8, 16, 32, 64, 128]
act_fns = ["linear", "tanh", "relu"]

results_dir = "pcn_results"
dataset = "MNIST"
width = 512
use_skips = True
weight_init = "standard"
param_type = "sp"
param_optim_id = "adam"
batch_size = 64
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

for act_fn in act_fns:
    print()  # blank line separating each activation's printout
    test_accs_per_H = {} 
    for n_hidden in n_hiddens:
        max_infer_iters = n_hidden  # inference steps tied to depth
        
        test_accs_all_seeds = [[] for seed in range(n_seeds)]

        # best lr based on act
        # NOTE(review): relu has no case for H = 128, so param_lr / activity_lr
        # keep the values from the previous H (64). This is harmless only because
        # relu runs with H >= 64 are replaced by chance below (save_path is built
        # but not loaded).
        if act_fn == "linear":
            param_lr = 1e-2
            activity_lr = 1e-2
        
        elif act_fn == "tanh":
            if n_hidden == 8:
                param_lr = 1e-2
                activity_lr = 1e-2
            elif n_hidden == 16:
                param_lr = 1e-2
                activity_lr = 1e-2
            elif n_hidden == 32:
                param_lr = 1e-2
                activity_lr = 1e-1
            elif n_hidden == 64:
                param_lr = 1e-2
                activity_lr = 1e-1
            elif n_hidden == 128:
                param_lr = 1e-2
                activity_lr = 1e-1
        
        elif act_fn == "relu":
            if n_hidden == 8:
                param_lr = 1e-2
                activity_lr = 1e-1
            elif n_hidden == 16:
                param_lr = 1e-2
                activity_lr = 1e-2
            elif n_hidden == 32:
                param_lr = 1e-2
                activity_lr = 1e-2
            elif n_hidden == 64:
                param_lr = 1e-1
                activity_lr = 1e-2
        
        for seed in range(n_seeds):
        
            save_path = os.path.join(
                results_dir,
                dataset,
                f"width_{width}",
                f"{n_hidden}_n_hidden",
                act_fn,
                "skips" if use_skips else "no_skips",
                f"{weight_init}_weight_init",
                f"{param_type}_param",
                f"param_optim_{param_optim_id}",
                f"param_lr_{param_lr}",
                f"batch_size_{batch_size}",
                f"{max_infer_iters}_max_infer_iters",
                f"activity_optim_{activity_optim_id}",
                f"activity_lr_{activity_lr}",
                f"activity_decay_{activity_decay}",
                f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}",
                f"{max_epochs}_epochs",
                str(seed)
            )
            # These configurations are treated as failed and not loaded; the
            # length 3 presumably matches the number of test evaluations of a
            # 1-epoch run with test_every = 300 -- TODO confirm.
            if (n_hidden >= 64 and act_fn != "tanh") or (n_hidden == 32 and act_fn == "linear"):
                 # chance for failed runs
                test_accs = np.array([10.] * 3)
            else:
                test_accs = np.load(f"{save_path}/test_accs.npy")

            test_accs_all_seeds[seed] = test_accs

        test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
        print(f"mean test accs for {act_fn} and H = {n_hidden}: {test_acc_means}")
        test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)

    plot_metric_per_iv(
        metric=test_accs_per_H,
        metric_id="test_acc", 
        iv_id="n_hidden",
        test_every=test_every, 
        save_path=f"{results_dir}/{dataset}/best_test_accs_depth_SP_{act_fn}.pdf"
    )


### BP with Depth-$\mu$P baseline

Learning-rate sweep for $N = 512$ and $H = 8$.
* MNIST: 1 epoch and 5 epochs (only relu)
* Fashion-MNIST: 15 epochs (relu)

In [50]:
# BP with Depth-muP: learning-rate sweep at N = 512, H = 8, mean/std test accuracy over seeds.
results_dir = "bp_results"
dataset = "Fashion-MNIST"
width = 512
n_hidden = 8
param_type = "depth_mup"
optim_id = "adam"
batch_size = 64
max_epochs = 15   # 1, 5 or 15
test_every = 900  # 300 for MNIST, 900 for Fashion-MNIST
n_seeds = 3

act_fns = ["relu"]
lrs = [1, 5e-1, 1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 5e-4, 1e-4]

for act_fn in act_fns:
    test_accs_per_lr = {}
    for lr in lrs:
        # Seed-independent part of the run directory.
        lr_dir = os.path.join(
            results_dir, dataset, f"width_{width}", f"{n_hidden}_n_hidden", act_fn,
            f"{param_type}_param", optim_id, f"lr_{lr}",
            f"batch_size_{batch_size}", f"{max_epochs}_epochs",
        )

        test_accs_all_seeds = []
        for seed in range(n_seeds):
            save_dir = os.path.join(lr_dir, str(seed))
            test_accs = np.load(f"{save_dir}/test_accs.npy")
            if len(test_accs) <= 1:
                # At most one evaluation saved: treat the run as failed (chance, 10%).
                test_accs = np.array([10.] * 3)
            test_accs_all_seeds.append(test_accs)

        test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)
        test_accs_per_lr[lr] = (test_acc_means, test_acc_stds)

    # NOTE(review): iv_id "activity_lr" is reused here for the BP learning rate.
    plot_metric_per_iv(
        metric=test_accs_per_lr,
        metric_id="test_acc",
        iv_id="activity_lr",
        test_every=test_every,
        save_path=f"{results_dir}/{dataset}/{act_fn}_best_test_accs.pdf",
        height=400,
        width=575
    )

# Best lrs found with this sweep:
# - MNIST, 1 epoch: 1e-3 for linear; 1e-2 for tanh and relu
# - MNIST, 5 epochs (relu): 5e-3
# - Fashion-MNIST, 15 epochs (relu): 5e-3

## Activity decay

In [21]:
def plot_accs_per_activity_decay(accs, activity_decays, save_path):
    """Plot mean test accuracy with a ±1 std band against activity decay.

    Args:
        accs: 2D array whose statistics are taken over axis 0; the activity-decay
            cell builds it as (n_seeds, len(activity_decays)).
        activity_decays: values shown as x-axis tick labels, one per column.
        save_path: output path passed to ``fig.write_image``.
    """
    mean_acc = accs.mean(axis=0)
    std_acc = accs.std(axis=0)
    upper = mean_acc + std_acc
    lower = mean_acc - std_acc

    # Columns are placed at integer positions; decays are shown via tick text.
    positions = list(range(len(mean_acc)))
    # Closed polygon for the band: along the upper edge, then back along the lower.
    band_x = positions + positions[::-1]
    band_y = list(upper) + list(lower[::-1])

    line_color = "#636EFA"
    band = go.Scatter(
        x=band_x,
        y=band_y,
        fill="toself",
        fillcolor=line_color,
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=False,
        opacity=0.3
    )
    mean_line = go.Scatter(
        y=mean_acc,
        mode="lines+markers",
        line=dict(width=2, color=line_color),
        showlegend=False
    )

    fig = go.Figure()
    fig.add_trace(band)
    fig.add_trace(mean_line)
    fig.update_layout(
        height=300,
        width=450,
        xaxis=dict(
            title="Activity decay",
            tickvals=positions,
            ticktext=[str(decay) for decay in activity_decays]
        ),
        yaxis=dict(title="Test accuracy (%)"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


In [22]:
# Activity-decay sweep: for each activity learning rate, collect the best test
# accuracy of every (activity decay, seed) run and plot mean ± std vs decay.
save_path = "pcn_results/MNIST/width_128/8_n_hidden/linear/no_skips/standard_weight_init/sp_param/param_optim_adam/param_lr_0.001/batch_size_64/500_max_infer_iters/activity_optim_gd/"
activity_lrs = [0.5, 0.1, 0.05]
activity_decays = [1, 5e-1, 1e-1, 5e-2, 1e-2, 0]
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
n_seeds = 3

for activity_lr in activity_lrs:

    # rows = seeds, columns = activity decays
    test_accs_seeds = np.zeros((n_seeds, len(activity_decays)))
    for i, activity_decay in enumerate(activity_decays):
        for seed in range(n_seeds):

            run_path = os.path.join(
                save_path,
                f"activity_lr_{activity_lr}",
                f"activity_decay_{activity_decay}",
                f"weight_decay_{weight_decay}",
                f"spectral_penalty_{spectral_penalty}",
                f"{max_epochs}_epochs",
                str(seed)
            )
            test_accs = np.load(f"{run_path}/test_accs.npy")
            test_accs_seeds[seed, i] = max(test_accs)

    # Plot once per activity lr, after every decay column is filled; plotting
    # inside the decay loop would redraw the figure from a partially-zero array.
    plot_accs_per_activity_decay(
        test_accs_seeds,
        activity_decays,
        os.path.join(
            save_path,
            f"activity_lr_{activity_lr}",
            "max_test_accs_per_activity_decay.pdf"
        )
    )

## Hyperparameter transfer results

In [3]:
def plot_learning_rates_contour(
        metric, 
        activity_lrs, 
        param_lrs, 
        metric_id, 
        save_path,
        smooth_contours=True,
        show_axes_label=False,
        num_ticks=6,
        show_contour_labels=False
):
    """Plot a metric over the (weight lr, activity lr) grid as a log10 contour.

    Args:
        metric: 2D array with rows indexed by ``param_lrs`` (y-axis) and
            columns by ``activity_lrs`` (x-axis). Values are plotted on a
            log10 scale; non-positive or non-finite cells (e.g. from failed
            runs) are left blank and ignored when choosing the colour range.
        activity_lrs: activity learning rates (x-axis, log scale).
        param_lrs: weight learning rates (y-axis, log scale).
        metric_id: ``"loss"`` selects the ``RdBu_r`` colour scale and a
            "Training loss" colorbar title; any other value uses ``Greens``
            and "Test accuracy (%)".
        save_path: output path passed to ``fig.write_image``.
        smooth_contours: colour as a heatmap (True) or as filled bands (False).
        show_axes_label: add "Activity lr" / "Weight lr" axis titles.
        num_ticks: number of colorbar ticks, evenly spaced in log10 space.
        show_contour_labels: annotate contour lines with their values.

    Raises:
        ValueError: if ``metric`` contains no finite positive value.
    """
    colorscale = "RdBu_r" if metric_id == "loss" else "Greens"
    contours_coloring = "heatmap" if smooth_contours else "fill"

    metric = np.asarray(metric, dtype=float)
    # Only finite positive entries have a log10. Zeros/NaNs left by failed runs
    # would make the colour range -inf/NaN and crash the int() conversion below,
    # so they are blanked (NaN) and excluded from the range instead.
    plottable = np.isfinite(metric) & (metric > 0)
    if not plottable.any():
        raise ValueError(
            "metric has no finite positive values to plot on a log scale"
        )
    log_metric = np.full(metric.shape, np.nan)
    log_metric[plottable] = np.log10(metric[plottable])

    emin = int(np.floor(log_metric[plottable].min()))
    emax = int(np.ceil(log_metric[plottable].max()))

    # Colorbar ticks evenly spaced in log10 space, labelled as powers of ten.
    tickvals = np.linspace(emin, emax, num=num_ticks)
    ticktext = [f"10<sup>{e:.1f}</sup>" for e in tickvals]

    contour = go.Contour(
        z=log_metric,
        x=activity_lrs,
        y=param_lrs,
        colorscale=colorscale,
        showscale=True,
        contours_coloring=contours_coloring,
        colorbar=dict(
            len=1.08,
            title="Training loss" if (
                metric_id == "loss"
            ) else "Test accuracy (%)",
            title_side="right",
            tickfont=dict(size=16),
            tickvals=tickvals,
            ticktext=ticktext
        )
    )
    fig = go.Figure(data=[contour])
    if show_contour_labels:
        fig.update_traces(
            contours=dict(
                showlabels=True,
                labelfont=dict(size=10, color="white")
            )
        )

    fig.update_layout(
        xaxis=dict(
            nticks=5, 
            type="log",
            exponentformat="power"
        ),
        yaxis=dict(
            nticks=2, 
            type="log",
            exponentformat="power"
        ),
        font=dict(size=18),
        plot_bgcolor="white",
        width=500, 
        height=400,
        margin=dict(
            r=100, 
            b=100,
            l=50, 
            t=80
        )
    )
    if show_axes_label:
        fig.update_layout(
            xaxis_title="Activity lr",
            yaxis_title="Weight lr",
        )

    fig.write_image(save_path)


#### Width

In [4]:
# Hyperparameter transfer across width: for each activation and width, sweep the
# (weight lr, activity lr) grid, take each seed's minimum batch training loss and
# maximum test accuracy, average over seeds and plot the training-loss contour.
results_dir = "pcn_results"
dataset = "MNIST"
act_fns = ["linear", "tanh", "relu"]
n_hidden = 8
use_skips = True
weight_init = "standard_gauss"
param_type = "μP"
param_optim_id = "adam"
batch_size = 64
max_infer_iters = 8
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 300
n_seeds = 3

widths = [2**i for i in range(6, 11)]
param_lrs = [5e-1, 1e-1, 5e-2, 1e-2]
activity_lrs = [1000, 500, 100, 50, 10, 5, 1, 5e-1, 1e-1, 5e-2, 1e-2]

for act_fn in act_fns:     
    for width in widths:

        # With 500 inference iterations one MNIST/linear/width-64 run failed,
        # so only the first 2 seeds are used for that configuration.
        if max_infer_iters == 500:
            n_seeds = 2 if (
                dataset == "MNIST" and act_fn == "linear" and width == 64
            ) else 3

        # Failed runs stay NaN (not 0) so np.nanmean below excludes them instead
        # of averaging a spurious zero loss/accuracy into the grid.
        min_train_losses = np.full(
            (len(param_lrs), len(activity_lrs), n_seeds), np.nan
        )
        max_test_accs = np.full_like(min_train_losses, np.nan)
        for i, param_lr in enumerate(param_lrs):
            for j, activity_lr in enumerate(activity_lrs):
                for seed in range(n_seeds):

                    save_path = os.path.join(
                        results_dir,
                        dataset,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        act_fn,
                        "1_skip" if use_skips else "no_skips",
                        f"{weight_init}_weight_init",
                        f"{param_type}_param",
                        f"param_optim_{param_optim_id}",
                        f"param_lr_{param_lr}",
                        f"batch_size_{batch_size}",
                        f"{max_infer_iters}_max_infer_iters",
                        f"activity_optim_{activity_optim_id}",
                        f"activity_lr_{activity_lr}",
                        f"activity_decay_{activity_decay}",
                        f"weight_decay_{weight_decay}",
                        f"spectral_penalty_{spectral_penalty}",
                        f"{max_epochs}_epochs",
                        str(seed)
                    )
                    batch_train_losses = np.load(f"{save_path}/batch_train_losses.npy")
                    test_accs = np.load(f"{save_path}/test_accs.npy")

                    # Presumably failed runs store arrays on which min()/max()
                    # raise; such entries are reported and left as NaN.
                    try:
                        min_train_losses[i, j, seed] = min(batch_train_losses)
                        max_test_accs[i, j, seed] = max(test_accs)
                    except Exception:
                        print(f"failed run for act fn {act_fn}, width = {width}, seed {seed}")

        plot_learning_rates_contour(
            metric=np.nanmean(min_train_losses, axis=-1), 
            activity_lrs=activity_lrs, 
            param_lrs=param_lrs, 
            metric_id="loss", 
            save_path=f"{results_dir}/min_train_losses_over_lrs_{act_fn}_width_{width}.pdf"
        )


#### Depth

In [5]:
### NOTE: this can be used to plot both SP and muPC results ###
# Hyperparameter transfer across depth: same sweep as the width cell, but varying
# the number of hidden layers (with max_infer_iters set equal to n_hidden).
results_dir = "pcn_results"
dataset = "MNIST"
act_fns = ["linear", "tanh", "relu"]
width = 512
use_skips = True
weight_init = "standard_gauss"
param_type = "μP" 
param_optim_id = "adam"
batch_size = 64
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

n_hiddens = [2**i for i in range(3, 8)]  # use range(3, 7) for SP
param_lrs = [5e-1, 1e-1, 5e-2, 1e-2]
activity_lrs = [1000, 500, 100, 50, 10, 5, 1, 5e-1, 1e-1, 5e-2, 1e-2]

for act_fn in act_fns:     
    for n_hidden in n_hiddens:
        max_infer_iters = n_hidden

        # Failed runs stay NaN (not 0) so np.nanmean below excludes them instead
        # of averaging a spurious zero loss/accuracy into the grid.
        min_train_losses = np.full(
            (len(param_lrs), len(activity_lrs), n_seeds), np.nan
        )
        max_test_accs = np.full_like(min_train_losses, np.nan)
        for i, param_lr in enumerate(param_lrs):
            for j, activity_lr in enumerate(activity_lrs):
                for seed in range(n_seeds):

                    save_path = os.path.join(
                        results_dir,
                        dataset,
                        f"width_{width}",
                        f"{n_hidden}_n_hidden",
                        act_fn,
                        "1_skip" if use_skips else "no_skips",
                        f"{weight_init}_weight_init",
                        f"{param_type}_param",
                        f"param_optim_{param_optim_id}",
                        f"param_lr_{param_lr}",
                        f"batch_size_{batch_size}",
                        f"{max_infer_iters}_max_infer_iters",
                        f"activity_optim_{activity_optim_id}",
                        f"activity_lr_{activity_lr}",
                        f"activity_decay_{activity_decay}",
                        f"weight_decay_{weight_decay}",
                        f"spectral_penalty_{spectral_penalty}",
                        f"{max_epochs}_epochs",
                        str(seed)
                    )
                    batch_train_losses = np.load(f"{save_path}/batch_train_losses.npy")
                    test_accs = np.load(f"{save_path}/test_accs.npy")

                    # Presumably failed runs store arrays on which min()/max()
                    # raise; such entries are left as NaN (blank in the plot).
                    try:
                        min_train_losses[i, j, seed] = min(batch_train_losses)
                        max_test_accs[i, j, seed] = max(test_accs)
                    except Exception:
                        pass

        plot_learning_rates_contour(
            metric=np.nanmean(min_train_losses, axis=-1), 
            activity_lrs=activity_lrs, 
            param_lrs=param_lrs, 
            metric_id="loss", 
            save_path=f"{results_dir}/min_train_losses_over_lrs_{act_fn}_depth_{n_hidden}.pdf"
        )


Additional learning-rate contours for Fashion-MNIST ($H = 8$, $N = 512$).

In [5]:
# Fashion-MNIST (H = 8, N = 512): training-loss and test-accuracy contours over
# the (weight lr, activity lr) grid for each activation function.
results_dir = "pcn_results"
dataset = "Fashion-MNIST"
act_fns = ["linear", "tanh", "relu"]
n_hidden = 8
width = 512
use_skips = True
weight_init = "standard_gauss"
param_type = "mupc"
param_optim_id = "adam"
batch_size = 64
max_infer_iters = 8
activity_optim_id = "gd"
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

param_lrs = [5e-1, 1e-1, 5e-2, 1e-2]
activity_lrs = [1000, 500, 100, 50, 10, 5, 1, 5e-1, 1e-1, 5e-2, 1e-2]

for act_fn in act_fns:     
    # Failed runs stay NaN (not 0) so np.nanmean below excludes them instead of
    # averaging a spurious zero loss/accuracy into the grid.
    min_train_losses = np.full(
        (len(param_lrs), len(activity_lrs), n_seeds), np.nan
    )
    max_test_accs = np.full_like(min_train_losses, np.nan)
    for i, param_lr in enumerate(param_lrs):
        for j, activity_lr in enumerate(activity_lrs):
            for seed in range(n_seeds):

                save_path = os.path.join(
                    results_dir,
                    dataset,
                    f"width_{width}",
                    f"{n_hidden}_n_hidden",
                    act_fn,
                    "skips" if use_skips else "no_skips",
                    f"{weight_init}_weight_init",
                    f"{param_type}_param",
                    f"param_optim_{param_optim_id}",
                    f"param_lr_{param_lr}",
                    f"batch_size_{batch_size}",
                    f"{max_infer_iters}_max_infer_iters",
                    f"activity_optim_{activity_optim_id}",
                    f"activity_lr_{activity_lr}",
                    f"activity_decay_{activity_decay}",
                    f"weight_decay_{weight_decay}",
                    f"spectral_penalty_{spectral_penalty}",
                    f"{max_epochs}_epochs",
                    str(seed)
                )
                batch_train_losses = np.load(f"{save_path}/batch_train_losses.npy")
                test_accs = np.load(f"{save_path}/test_accs.npy")

                # Presumably failed runs store arrays on which min()/max()
                # raise; such entries are reported and left as NaN.
                try:
                    min_train_losses[i, j, seed] = min(batch_train_losses)
                    max_test_accs[i, j, seed] = max(test_accs)
                except Exception:
                    print(f"failed run for act fn {act_fn}, width = {width}, seed {seed}")

    plot_learning_rates_contour(
        metric=np.nanmean(min_train_losses, axis=-1), 
        activity_lrs=activity_lrs, 
        param_lrs=param_lrs, 
        metric_id="loss", 
        save_path=f"{results_dir}/min_train_losses_over_lrs_{dataset}_{act_fn}_width_{width}.pdf",
        show_axes_label=False
    )
    plot_learning_rates_contour(
        metric=np.nanmean(max_test_accs, axis=-1), 
        activity_lrs=activity_lrs, 
        param_lrs=param_lrs, 
        metric_id="accuracy", 
        save_path=f"{results_dir}/max_test_accs_over_lrs_{dataset}_{act_fn}_width_{width}.pdf",
        show_axes_label=False
    )


## Theory energy results

In [3]:
from matplotlib.colors import TwoSlopeNorm

In [37]:
def get_label_from_iv(iv):
    """Return the exponent label ``"2^k"`` for a power-of-two value ``iv``.

    Used to build legend entries such as ``$N = {2^3}$``.

    Args:
        iv: a positive power of two. Python ints, numpy ints and integral
            floats such as ``8.0`` are accepted.

    Returns:
        The string ``"2^k"`` where ``iv == 2**k``.

    Raises:
        ValueError: if ``iv`` is not a positive power of two.
    """
    as_int = int(iv)
    # For n >= 1, n & (n - 1) == 0 exactly when n is a power of two.
    if as_int != iv or as_int < 1 or as_int & (as_int - 1):
        raise ValueError(f"expected a positive power of two, got {iv!r}")
    return f"2^{as_int.bit_length() - 1}"

        
def plot_loss_energy_ratio_phase_diagram(
        metric, 
        colorbar_title, 
        param_type,
        save_path, 
        cmap="inferno",
        title=None,
        show_cells=False,
        norm=None
    ):
    """Draw a width x depth grid of a metric as an image and save it.

    Args:
        metric: 2D array indexed as ``metric[width_idx, depth_idx]``. Rows are
            drawn along the y-axis ($N$) and columns along the x-axis ($H$).
            Tick labels assume the k-th row/column corresponds to $2^k$.
        colorbar_title: label for the colorbar.
        param_type: not used by the plot; kept for call-site compatibility.
        save_path: output path passed to ``fig.savefig``.
        cmap: colormap forwarded to ``imshow``.
        title: optional axes title.
        show_cells: if True, annotate every cell with its value (2 decimals).
        norm: optional normalisation forwarded to ``imshow`` (e.g. ``LogNorm``).

    Returns:
        The matplotlib figure (already closed after saving).
    """
    n_widths, n_hiddens = metric.shape[0], metric.shape[1]

    fig, ax = plt.subplots()
    im = ax.imshow(
        metric, 
        norm=norm,
        origin="lower", 
        interpolation="bicubic",
        cmap=cmap
    )
    # imshow puts rows (widths) on the y-axis and columns (depths) on the x-axis.
    ax.set_xlabel("$H$", fontsize=30, labelpad=15)
    ax.set_ylabel("$N$", fontsize=30, labelpad=15)
    if title is not None:
        ax.set_title(title, fontsize=30, pad=20)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(
        colorbar_title, 
        fontsize=30, 
        labelpad=15
    )
    cbar.ax.tick_params(labelsize=18)

    # Each axis gets one tick per cell along that axis, so non-square grids are
    # labelled correctly. Braces keep multi-digit exponents (2^{10}) superscripted.
    ax.set_xticks(list(range(n_hiddens)))
    ax.set_yticks(list(range(n_widths)))
    ax.set_xticklabels([f"$2^{{{k}}}$" for k in range(n_hiddens)], fontsize=18)
    ax.set_yticklabels([f"$2^{{{k}}}$" for k in range(n_widths)], fontsize=18)

    if show_cells:
        for i in range(n_widths):
            for j in range(n_hiddens):
                ax.text(
                    j, 
                    i, 
                    f"{metric[i, j]:.2f}", 
                    color="white", 
                    ha="center", 
                    va="center", 
                    fontsize=8
                )

    fig.savefig(save_path, bbox_inches="tight")
    # Close only this figure so other open figures are left untouched.
    plt.close(fig)
    return fig


def plot_metrics_per_iv(
        first_metric,
        second_metric,
        ivs,
        iv_id,
        save_path,
        plot_second_metric=True
    ):
    """Plot one training curve of ``first_metric`` per independent-variable value.

    Args:
        first_metric: mapping from each value in ``ivs`` to a 1D sequence
            indexed by training iteration.
        second_metric: mapping with the same keys; only the curve for the last
            value of ``ivs`` is drawn (legend "loss"), and only when
            ``plot_second_metric`` is True.
        ivs: independent-variable values to plot; legend labels come from
            ``get_label_from_iv``.
        iv_id: symbol shown in each legend entry, e.g. ``"N"``.
        save_path: output path passed to ``fig.write_image``.
        plot_second_metric: also selects the colour scale ("Purples" vs "Reds")
            and the y-axis title ("Equilib. energy" vs "Loss").
    """
    fig = go.Figure()

    steps = list(range(len(first_metric[ivs[0]])))

    scale_name = "Purples" if plot_second_metric else "Reds"
    # Drop the first (lightest) samples of the sequential scale.
    palette = pc.sample_colorscale(scale_name, len(ivs) + 3)[3:]
    for iv, curve_color in zip(ivs, palette):
        legend_label = get_label_from_iv(iv)
        curve = go.Scatter(
            x=steps,
            y=first_metric[iv],
            name=f"${{{iv_id}}} = {{{legend_label}}}$",
            mode="lines",
            line=dict(width=2, color=curve_color),
            opacity=0.8
        )
        fig.add_traces(curve)

    if plot_second_metric:
        reference = go.Scatter(
            x=steps,
            y=second_metric[ivs[-1]],
            name="loss",
            mode="lines",
            line=dict(width=1, color="#EF553B"),
            legendrank=0
        )
        fig.add_traces(reference)

    last_step = steps[-1]
    tick_positions = [0, last_step // 2, last_step]
    y_title = "Equilib. energy" if plot_second_metric else "Loss"
    fig.update_layout(
        height=350,
        width=550,
        xaxis=dict(
            title="Training iteration",
            tickvals=tick_positions,
            ticktext=tick_positions
        ),
        yaxis=dict(title=y_title),
        font=dict(size=18),
        margin=dict(r=140, b=90, l=110)
    )
    fig.write_image(save_path)


def plot_loss_energy_ratio_over_training(
        loss_energy_ratios,
        ivs,
        iv_id,
        save_path
    ):
    """Plot the loss/energy ratio over training, one curve per value in ``ivs``.

    Args:
        loss_energy_ratios: mapping from each value in ``ivs`` to a 1D sequence
            indexed by training iteration.
        ivs: independent-variable values to plot; legend labels come from
            ``get_label_from_iv``.
        iv_id: symbol shown in each legend entry, e.g. ``"N"``.
        save_path: output path passed to ``fig.write_image``.
    """
    fig = go.Figure()

    steps = list(range(len(loss_energy_ratios[ivs[0]])))

    # Drop the first 3 samples of the scale (alternatives: Plasma or Inferno).
    palette = pc.sample_colorscale("Viridis", len(ivs) + 3)[3:]
    for iv, curve_color in zip(ivs, palette):
        legend_label = get_label_from_iv(iv)
        curve = go.Scatter(
            x=steps,
            y=loss_energy_ratios[iv],
            name=f"${{{iv_id}}} = {{{legend_label}}}$",
            mode="lines",
            line=dict(width=2, color=curve_color),
            opacity=0.8
        )
        fig.add_traces(curve)

    last_step = steps[-1]
    tick_positions = [0, last_step // 2, last_step]
    fig.update_layout(
        height=350,
        width=550,
        xaxis=dict(
            title="Training iteration",
            tickvals=tick_positions,
            ticktext=tick_positions
        ),
        yaxis=dict(title="Ratio"),
        font=dict(size=18),
        margin=dict(r=140, b=90, l=110)
    )
    fig.write_image(save_path)


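Vary both width $N$ and depth $H$: phase diagrams of the ratio $\mathcal{L} / \mathcal{F}^*$ at selected training iterations.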

In [56]:
# Phase diagrams of the loss / equilibrium-energy ratio over a width x depth
# grid, one figure per selected training iteration.
RESULTS_DIR = "energy_theory_results"
USE_SKIPS = True
PARAM_OPTIM_ID = "adam"
ACT_FN = "linear"

PARAM_TYPE = "sp"   # or "mupc"
PARAM_LR = 1e-1 if PARAM_TYPE == "mupc" else 1e-4

WIDTHS = [2**i for i in range(7)]
N_HIDDENS = [2**i for i in range(7)]
SEED = 4320

DATASET_SIZE = 60000
BATCH_SIZE = 64
n_train_iters = DATASET_SIZE // BATCH_SIZE

PLOT_TS = [0, 5, 10, 50, 100, 200, 450, 900]
vmin = 1
vmax = 10**3 if PARAM_TYPE == "mupc" else 10**7

# axis 0 = width index, axis 1 = depth index, axis 2 = training iteration
loss_energy_ratios_per_N_L = np.zeros(
    (len(WIDTHS), 
     len(N_HIDDENS), 
     n_train_iters)
)
for i, width in enumerate(WIDTHS):
    for j, n_hidden in enumerate(N_HIDDENS):
        save_dir = os.path.join(
            RESULTS_DIR,
            ACT_FN,
            f"{PARAM_TYPE}",
            "skips" if USE_SKIPS else "no_skips",
            PARAM_OPTIM_ID,
            f"param_lr_{PARAM_LR}",
            f"width_{width}",
            f"{n_hidden}_n_hidden",
            str(SEED)
        )
        # each run is expected to store one ratio per training iteration
        loss_energy_ratios = np.load(f"{save_dir}/loss_energy_ratios.npy")
        loss_energy_ratios_per_N_L[i, j] = loss_energy_ratios

for t in PLOT_TS:
    plot_loss_energy_ratio_phase_diagram(
        loss_energy_ratios_per_N_L[:, :, t], 
        # raw string: "\m" is an invalid escape sequence in a plain literal
        colorbar_title=r"$\mathcal{L} / \mathcal{F}^*$", 
        param_type=PARAM_TYPE,
        # NOTE(review): save_dir is the directory of the last (width, depth) run
        # loaded above, so every diagram is written into that run's folder.
        save_path=f"{save_dir}/loss_energy_ratio_t_{t}.pdf",
        cmap="inferno",
        title=f"$t = {{{t}}}$",
        show_cells=PARAM_TYPE == "mupc",
        norm=LogNorm(vmin=vmin, vmax=vmax)
    )

Vary width $N$ at fixed depth ($H = 4$): equilibrium energies, losses and loss/energy ratios over training.

In [13]:
# Fixed depth, varying width: for every activation and parameterisation, load the
# per-width training losses, equilibrium energies and loss/energy ratios, then
# plot each quantity over training for every other width.
RESULTS_DIR = "energy_theory_results"
USE_SKIPS = True
PARAM_OPTIM_ID = "adam"
PARAM_LR = 1e-3

ACT_FNS = ["linear", "tanh", "relu"]
PARAM_TYPES = ["sp", "mupc"]
WIDTHS = [2**i for i in range(10)]
N_HIDDEN = 4
SEED = 4320

DATASET_SIZE = 60000
BATCH_SIZE = 64
n_train_iters = DATASET_SIZE // BATCH_SIZE

for act_fn in ACT_FNS:
    for param_type in PARAM_TYPES:
        # rebuilt for each (act_fn, param_type) so no curves carry over
        losses_per_width = {}
        energies_per_width = {}
        loss_energy_ratios_per_width = {}
        for width in WIDTHS:
            save_dir = os.path.join(
                RESULTS_DIR,
                act_fn,
                f"{param_type}",
                "skips" if USE_SKIPS else "no_skips",
                PARAM_OPTIM_ID,
                f"param_lr_{PARAM_LR}",
                f"width_{width}",
                f"{N_HIDDEN}_n_hidden",
                str(SEED)
            )
            losses_per_width[width] = np.load(f"{save_dir}/train_losses.npy")
            energies_per_width[width] = np.load(f"{save_dir}/train_energies.npy")
            loss_energy_ratios_per_width[width] = np.load(f"{save_dir}/loss_energy_ratios.npy")

        # figures go into the directory of the last width loaded above
        plotted_widths = WIDTHS[1::2]
        plot_metrics_per_iv(
            first_metric=energies_per_width,
            second_metric=losses_per_width,
            ivs=plotted_widths,
            iv_id="N",
            save_path=f"{save_dir}/energies_per_width.pdf",
            plot_second_metric=True
        )
        plot_metrics_per_iv(
            first_metric=losses_per_width,
            second_metric=energies_per_width,
            ivs=plotted_widths,
            iv_id="N",
            save_path=f"{save_dir}/losses_per_width.pdf",
            plot_second_metric=False
        )
        plot_loss_energy_ratio_over_training(
            loss_energy_ratios=loss_energy_ratios_per_width,
            ivs=plotted_widths,
            iv_id="N",
            save_path=f"{save_dir}/loss_energy_ratios_per_width.pdf"
        )
