### Initialization

In [1]:
%config Completer.use_jedi=False

LaTeX headers
$\newcommand{\E}{\mathbb{E}}$
$\newcommand{\Var}{\mathrm{Var}}$
$\newcommand{\Cov}{\mathrm{Cov}}$

In [2]:
import tqdm
import math

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import torch
import salina
import wandb

In [3]:
%env "WANDB_NOTEBOOK_NAME" "self_organizing_multiparadigm_networks"
wandb.init(project="self_organizing_multiparadigm_networks", entity="jacobfv")

env: "WANDB_NOTEBOOK_NAME"="critical_multiparadigm_networks"


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjacobfv[0m (use `wandb login --relogin` to force relogin)


In [4]:
wandb.config = {
    "epochs": 100,
    "batch_size": 128,
    "sgd_learning_rate": 0.001,
    "ul_learning_rate": 0.001,
    "IP_stddev": 0.05,
    "IP_mean": 0.1,
    "cm_alpha": 2.0,
    "lambda_reconstruct_x": 1.0,
    "lambda_reconstruct_h1": 0.9,
    "lambda_reconstruct_h2": 0.8,
    "lambda_reconstruct_h3": 0.7,
}

# Self-Organizing Multi-Paradigm Networks (SOMPNets)

Self-organizing multi-paradigm networks (SOMPNets) are a family of deep learning architectures designed for iterative update processing (such as progressive representation refinement). SOMPNets are characterized by:
- activations displaying critical phenomena, including very long correlation distances and ergodicity
- layer-local, iteratively applied unsupervised parameter update rules
- training by a combination of supervised, self-supervised, unsupervised, and reinforcement learning
- pseudo-linear activation transformations that allow linear combination of separate hidden states and parameters from asynchronously trained networks

## Design

Each SOMPCell in a self-organizing multi-paradigm network updates by a system of fast and slow temporal dynamics.

### Fast Dynamics

**Activations** may be dense or sparse and are given by
$$Z_t = X^1_t W^1 + X^2_t W^2 + \dots + b + \zeta$$
$$B_t = a_{b,decay} B_{t-1} + Z_t - a_{b,discharge} |Y^{sparse}_{t-1}|$$
$$Y^{sparse}_t = \alpha_{sparse} \Theta_{\pm} (\alpha_{sparse}^{-1} Z_t)$$
$$Y^{dense}_t = \alpha_{dense} \tanh (\alpha_{dense}^{-1}Z_t)$$
where
- $X^1 \in \mathbb{R}^{\dots \times t \times n_1}, X^2 \in \mathbb{R}^{\dots \times t \times n_2}, \dots$ are the inputs,
- $W^1 \in \mathbb{R}^{n_1 \times n_{out}}, W^2 \in \mathbb{R}^{n_2 \times n_{out}}, \dots$ are weights corresponding to each input,
- $b \in \mathbb{R}^{n_{out}}$ is the bias vector,
- $\zeta \in \mathbb{R}^{n_{out}} \sim \mathcal{N}(\mu=\mu_{\zeta}, \Sigma=\Sigma_{\zeta})$ is a noise factor,
- $Z_t \in \mathbb{R}^{\dots \times t \times n_{out}}$ is the input summand,
- $B_t \in \mathbb{R}^{\dots \times t \times n_{out}}$ is the sparse activation bucket,
- $Y^{\{sparse,dense\}} \in \mathbb{R}^{\dots \times t \times n_{out}}$ is the output activation,

and $\Theta_{\pm}$ is a positive-negative threshold function
$$\Theta_{\pm}(x) = \begin{cases}
    +1 & \text{if  } x \ge 1 \\
    0  & \text{if  } -1 < x < 1 \\
    -1 & \text{if  } x \le -1 
\end{cases}$$
with gradients flowing through $\Theta_{\pm}(x)$ as if it were linear with respect to $x$.
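
The threshold and the two activation types can be sketched in NumPy as follows. This covers forward values only; the straight-through gradient is not modeled, and the function names `theta_pm`, `sparse_activation`, and `dense_activation` are illustrative rather than part of this notebook.

```python
import numpy as np

def theta_pm(x):
    """Positive-negative threshold: +1 if x >= 1, -1 if x <= -1, else 0."""
    x = np.asarray(x, dtype=float)
    return np.where(x >= 1.0, 1.0, np.where(x <= -1.0, -1.0, 0.0))

def sparse_activation(z, alpha):
    """Y_sparse = alpha * Theta(z / alpha), with values in {-alpha, 0, +alpha}."""
    return alpha * theta_pm(np.asarray(z, dtype=float) / alpha)

def dense_activation(z, alpha):
    """Y_dense = alpha * tanh(z / alpha), with values in (-alpha, +alpha)."""
    return alpha * np.tanh(np.asarray(z, dtype=float) / alpha)

z = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
y_sparse = sparse_activation(z, alpha=2.0)  # only |z| >= alpha produces a spike
y_dense = dense_activation(z, alpha=2.0)    # close to z itself for small |z|
```

For small inputs the dense output stays close to the input, which is the pseudo-linear behavior the rescaling is meant to provide.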

**Dense activations:** Temporarily rescaling dense activation inputs (dividing by $\alpha_{dense}$ before the $\tanh$ and multiplying by it afterward) extends the $\tanh$ nonsaturating regime to $\approx( -\alpha_{dense}/2, \alpha_{dense}/2)$, which facilitates long propagation distances and helps keep gradients alive during backpropagation. Within the pseudo-linear regime, activations and weights may be linearly combined, facilitating decentralized and asynchronous training.

**Sparse activations** utilize a presynaptic bucket $B$ to store signals which gradually decay by a factor of $a_{b,decay}$ at every time step and abruptly decrement by $a_{b,discharge}$ when discharging. 
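
A minimal sketch of the bucket recurrence, iterating the $B_t$ and $Y^{sparse}_t$ formulas literally over a short sequence. The function name `run_bucket` and all default values are hypothetical, chosen only to make the arithmetic easy to follow.

```python
import numpy as np

def run_bucket(z_seq, a_decay=0.9, a_discharge=1.0, alpha=1.0):
    """Iterate B_t = a_decay * B_{t-1} + Z_t - a_discharge * |Y_{t-1}| over time.

    z_seq is shaped [T, n_out]. Y_t is thresholded from Z_t as in the
    Y^sparse formula; the defaults are illustrative assumptions.
    """
    z_seq = np.asarray(z_seq, dtype=float)
    bucket = np.zeros(z_seq.shape[-1])
    y_prev = np.zeros(z_seq.shape[-1])
    buckets, outputs = [], []
    for z_t in z_seq:
        bucket = a_decay * bucket + z_t - a_discharge * np.abs(y_prev)
        scaled = z_t / alpha
        # Theta_pm written as the difference of two indicator functions
        y_t = alpha * (np.where(scaled >= 1, 1.0, 0.0) - np.where(scaled <= -1, 1.0, 0.0))
        buckets.append(bucket.copy())
        outputs.append(y_t)
        y_prev = y_t
    return np.stack(buckets), np.stack(outputs)

B, Y = run_bucket([[0.5], [1.2], [0.0]])
```

In this run the bucket charges over the first two steps and is decremented by `a_discharge` on the step after the spike.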

In both the dense and sparse case, activations are expressed within a common $(-\alpha, +\alpha)$ range (continuous and open for dense activations, discrete and closed for sparse ones). Weights $W$ are initialized with the expectation that inputs share an equivalent domain, and bias $b$ is initialized at $0$. Both trainable variables $W$ and $b$ are $L_2$-regularized to limit exploding gradients and alleviate soft or hard signal saturation. (Conversely, noise covariance $\Sigma_{\zeta}$ can be increased with the aim of encouraging weight growth (hence, saturating activations) and in later epochs might be increased (even cyclically) to help sparsify weights.)

**Multiparadigm training** blends unsupervised processes, supervisory and reinforcement signals, and self-supervised objectives to optimize the trainable parameters of SOMPNets. Being end-to-end differentiable, individual self-organizing multi-paradigm cells (SOMPCells) can be combined with other deep learning layers such as convolutional layers and attention mechanisms. Unsupervised weight update rules compute gradients *on each forward pass* with an unsupervised update rate that may optionally be modulated by a reward signal. Gradients may be applied immediately, or they may be accumulated over a time sequence and then applied along with backpropagated gradients by the optimizer. Note that combining unsupervised update rules with supervised ones requires flipping the sign of the unsupervised deltas in code, since the optimizer updates variables in the direction opposite to the gradients (I keep the sign positive in the math). I apply the following update rules with the intention that weights and biases should adapt to represent the statistical dynamics underlying their associated activations:

**1. Spike Timing Dependent Plasticity (STDP)** makes synapse $w_{ji}$ decrease when $y_i$ activates before $x_j$ (since the connection must not have been important) and increase when $y_i$ activates after $x_j$ (since the connection made a significant contribution). It has no effect when one of the values is $0$ and the opposite effect when one value is positive and the other is negative:
$$\Delta w_{STDP,ji} = \hat{x}_{t-1,j} \hat{y}_{t,i} - \hat{x}_{t,j} \hat{y}_{t-1,i}$$
with $\hat{\cdot}$ denoting division by $\alpha$, where $\alpha$ belongs to the layer that produced the activation (this matters in heterogeneous architectures). In the extended temporal case, STDP is expressed using Python-style slicing as:
$$\Delta w_{STDP,ji} = \hat{x}_{:-1,j} \hat{y}_{1:,i} - \hat{x}_{1:,j} \hat{y}_{:-1,i}$$
Finally, full matrix deltas are computed as:
```python
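# xhat: [..., T, Nin], yhat: [..., T, Nout]; negated because the optimizer steps against the gradient
dW_stdp = - (torch.einsum('...j,...i->ji', xhat[..., :-1, :], yhat[..., 1:, :])
            - torch.einsum('...j,...i->ji', xhat[..., 1:, :], yhat[..., :-1, :]))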
```
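
To make the direction of the STDP rule concrete, here is a small NumPy sketch for a single unbatched sequence. It returns the positive-sign delta from the formula (without the negation used for the optimizer), and the name `stdp_delta` is illustrative.

```python
import numpy as np

def stdp_delta(xhat, yhat):
    """Positive-sign STDP delta for xhat: [T, Nin], yhat: [T, Nout] -> [Nin, Nout]."""
    xhat = np.asarray(xhat, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    pre_then_post = xhat[:-1].T @ yhat[1:]   # sum over t of x[t-1, j] * y[t, i]
    post_then_pre = xhat[1:].T @ yhat[:-1]   # sum over t of x[t, j] * y[t-1, i]
    return pre_then_post - post_then_pre

early = np.array([[1.0], [0.0], [0.0]])  # unit active at t = 0
late = np.array([[0.0], [1.0], [0.0]])   # unit active at t = 1

dw_causal = stdp_delta(early, late)      # input fires first, then output
dw_anticausal = stdp_delta(late, early)  # output fires first, then input
```

The weight is strengthened when the input precedes the output and weakened in the reverse order.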

**2. Covariance Decay (CD)** makes synapse $w_{ji}$ decay or grow in nonlinear proportion to the absolute covariance between $x_j$ and $y_i$, computed via $\beta_{cd}$ rolling means:
$$\sigma_{ji} = \Cov(\hat{x}_{:-1,j}, \hat{y}_{1:,i}) = \E_{\beta_{cd}}{[ \hat{x}_{t-1,j}\hat{y}_{t,i} ]} - \E_{\beta_{cd}}{[ \hat{x}_{t-1,j} ]} \E_{\beta_{cd}}{[ \hat{y}_{t,i} ]}$$
The absolute value of this covariance $|\sigma_{ji}|$ is next scaled around $1$ by a learned parameter $a_{cd}$, giving the covariance decay coefficient: $$c_{ji} = a_{cd}(|\sigma_{ji}|-1)+1$$ Ideally, the covariance decay coefficient $c_{ji}$ could be applied directly to its corresponding weight as $w_{ji} \leftarrow c_{ji}w_{ji}$. However, to remain compatible with gradient-based update paradigms, this coefficient is instead expressed as a weight increment to be compounded with other gradients, chosen so that $w_{ji} + \Delta w_{CD,ji} = c_{ji} w_{ji}$:
$$\Delta w_{CD,ji} = (c_{ji}-1) w_{ji}$$

In tensor computations, this is expressed:
```python
# Bmean: rolling-mean helper (not defined here); xhat: [..., T, Nin], yhat: [..., T, Nout]
x_prev, y_next = xhat[..., :-1, :], yhat[..., 1:, :]
cov = Bmean(x_prev[..., :, None] * y_next[..., None, :], beta=beta_cd, axis=-3) \
    - Bmean(x_prev, beta=beta_cd, axis=-2)[..., :, None] * Bmean(y_next, beta=beta_cd, axis=-2)[..., None, :]
coef = a_cd * (abs(cov) - 1) + 1
dW_cd = -(coef - 1) * W  # negated because the optimizer steps against the gradient
```
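
As a rough illustration of the covariance decay coefficient, the sketch below applies it in the direct multiplicative form $w_{ji} \leftarrow c_{ji} w_{ji}$ for one unbatched sequence. It is an assumption-laden simplification: a plain mean over time stands in for the $\beta_{cd}$ rolling mean, and the name `covariance_decay` and the value `a_cd=0.1` are hypothetical.

```python
import numpy as np

def covariance_decay(xhat, yhat, W, a_cd=0.1):
    """Return (c * W, c) with c = a_cd * (|cov| - 1) + 1.

    xhat: [T, Nin], yhat: [T, Nout], W: [Nin, Nout]. A plain time average
    replaces the rolling mean used in the notebook (an assumption).
    """
    x_prev = np.asarray(xhat, dtype=float)[:-1]   # x at t-1
    y_next = np.asarray(yhat, dtype=float)[1:]    # y at t
    joint = (x_prev[:, :, None] * y_next[:, None, :]).mean(axis=0)
    cov = joint - x_prev.mean(axis=0)[:, None] * y_next.mean(axis=0)[None, :]
    coef = a_cd * (np.abs(cov) - 1.0) + 1.0
    return coef * W, coef

xhat = np.array([[1.0], [-1.0], [1.0], [-1.0], [1.0]])
yhat_follow = np.array([[0.0], [1.0], [-1.0], [1.0], [-1.0]])  # y[t] == x[t-1]
W = np.array([[0.5]])

W_kept, c_kept = covariance_decay(xhat, yhat_follow, W)             # |cov| = 1
W_decayed, c_decayed = covariance_decay(xhat, np.zeros((5, 1)), W)  # cov = 0
```

A synapse whose output perfectly tracks its delayed input keeps its weight, while an uncorrelated synapse shrinks by the factor $1 - a_{cd}$.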

**3. Structural Plasticity (SP)** randomly adds small-valued synapses between unconnected neurons with Bernoulli probability $p_{SP}=\frac{a_{sp}}{N_x N_y}$, which scales inversely with the product of the input dimension $N_x$ and the output dimension $N_y$.
$$\Delta w_{SP,ji} = d, \quad d \sim \mathcal{B}\left(p=\frac{a_{sp}}{N_x N_y}\right)$$
$a_{sp}$ is made differentiable by the reparametrization trick:
```python
dW_SP = -sigmoid(TODO)
```
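
The sampling step alone (not the differentiable reparametrization, which is still a TODO above) might look like the following sketch. The function name, the `scale` of new synapses, and the choice to give every new synapse the same positive value are assumptions made for illustration.

```python
import numpy as np

def structural_plasticity_delta(W, a_sp=1.0, scale=0.01, rng=None):
    """Sample new synapses on zero-valued weights with probability a_sp / (N_x * N_y)."""
    rng = np.random.default_rng() if rng is None else rng
    n_x, n_y = W.shape
    p_sp = a_sp / (n_x * n_y)
    # only currently unconnected entries (weight exactly 0) may grow a synapse
    grow = (rng.random(W.shape) < p_sp) & (W == 0)
    return np.where(grow, scale, 0.0)

W = np.zeros((4, 5))
W[0, 0] = 1.0  # an existing connection that must be left alone
dW_sp = structural_plasticity_delta(W, a_sp=2.0, rng=np.random.default_rng(0))
```

With `a_sp=2.0` and a 4 x 5 weight matrix, each unconnected entry has a 10% chance of growing a synapse on a given call.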

**4. Intrinsic Plasticity (IP)** homeostatically shifts the inputs to maintain a mean firing rate $H_{IP} \sim \mathcal{N}(\mu_{IP}=0.1, \Sigma_{IP}^2 = 0)$:
$$\Delta \mu_{\zeta} = \eta_{IP}[x(t) - H_{IP}]$$
TODO


**5. Variance-Invariance-Covariance Sequence Regularization (VIC)** aims to build representations that are progressively insensitive to time by applying the following regularization penalties:
- local temporal element invariance: assuming that sequence elements in a local neighborhood represent the same information, maximize the similarity $s$ between $y_{\dots,t,:}$ and $y_{\dots,t-1,:},y_{\dots,t-2,:},\dots$, assigning exponentially greater weight to nearer elements
$$\mathcal{L}_{VIC} = $$ 
https://arxiv.org/pdf/2109.00783.pdf

**6. L2 regularization** prevents weights from growing excessively large unless they serve a meaningful statistical purpose. Formally,
$$ \mathcal{L}_{L2,W} = \frac{1}{2} \sum_{ji} w_{ji}^2$$
$$ \mathcal{L}_{L2,b} = \frac{1}{2} \sum_{i} b_{i}^2$$
and in code,
```python
add_loss(eta_L2W * 0.5 * (W**2).sum())  # 0.5 matches the 1/2 factor in the formula
add_loss(eta_L2W * 0.5 * (b**2).sum())  # same coefficient applied to the bias
```

**7. Mean Regularization (MR)** aims to preserve a consistent mean activation across SOMPCells by
$$\mathcal{L}_{MR} = (\sum \hat{y}-\mu_{MR})^2$$
with gradients backpropagated through all trainable parameters by
```python
loss_MR = (yhat.sum() - mu_MR)**2
add_loss(eta_MR * loss_MR)
```

**8. Sparsity Regularization (SR)** (which applies only to sparse-activation `SOMPCell`s) aims to tune sparsity by penalizing the KL divergence of the activations from a Bernoulli distribution with rate $a_{SR}$:
$$\mathcal{L}_{SR} = \sum \biggl[ |\hat{y}| \log { \frac{ |\hat{y}| }{ a_{SR} } } + (1 - |\hat{y}|) \log { \frac{ 1 - |\hat{y}| }{ 1 - a_{SR} } } \biggr]$$
with gradients backpropagated through all trainable parameters by
```python
p = abs(yhat).clamp(eps, 1 - eps)  # eps: small constant that keeps log() finite
kl_SR = p * log(p / a_SR) + (1 - p) * log((1 - p) / (1 - a_SR))
loss_SR = kl_SR.sum()
add_loss(eta_SR * loss_SR)
```
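
For reference, a small NumPy sketch of an elementwise Bernoulli KL penalty. It uses the standard form of the divergence, in which the two terms are added, and clips probabilities away from 0 and 1 so the logarithms stay finite. The name `bernoulli_kl`, the `eps` value, and the target rate `0.4` are illustrative assumptions.

```python
import numpy as np

def bernoulli_kl(p, q, eps=1e-7):
    """Elementwise KL(Bernoulli(p) || Bernoulli(q)), with p clipped into (0, 1)."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return p * np.log(p / q) + (1.0 - p) * np.log((1.0 - p) / (1.0 - q))

yhat = np.array([0.0, 1.0, 0.0, 0.0, -1.0])  # normalized sparse activations
loss_sr = bernoulli_kl(np.abs(yhat), q=0.4).sum()
```

The divergence is zero when the activation probability equals the target rate and positive otherwise.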

**9. Supervised Learning (SL)** trains weights with respect to a supervised or self-supervised objective as usual.
$$ \Delta W_{SL} = \frac{\partial \mathcal{L}}{\partial W} \biggr|_{x}$$
And in code,
```python
sl_loss = eta_SL * sl_loss  # scale the loss once by the SL rate
sl_loss.backward()  # gradients already carry the eta_SL factor, so no per-variable rescaling
```

Many of the problem domains to which SOMPNets are applied lend themselves to self-supervised next-sequence-element prediction as a powerful pretraining objective:
```python
loss = (ypred - ytrue)**2
```
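
The time shift implied by next-element prediction can be made explicit as below. This is a sketch under the assumption that the prediction emitted at step $t$ is compared against the true element at $t+1$ and that the squared errors are averaged; `next_element_loss` is an illustrative name.

```python
import numpy as np

def next_element_loss(y_pred, sequence):
    """Mean squared error between the prediction at step t and the true element at t+1.

    Both arrays are shaped [..., T, N]; the last prediction has no target and is dropped.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    sequence = np.asarray(sequence, dtype=float)
    return np.mean((y_pred[..., :-1, :] - sequence[..., 1:, :]) ** 2)

seq = np.arange(8.0).reshape(4, 2)   # [T=4, N=2]
perfect = np.roll(seq, -1, axis=0)   # the prediction at t equals seq[t+1]

loss_perfect = next_element_loss(perfect, seq)
loss_copy = next_element_loss(seq, seq)  # simply repeating the current element
```

Copying the current element forward is penalized, while a prediction that matches the shifted sequence gives zero loss.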

Unsupervised weight updates are applied as gradients, which accumulate over the steps of a sequence along with backpropagated gradients before being applied to the trainable variables by an optimizer. Putting it all together, a forward pass looks *like* this:
```python
def forward(self, inputs, previous, **kwargs):
    
    # possibly build new weights if new inputs have been added
    self.build({k: inputs[k].shape for k in inputs.keys()})
    
    # forward computations
    zeta = ...
    Z = sum(inputs[k] @ self.weights[k] for k in inputs.keys()) + self.bias + zeta
    Y = alpha * tanh(Z / alpha)  # the output
    # shaped like `previous` but shifted. i.e.: `previous[...,-1, :] == Y[..., -2, :]`
    
    # slow updates
    # NOTE: Gradients are reversed from the deltas given in the above
    # formulae because optimization is a minimization process
    # TODO: double-check the above statement with torch.
    # 1. STDP
    for k in inputs.keys():
        dW_stdp = ...
        self.weights[k].grad -= lr_stdp * dW_stdp
        
    # 2. CD
    for k in inputs.keys():
        dW_cd = ...
        self.weights[k].grad -= lr_cd * dW_cd
        
    # 3. SP
    for k in inputs.keys():
        dW_sp = ...
        self.weights[k].grad -= lr_sp * dW_sp
    
    # 4. IP
    self.bias.grad -= ...
    
    # 5. VIC
    sim = ...
    covar = ...
    invar = ...
    self.add_loss(lambda_vic_sim*sim + lambda_vic_covar*covar + lambda_vic_invar*invar)
    
    # ... other update rules
    
    # return output signal with values for next time step
    return Y
```
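
The sign convention mentioned in the NOTE can be checked with a toy example. The sketch below assumes plain gradient descent, which moves each variable against its gradient; `sgd_step` is a hypothetical stand-in for the optimizer, not the one used in this notebook.

```python
import numpy as np

def sgd_step(w, grad, lr):
    """Plain gradient descent: move against the gradient."""
    return w - lr * grad

w = np.array([0.5, -0.2])
delta = np.array([0.1, 0.3])  # desired unsupervised increment (positive-sign convention)

grad = np.zeros_like(w)       # stand-in for an accumulated gradient buffer
grad -= delta                 # store the negated delta, as in `grad -= lr * dW`
w_new = sgd_step(w, grad, lr=1.0)
```

Subtracting the positive-sign delta from the gradient buffer makes a descent step add that delta to the weight.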

### Training

The objective of weight training is to move optimization into the inner loop. This almost inherently demands cyclic, multi-interaction-step network topologies and update processes. I blend gradient descent and unsupervised updates as:
```python
# `env` can be a dataset providing sequential examples (videos, text, robot task)
while not env.done:
    # forward pass and local unsupervised updates
    y = model(x)
    # supervised loss, critic function
    sl_loss = eta_SL * loss(y)  # separate name so the `loss` function is not overwritten
    sl_loss.backward()  # accumulate SL gradients
    # all gradients have accumulated, now apply them (faster optimization)
    opt.apply_gradients()
    
# or wait until the end of the interaction sequence to apply gradients (slower optimization)
opt.apply_gradients()
```

## Implementation

### CMPCell
TODO: 
- specify time must be >2
- change it to state that grads are added over the interaction steps


### CMPNet

Since `CMPCell`'s update iteratively, I use `salina` to wrap potentially cyclic layer connectivity graphs into a single seq2seq `CMPNet` which can connect and differentiate with other networks. Updates 

## Experiments

I test CMPNets on the following experiments:
- progressive representation refinement
- autoregression
- decision making

I also observe how well CMPNets perform acting as:
- feedforward dense layers
- standard convolution layers {1D, 2D, 3D}
- replacing single/multi-head attention key/query weight matrices
- performing 

## Slow Dynamics

Weights are modified with the intention that they should describe the statistical dynamics underlying actual activation patterns. I interpret signal values as probabilities from $-\alpha$ (certainly false), through $0$ (unknown), to $+\alpha$ (certainly true), and update weights by the following unsupervised rules:

1. **Spike Timing Dependent Plasticity (STDP)** makes synapse $w_{ji}$ decrease when $y_i$ activates before $x_j$ (since the connection must not have been important) and increase when $y_i$ activates after $x_j$ (since the connection made a significant contribution). It has no effect when one of the values is $0$:
$$\Delta w_{ji} = \eta_{STDP} ( x_{t-1,j} y_{t,i} - x_{t,j} y_{t-1,i} )$$
In the extended temporal case,
```python
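# input is shaped [..., T, Nin]
# output is shaped [..., T, Nout]
# W is shaped [Nin, Nout]

# dw_ji = eta_STDP * (input_j[:-1]output_i[1:] - input_j[1:]output_i[:-1])
dw = eta * (
    torch.einsum('...tj,...ti->ji', input[..., :-1, :], output[..., 1:, :])
    - torch.einsum('...tj,...ti->ji', input[..., 1:, :], output[..., :-1, :])
)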

```

- general correlation updates: $cw_{ji} = Coor(x_

```python
# spike timing dependent plasticity; input: [T, Nin], output: [T, Nout]
dW = input[:-1].T @ output[1:] - input[1:].T @ output[:-1]
if reward:
    dW *= (1 + reward)  # or some other function to make rewarding events more memorable
```

The noise mean component homeostatically shifts to center average activation values around a desired setpoint
```python

```

## Regularization

Weights $W_{.}$ and biases $b$ are $L_2$-regularized to avoid saturating the $\tanh$ input:
$$\mathcal{L}_{L_2} = \lambda_{L_2} \sum_{i,j} w_{ij}^2 + \lambda_{L_2} \sum_{i} b_i^2$$ 
```python
add_loss(lambda_L2_reg*( (W**2).sum() + (b**2).sum() ))
```


Modify $W$ with the assumption that it should describe the statistical dynamics underlying actual activation patterns. Update by

`w = beta_cW*w - eta_sgd*dw_sgd + eta_stdp*(1+arousal(r))*dW_stdp`

where
- gradients: `dw_sgd` comes from SGD
- normalized spike timing dependent plasticity: `dW_stdp = j1[:-1]i1[1:] - j1[1:]i1[:-1]`
- `cW` from absolute normalized correlation on shifted time: `cW = | Coor(i1[1:], j1[:-1]) |`
- `beta_cW` is a rolling mean of `cW`
- `i1` is the activation normalized by its scaling parameter to be in (-1,+1)
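
The combined rule above can be written as a single function. This is only a sketch: the `arousal` function is not defined in these notes, so the absolute value of the reward is used as a placeholder, and the default learning rates are arbitrary.

```python
import numpy as np

def combined_update(w, dw_sgd, dw_stdp, beta_cw, reward=0.0,
                    eta_sgd=0.001, eta_stdp=0.001, arousal=abs):
    """w <- beta_cW * w - eta_sgd * dw_sgd + eta_stdp * (1 + arousal(r)) * dW_stdp.

    `arousal` defaults to abs(), a placeholder assumption for the undefined function.
    """
    return beta_cw * w - eta_sgd * dw_sgd + eta_stdp * (1.0 + arousal(reward)) * dw_stdp

w = np.array([[1.0]])
w_next = combined_update(
    w,
    dw_sgd=np.array([[2.0]]),
    dw_stdp=np.array([[3.0]]),
    beta_cw=0.5,
    reward=-1.0,
    eta_sgd=0.1,
    eta_stdp=0.1,
)
```

With these numbers the three terms are 0.5 (decay), -0.2 (SGD), and +0.6 (reward-modulated STDP).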

## Theory

Critical multi-paradigm networks utilize a rescaled activation $Y = \alpha \tanh(\alpha^{-1} XW + b + k + \zeta)$. Scaling the activation space around $\tanh$ extends the nonsaturating regime to $\approx( -\alpha/2, \alpha/2)$, facilitating pseudo-linear forward propagation of activations and backpropagation of gradients across deep layer traversals. Weights $W$ and biases $b$ are $L_2$-regularized to further alleviate saturation. Conversely, noise $\zeta \sim \mathcal{N}(\mu=0, \sigma^2=\sigma_{\zeta}^2)$ encourages larger weights (hence, saturating activations) and in later epochs might be increased (even cyclically) to help sparsify weights. Receiving and expressing activations in the $(-\alpha, +\alpha)$ range allows CMNs to interface and train end-to-end with other deep learning layers, including attention mechanisms. CMNs internally utilize unsupervised weight update rules with a nonhomogeneous, reward-modulated update rate.

## Implementation

Unsupervised rules should still apply their changes by addition on each timestep (by default, gradients don't flow through unsupervised weight update rules). The previous activation outputs must be supplied if any other inputs are supplied (programmatically, this is a hidden state).

```python
wandb.config = { ... }  # lowest precedence hyperparameters (optional)

...

# salina could turn this into a simple rnn
# retrieve variables from recurrent workspace
y_next = CMCell({'x': X, 'y': y}, prev=y)  # generate y value for t+1 given values <=t
# write variable to recurrent workspace

# open-ended growth
y_next = CMCell({'x': X, 'y': y, 'v': V}, prev=y)  # new inputs
y_next = CMCell({'y': y, 'v': V}, prev=y)  # optional inputs
y_next = CMCell({'x': X, 'y': y, 'v': V}, prev=y, update=False, hparams=overrides)

...

# build recurrent model from cell
model.fit(X,Ytrue)
```

Hyperparameters are resolved from the following sources, in order of increasing precedence: `wandb.config` (optional, lowest precedence), `CMCell.HPARAM_DEFAULTS`, `self.hparams`, `hparams` (fn arg), and finally specific args and assignments (highest precedence).
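
One way to realize this precedence chain is with `collections.ChainMap`, where the first mapping that contains a key wins. The function name `resolve_hparams` and the example values are hypothetical; this is a sketch of the lookup order, not the cell's actual code.

```python
from collections import ChainMap

def resolve_hparams(wandb_config, class_defaults, instance_hparams, call_hparams, explicit):
    """Merge hyperparameter sources; mappings listed first in the ChainMap take precedence."""
    return dict(ChainMap(explicit, call_hparams, instance_hparams, class_defaults, wandb_config))

resolved = resolve_hparams(
    wandb_config={"cm_alpha": 2.0, "IP_mean": 0.1},       # lowest precedence
    class_defaults={"cm_alpha": 1.0, "lr_stdp": 0.001},
    instance_hparams={"lr_stdp": 0.01},
    call_hparams={},
    explicit={"IP_mean": 0.2},                            # highest precedence
)
```

Here `cm_alpha` comes from the class defaults, `lr_stdp` from the instance, and `IP_mean` from the explicit argument.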

Input values are passed as a dictionary `{'x1': x1, 'x2': x2, 'y': y}`. All inputs for a single call must have matching batch axes and sequence lengths but can have individually varying input dimensions (i.e.: input must be shaped `[..., T, N1], [..., T, N2], [..., T, N3], ...`). You can change the number of inputs at any time by calling `build` with a corresponding dictionary of input shapes or just passing a new dictionary key in. Keys cannot be deleted. All inputs are optional.
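
The shape constraint on the input dictionary can be checked with a short helper. This is a minimal sketch; `input_dims` is an illustrative name and NumPy arrays stand in for the real tensors.

```python
import numpy as np

def input_dims(inputs):
    """Check that all inputs share their leading axes [..., T] and return {key: N_k}."""
    leading = {v.shape[:-1] for v in inputs.values()}
    if len(leading) > 1:
        raise ValueError(f"mismatched batch/time axes: {sorted(leading)}")
    return {k: v.shape[-1] for k, v in inputs.items()}

dims = input_dims({
    "x1": np.zeros((8, 10, 3)),
    "x2": np.zeros((8, 10, 5)),
    "y": np.zeros((8, 10, 4)),
})
```

The returned dictionary of per-key feature sizes is the kind of information `build` would need to allocate one weight matrix per input.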

## Unsupervised Learning

I interpret signal values as probabilities from $-\alpha$ (certainly false), through $0$ (unknown), to $+\alpha$ (certainly true), and assume neurons follow an extrapolated rule of implication. For each weight $w_{ji} \in W$ connecting input $j$ to output $i$, elementwise updates proceed as:

```python
def update_weight(w_ji, x, y, x_prev, y_prev, t):
    #  Updates the weight between input j and output i
    #  w_ji: current value of weight
    #  x: activation value of input j at time t
    #  y: activation value of output i at time t
    #  x_prev: activation value of input j at time t-1
    #  y_prev: activation value of output i at time t-1
    #  t: threshold to establish nonzero activation value = small number like 0.2
    
    # TODO: maybe identify a few simple expressions that capture everything
    
    if w_ji == 0:  # weight is dead
        dw = 0  # do nothing
    elif x_prev > t and y > t:  # parent was + and the child was +
        dw = 0  # nothing needs to change
    elif x_prev > t and -t < y < t:  # parent was + and the child was neutral
        dw = +1  # the weight mattered so it should be strengthened
    elif x_prev > t and y < -t:  # parent was + and the child was -
        dw = -1  # the weight didn't matter so it should be recycled
    else:  # remaining cases (neutral or negative parent) are not yet specified
        dw = 0
    return dw
```

Bias $b$

## Supervised Learning

Use standard gradient descent to optimize:

$$\arg\min_{\theta_{model}} L(x)$$

In [None]:
class

In [None]:
# in the training loop
wandb.log({"loss": loss})

# Optional
wandb.watch(model)