This repository is a lightweight re-implementation of the Hamux framework for generalized hierarchical associative memory (HAM) networks in PyTorch.
The paper on hierarchical associative memory can be found here.
The purpose of this repository is mostly educational for myself: I wanted to better understand how energy-based models can be implemented in PyTorch, especially with iterated gradients. This is possible via `torch.autograd.grad` with the option `create_graph=True`, which makes it possible to backpropagate through a model whose forward pass is a discretized gradient flow with respect to some global energy.
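As a minimal illustration of this pattern (a toy quadratic energy, not the repository's actual API), one can differentiate through a few explicit gradient-descent steps:

```python
import torch

# Toy energy E(x) = 0.5 * ||x - w||^2, whose minimum is the learnable vector w.
w = torch.randn(4, requires_grad=True)
x = torch.zeros(4, requires_grad=True)

step_size = 0.5
for _ in range(10):
    energy = 0.5 * ((x - w) ** 2).sum()
    # create_graph=True retains the graph of this gradient computation,
    # so a loss on the final state can backpropagate through all iterates.
    (grad_x,) = torch.autograd.grad(energy, x, create_graph=True)
    x = x - step_size * grad_x  # one discretized gradient-flow step

# A loss on the relaxed state trains w by differentiating through the flow.
loss = (x**2).sum()
loss.backward()  # populates w.grad through all ten inner steps
```

The same mechanism, applied to a global HAM energy instead of this toy quadratic, is what the forward pass of the network in this repository amounts to.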
The code in this repository is experimental and not built for efficient training. Furthermore, it is purely a re-implementation of existing methods. Credits go to Ben Hoover for the original JAX implementation and to Dmitry Krotov for the invention of HAM in this form.
Install the uv package manager:

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
From the root of the repository, run

```bash
uv sync
```
Then activate the virtual environment:

```bash
source .venv/bin/activate
```
A HAM network is specified by the following data:
- A list of neuron layers with associated Lagrangians (whose derivatives yield the activation functions acting on the neurons).
- A list of (hyper)synapses, each of which takes a number of neuron activations as input and returns a scalar representing the energy of that synapse.
- A list of synaptic connections, i.e. which layers wire into which synapses.
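Concretely (following the conventions of the HAM paper, with notation adapted here), the global energy is a sum of per-layer terms and per-synapse terms:

```math
E = \sum_{\ell} \Bigl( x^{\ell} \cdot g^{\ell} - L^{\ell}(x^{\ell}) \Bigr) + \sum_{s} E_s\bigl(g^{\ell_1}, \dots, g^{\ell_k}\bigr), \qquad g^{\ell} = \nabla L^{\ell}(x^{\ell}),
```

where $x^{\ell}$ are the neurons of layer $\ell$, $g^{\ell}$ their activations, and $E_s$ the energy of synapse $s$ evaluated on the activations of the layers it connects. The network state evolves by a (discretized) gradient flow that decreases this energy.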
As an example, consider the following hierarchical network with three neuron layers and a synapse connecting each pair of consecutive layers:
(Image taken from the HAM paper)
The following code snippet implements this network as a fully functional torch module:
```python
from ham.config import (
    NeuronLayerConfig,
    SynapseType,
    HyperSynapseConfig,
    LagrangianConfig,
    LagrangianType,
    HAMConfig,
)
from ham.network import HAMNetwork

# Three neuron layers with different numbers of neurons and Lagrangians
layers = [
    NeuronLayerConfig(
        num_neurons=8, lagrangian=LagrangianConfig(type=LagrangianType.QUADRATIC)
    ),
    NeuronLayerConfig(
        num_neurons=16,
        lagrangian=LagrangianConfig(type=LagrangianType.LSE, kwargs={"beta": 1.0}),
    ),
    NeuronLayerConfig(
        num_neurons=32,
        lagrangian=LagrangianConfig(
            type=LagrangianType.LOG_COSH, kwargs={"beta": 1.0}
        ),
        time_constant=0.1,
    ),
]

# Two synapses that each take two neuron activations as input
synapses = [
    HyperSynapseConfig(type=SynapseType.DENSE_2, kwargs={"dim_1": 8, "dim_2": 16}),
    HyperSynapseConfig(type=SynapseType.DENSE_2, kwargs={"dim_1": 16, "dim_2": 32}),
]

# Synaptic connections: synapse 0 takes activations from layers 0 and 1,
# synapse 1 takes activations from layers 1 and 2
synapse_connections = {0: (0, 1), 1: (1, 2)}

network_config = HAMConfig(
    neuron_layers=layers, synapses=synapses, synapse_connections=synapse_connections
)

# Initialize the network as a fully functional torch module
network = HAMNetwork(config=network_config)
```