This is a guide to drawing Neural Circuit Diagrams by Vincent Abbott, based on the paper [*Neural Circuit Diagrams: Robust Diagrams for the Communication, Implementation, and Analysis of Deep Learning Architectures*](https://openreview.net/forum?id=RyZB4qXEgt). Neural circuit diagrams are a novel diagrammatic scheme that allows deep learning algorithms to be expressed comprehensively.

This is the Mathcha component of the guide. The diagrams themselves can be found in the [Mathcha](https://www.mathcha.io/editor/p8KjdC6yI3nH7VLEpfLoK4lOSNWnNl1hNNOmp3) document.

### Basic Components.
We begin with an overview of the basics of neural circuit diagrams. Wires represent axes, and dashed lines represent tuples. Operations change the shape of data and are drawn as symbols or pictograms. Extra axes can be drawn to broadcast an operation over them.

<img src="PNG/BasicComponents.png" width="1000">
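As an illustration (our own sketch, not code from the paper), these conventions map onto ordinary PyTorch. An operation on one axis is a function over the last axis. Broadcasting it over an extra axis gives the same result as looping over that axis, and a tuple is simply a Python tuple of tensors. The operation `op` (here a softmax) and the helper `broadcast_matches_loop` are hypothetical examples.

```python
import torch

def op(v: torch.Tensor) -> torch.Tensor:
    """Hypothetical example operation acting on the last axis: a softmax."""
    return torch.softmax(v, dim=-1)

def broadcast_matches_loop(batch: int = 3, m: int = 5) -> bool:
    """Compare broadcasting `op` over a leading axis with an explicit loop."""
    torch.manual_seed(0)
    x = torch.rand(batch, m)
    broadcast = op(x)                                       # one call over all rows
    looped = torch.stack([op(x[i]) for i in range(batch)])  # one call per row
    return torch.allclose(broadcast, looped)

# A tuple of data (dashed lines) is a Python tuple of tensors, e.g. from torch.split.
parts = torch.split(torch.rand(3, 5), [2, 3], dim=-1)
print([tuple(p.shape) for p in parts])  # [(3, 2), (3, 3)]
print(broadcast_matches_loop())         # True
```

Drawing an extra axis wire therefore costs nothing in the implementation: the same function is called, and PyTorch handles the additional leading axis.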

### Guidelines.
We use [mathcha.io](https://www.mathcha.io/editor) (math-*cha*) to make and modify diagrams. Standard settings for snapping and managing diagrams make them easier to create.

<img src="PNG/Guidelines.png" width="700">

### Implementation.
Diagrams correspond closely to implementations, so once a diagram is made, implementing the algorithm it describes is straightforward. As an example, the two-layer residual MLP diagrammed below is implemented in PyTorch.

<img src="PNG/TwoLayer.png" width="850">

```python
import torch
from torch import Tensor as T
from typing import Tuple, Any
import torch.nn as nn
import math
```

```python
class TwoLayerResidualMLP(nn.Module):
    """ A feed-forward layer of two learned linear layers with a ReLU and a residual connection. """
    def __init__(self, xbar : int, m: int, dff : int,
        device : Any | None = None, dtype : Any | None = None) -> None:
        super().__init__()
        self.xbar, self.m, self.dff = xbar, m, dff
        # Bold (learned) components must be initialized.
        # The + indicates that the learned linear layers have a bias.
        bias = True
        self.L0 = nn.Linear(m, dff, bias, device, dtype)
        self.L1 = nn.Linear(dff, m, bias, device, dtype)

    def forward(self, x : T) -> T:
        """ ... m -> ... m """
        # We keep "x" as an implicit copy for the residual connection.
        # Linear layers act on the last axis and broadcast over all others.
        x1 = self.L0(x)
        x1 = nn.functional.relu(x1)
        x1 = self.L1(x1)
        x1 = x1 + x
        return x1

# Now we can run it on fake data:
xbar, m, dff = 256, 1024, 4096
# Input data
x = torch.rand((xbar, m))

ff = TwoLayerResidualMLP(xbar, m, dff)
# Calling the module directly (rather than .forward) also runs any registered hooks.
assert tuple(ff(x).size()) == (xbar, m)
```
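To see the correspondence at the level of matrices, the residual MLP can also be written out explicitly as `x + relu(x @ W0.T + b0) @ W1.T + b1`, since `nn.Linear` computes `x @ weight.T + bias` over the last axis. The sketch below is our own check rather than part of the original guide. `check_residual_mlp` is a hypothetical helper, and it builds fresh layers instead of reusing `ff` so that it runs on its own.

```python
import torch
import torch.nn as nn

def check_residual_mlp(xbar: int = 8, m: int = 16, dff: int = 64) -> bool:
    """Compare the layer-by-layer residual MLP with its explicit matrix formula."""
    torch.manual_seed(0)
    L0, L1 = nn.Linear(m, dff), nn.Linear(dff, m)
    x = torch.rand(xbar, m)
    # Diagram order: copy x, apply L0 -> ReLU -> L1, then add the copy back.
    composed = L1(torch.relu(L0(x))) + x
    # The same computation written with explicit weights and biases.
    hidden = torch.relu(x @ L0.weight.T + L0.bias)
    explicit = x + hidden @ L1.weight.T + L1.bias
    return torch.allclose(composed, explicit, atol=1e-6)

print(check_residual_mlp())  # True
```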

### Einstein Operations.
By keeping operation and shape columns separate, we can easily diagram more complex algorithms in Mathcha. Here, we diagram Multi-Head Attention, an algorithm that uses multiple Einstein operations. The diagram shows clearly how the axes interact.

*(This is a subsection of the [full transformer diagram](https://twitter.com/jxmnop/status/1757244005639766157).)*

<img src="PNG/Einops.png" width="700">
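In an Einstein operation every axis is named. In `'... y k h, ... x k h -> ... y x h'`, the shared axis `k` is summed over, `h` is carried through as a batch axis, and `y` and `x` combine as an outer product. As a sketch, the attention-score step is checked below against an equivalent permute-and-`matmul` formulation. It uses `torch.einsum` instead of `einops` so that it runs without the extra package, and drops the `...` batch axes for brevity; the function names are hypothetical.

```python
import torch

def scores_einsum(q: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # q: (y, k, h), key: (x, k, h) -> (y, x, h). k is contracted, h is kept.
    return torch.einsum('ykh,xkh->yxh', q, key)

def scores_matmul(q: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # Move h to the front as a batch axis: (h, y, k) @ (h, k, x) -> (h, y, x),
    # then restore the (y, x, h) layout.
    return (q.permute(2, 0, 1) @ key.permute(2, 1, 0)).permute(1, 2, 0)

torch.manual_seed(0)
q, key = torch.rand(5, 16, 4), torch.rand(7, 16, 4)
a, b = scores_einsum(q, key), scores_matmul(q, key)
print(a.shape)                          # torch.Size([5, 7, 4])
print(torch.allclose(a, b, atol=1e-5))  # True
```

The einsum version states the axis interactions directly, while the matmul version has to encode them through the order of a permutation.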

```python
%pip install einops
import einops
```

```python
class Multilinear(nn.Module):
    """ A learned linear layer which supports tuple axis sizes. """
    def __init__(self, in_size  : Tuple[int, ...] | int, out_size : Tuple[int, ...] | int,
        bias : bool = True, device : Any | None = None, dtype : Any | None = None) -> None:
        super().__init__()
        # Store each size as a tuple of axes, together with its flattened feature count.
        get_size = lambda x: (x, math.prod(x)) \
            if isinstance(x, tuple) else ((x,), x)
        self.in_size,  self.in_features  = get_size(in_size)
        self.out_size, self.out_features = get_size(out_size)
        # Set up the linear module over the flattened features.
        self.linear = nn.Linear(self.in_features, self.out_features, bias, device, dtype)

    def forward(self, x_in : T) -> T:
        # Flatten the trailing input axes. They should match in_size;
        # if their total size differs, reshape raises an error.
        x_in = x_in.reshape(
            x_in.shape[:-len(self.in_size)] + (self.in_features,))
        # Apply the linear layer over the last axis.
        x = self.linear(x_in)
        # Unflatten the last axis into the output axes.
        return x.reshape(
            x.shape[:-1] + self.out_size)

def MultiHeadDotProductAttention(q: T, k: T, v: T) -> T:
    ''' ... y k h, ... x k h, ... x k h -> ... y k h '''
    # In practice, we add dots (...) to allow for batched axes.
    klength = k.size()[-2]
    x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')
    # Scale by the square root of the key size, then normalize over the x axis.
    x = torch.softmax(x / math.sqrt(klength), dim=-2)
    x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')
    return x

# We implement this component as a neural network module.
# This is necessary when there are bold, learned components that need to be initialized.
class MultiHeadAttention(nn.Module):
    # The settings of multi-head attention become arguments to the initializer:
    # m is the input feature size, k the key size per head, and h the number of heads.
    def __init__(self, m : int, k : int, h : int) -> None:
        super().__init__()
        self.m, self.k, self.h = m, k, h
        # Set up all the boldface, learned components (without bias).
        self.Lq = Multilinear(m, (k,h), False)
        self.Lk = Multilinear(m, (k,h), False)
        self.Lv = Multilinear(m, (k,h), False)
        self.Lo = Multilinear((k,h), m, False)

    # We have endogenous data (y) and external / injected data (x).
    def forward(self, y : T, x : T) -> T:
        """ ... ybar m, ... xbar m -> ... ybar m """
        # We first generate query, key, and value vectors.
        # Linear layers are automatically broadcast.
        q = self.Lq(y)
        k = self.Lk(x)
        v = self.Lv(x)

        # We feed q, k, and v to standard multi-head scaled dot-product attention.
        o = MultiHeadDotProductAttention(q, k, v)
        return self.Lo(o)

# Now we can run it on fake data:
ybar, xbar, m, k, h = 20, 22, 128, 16, 4
# Internal data
y = torch.rand((ybar, m))
# External data
x = torch.rand((xbar, m))

mha = MultiHeadAttention(m, k, h)
assert tuple(mha(y, x).size()) == (ybar, m)
```
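`Multilinear` treats a tuple of axes such as `(k, h)` as one flattened feature axis of size `k * h`, then reshapes. Because `reshape` uses row-major order, entry `[i, j]` of the `(k, h)` output is flat feature `i * h + j`, so head `j` is computed from weight rows `j, j + h, j + 2h, ...`. The sketch below checks this with a plain `nn.Linear`; it is our own illustration, and `check_tuple_axes` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def check_tuple_axes(m: int = 8, k: int = 3, h: int = 2) -> bool:
    """Check how a flat linear output of size k*h maps onto (k, h) axes."""
    torch.manual_seed(0)
    linear = nn.Linear(m, k * h, bias=False)
    x = torch.rand(5, m)
    flat = linear(x)               # (5, k*h)
    split = flat.reshape(5, k, h)  # (5, k, h), as Multilinear would return it
    # Row-major order: element [i, j] of the (k, h) block is flat feature i*h + j.
    same_entries = all(
        torch.equal(split[:, i, j], flat[:, i * h + j])
        for i in range(k) for j in range(h)
    )
    # Head 0 therefore uses weight rows 0, h, 2h, ...
    head0 = x @ linear.weight[0::h].T
    return same_entries and torch.allclose(split[..., 0], head0, atol=1e-6)

print(check_tuple_axes())  # True
```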