# Weight changes due to STDP

Bimodal distribution of synaptic weights in the network after training on N4P2 shapes, shown for modifiable connections between excitatory neurons.

**Plots (labelled by columns):**

A) Feedforward weights: L0 -> L1

B) Lateral weights: L1 <-> L1

C) Feedback weights: L2 -> L1

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import BaseModel

from hsnn.core import Projection
from hsnn import simulation, viz
from hsnn.utils import handler, io
from hsnn.utils.handler import TrialView

# Output directory for this figure; created (with parents) if it does not exist.
RESULTS_DIR = io.BASE_DIR / "out/figures/fig7"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

pidx = pd.IndexSlice  # shorthand for pandas MultiIndex slicing

def load_syn_params(
    trial: TrialView,
    projs: list[Projection] | list[str] | None = None,
    chkpt_mapping: dict[str, int | None] | None = None,
) -> dict[str, pd.DataFrame]:
    """Load the synapse parameters of `trial` for the "pre" and "post" states.

    A simulator is built from ``trial.config``. For each state (first "pre",
    then "post") the checkpoint index ``chkpt_mapping[state]`` is looked up:
    an ``int`` selects ``trial.checkpoints[index]``, whose ``store_path`` is
    restored (if not ``None``). Any other value, e.g. ``None``, reads the
    parameters of the network as built from the config.

    Args:
        trial: Trial providing the config and checkpoints.
        projs: Projections forwarded as ``projections=`` to
            ``sim.network.get_syn_params``. The caller in this notebook passes
            ``Projection`` members.
        chkpt_mapping: State name -> checkpoint index. Defaults to the
            module-level ``state_chkpt_mapping``, resolved at call time.

    Returns:
        Dict mapping "pre" and "post" to the result of
        ``sim.network.get_syn_params`` for that state.
    """
    if chkpt_mapping is None:
        # Resolved lazily: the module-level default is defined after this function.
        chkpt_mapping = state_chkpt_mapping
    sim = simulation.Simulator.from_config(trial.config)
    restored = False
    dst = {}
    for state in ("pre", "post"):
        chkpt = chkpt_mapping[state]
        store_path = (
            trial.checkpoints[chkpt].store_path if isinstance(chkpt, int) else None
        )
        if store_path is not None:
            sim.restore(store_path)
            restored = True
        elif restored:
            # An earlier state restored a checkpoint into `sim`; rebuild it so
            # this state reads the config-built parameters, not stale ones.
            sim = simulation.Simulator.from_config(trial.config)
            restored = False
        dst[state] = sim.network.get_syn_params(projections=projs)
    return dst


class InferenceConfig(BaseModel):
    """Keyword arguments for inference, forwarded to `handler.load_results`."""

    amplitude: int  # required, no default
    subdir: str | None = None  # optional; None when not given


# Checkpoint index restored for each state by `load_syn_params`:
# None -> network as built from the config, -1 -> last checkpoint.
state_chkpt_mapping = {"pre": None, "post": -1}
# One accumulator list per state ("pre", "post"), filled with one entry per trial.
syn_params: dict[str, list[pd.DataFrame]] = {state: [] for state in state_chkpt_mapping}

viz.setup_journal_env()

### 1) Load N4P2 recorded results

**Get the inference trials logged in metadata for the `ALL` combination (plastic `FF`, `E2E` and `FB` projections)**

In [None]:
logdir = "n4p2/train_n4p2_lrate_0_02_181023"

expt = handler.ExperimentHandler(logdir)
# Name of the directory containing the experiment log directory.
dataset_name = expt.logdir.parent.stem

# Names of the "inference" trials registered under the "ALL" key, then the trials themselves.
trial_names = expt.metadata.get_trials_dict("ALL")["inference"]
trials = [expt[name] for name in trial_names]

In [None]:
# Parameters shared by all trials, read from the first trial's config
config = trials[0].config
topology = config["topology"]

input_shape = tuple(topology["poisson"]["EXC"])  # shape of the Poisson EXC population
layer_shape = tuple(topology["spatial"]["EXC"])  # shape of the spatial EXC population

# Projection keys whose weights are loaded and plotted
projs_plastic = ["FF", "E2E", "FB"]

In [None]:
# Load the "pre"/"post" synapse parameters of every trial and append each
# state's result to the matching list in `syn_params`.
projs_enum = [Projection[name] for name in projs_plastic]
for trial in trials:
    trial_params = load_syn_params(trial, projs=projs_enum)
    for state, per_state in syn_params.items():
        per_state.append(trial_params[state])

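### 2) Group weights

For each state (`pre`, `post`), gather the weights `w` of all trials into one DataFrame with one column per trial; rows keep the (layer, projection) synapse index.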

In [None]:
# One DataFrame per state: the per-trial "w" columns side by side, relabelled
# with a ("w", trial_index) column MultiIndex.
weights: dict[str, pd.DataFrame] = {}
for state, weights_trials in syn_params.items():
    weights[state] = pd.concat([syn["w"] for syn in weights_trials], axis=1)
    # Size the column index from this state's own trial list so it always
    # matches the number of frames concatenated above.
    weights[state].columns = pd.MultiIndex.from_product(
        [["w"], range(len(weights_trials))]
    )


### 3) Plot weights

In [None]:
figsize = (5.5, 4)

# Bar styling passed to `viz.hist_weights` as `hist_kwargs`.
hist_kwargs = dict(
    density=True,
    facecolor="#1f77b4",
    edgecolor="#1f77b4",
)
# Text styling passed to `viz.hist_weights` as `text_kwargs`.
text_kwargs = dict(fontsize=8, bbox=dict(facecolor="white"))

# Weight histograms (80 bins) of the post-training weights for the plastic projections.
axes = viz.hist_weights(
    weights["post"],
    projs_plastic,
    bins=80,
    text_kwargs=text_kwargs,
    hist_kwargs=hist_kwargs,
    figsize=figsize,
)

# Label the x-axis on the bottom row of panels only.
ax: plt.Axes
for ax in axes[-1, :]:
    ax.set_xlabel("Weight")


def to_nearest(value, step=2):
    """Round `value` down to a multiple of `step` (e.g. 21 -> 20 for step=2).

    Despite the name this floors rather than rounds: Python's ``%`` takes the
    sign of `step`, so for a positive step ``-3 -> -4``.
    """
    remainder = value % step
    return value - remainder


ymax_values = [4, 8, 16, 21]  # y-axis upper limit for each row of `axes`
for i, row in enumerate(axes):
    for ax in row:
        if not ax.has_data():
            continue  # leave empty panels untouched
        ymax = ymax_values[i]
        ax.set_ylim(0, ymax)
        # Two ticks: 0 and ymax floored to a multiple of 2 (e.g. 21 -> 20).
        ax.set_yticks(np.linspace(0, to_nearest(ymax), 2))
        if ax.get_ylabel():
            # Integer tick labels, and a fixed label offset so labelled y-axes line up.
            ax.set_yticklabels([f"{ytick:.0f}" for ytick in ax.get_yticks()])
            ax.yaxis.set_label_coords(-0.3, 0.5)

# Take the figure from the axes themselves rather than pyplot's global
# "current figure", so the saved file is guaranteed to be these panels.
f = axes.flat[0].figure
f.tight_layout()
f.savefig(
    RESULTS_DIR / "fig_synaptic_weights_yticks.pdf",
    dpi=300,
)