# Toy experiments

## Main Contents
- 🏔️ [Inference landscape geometry of 1-MLPs](#Inference-landscape-geometry-of-1-MLPs)
- 🏔️ [Inference landscape geometry of 2-MLPs](#Inference-landscape-geometry-of-2-MLPs)
- 🎢 [Inference dynamics of 2-MLPs](#Inference-dynamics-of-2-MLPs)
- 🏔️ [Inference landscape slices of deep & wide MLPs](#Inference-landscape-slices-of-deep-&-wide-MLPs)
- 🎢 [Inference & learning dynamics of deep chains](#Inference-&-learning-dynamics-of-deep-chains)
- 🎢 [Training with $z^*$](#Training-with-$z^*$)


## Installations & imports

In [1]:
%%capture
!pip install kaleido==0.2.1
!pip install plotly==5.24.1

In [2]:
import os
import numpy as np

import jax
import jax.random as jr
import jax.numpy as jnp
from jax import value_and_grad

import jpc
import equinox as eqx
import optax
from diffrax import (
    Euler, 
    Heun, 
    Dopri5, 
    ConstantStepSize,
    PIDController
)

import plotly.graph_objs as go
import plotly.colors as pc
import plotly.figure_factory as ff

import warnings
warnings.simplefilter('ignore')  # ignore warnings
os.environ["BROWSER_PATH"] = "/home/myhome/chrome-headless-shell/linux-132.0.6834.83/chrome-headless-shell-linux64/chrome-headless-shell"

## Utils

In [3]:
def make_gaussian_dataset(key, mean, std, batch_size):
    """Sample a scalar Gaussian dataset whose targets equal its inputs.

    Args:
        key: JAX PRNG key used to draw the samples.
        mean: mean of the Gaussian.
        std: standard deviation of the Gaussian.
        batch_size: number of samples to draw.

    Returns:
        A pair `(x, y)` where `x` has shape `(batch_size, 1)` and `y` is the
        very same array (identity regression task).
    """
    noise = jr.normal(key, (batch_size, 1))
    inputs = mean + std * noise
    return inputs, inputs


def init_weight(network, idx, value):
    """Return a copy of `network` with the weight of layer `idx` replaced.

    The weight is located at `network[idx][1].weight` and is replaced by the
    one-element array `jnp.array([value])`; `network` itself is not mutated
    (the replacement goes through `eqx.tree_at`).
    """
    def _select_weight(layers):
        return layers[idx][1].weight

    replacement = jnp.array([value])
    return eqx.tree_at(_select_weight, network, replacement)


def get_network_weights(network):
    """Collect `layer[1].weight` for every layer of `network`, in order."""
    return [layer[1].weight for layer in network]
    

def slice_energy_1D(network, zs, x, y):
    """Evaluate the energy of a 1-hidden-unit network along a grid of activities.

    Args:
        network: two-layer model; its weights are read from
            `network[0][1].weight` and `network[1][1].weight`.
        zs: candidate hidden activities; one energy is computed per element
            along the first axis.
        x: network input.
        y: network target.

    Returns:
        A tuple `(energy_slice, z0_mean, energy0, z_star_mean, energy_star)`:
        the list of energies over `zs`, the batch mean of the feedforward
        initialisation `z0 = w1 * x` with its energy, and the batch mean of
        `z* = (w1*x + w2*y) / (1 + w2**2)` with its energy.
    """
    w1 = network[0][1].weight
    w2 = network[1][1].weight
    params = (network, None)

    # The same value is passed for both entries of `activities`, as in every
    # other grid evaluation in this file.
    energy_slice = [
        jpc.pc_energy_fn(params=params, activities=[z, z], y=y, x=x)
        for z in zs
    ]

    # Feedforward initialisation of the hidden unit. The output activity is
    # derived from `z0` itself (not from the whole grid `zs`), mirroring the
    # `z_star` evaluation below.
    z0 = w1 * x
    energy0 = jpc.pc_energy_fn(
        params=params,
        activities=[z0, w2 * z0],
        y=y,
        x=x
    )

    # NOTE(review): this closed form is the stationary point of the quadratic
    # (linear-network) energy; for tanh/relu it is presumably only a reference
    # point -- confirm.
    z_star = (w1 * x + w2 * y) / (1 + w2**2)
    energy_star = jpc.pc_energy_fn(
        params=params,
        activities=[z_star, w2 * z_star],
        y=y,
        x=x
    )
    return (
        energy_slice,
        z0.mean(),
        energy0,
        z_star.mean(),
        energy_star
    )


def slice_energy_2D(network, skip_model, zs, x, y, param_type):
    """Evaluate the energy of a 2-hidden-unit network on a 2D activity grid.

    Args:
        network: model passed to `jpc.pc_energy_fn`.
        skip_model: optional skip-connection model (`None` disables skips).
        zs: pair of activity grids; `zs[0]` sweeps the first hidden unit and
            `zs[1]` the second.
        x: network input.
        y: network target.
        param_type: parameterisation forwarded to `jpc`.

    Returns:
        `(energy_mesh, zs_star_means, energy_star)`: the energy grid indexed
        as `[z2 index, z1 index]`, the per-layer means of the linear activity
        solution, and the energy at that solution.
    """
    params = (network, skip_model)
    z1_grid, z2_grid = zs[0], zs[1]
    n_points = z1_grid.shape[0]

    # Rows follow z2 and columns follow z1, i.e. the (y, x) layout expected
    # by the contour/surface plots.
    energy_mesh = jnp.zeros((n_points, n_points))
    for col, z1 in enumerate(z1_grid):
        for row, z2 in enumerate(z2_grid):
            point_energy = jpc.pc_energy_fn(
                params,
                [z1, z2, z2],
                y,
                x=x,
                param_type=param_type
            )
            energy_mesh = energy_mesh.at[row, col].set(point_energy)

    zs_star = jpc.compute_linear_activity_solution(
        network,
        x,
        y,
        use_skips=skip_model is not None,
        param_type=param_type
    )
    energy_star = jpc.pc_energy_fn(
        params=params,
        activities=zs_star,
        y=y,
        x=x,
        param_type=param_type
    )
    solution_means = [layer_solution.mean() for layer_solution in zs_star]
    return energy_mesh, solution_means, energy_star


def unwrap_hessian_pytree(hessian_pytree, activities, keep_last_term=False):
    """Assemble a nested Hessian pytree into one dense square matrix.

    Args:
        hessian_pytree: nested sequence where `hessian_pytree[l][k]` holds the
            second derivatives w.r.t. layers `l` and `k`. Each block is
            expected in the layout produced by `jax.hessian` for batched
            activities, i.e. reshapeable to
            `(batch, width_l, batch, width_k)`.
        activities: list of per-layer activities of shape `(batch, width)`.
        keep_last_term: if False (default), the last layer's row and column
            of blocks are dropped.

    Returns:
        A NumPy array of shape `(N, N)` with `N` the summed layer widths,
        where every block has been summed over both batch axes.
    """
    batch_size = activities[0].shape[0]
    if not keep_last_term:
        activities = activities[:-1]
        hessian_pytree = [row[:-1] for row in hessian_pytree[:-1]]

    widths = [a.shape[1] for a in activities]

    # offsets[l] is the first row/column of layer l in the dense matrix.
    offsets = [0]
    for width in widths:
        offsets.append(offsets[-1] + width)

    hessian_matrix = np.zeros((offsets[-1], offsets[-1]))
    for l, row_blocks in enumerate(hessian_pytree):
        for k, raw_block in enumerate(row_blocks):
            # Keep the (batch, w_l, batch, w_k) axis order and sum over the
            # two batch axes; reinterpreting the buffer as
            # (batch, batch, w_l, w_k) would mix batch and feature entries
            # whenever batch_size > 1.
            block = raw_block.reshape(
                batch_size, widths[l], batch_size, widths[k]
            ).sum(axis=(0, 2))

            hessian_matrix[
                offsets[l]:offsets[l + 1],
                offsets[k]:offsets[k + 1]
            ] = block

    return hessian_matrix
    

## Plotting

In [4]:
def plot_1D_energy_slices(network, weight_idx, zs, x, y, save_path, show_init=False):
    """Plot 1D energy slices for three values of one weight and save the figure.

    The energy is evaluated (via `slice_energy_1D`) for `network` as given
    and for two copies whose weight `weight_idx` is set to 2 and to 0.1.

    Args:
        network: base network; the legend labels it with "= 1".
        weight_idx: zero-based index of the layer whose weight is varied.
        zs: activity grid of rank 3; `zs[:, 0, 0]` is used as the x-axis.
        x: network input.
        y: network target.
        save_path: destination passed to `fig.write_image`.
        show_init: if True, also mark the feedforward initialisation `z_0`
            of each network.
    """
    slices = {
        "base": slice_energy_1D(network, zs, x, y),
        "larger": slice_energy_1D(init_weight(network, weight_idx, 2), zs, x, y),
        "smaller": slice_energy_1D(init_weight(network, weight_idx, 0.1), zs, x, y),
    }

    z = zs[:, 0, 0]
    colors = pc.sample_colorscale("Blues", 5)
    fig = go.Figure()

    # (variant, relation shown in the legend, index into `colors`)
    curve_specs = [
        ("larger", ">", 3),
        ("base", "=", 2),
        ("smaller", "<", 1),
    ]
    for variant, relation, color_idx in curve_specs:
        fig.add_trace(
            go.Scatter(
                x=z,
                y=slices[variant][0],
                name=rf"$\Large{{w_{weight_idx + 1}{relation}1}}$",
                mode="lines",
                line=dict(width=3, color=colors[color_idx])
            )
        )

    marker_order = ("base", "larger", "smaller")

    # Solution markers; only the first one gets a legend entry.
    for i, variant in enumerate(marker_order):
        _, _, _, z_star, energy_star = slices[variant]
        fig.add_trace(
            go.Scatter(
                x=[z_star],
                y=[energy_star],
                name=r"$\Large{z^*}$",
                mode="markers",
                marker=dict(
                    color="yellow",
                    size=10,
                    line=dict(
                        color='DarkSlateGrey',
                        width=1
                    ),
                    symbol="star"
                ),
                showlegend=(i == 0)
            )
        )

    if show_init:
        for i, variant in enumerate(marker_order):
            _, z0, energy0, _, _ = slices[variant]
            fig.add_trace(
                go.Scatter(
                    x=[z0],
                    y=[energy0],
                    name=r"$\Large{z_0}$",
                    mode="markers",
                    marker=dict(
                        color="black",
                        size=10,
                        symbol="circle-open"
                    ),
                    showlegend=(i == 0)
                )
            )

    fig.update_layout(
        height=350,
        width=600,
        xaxis=dict(title=r"$\LARGE{z}$"),
        yaxis=dict(title=r"$\LARGE{\mathcal{F}}$"),
        font=dict(size=16),
        margin=dict(r=130)
    )
    fig.write_image(save_path)


def plot_energy_contour_2mlp(
        network,
        skip_model,
        zs,
        x,
        y,
        param_type,
        title,
        save_path,
        smooth_contours=True,
        activity_updates=None
):
    """Draw the 2D energy landscape of a 2-MLP as a contour plot and save it.

    Args:
        network: model forwarded to `slice_energy_2D`.
        skip_model: optional skip-connection model (or `None`).
        zs: pair of rank-3 activity grids; `zs[i][:, 0, 0]` gives the axis
            values.
        x: network input.
        y: network target.
        param_type: parameterisation forwarded to `slice_energy_2D`.
        title: figure title, or `None` for no title.
        save_path: destination passed to `fig.write_image`.
        smooth_contours: colour contours as a heatmap (True) or as filled
            bands (False).
        activity_updates: optional pair `[z1_values, z2_values]` of a
            trajectory to overlay as a line with markers.
    """
    energy_mesh, zs_star, _ = slice_energy_2D(
        network=network,
        skip_model=skip_model,
        zs=zs,
        x=x,
        y=y,
        param_type=param_type
    )
    z1, z2 = zs[0][:, 0, 0], zs[1][:, 0, 0]

    colorscale, points_color = "Viridis", "yellow"
    contours_coloring = "heatmap" if smooth_contours else "fill"

    # contour plot
    contour = go.Contour(
        z=energy_mesh,
        x=z1,
        y=z2,
        colorscale=colorscale,
        showscale=False,
        contours_coloring=contours_coloring
    )
    fig = go.Figure(data=[contour])

    # solution marker
    fig.add_trace(
        go.Scatter(
            x=[zs_star[0]],
            y=[zs_star[1]],
            mode="markers",
            marker=dict(
                color=points_color,
                size=8,
                line=dict(
                    color="white",
                    width=1
                ),
                symbol="circle"
            ),
            showlegend=False
        )
    )

    # solution label, nudged less when a trajectory is also drawn
    z_star_displacement = 0.4 if activity_updates is None else 0.2
    fig.add_trace(
        go.Scatter(
            x=[zs_star[0] + z_star_displacement],
            y=[zs_star[1] + z_star_displacement],
            text=[r"$\Large{z^*}$"],
            mode="text",
            textfont=dict(size=20, color=points_color),
            showlegend=False
        )
    )

    # colorbar: an empty scatter trace carries a "Low"/"High" colorbar since
    # the contour's own scale is hidden
    max_energy, min_energy = float(energy_mesh.max()), float(energy_mesh.min())
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        showlegend=False,
        marker=dict(
            colorscale=colorscale,
            showscale=True,
            cmin=min_energy,
            cmax=max_energy,
            colorbar=dict(
                title=r"$\LARGE{\mathcal{F}}$",
                len=0.5,
                title_side="right",
                tickfont=dict(size=16),
                tickvals=[min_energy, max_energy],
                ticktext=["Low", "High"]
            )
        ),
        hoverinfo="none"
    )
    fig.add_trace(colorbar_trace)

    if activity_updates is not None:
        fig.add_trace(
            go.Scatter(
                x=activity_updates[0],
                y=activity_updates[1],
                mode="lines+markers",
                line=dict(
                    color=points_color,
                    width=2
                ),
                marker=dict(
                    size=4,
                    color=points_color
                ),
                showlegend=False
            )
        )

    fig.update_layout(
        xaxis=dict(title=r"$\Large{z_1}$", nticks=5),
        yaxis=dict(title=r"$\Large{z_2}$", nticks=5),
        font=dict(size=16),
        plot_bgcolor="white",
        width=500,
        height=400,
        margin=dict(
            r=100,
            b=100,
            l=50,
            t=80
        )
    )

    if title is not None:
        fig.update_layout(
            title=dict(
                text=title,
                y=0.85,
                # titles containing a comma start further left
                x=0.25 if "," not in title else 0.18,
                xanchor="left",
                yanchor="top"
            )
        )

    fig.write_image(save_path)


def plot_energy_surface_2mlp(
        network,
        skip_model,
        zs,
        x,
        y,
        param_type,
        save_path
    ):
    """Render the 2D energy landscape of a 2-MLP as a 3D surface and save it.

    The energy grid comes from `slice_energy_2D`; `zs[i][:, 0, 0]` provides
    the axis values. Axis titles and tick labels are hidden, and the figure
    is written to `save_path`.
    """
    mesh, _, _ = slice_energy_2D(
        network=network,
        skip_model=skip_model,
        zs=zs,
        x=x,
        y=y,
        param_type=param_type
    )

    surface = go.Surface(
        z=mesh,
        x=zs[0][:, 0, 0],
        y=zs[1][:, 0, 0],
        colorscale="Viridis",
    )
    fig = go.Figure(data=surface)

    # Project coloured z-contours and hide the colour scale.
    z_contours = dict(
        show=True,
        usecolormap=True,
        highlightcolor="limegreen",
        project_z=True
    )
    fig.update_traces(contours_z=z_contours, showscale=False)

    # Horizontal axes are reversed and unlabelled; the vertical axis is just
    # unlabelled.
    reversed_blank_axis = dict(
        title="",
        autorange="reversed",
        showticklabels=False
    )
    blank_axis = dict(title="", showticklabels=False)
    camera = dict(
        center=dict(x=0.05, y=0.1, z=0),
        eye=dict(x=0.75, y=1.8, z=1.25)
    )
    fig.update_layout(
        scene=dict(
            xaxis=dict(reversed_blank_axis),
            yaxis=dict(reversed_blank_axis),
            zaxis=blank_axis
        ),
        scene_camera=camera,
        font=dict(size=16),
        height=600,
        width=700,
        scene_aspectmode="cube"
    )
    fig.write_image(save_path)


def plot_energy_surface_slice(energy_mesh, activities, save_path, showbackground=True):
    """Render a precomputed 2D energy slice as a 3D surface and save it.

    Args:
        energy_mesh: 2D array of energies.
        activities: pair of 1D coordinate arrays for the x- and y-axes.
        save_path: destination passed to `fig.write_image`.
        showbackground: whether to draw the axis backgrounds; it also
            toggles the colour scale and the projected z-contours.
    """
    energy_max, energy_min = energy_mesh.max(), energy_mesh.min()
    # Only label the colorbar ends when the energy range is non-negligible.
    label_ticks = (energy_max - energy_min) > 0.1

    fig = go.Figure(
        data=go.Surface(
            z=energy_mesh,
            x=activities[0],
            y=activities[1],
            colorscale="Viridis",
            showscale=bool(showbackground),
            colorbar=dict(
                # `title.side` replaces the deprecated `titleside` attribute.
                title=dict(text=r"$\LARGE{\mathcal{F}}$", side="right"),
                x=0.85,
                y=0.57,
                len=0.3,
                tickfont=dict(size=16),
                tickvals=[energy_min, energy_max],
                ticktext=[
                    "Low" if label_ticks else "",
                    "High" if label_ticks else ""
                ]
            )
        )
    )
    if showbackground:
        fig.update_traces(
            contours_z=dict(
                show=True,
                usecolormap=True,
                highlightcolor="limegreen",
                project_z=True,
            )
        )

    fig.update_layout(
        scene=dict(
            xaxis=dict(
                title="",
                autorange="reversed",
                showticklabels=False,
                showbackground=showbackground
            ),
            yaxis=dict(
                title="",
                autorange="reversed",
                showticklabels=False,
                showbackground=showbackground
            ),
            zaxis=dict(
                title="",
                showticklabels=False,
                showbackground=showbackground
            )
        ),
        scene_camera=dict(
            center=dict(x=0.05, y=0.2, z=0),
            eye=dict(x=1.4, y=1.4, z=1.25)
        ),
        font=dict(size=16),
        height=600,
        width=700,
        margin=dict(r=30, b=10, l=0, t=40),
        scene_aspectmode="cube"
    )
    fig.write_image(save_path)


def plot_energy_contour_slice(
        energy_mesh,
        activities,
        plot_solution,
        save_path,
        smooth_contours=True
):
    """Draw a precomputed 2D energy slice as a contour plot and save it.

    Args:
        energy_mesh: 2D array of energies.
        activities: pair of 1D coordinate arrays for the x- and y-axes.
        plot_solution: if True, mark and label the origin of the slice
            as `z^*`.
        save_path: destination passed to `fig.write_image`.
        smooth_contours: colour contours as a heatmap (True) or as filled
            bands (False).
    """
    colorscale, points_color = "Viridis", "yellow"
    contours_coloring = "heatmap" if smooth_contours else "fill"

    # contour plot
    contour = go.Contour(
        z=energy_mesh,
        x=activities[0],
        y=activities[1],
        colorscale=colorscale,
        showscale=False,
        contours_coloring=contours_coloring
    )
    fig = go.Figure(data=[contour])

    if plot_solution:
        # marker at the origin of the slice coordinates
        fig.add_trace(
            go.Scatter(
                x=[0],
                y=[0],
                mode="markers",
                marker=dict(
                    color=points_color,
                    size=8,
                    line=dict(
                        color="white",
                        width=1
                    ),
                    symbol="circle"
                ),
                showlegend=False
            )
        )
        # label, offset from the marker
        fig.add_trace(
            go.Scatter(
                x=[0 + 0.3],
                y=[0 + 0.3],
                text=[r"$\Large{z^*}$"],
                mode="text",
                textfont=dict(size=20, color=points_color),
                showlegend=False
            )
        )

    # colorbar: an empty scatter trace carries a "Low"/"High" colorbar since
    # the contour's own scale is hidden
    max_energy, min_energy = float(energy_mesh.max()), float(energy_mesh.min())
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        showlegend=False,
        marker=dict(
            colorscale=colorscale,
            showscale=True,
            cmin=min_energy,
            cmax=max_energy,
            colorbar=dict(
                title=r"$\LARGE{\mathcal{F}}$",
                len=0.5,
                title_side="right",
                tickfont=dict(size=16),
                tickvals=[min_energy, max_energy],
                ticktext=["Low", "High"]
            )
        ),
        hoverinfo="none"
    )
    fig.add_trace(colorbar_trace)

    fig.update_layout(
        xaxis=dict(title=r"$\Large{\hat{\mathbf{v}}_{min}}$", nticks=5),
        yaxis=dict(title=r"$\Large{\hat{\mathbf{v}}_{max}}$", nticks=5),
        font=dict(size=16),
        plot_bgcolor="white",
        width=500,
        height=400,
        margin=dict(
            r=100,
            b=100,
            l=50,
            t=80
        )
    )

    fig.write_image(save_path)


def visualise_energy_slice(
        model,
        skip_model,
        x,
        y,
        use_skips,
        param_type,
        activity_decay,
        domain,
        sampling_resolution,
        hessian_eigenvecs,
        save_path,
        plot_solution,
        showbackground=True
):
    """Plot the energy around the linear activity solution along two directions.

    The activities returned by `jpc.compute_linear_activity_solution` (all
    but the last layer) are flattened, perturbed by
    `a * hessian_eigenvecs[0] + b * hessian_eigenvecs[-1]` for every `(a, b)`
    on a square grid over `[-domain, domain]`, and the energy of each
    perturbed state is evaluated. The grid is saved both as a surface
    (`{save_path}_surface.pdf`) and as a contour plot
    (`{save_path}_contour.pdf`).

    Args:
        model: network forwarded to `jpc`.
        skip_model: optional skip-connection model (or `None`).
        x: network input.
        y: network target.
        use_skips: forwarded to `jpc.compute_linear_activity_solution`.
        param_type: parameterisation forwarded to `jpc`.
        activity_decay: forwarded to `jpc`.
        domain: half-width of the perturbation grid.
        sampling_resolution: number of grid points per direction.
        hessian_eigenvecs: sequence of direction vectors; the first and last
            entries are used.
        save_path: path prefix for the two output files.
        plot_solution: forwarded to `plot_energy_contour_slice`.
        showbackground: forwarded to `plot_energy_surface_slice`.
    """
    scaling_factors = [
        np.linspace(
            -domain, domain, sampling_resolution
        ) for _ in range(2)
    ]

    # The solution does not depend on the grid point, so compute and flatten
    # it once instead of once per (a, b) pair.
    all_zs_star = jpc.compute_linear_activity_solution(
        network=model,
        x=x,
        y=y,
        use_skips=use_skips,
        param_type=param_type,
        activity_decay=activity_decay
    )
    zs_star = all_zs_star[:-1]
    z_widths = [z.shape[1] for z in zs_star]
    z_vec = np.hstack(zs_star)
    # column indices at which the flat vector is cut back into layers
    split_indices = np.cumsum(z_widths)[:-1]

    # rows follow `b` (last direction), columns follow `a` (first direction)
    energy_mesh = np.zeros((sampling_resolution, sampling_resolution))
    for j, a in enumerate(scaling_factors[0]):
        for i, b in enumerate(scaling_factors[1]):
            # perturb along the 2 given hessian directions
            perturbed_z_vec = z_vec + (a * hessian_eigenvecs[0]) + (b * hessian_eigenvecs[-1])

            # reshape perturbed vector into layers
            perturbed_activities = np.hsplit(perturbed_z_vec, split_indices)

            energy = jpc.pc_energy_fn(
                (model, skip_model),
                perturbed_activities + [all_zs_star[-1]],
                y,
                x=x,
                param_type=param_type,
                activity_decay=activity_decay
            )
            energy_mesh[i, j] = energy

    plot_energy_surface_slice(
        energy_mesh,
        scaling_factors,
        f"{save_path}_surface.pdf",
        showbackground=showbackground
    )
    plot_energy_contour_slice(
        energy_mesh,
        scaling_factors,
        plot_solution,
        f"{save_path}_contour.pdf"
    )
    

## Hyperparameters

In [5]:
# Statistics of the Gaussian toy data.
DATA_MEAN = 1.0
DATA_STD = 0.1
BATCH_SIZE = 64

# Activation functions compared in every experiment.
ACT_FNS = ["linear", "tanh", "relu"]

# Number of grid points per axis when sampling energy landscapes.
SAMPLING_RESOLUTION = 30

# Root directory for all saved figures.
SAVE_DIR = "toy_results"
os.makedirs(SAVE_DIR, exist_ok=True)

## Inference landscape geometry of 1-MLPs
This code snippet visualises the activity or inference landscape of scalar predictive coding networks with 1 hidden unit (1-MLPs), i.e. $\mathcal{F}(z) = \frac{1}{2}(z - w_1x)^2 + \frac{1}{2}(y-w_2z)^2$. Since the landscape is convex (at least in the linear case), it is completely characterised by the activity Hessian, which for 1-MLPs is given by $H_{\mathbf{z}} \coloneqq \partial^2 \mathcal{F}/ \partial z^2 = 1 + w_2^2$. The plots verify this by showing that only perturbations of $w_2$, and not $w_1$, change the curvature of the landscape.

In [None]:
# Output directory for the 1-MLP figures.
mlp1_save_dir = f"{SAVE_DIR}/1mlps"
os.makedirs(mlp1_save_dir, exist_ok=True)

# One key for the data, one for the network initialisation.
key = jr.PRNGKey(54829)
keys = jr.split(key, 2)
data_key, model_key = keys[0], keys[1]

X, Y = make_gaussian_dataset(data_key, DATA_MEAN, DATA_STD, BATCH_SIZE)

# Grid of candidate hidden activities: the same 1D grid is repeated for
# every batch element and a trailing feature axis is appended.
DOMAIN = 3
activity_grid = jnp.linspace(-DOMAIN, DOMAIN, SAMPLING_RESOLUTION)
Zs = jnp.tile(activity_grid, (BATCH_SIZE, 1)).T[:, :, jnp.newaxis]

for act_fn in ACT_FNS:
    print(f"act fn: {act_fn}")
    network = jpc.make_mlp(
        key=model_key,
        input_dim=1,
        width=1,
        depth=2,
        output_dim=1,
        act_fn=act_fn,
        use_bias=False
    )

    # set all the weights to 1 for simplicity
    for layer_idx in range(2):
        network = eqx.tree_at(
            lambda layers, k=layer_idx: layers[k][1].weight,
            network,
            jnp.array([1])
        )

    # one figure per perturbed weight
    for weight_idx in range(2):
        figure_path = f"{mlp1_save_dir}/{act_fn}_w{weight_idx+1}_energy_slices.pdf"
        plot_1D_energy_slices(network, weight_idx, Zs, X, Y, figure_path)

## Inference landscape geometry of 2-MLPs
This extends the previous analysis by visualising the inference landscape of scalar predictive coding networks with 2 hidden units (2-MLPs), i.e. $\mathcal{F}(\mathbf{z}) = \frac{1}{2}(z_1 - a_1 w_1x)^2 + \frac{1}{2}(z_2 - a_2 w_2z_1)^2 + \frac{1}{2}(y - a_3 w_3z_2)^2$. The (2D) landscape is visualised both as a contour and as a surface plot. The results confirm the 1-MLP analysis.

We also compare the standard parameterisation (SP) given by $a_1 = a_2 = a_3 = 1$ with what we call "$\mu$PC" where $a_2 = 3$. We observe no notable differences, which is expected given the small depth. 

In [None]:
# Output directory for the 2-MLP figures.
mlp2_save_dir = f"{SAVE_DIR}/2mlps"
os.makedirs(mlp2_save_dir, exist_ok=True)

# One key for the data, one for the network initialisation.
key = jr.PRNGKey(4328704)
keys = jr.split(key, 2)

X, Y = make_gaussian_dataset(
    keys[0],
    DATA_MEAN,
    DATA_STD,
    BATCH_SIZE
)

# One activity grid per hidden unit: the same 1D grid repeated for every
# batch element, with a trailing feature axis.
DOMAIN = 2
Zs = [
    jnp.tile(
        jnp.linspace(-DOMAIN, DOMAIN, SAMPLING_RESOLUTION),
        (BATCH_SIZE, 1)
    ).T[:, :, jnp.newaxis] for _ in range(2)
]

PARAM_TYPES = ["sp", "mupc"]
for param_type in PARAM_TYPES:
    print(f"\nparam type: {param_type}\n")

    for act_fn in ACT_FNS:
        print(f"\tact fn: {act_fn}")
        save_dir = f"{mlp2_save_dir}/{param_type}/{act_fn}"
        os.makedirs(save_dir, exist_ok=True)

        network = jpc.make_mlp(
            key=keys[1],
            input_dim=1,
            width=1,
            depth=3,
            output_dim=1,
            act_fn=act_fn,
            use_bias=False,
            param_type=param_type
        )
        skip_model = jpc.make_skip_model(3) if param_type == "mupc" else None

        # set all the weights to 1 for simplicity
        for layer_idx in range(3):
            network = eqx.tree_at(
                lambda layers, k=layer_idx: layers[k][1].weight,
                network,
                jnp.array([[1]])
            )

        # create networks with imbalanced weights
        larger_w2_network = init_weight(network, 1, 3)
        larger_w3_network = init_weight(network, 2, 3)

        # (network, contour title, file-name suffix)
        variants = [
            (network, r"$\Large{w_3=w_2=w_1=1}$", ""),
            (larger_w2_network, r"$\Large{w_3=1, w_2=3, w_1=1}$", "_large_w2"),
            (larger_w3_network, r"$\Large{w_3=3, w_2=1, w_1=1}$", "_large_w3"),
        ]

        # landscape contour plots
        for variant_network, title, suffix in variants:
            plot_energy_contour_2mlp(
                variant_network,
                skip_model,
                Zs,
                X,
                Y,
                param_type,
                title=title,
                save_path=f"{save_dir}/energy_contour{suffix}.pdf"
            )

        # landscape surface plots
        for variant_network, _, suffix in variants:
            plot_energy_surface_2mlp(
                variant_network,
                skip_model,
                Zs,
                X,
                Y,
                param_type,
                save_path=f"{save_dir}/energy_surface{suffix}.pdf"
            )


## Inference dynamics of 2-MLPs
This extends the previous analysis of 2-MLPs by also visualising the minimisation $\min_z \mathcal{F}$ with different optimisers (e.g. gradient flow vs gradient descent vs Adam).

In [None]:
# Output directory for the 2-MLP figures.
mlp2_save_dir = f"{SAVE_DIR}/2mlps"
os.makedirs(mlp2_save_dir, exist_ok=True)

# One key for the data, one for the network initialisation.
key = jr.PRNGKey(10473)
keys = jr.split(key, 2)

SKIP_MODEL = None
PARAM_TYPE = "sp"
MAX_T1 = 50

X, Y = make_gaussian_dataset(
    keys[0],
    DATA_MEAN,
    DATA_STD,
    BATCH_SIZE
)

# One activity grid per hidden unit: the same 1D grid repeated for every
# batch element, with a trailing feature axis.
DOMAIN = 2
Zs = [
    jnp.tile(
        jnp.linspace(-DOMAIN, DOMAIN, SAMPLING_RESOLUTION),
        (BATCH_SIZE, 1)
    ).T[:, :, jnp.newaxis] for _ in range(2)
]

ODE_SOLVERS = {
    "Euler": Euler(),
    "Heun": Heun(),
    "Dopri5": Dopri5(),
}

# Discrete optimisers applied directly to the activities:
# name -> (optimiser, number of update steps).
ACTIVITY_OPTIMISERS = {
    "GD": (optax.sgd(5e-1), 100),
    "Adam": (optax.adam(5e-1), 4),
}


def _descend_activities(network, skip_model, activity_optim, n_steps, x, y):
    """Run `n_steps` optimiser updates on the activities.

    Activities start from `jpc.init_activities_with_ffwd`. Returns a pair of
    lists with the mean of the first and second activity after every update.
    """
    activities = jpc.init_activities_with_ffwd(
        model=network,
        input=x
    )
    activity_opt_state = activity_optim.init(activities)
    activity_updates = [[], []]
    for _ in range(n_steps):
        _, activity_grads = jpc.compute_activity_grad(
            params=(network, skip_model),
            activities=activities,
            y=y,
            x=x
        )
        updates, activity_opt_state = activity_optim.update(
            updates=activity_grads,
            state=activity_opt_state,
            params=activities
        )
        activities = eqx.apply_updates(
            model=activities,
            updates=updates
        )
        activity_updates[0].append(activities[0].mean())
        activity_updates[1].append(activities[1].mean())
    return activity_updates


for act_fn in ACT_FNS:
    print(f"act fn: {act_fn}")
    save_dir = f"{mlp2_save_dir}/{PARAM_TYPE}/{act_fn}"
    os.makedirs(save_dir, exist_ok=True)

    network = jpc.make_mlp(
        key=keys[1],
        input_dim=1,
        width=1,
        depth=3,
        output_dim=1,
        act_fn=act_fn,
        use_bias=False,
        param_type=PARAM_TYPE
    )
    # Imbalanced weights; the first weight is positive only for relu.
    where1, where2, where3 = lambda l: l[0][1].weight, lambda l: l[1][1].weight, lambda l: l[2][1].weight
    network = eqx.tree_at(where1, network, jnp.array([[1 if act_fn == "relu" else -1]]))
    network = eqx.tree_at(where2, network, jnp.array([[2]]))
    network = eqx.tree_at(where3, network, jnp.array([[10]]))

    # Not used by the plots below; kept so the weights and the conditioning
    # of the linear activity Hessian can be inspected interactively.
    weights = get_network_weights(network)
    w1, w2, w3 = weights
    activity_hessian = jpc.compute_linear_activity_hessian(weights)
    cond_num = jnp.linalg.cond(activity_hessian)

    optim = optax.adam(1e-3)
    opt_state = optim.init(
        (eqx.filter(network, eqx.is_array), SKIP_MODEL)
    )

    #################### ODE solvers ####################
    for solver_id, solver in ODE_SOLVERS.items():
        dt = 5e-1
        # Euler uses a fixed step; the other solvers use adaptive stepping.
        if solver_id == "Euler":
            stepsize_controller = ConstantStepSize()
        else:
            stepsize_controller = PIDController(rtol=1e-3, atol=1e-3)

        result = jpc.make_pc_step(
            network,
            optim,
            opt_state,
            output=Y,
            input=X,
            ode_solver=solver,
            max_t1=MAX_T1,
            dt=dt,
            stepsize_controller=stepsize_controller,
            record_activities=True
        )
        t_max, activities = result["t_max"], result["activities"]
        plot_energy_contour_2mlp(
            network,
            SKIP_MODEL,
            Zs,
            X,
            Y,
            PARAM_TYPE,
            title=None,
            save_path=f"{save_dir}/{solver_id}_dynamics_dt_{dt}_t1_{MAX_T1}.pdf",
            # trajectory of the recorded activities up to index t_max,
            # averaged over the last two axes
            activity_updates=[
                activities[0][:t_max+1].mean(axis=(-2, -1)),
                activities[1][:t_max+1].mean(axis=(-2, -1))
            ]
        )

    #################### GD and Adam ####################
    for optim_id, (activity_optim, t_max) in ACTIVITY_OPTIMISERS.items():
        activity_updates = _descend_activities(
            network,
            SKIP_MODEL,
            activity_optim,
            t_max,
            X,
            Y
        )
        plot_energy_contour_2mlp(
            network,
            SKIP_MODEL,
            Zs,
            X,
            Y,
            PARAM_TYPE,
            title=None,
            save_path=f"{save_dir}/{optim_id}_dynamics_t_{t_max}.pdf",
            activity_updates=activity_updates
        )


## Inference landscape slices of deep & wide MLPs
This plots slices of the energy for arbitrary MLPs at the linear inference solution along the directions of largest and smallest curvature, i.e. $\mathcal{F}(z^* + \alpha \hat{v}_{\text{min}} + \beta \hat{v}_{\text{max}})$ where $\hat{v}_{\text{max}}$ and $\hat{v}_{\text{min}}$ are the eigenvectors of the activity Hessian associated with its largest and smallest eigenvalues, respectively. For both linear and nonlinear networks, we find that: 
* the landscape always becomes more ill-conditioned with depth, and
* skip connections greatly increase ill-conditioning at large depth, although more so for the standard (vs $\mu$PC) parameterisation.

In [6]:
# Output directory for the landscape-slice figures.
mlp_landscape_save_dir = f"{SAVE_DIR}/mlps/landscape_slices"
os.makedirs(mlp_landscape_save_dir, exist_ok=True)

# One key for the data, one for the network initialisation.
key = jr.PRNGKey(19235)
keys = jr.split(key, 2)

DOMAIN = 2
# A single data point (batch size 1).
X, Y = make_gaussian_dataset(keys[0], DATA_MEAN, DATA_STD, 1)

WIDTH_DEPTH_COMBOS = [
    {"width": 128, "depth": 16},
    #{"width": 2, "depth": 3}
]
PARAM_TYPES = ["sp", "mupc"]

for width_depth in WIDTH_DEPTH_COMBOS:
    print(f"\nwidth and depth: {width_depth}\n")
    width = width_depth["width"]
    depth = width_depth["depth"]

    for param_type in PARAM_TYPES:
        print(f"\tparam type: {param_type}\n")

        # The standard parameterisation is run with and without skips;
        # the other one only with skips.
        skip_uses = [False, True] if param_type == "sp" else [True]
        for use_skips in skip_uses:
            print(f"\t\tuse_skips: {use_skips}\n")

            for act_fn in ACT_FNS:
                print(f"\t\t\tact fn: {act_fn}\n")

                save_dir = os.path.join(
                    mlp_landscape_save_dir,
                    param_type,
                    "skips" if use_skips else "no_skips",
                    act_fn
                )
                os.makedirs(save_dir, exist_ok=True)

                network = jpc.make_mlp(
                    keys[1],
                    input_dim=1,
                    width=width,
                    depth=depth,
                    output_dim=1,
                    act_fn=act_fn,
                    use_bias=False,
                    param_type=param_type
                )
                skip_model = jpc.make_skip_model(depth) if use_skips else None

                # Activity Hessian of the energy at the feedforward
                # initialisation.
                activities = jpc.init_activities_with_ffwd(
                    network,
                    X,
                    skip_model=skip_model,
                    param_type=param_type
                )
                hessian_pytree = jax.hessian(jpc.pc_energy_fn, argnums=1)(
                    (network, skip_model),
                    activities,
                    Y,
                    x=X,
                    param_type=param_type
                )
                num_H = unwrap_hessian_pytree(hessian_pytree, activities)
                eigenvals, eigenvecs = jnp.linalg.eigh(num_H)
                # Array reductions instead of Python's builtin max/min, which
                # would iterate over the eigenvalues one by one.
                cond_num = float(
                    jnp.abs(eigenvals.max()) / jnp.abs(eigenvals.min())
                )
                print(f"\t\t\t\tcondition number: {cond_num:.4f}\n")

                # first and last columns of the eigenvector matrix
                min_max_eigenvecs = [eigenvecs[:, 0], eigenvecs[:, -1]]

                visualise_energy_slice(
                    model=network,
                    skip_model=skip_model,
                    x=X,
                    y=Y,
                    use_skips=use_skips,
                    # Use the same parameterisation as the network and the
                    # Hessian above, so the slice shows the landscape whose
                    # curvature directions were just computed.
                    param_type=param_type,
                    activity_decay=False,
                    domain=DOMAIN,
                    sampling_resolution=SAMPLING_RESOLUTION,
                    hessian_eigenvecs=min_max_eigenvecs,
                    save_path=f"{save_dir}/energy_slice_N_{width}_L_{depth}",
                    plot_solution=(act_fn == "linear"),
                    showbackground=False  # changed
                )



width and depth: {'width': 128, 'depth': 16}

	param type: sp

		use_skips: False

			act fn: linear

				condition number: 42.1188

			act fn: tanh

				condition number: 41.9170

			act fn: relu

				condition number: 15.7973

		use_skips: True

			act fn: linear

				condition number: 45765.4336

			act fn: tanh

				condition number: 534.2700

			act fn: relu

				condition number: 9763.3066

	param type: mupc

		use_skips: True

			act fn: linear

				condition number: 1917.8442

			act fn: tanh

				condition number: 1343.7063

			act fn: relu

				condition number: 1345.4940



## Inference & learning dynamics of deep chains
The script below analyses the inference and learning dynamics of deep chains or scalar PCNs ($N=1$) for different parameterisations.

In [None]:
from jax.tree_util import tree_map

from utils import (
    init_weights,    
    compute_param_spectral_norms, 
    compute_hessian_eigens
)
from plotting import (
    plot_loss,
    plot_norms,
    plot_energies,
    plot_max_min_eigenvals,
    plot_max_min_eigenvals_2_axes
)

In [None]:
def plot_activities(activities, save_path, theory_activities=None, log=False):
    """Plot per-layer activity trajectories and write the figure to disk.

    Args:
        activities: 2D array indexed as ``[iteration, layer]``. Each column is
            drawn as one curve; when there is a single row, each column is
            drawn as a horizontal line instead.
        save_path: Destination passed to ``fig.write_image``.
        theory_activities: Optional sequence with one value per column of
            ``activities``; each is drawn as a dashed horizontal line in the
            colour of the matching curve.
        log: If True, the y-axis is switched to a log scale.

    Note:
        Only five layer labels are defined, so passing more than five columns
        raises ``IndexError``.
    """
    n_layers = activities.shape[1]
    # rows are plotted against the x-axis titled "Inference iteration"
    n_iters = activities.shape[0]
    iters = list(range(n_iters))
    layer_labels = [1, "1/4L", "1/2L", "3/4L", "L"]

    fig = go.Figure()
    if theory_activities is not None:
        # empty trace whose only purpose is a single "theory" legend entry
        fig.add_traces(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(width=3, color="black", dash="dash"),
                name="theory"
            )
        )

    colorscale = "Reds"
    # sample two extra colours and drop the first two samples of the scale
    colors = pc.sample_colorscale(colorscale, n_layers + 2)[2:]
    for i, color in enumerate(colors):
        layer_label = layer_labels[i]
        # raw f-string: "\ell" is LaTeX, not a Python escape sequence
        trace_name = rf"$\ell = {{{layer_label}}}$"

        if n_iters == 1:
            fig.add_hline(
                y=activities[0, i],
                name=trace_name,
                line=dict(
                    color=color,
                    width=2
                ),
                showlegend=True
            )
        else:
            fig.add_traces(
                go.Scatter(
                    x=iters,
                    y=activities[:, i],
                    name=trace_name,
                    mode="lines",
                    line=dict(
                        width=2,
                        color=color
                    )
                )
            )

        if theory_activities is not None:
            fig.add_hline(
                y=theory_activities[i],
                line=dict(
                    color=color,
                    width=4,
                    dash="dash"
                )
            )

    fig.update_layout(
        height=350,
        width=525,
        xaxis=dict(
            title="Inference iteration",
            showticklabels=False
        ),
        yaxis=dict(title=r"$\Large{z_\ell}$"),
        font=dict(size=18),
        margin=dict(r=140, l=100, b=90)
    )
    if log:
        fig.update_layout(
            yaxis=dict(
                type="log",
                exponentformat="power",
                dtick=1
            )
        )
    if n_iters > 1:
        # show only the first, middle and last iteration on the x-axis
        tick_iters = [0, int(iters[-1] / 2), iters[-1]]
        fig.update_layout(
            xaxis=dict(
                showticklabels=True,
                tickvals=tick_iters,
                ticktext=tick_iters
            )
        )

    fig.write_image(save_path)


def plot_2D_data(x, y, y_hat, save_path):
    """Scatter targets and predictions against the inputs and save the figure.

    Args:
        x: 2D array; its first column is used as the x-coordinate.
        y: 2D array; its first column is plotted as the (unlabelled) targets.
        y_hat: 2D array; its first column is plotted as the predictions.
        save_path: Destination passed to ``fig.write_image``.
    """
    inputs = x[:, 0]

    fig = go.Figure()
    # targets: blue markers, kept out of the legend
    fig.add_trace(
        go.Scatter(
            x=inputs,
            y=y[:, 0],
            mode="markers",
            marker=dict(size=8, color="#636EFA"),
            showlegend=False
        )
    )
    # predictions: red markers; raw strings keep the LaTeX backslashes literal
    fig.add_trace(
        go.Scatter(
            x=inputs,
            y=y_hat[:, 0],
            mode="markers",
            marker=dict(size=8, color="#EF553B"),
            name=r"$\hat{y}$"
        )
    )
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(title=r"$\Large{x}$"),
        yaxis=dict(title=r"$\Large{y}$"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


In [None]:
# Deep-chain experiment: trains width-1 networks on noisy y = -x data and
# records losses, activities, energies, parameter norms and Hessian
# eigenvalues for every configuration in the sweep below.
chains_save_dir = f"{SAVE_DIR}/deep_chains"
os.makedirs(chains_save_dir, exist_ok=True)

key = jax.random.PRNGKey(23853)
keys = jax.random.split(key, 3)

DATA_MEAN, DATA_STD = 1, 1
NOISE_STD = 0.5
WIDTH = 1
BATCH_SIZE = 64
N_TRAIN_ITERS = 1000
TEST_EVERY = 100

N_HIDDENS = [64]#[2**i for i in range(3, 7)]
PARAM_TYPES = ["mupc"]#, "sp"]
ACTIVITY_INITS = ["ffwd"]#, "ffwd", "random"]

for n_hidden in N_HIDDENS:
    print(f"\nn_hidden = {n_hidden}")

    for param_type in PARAM_TYPES:
        print(f"\n\tparam type: {param_type}")

        param_lr = 1e-3 #if param_type == "sp" else 1e-1
        skip_uses = [True] if param_type == "sp" else [True]  # NOTE: removed nonskip from sp for testing

        for use_skips in skip_uses:
            print(f"\n\t\tuse_skips: {use_skips}")

            for activity_init in ACTIVITY_INITS:
                print(f"\n\t\t\tactivity init: {activity_init}")

                # starting from the analytical solution needs no inference
                # iterations, and that solution is only used with linear nets
                if activity_init == "theory":
                    max_infer_iters = [0]
                    activity_lrs = [1]
                    act_fns = ["linear"]
                else:
                    max_infer_iters = [1, n_hidden, n_hidden*2]
                    activity_lrs = [1, 10, 50]
                    act_fns = ["linear"] #, "tanh", "relu"]  # NOTE: removed for testing

                for act_fn in act_fns:
                    print(f"\n\t\t\t\tact fn: {act_fn}")

                    for activity_lr in activity_lrs:
                        print(f"\n\t\t\t\t\tactivity_lr: {activity_lr}")

                        for n_infer_iters in max_infer_iters:
                            print(f"\n\t\t\t\t\t\tn_infer_iters: {n_infer_iters}\n")

                            # one output directory per configuration
                            save_dir = os.path.join(
                                chains_save_dir,
                                f"noise_std_{NOISE_STD}",
                                f"{n_hidden}_n_hidden",
                                param_type,
                                "skips" if use_skips else "no_skips",
                                f"{activity_init}_activity_init",
                                act_fn,
                                f"activity_lr_{activity_lr}",
                                f"{n_infer_iters}_n_infer_iters"
                            )
                            os.makedirs(save_dir, exist_ok=True)

                            # create and initialise model
                            d_in, d_out = 1, 1
                            L = n_hidden + 1
                            model = jpc.make_mlp(
                                key=keys[0],
                                input_dim=d_in,
                                width=WIDTH,
                                depth=L,
                                output_dim=d_out,
                                act_fn=act_fn,
                                use_bias=False,
                                param_type=param_type
                            )
                            skip_model = jpc.make_skip_model(L) if use_skips else None

                            # optimisers
                            param_optim = optax.adam(param_lr)
                            param_opt_state = param_optim.init(
                                (eqx.filter(model, eqx.is_array), skip_model)
                            )
                            activity_optim = optax.sgd(activity_lr)

                            # metrics
                            train_losses = np.zeros(N_TRAIN_ITERS+1)

                            n_test_iters = N_TRAIN_ITERS // TEST_EVERY
                            # layers tracked: first, ~1/4, ~1/2, ~3/4 and last
                            layer_idxs = [0, int(L / 4) - 1, int(L / 2) - 1, int(L * 3 / 4) - 1, L - 1]

                            train_num_activities = np.zeros(
                                (N_TRAIN_ITERS+1, n_infer_iters+1, len(layer_idxs))
                            )
                            # row 0: theory at the first train iteration;
                            # row -1: theory at the most recent test iteration
                            theory_activities_start_end_train = np.zeros((2, len(layer_idxs)))

                            train_num_energies = np.zeros((n_test_iters+1, len(layer_idxs)))
                            train_theory_energies = np.zeros_like(train_num_energies)

                            param_spectral_norms = np.zeros((N_TRAIN_ITERS+1, len(layer_idxs)))
                            max_min_hess_eigenvals = np.zeros((2, n_test_iters+1))

                            # The loop below runs N_TRAIN_ITERS + 1 times, so it
                            # needs that many keys. JAX clamps out-of-range
                            # indices, so a shorter split would silently reuse
                            # the last key (and hence the same batch) on the
                            # final iteration.
                            iter_keys = jr.split(keys[-1], N_TRAIN_ITERS + 1)
                            for train_iter in range(N_TRAIN_ITERS+1):
                                data_key, noise_key, init_key = jr.split(iter_keys[train_iter], 3)
                                # fresh batch each iteration: y = -x + noise
                                x = DATA_MEAN + DATA_STD * jr.normal(data_key, (BATCH_SIZE, 1))
                                eps = NOISE_STD * jr.normal(noise_key, (BATCH_SIZE, 1))
                                y = -x + eps

                                # preds
                                activities = jpc.init_activities_with_ffwd(
                                    model=model,
                                    input=x,
                                    skip_model=skip_model,
                                    param_type=param_type
                                )
                                preds = activities[-1]
                                train_loss = jpc.mse_loss(preds, y)
                                train_losses[train_iter] = train_loss
                                # stop this configuration once training diverges
                                if np.isinf(train_loss) or np.isnan(train_loss):
                                    break

                                # initialise activities ("ffwd" keeps the
                                # feedforward activities computed above)
                                if activity_init == "random":
                                    activities = jpc.init_activities_from_normal(
                                        key=init_key,
                                        layer_sizes=[d_in] + [WIDTH]*n_hidden + [d_out],
                                        mode="supervised",
                                        batch_size=BATCH_SIZE,
                                        sigma=1
                                    )

                                elif activity_init == "theory":
                                    activities = jpc.compute_linear_activity_solution(
                                        network=model,
                                        x=x,
                                        y=y,
                                        use_skips=use_skips,
                                        param_type=param_type
                                    )

                                activity_opt_state = activity_optim.init(activities)

                                # record metrics at init
                                i = 0
                                for l, activity in enumerate(activities):
                                    if l in layer_idxs:
                                        train_num_activities[train_iter, 0, i] = activity.mean()
                                        i += 1

                                if train_iter == 0:
                                    theory_activities = jpc.compute_linear_activity_solution(
                                        network=model,
                                        x=x,
                                        y=y,
                                        use_skips=use_skips,
                                        param_type=param_type,
                                    )
                                    i = 0
                                    for l, activity in enumerate(theory_activities):
                                        if l in layer_idxs:
                                            theory_activities_start_end_train[train_iter, i] = activity.mean()
                                            i += 1

                                    num_energies = jpc.pc_energy_fn(
                                        params=(model, skip_model),
                                        activities=activities,
                                        y=y,
                                        x=x,
                                        param_type=param_type,
                                        record_layers=True
                                    )
                                    theory_energies = jpc.pc_energy_fn(
                                        params=(model, skip_model),
                                        activities=theory_activities,
                                        y=y,
                                        x=x,
                                        param_type=param_type,
                                        record_layers=True
                                    )
                                    train_num_energies[0] = np.array([
                                        e for l, e in enumerate(reversed(num_energies)) if l in layer_idxs
                                    ])
                                    train_theory_energies[0] = np.array([
                                        e for l, e in enumerate(reversed(theory_energies)) if l in layer_idxs
                                    ])

                                    # Hessian spectrum is computed on the first
                                    # sample of the batch only
                                    eigenvals, eigenvecs = compute_hessian_eigens(
                                        params=(model, skip_model),
                                        activities=tree_map(lambda a: a[[0], :], activities),
                                        y=y[[0], :],
                                        x=x[[0], :],
                                        param_type=param_type
                                    )
                                    max_min_hess_eigenvals[:, 0] = np.array(
                                        [max(eigenvals), min(eigenvals)]
                                    )

                                # inference
                                if activity_init != "theory":
                                    for t in range(n_infer_iters):
                                        activity_update_result = jpc.update_activities(
                                            params=(model, skip_model),
                                            activities=activities,
                                            optim=activity_optim,
                                            opt_state=activity_opt_state,
                                            output=y,
                                            input=x,
                                            param_type=param_type
                                        )
                                        activities = activity_update_result["activities"]
                                        activity_opt_state = activity_update_result["opt_state"]

                                        # slot 0 holds the initial activities,
                                        # so step t is stored at t+1
                                        i = 0
                                        for l, act in enumerate(activities):
                                            if l in layer_idxs:
                                                train_num_activities[train_iter, t+1, i] = jnp.mean(act)
                                                i += 1
                                else:

                                    i = 0
                                    for l, act in enumerate(activities):
                                        if l in layer_idxs:
                                            train_num_activities[train_iter, 0, i] = jnp.mean(act)
                                            i += 1

                                param_spectral_norms[train_iter] = compute_param_spectral_norms(
                                    model=model,
                                    act_fn=act_fn,
                                    layer_idxs=layer_idxs
                                )

                                # update parameters
                                param_update_result = jpc.update_params(
                                    params=(model, skip_model),
                                    activities=activities,
                                    optim=param_optim,
                                    opt_state=param_opt_state,
                                    output=y,
                                    input=x,
                                    param_type=param_type
                                )
                                model = param_update_result["model"]
                                skip_model = param_update_result["skip_model"]
                                param_opt_state = param_update_result["opt_state"]

                                # periodic diagnostics, evaluated with the
                                # just-updated parameters on the current batch
                                if train_iter > 0 and train_iter % TEST_EVERY == 0:
                                    print(
                                        f"Train loss: {train_loss:.7f} [{train_iter}/{N_TRAIN_ITERS}]"
                                    )
                                    test_iter = int(train_iter / TEST_EVERY)
                                    theory_activities = jpc.compute_linear_activity_solution(
                                        network=model,
                                        x=x,
                                        y=y,
                                        use_skips=use_skips,
                                        param_type=param_type
                                    )
                                    i = 0
                                    for l, activity in enumerate(theory_activities):
                                        if l in layer_idxs:
                                            theory_activities_start_end_train[-1, i] = activity.mean()
                                            i += 1

                                    num_energies = jpc.pc_energy_fn(
                                        params=(model, skip_model),
                                        activities=activities,
                                        y=y,
                                        x=x,
                                        param_type=param_type,
                                        record_layers=True
                                    )
                                    theory_energies = jpc.pc_energy_fn(
                                        params=(model, skip_model),
                                        activities=theory_activities,
                                        y=y,
                                        x=x,
                                        param_type=param_type,
                                        record_layers=True
                                    )
                                    train_num_energies[test_iter] = np.array([
                                        e for l, e in enumerate(reversed(num_energies)) if l in layer_idxs
                                    ])
                                    train_theory_energies[test_iter] = np.array([
                                        e for l, e in enumerate(reversed(theory_energies)) if l in layer_idxs
                                    ])

                                    eigenvals, eigenvecs = compute_hessian_eigens(
                                        params=(model, skip_model),
                                        activities=tree_map(lambda a: a[[0], :], activities),
                                        y=y[[0], :],
                                        x=x[[0], :],
                                        param_type=param_type
                                    )
                                    max_min_hess_eigenvals[:, test_iter] = np.array(
                                        [max(eigenvals), min(eigenvals)]
                                    )

                            # plotting (x, y and preds come from the last
                            # iteration that was executed)
                            plot_2D_data(x, y, preds, f"{save_dir}/data_samples.pdf")

                            plot_loss(
                                loss=train_losses,
                                yaxis_title="Train loss",
                                xaxis_title="Iteration",
                                save_path=f"{save_dir}/train_losses.pdf",
                                mode="lines"
                            )
                            plot_norms(
                                norms=param_spectral_norms,
                                norm_type="param_spectral",
                                save_path=f"{save_dir}/param_spectral_norms.pdf"
                            )
                            plot_energies(
                                energies=train_num_energies.T,
                                test_every=TEST_EVERY,
                                save_path=f"{save_dir}/energies.pdf",
                                theory_energies=train_theory_energies.T,
                                log=False
                            )
                            plot_energies(
                                energies=train_num_energies.T,
                                test_every=TEST_EVERY,
                                save_path=f"{save_dir}/log_energies.pdf",
                                theory_energies=train_theory_energies.T,
                                log=True
                            )
                            plot_max_min_eigenvals(
                                max_min_eigenvals=max_min_hess_eigenvals,
                                test_every=TEST_EVERY,
                                save_path=f"{save_dir}/max_min_activity_hess_eigenvals.pdf"
                            )
                            plot_max_min_eigenvals_2_axes(
                                max_min_eigenvals=max_min_hess_eigenvals,
                                test_every=TEST_EVERY,
                                save_path=f"{save_dir}/max_min_activity_hess_eigenvals_2_axes.pdf"
                            )

                            # plot activities at start & end of training
                            plot_activities(
                                activities=train_num_activities[0],
                                save_path=f"{save_dir}/activities_over_inference_at_init.pdf",
                                theory_activities=theory_activities_start_end_train[0],
                                log=False
                            )
                            plot_activities(
                                activities=train_num_activities[-1],
                                save_path=f"{save_dir}/activities_over_inference_at_last_train_iter.pdf",
                                theory_activities=theory_activities_start_end_train[-1],
                                log=False
                            )


## Training with $z^*$

This section analyses the inference and learning dynamics of PCNs trained using the analytical inference solution $\mathbf{z}^* = \mathbf{H}^{-1}\mathbf{b}$.

In [None]:
from jax.tree_util import tree_map

from experiments.datasets import get_dataloaders
from utils import (
    set_seed,    
    init_weights,
    compute_param_spectral_norms, 
    compute_hessian_eigens, 
    spectral_norm,
    compute_metric_stats
)
from plotting import (
    plot_loss,
    plot_loss_and_accuracy,
    plot_n_infer_iters, 
    plot_norms,
    plot_energies,
    plot_hessian_eigenvalues_during_training,
    plot_max_min_eigenvals,
    plot_max_min_eigenvals_2_axes
)

In [None]:
def evaluate(params, test_loader, param_type):
    """Return the test loss and accuracy averaged over all batches.

    Args:
        params: Pair ``(model, skip_model)``.
        test_loader: Iterable of ``(images, labels)`` batches that supports
            ``len``; each element must expose ``.numpy()``.
        param_type: Forwarded unchanged to ``jpc.test_discriminative_pc``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``, each the sum of the per-batch
        values divided by ``len(test_loader)``.
    """
    network, skips = params
    loss_sum = 0
    acc_sum = 0

    for images, labels in test_loader:
        # convert the loader's batches to NumPy arrays before evaluating
        images = images.numpy()
        labels = labels.numpy()

        batch_loss, batch_acc = jpc.test_discriminative_pc(
            model=network,
            output=labels,
            input=images,
            skip_model=skips,
            param_type=param_type
        )
        loss_sum += batch_loss
        acc_sum += batch_acc

    n_batches = len(test_loader)
    return loss_sum / n_batches, acc_sum / n_batches


def plot_metric_stats(metric, yaxis_title, test_every, save_path, yaxis_type="linear"):
    """Plot mean curves with +/- one-std bands, one per key, and save the figure.

    Args:
        metric: Mapping whose keys are shown in the legend as ``H=<key>`` and
            whose values are indexable, with element 0 the means and element 1
            the standard deviations (assumed to be arrays of equal length,
            since they are added and subtracted elementwise).
        yaxis_title: Y-axis title. It also selects the colour scale ("Reds" if
            it contains "loss", else "Blues") and, if it contains "Train", a
            marker-free line style with raw indices as tick labels.
        test_every: Multiplier used to turn tick positions into tick labels
            for non-"Train" metrics.
        save_path: Destination passed to ``fig.write_image``.
        yaxis_type: Plotly y-axis type, "linear" by default.
    """
    first_iv = next(iter(metric))
    steps = list(range(len(metric[first_iv][0])))
    is_train_metric = "Train" in yaxis_title

    palette_name = "Reds" if "loss" in yaxis_title else "Blues"
    # sample two extra colours and discard the first two samples of the scale
    palette = pc.sample_colorscale(palette_name, len(metric.keys()) + 2)[2:]
    line_mode = "lines" if is_train_metric else "lines+markers"

    fig = go.Figure()
    for iv, colour in zip(metric.keys(), palette):
        means = metric[iv][0]
        stds = metric[iv][1]
        upper = means + stds
        lower = means - stds

        # closed polygon: upper bound left-to-right, then lower bound back
        band = go.Scatter(
            x=steps + steps[::-1],
            y=list(upper) + list(lower[::-1]),
            fill="toself",
            fillcolor=colour,
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
            opacity=0.3
        )
        fig.add_trace(band)

        curve = go.Scatter(
            x=steps,
            y=means,
            mode=line_mode,
            line=dict(width=2, color=colour),
            name=f"$H={iv}$"
        )
        fig.add_trace(curve)

    last_step = steps[-1]
    tick_positions = [0, int(last_step / 2), last_step]
    if is_train_metric:
        tick_labels = tick_positions
    else:
        # relabel tick positions by scaling with test_every
        tick_labels = [(pos + 1) * test_every for pos in tick_positions]

    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=tick_positions,
            ticktext=tick_labels
        ),
        yaxis=dict(title=yaxis_title, type=yaxis_type),
        font=dict(size=16),
        margin=dict(r=120)
    )
    fig.write_image(save_path)


In [None]:
theory_save_dir = f"{SAVE_DIR}/mlps/analytical_inference"
os.makedirs(theory_save_dir, exist_ok=True)

# for accurate inversion of the Hessian with wide and deep nets
jax.config.update("jax_enable_x64", True)

dataset = "MNIST"
width = 128
act_fn = "linear"
use_skips = True
param_optim_id = "Adam"
activity_optim_id = "GD"
activity_lr = 1
batch_size = 64
activity_decay = 0
weight_decay = 0
spectral_penalty = 0
max_epochs = 1
test_every = 100
n_seeds = 3

sigma = 1

PARAM_TYPES = ["mupc"]  #sp
N_HIDDENS = [8, 16, 32]
for param_type in PARAM_TYPES:
    for n_hidden in N_HIDDENS:
        for seed in range(n_seeds):
            print(
                f"\nStarting experiment with {param_type}, H = {n_hidden}, seed {seed}\n"
            )
            save_dir = os.path.join(
                theory_save_dir,
                param_type,
                f"{n_hidden}_n_hidden",
                str(seed)
            )
            set_seed(seed)
            key = jax.random.PRNGKey(seed)
            keys = jax.random.split(key, 2)
            os.makedirs(save_dir, exist_ok=True)
        
            # create and initialise model
            d_in, d_out = 784, 10
            L = n_hidden + 1
            model = jpc.make_mlp(
                key=keys[0],
                input_dim=d_in,
                width=width,
                depth=L,
                output_dim=d_out,
                act_fn=act_fn,
                use_bias=False,
                param_type=param_type
            )
            skip_model = jpc.make_skip_model(L) if use_skips else None
            
            # optimisers
            param_lr = 1e-3 if param_type == "sp" else 1e-1
            if param_optim_id == "SGD":
                param_optim = optax.sgd(param_lr)
            elif param_optim_id == "Adam":
                param_optim = optax.adam(param_lr)
            
            param_opt_state = param_optim.init(
                (eqx.filter(model, eqx.is_array), skip_model)
            )
            
            activity_optim = optax.sgd(activity_lr) if (
                    activity_optim_id == "GD"
            ) else optax.adam(activity_lr)
            
            # data
            train_loader, test_loader = get_dataloaders(dataset, batch_size)
            
            # metrics
            train_losses = []
            test_losses, test_accs = [], []
    
            n_train_iters = len(train_loader.dataset) // batch_size * max_epochs
            n_test_iters = n_train_iters // test_every * max_epochs
            layer_idxs = [0, int(L / 4) - 1, int(L / 2) - 1, int(L * 3 / 4) - 1, L - 1]
        
            activity_l2_norms = np.zeros((n_train_iters, len(layer_idxs)))
            
            param_spectral_norms = np.zeros((n_train_iters+1, len(layer_idxs)))    
            hessian_eigenvals = np.zeros((n_test_iters + 1, width * n_hidden))
            max_min_hess_eigenvals = np.zeros((2, n_test_iters + 1))
            
            train_num_energies = np.zeros((n_test_iters + 1, len(layer_idxs)))
            
            global_batch_id = 0
            for train_iter, (img_batch, label_batch) in enumerate(train_loader):
                img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
        
                # compute theory activities
                activities = jpc.compute_linear_activity_solution(
                    network=model,
                    x=img_batch,
                    y=label_batch,
                    use_skips=use_skips,
                    param_type=param_type,
                    activity_decay=activity_decay
                )
                train_loss = jpc.mse_loss(activities[-1], label_batch)
        
                i = 0
                for l, act in enumerate(activities):
                    if l in layer_idxs:
                        activity_l2_norms[global_batch_id, i] = np.array(
                            jnp.linalg.norm(act, axis=1, ord=2).mean()
                        )
                        i += 1
                
                if global_batch_id == 0:
                    energies = jpc.pc_energy_fn(
                        params=(model, skip_model),
                        activities=activities,
                        y=label_batch,
                        x=img_batch,
                        param_type=param_type,
                        activity_decay=activity_decay,
                        weight_decay=weight_decay,
                        spectral_penalty=spectral_penalty,
                        record_layers=True
                    )
                    train_num_energies[0] = np.array([
                            e for l, e in enumerate(reversed(energies)) if l in layer_idxs
                    ])
            
                    param_spectral_norms[0] = compute_param_spectral_norms(
                        model=model,
                        act_fn=act_fn,
                        layer_idxs=layer_idxs
                    )
                    eigenvals, eigenvecs = compute_hessian_eigens(
                        params=(model, skip_model),
                        activities=tree_map(lambda a: a[[0], :], activities),
                        y=label_batch[[0], :],
                        x=img_batch[[0], :],
                        param_type=param_type,
                        activity_decay=activity_decay,
                        weight_decay=weight_decay,
                        spectral_penalty=spectral_penalty
                    )
                    hessian_eigenvals[0] = eigenvals
                    max_min_hess_eigenvals[:, 0] = np.array(
                        [max(eigenvals), min(eigenvals)]
                    )
        
                # update parameters
                param_update_result = jpc.update_params(
                    params=(model, skip_model),
                    activities=activities,
                    optim=param_optim,
                    opt_state=param_opt_state,
                    output=label_batch,
                    input=img_batch,
                    param_type=param_type,
                    activity_decay=activity_decay,
                    weight_decay=weight_decay,
                    spectral_penalty=spectral_penalty
                )
                model = param_update_result["model"]
                skip_model = param_update_result["skip_model"]
                param_opt_state = param_update_result["opt_state"]
            
                param_spectral_norms[global_batch_id+1] = compute_param_spectral_norms(
                    model=model,
                    act_fn=act_fn,
                    layer_idxs=layer_idxs
                )
                train_losses.append(train_loss)
                global_batch_id += 1
            
                if global_batch_id % test_every == 0:
                    print(
                        f"Train loss: {train_loss:.7f} [{train_iter * len(img_batch)}/{len(train_loader.dataset)}]"
                    )
                    avg_test_loss, avg_test_acc = evaluate(
                        params=(model, skip_model),
                        test_loader=test_loader,
                        param_type=param_type
                    )
                    test_losses.append(avg_test_loss)
                    test_accs.append(avg_test_acc)
                    print(f"Avg test accuracy: {avg_test_acc:.4f}\n")
            
                    test_iter = int(global_batch_id / test_every)
                    eigenvals, eigenvecs = compute_hessian_eigens(
                        params=(model, skip_model),
                        activities=tree_map(lambda a: a[[0], :], activities),
                        y=label_batch[[0], :],
                        x=img_batch[[0], :],
                        param_type=param_type,
                        activity_decay=activity_decay,
                        weight_decay=weight_decay,
                        spectral_penalty=spectral_penalty
                    )
                    hessian_eigenvals[test_iter] = eigenvals
                    max_min_hess_eigenvals[:, test_iter] = np.array(
                        [max(eigenvals), min(eigenvals)]
                    )
                    energies = jpc.pc_energy_fn(
                        params=(model, skip_model),
                        activities=activities,
                        y=label_batch,
                        x=img_batch,
                        param_type=param_type,
                        activity_decay=activity_decay,
                        weight_decay=weight_decay,
                        spectral_penalty=spectral_penalty,
                        record_layers=True
                    )
                    train_num_energies[test_iter] = np.array([
                            e for l, e in enumerate(reversed(energies)) if l in layer_idxs
                        ])
    
            np.save(f"{save_dir}/train_losses.npy", train_losses)
            np.save(f"{save_dir}/test_losses.npy", test_losses)
            np.save(f"{save_dir}/test_accs.npy", test_accs)
             
            # plotting
            plot_loss(
                loss=train_losses,
                yaxis_title="Train loss",
                xaxis_title="Iteration",
                save_path=f"{save_dir}/train_losses.pdf"
            )
            plot_loss_and_accuracy(
                loss=test_losses,
                accuracy=test_accs,
                mode="test",
                xaxis_title="Training iteration",
                save_path=f"{save_dir}/test_losses_and_accs.pdf",
                test_every=test_every
            )
            plot_norms(
                norms=param_spectral_norms,
                norm_type="param_spectral",
                save_path=f"{save_dir}/param_spectral_norms.pdf"
            )
            plot_energies(
                energies=train_num_energies.T,
                test_every=test_every,
                save_path=f"{save_dir}/energies.pdf",
                log=False
            )
            plot_energies(
                energies=train_num_energies.T,
                test_every=test_every,
                save_path=f"{save_dir}/log_energies.pdf",
                log=True
            )
            plot_hessian_eigenvalues_during_training(
                eigenvals=[e for i, e in enumerate(hessian_eigenvals) if i % 2 == 0],
                test_every=200,
                save_path=f"{save_dir}/hessian_eigenvals.pdf"
            )
            plot_max_min_eigenvals(
                max_min_eigenvals=max_min_hess_eigenvals,
                test_every=test_every,
                save_path=f"{save_dir}/max_min_eigenvals.pdf"
            )
            plot_max_min_eigenvals_2_axes(
                max_min_eigenvals=max_min_hess_eigenvals,
                test_every=test_every,
                save_path=f"{save_dir}/max_min_eigenvals_2_axes.pdf"
            )
            plot_norms(
                norms=np.array([activity_l2_norms[0]] * 2),
                norm_type="activity",
                save_path=f"{save_dir}/theory_activity_l2_norms_at_init.pdf",
                theory_norms=activity_l2_norms[0],
                showticklabels=False
            )
            plot_norms(
                norms=np.array([activity_l2_norms[-1]] * 2),
                norm_type="activity",
                save_path=f"{save_dir}/theory_activity_l2_norms_last_train_iter.pdf",
                theory_norms=activity_l2_norms[-1],
                showticklabels=False
            )


In [None]:
############ plot results ############
# Aggregate the per-seed curves saved as .npy files under `theory_save_dir`
# and, for each parameterisation, plot the across-seed statistics of the train
# loss, test loss and test accuracy with one entry per hidden width.
PARAM_TYPES = ["sp", "mupc"]
N_HIDDENS = [8, 16, 32]
n_seeds = 3
test_every = 100

for param_type in PARAM_TYPES:
    # Rebuild the per-width containers for every parameterisation so that no
    # entry from a previous `param_type` can end up in the current figures.
    train_losses_per_H = {}
    test_losses_per_H = {}
    test_accs_per_H = {}
    for n_hidden in N_HIDDENS:

        train_losses_all_seeds = []
        test_losses_all_seeds = []
        test_accs_all_seeds = []
        for seed in range(n_seeds):
            save_dir = os.path.join(
                theory_save_dir,
                param_type,
                f"{n_hidden}_n_hidden",
                str(seed)
            )
            train_losses = np.load(f"{save_dir}/train_losses.npy")
            test_losses = np.load(f"{save_dir}/test_losses.npy")
            test_accs = np.load(f"{save_dir}/test_accs.npy")

            train_losses_all_seeds.append(train_losses)
            test_losses_all_seeds.append(test_losses)
            test_accs_all_seeds.append(test_accs)

        # Stack into (n_seeds, n_points) arrays. The number of points is taken
        # from the saved files instead of being fixed in advance, so runs with
        # a different number of training iterations or evaluations still load.
        # `np.stack` raises a ValueError if the seeds disagree in length. The
        # float64 cast keeps the statistics in double precision regardless of
        # the dtype the curves were saved with.
        train_losses_all_seeds = np.stack(train_losses_all_seeds).astype(np.float64)
        test_losses_all_seeds = np.stack(test_losses_all_seeds).astype(np.float64)
        test_accs_all_seeds = np.stack(test_accs_all_seeds).astype(np.float64)

        train_losses_means, train_losses_stds = compute_metric_stats(train_losses_all_seeds)
        test_losses_means, test_losses_stds = compute_metric_stats(test_losses_all_seeds)
        test_acc_means, test_acc_stds = compute_metric_stats(test_accs_all_seeds)

        train_losses_per_H[n_hidden] = (train_losses_means, train_losses_stds)
        test_losses_per_H[n_hidden] = (test_losses_means, test_losses_stds)
        test_accs_per_H[n_hidden] = (test_acc_means, test_acc_stds)

    # One figure per metric and parameterisation; the losses use a log y-axis
    # for "sp" and a linear one otherwise.
    plot_metric_stats(
        metric=train_losses_per_H,
        yaxis_title="Train loss",
        test_every=test_every,
        save_path=f"{theory_save_dir}/{param_type}_train_losses.pdf",
        yaxis_type="log" if param_type == "sp" else "linear"
    )
    plot_metric_stats(
        metric=test_losses_per_H,
        yaxis_title="Test loss",
        test_every=test_every,
        save_path=f"{theory_save_dir}/{param_type}_test_losses.pdf",
        yaxis_type="log" if param_type == "sp" else "linear"
    )
    plot_metric_stats(
        metric=test_accs_per_H,
        yaxis_title="Test accuracy (%)",
        test_every=test_every,
        save_path=f"{theory_save_dir}/{param_type}_test_accs.pdf"
    )