CoEich/HAM-PyTorch
HAM-PyTorch

This repository is a lightweight re-implementation of the Hamux framework for generalized hierarchical associative memory (HAM) networks in PyTorch.

The paper on hierarchical associative memory can be found here.

Disclaimer

This repository is mainly an educational exercise for myself: I wanted to better understand how energy-based models can be implemented in PyTorch, especially with iterated gradients. This is possible using torch.autograd.grad with the option create_graph=True, which lets one backpropagate through a model whose forward pass is a discretized gradient flow with respect to some global energy.
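A minimal standalone sketch of this mechanism, using a toy quadratic energy rather than the HAM energy from this repository:

```python
import torch

# Toy quadratic energy parametrized by a weight matrix W.
W = torch.randn(4, 4, requires_grad=True)

def energy(x):
    return 0.5 * (W @ x).pow(2).sum()

x = torch.randn(4, requires_grad=True)  # initial neuron state
step_size = 0.1

# Forward pass: a few steps of discretized gradient flow on the energy.
# create_graph=True keeps these steps on the autograd graph, so the
# loss below can backpropagate through the whole relaxation into W.
for _ in range(5):
    (grad_x,) = torch.autograd.grad(energy(x), x, create_graph=True)
    x = x - step_size * grad_x

loss = x.pow(2).sum()  # any loss on the relaxed state
loss.backward()        # gradients flow into W through all 5 steps
```

Without create_graph=True, the descent steps would be detached from the graph and W.grad would never be populated.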

The code in this repository is experimental and not built for efficient training. Furthermore, it is purely a re-implementation of existing methods. Credits go to Ben Hoover for the original JAX implementation and to Dmitry Krotov for the invention of HAM in this form.

Install

Download uv

```sh
curl -LsSf https://astral.sh/uv/install.sh | sh
```

Install package

From the root of the repository, run

```sh
uv sync
```

Then activate the virtual environment by running

```sh
source .venv/bin/activate
```

How to build a HAM network

A HAM network is specified by the following data:

  • A list of neuron layers with associated Lagrangians (whose derivatives are the activation functions applied to the neuron states).
  • A list of (hyper)synapses, each of which takes a number of neuron activations as input and returns a scalar: the energy of the synapse.
  • A list of synaptic connections, i.e. which layers wire into which synapses.
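These ingredients combine into a single global energy that the network descends. As a sketch following the HAM paper (notation may differ from this code base): with neuron states $x_A$, Lagrangians $L_A$, and activations $g_A = \partial L_A / \partial x_A$, the total energy is the sum of per-layer neuron energies and per-synapse energies,

```math
E \;=\; \sum_{A} \bigl( x_A \cdot g_A - L_A(x_A) \bigr) \;+\; \sum_{S} E_S\bigl(\{ g_A \}_{A \to S}\bigr),
```

where $\{ g_A \}_{A \to S}$ denotes the activations of the layers wired into synapse $S$.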

As an example, consider the following hierarchical network with three neuron layers and a synapse connecting subsequent layers:

(Figure: a hierarchical network with three neuron layers and synapses between subsequent layers; image taken from the HAM paper.)

The following code snippet implements this network as a fully functional torch module:

```python
from ham.config import (
    NeuronLayerConfig,
    SynapseType,
    HyperSynapseConfig,
    LagrangianConfig,
    LagrangianType,
    HAMConfig,
)
from ham.network import HAMNetwork

# Three neuron layers with different numbers of neurons and Lagrangians
layers = [
    NeuronLayerConfig(
        num_neurons=8, lagrangian=LagrangianConfig(type=LagrangianType.QUADRATIC)
    ),
    NeuronLayerConfig(
        num_neurons=16,
        lagrangian=LagrangianConfig(type=LagrangianType.LSE, kwargs={"beta": 1.0}),
    ),
    NeuronLayerConfig(
        num_neurons=32,
        lagrangian=LagrangianConfig(
            type=LagrangianType.LOG_COSH, kwargs={"beta": 1.0}
        ),
        time_constant=0.1,
    ),
]

# Two synapses that each take two neuron activations as input
synapses = [
    HyperSynapseConfig(type=SynapseType.DENSE_2, kwargs={"dim_1": 8, "dim_2": 16}),
    HyperSynapseConfig(type=SynapseType.DENSE_2, kwargs={"dim_1": 16, "dim_2": 32}),
]

# Synaptic connections: synapse 0 reads activations from layers 0 and 1,
# synapse 1 reads activations from layers 1 and 2
synapse_connections = {0: (0, 1), 1: (1, 2)}

network_config = HAMConfig(
    neuron_layers=layers, synapses=synapses, synapse_connections=synapse_connections
)

# Initialize the network as a fully functional torch module
network = HAMNetwork(config=network_config)
```
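The snippet above stops at construction. How inference is invoked depends on HAMNetwork's actual interface, which this README does not show; the following is only a hypothetical sketch of the relaxation loop such a module performs internally, demonstrated on a toy stand-in energy (the function and parameter names here are illustrative assumptions, not this repository's API):

```python
import torch

def relax(energy_fn, states, num_steps=10, step_size=0.1):
    """Discretized gradient flow of the layer states on a global energy.

    Hypothetical sketch: `energy_fn` stands in for the network's global
    energy; the real HAMNetwork interface may differ.
    """
    for _ in range(num_steps):
        grads = torch.autograd.grad(energy_fn(states), states, create_graph=True)
        states = [s - step_size * g for s, g in zip(states, grads)]
    return states

# Usage with a toy energy over three "layers" of the sizes used above
states = [torch.randn(n, requires_grad=True) for n in (8, 16, 32)]
toy_energy = lambda xs: sum(x.pow(2).sum() for x in xs)
relaxed = relax(toy_energy, states, num_steps=5)
```

Because the loop uses create_graph=True, the relaxed states remain differentiable, so a training loss on them can backpropagate into any parameters inside the energy.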
