# Continuous Hopfield Neural Networks

## Introduction

In [None]:
# packages and versions
import os
from typing import Callable, List, Tuple
from tqdm import notebook

import torch
print(f"{torch.__version__=}")
from torch import nn
from torch import optim
from torch.optim import optimizer
from torch.nn import modules
from torch.nn.modules import loss
from torch.utils import data

import torchvision
print(f"{torchvision.__version__=}")
from torchvision import datasets
from torchvision import transforms

# settings
print(f"{torch.cuda.is_available()=}")

In [None]:
# download data
# Every dataset lives in its own sub-folder of `data_root`, named after the
# dataset class; both the training and the test split are fetched.
data_root = os.path.join('.', 'data', 'db')
for _dataset_cls in (datasets.CIFAR100, datasets.CIFAR10, datasets.MNIST):
    _dataset_root = os.path.join(data_root, _dataset_cls.__name__)
    for _train_split in (True, False):
        _dataset_cls(_dataset_root, train=_train_split, download=True)

# Hopfield network implementations

In [17]:
class Conv2dHopfield(nn.Module):
    """Convolutional continuous Hopfield layer.

    The stored patterns are used as convolution kernels: the input is first
    scored against every pattern (a convolution scaled by ``beta``), then the
    layer repeatedly applies a softmax over the pattern axis followed by a
    1x1 convolution with ``beta * P @ P.T`` (``P`` being the flattened
    patterns).

    Args:
        channels_in: number of input channels (must be > 0).
        channels_out: number of stored patterns / output channels (must be > 0).
        kernel_size: spatial size of each pattern.
        padding: padding of the scoring convolution.
        stride: stride of the scoring convolution.
        iterations: number of softmax + projection steps (must be > 0).
        beta_trainable: whether ``log(beta)`` receives gradients.
        patterns_trainable: whether the patterns receive gradients.
    """

    channels_in: int  # number of input channels
    channels_out: int  # number of output channels
    kernel_size: Tuple[int, int]  # kernel size
    padding: Tuple[int, int]  # padding
    stride: Tuple[int, int]  # stride

    iterations: int  # number of iterations

    _logbeta: nn.Parameter  # log(beta), shape (1,)
    _patterns: nn.Parameter  # patterns, shape (channels_out, channels_in, kh, kw)

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            kernel_size: Tuple[int, int],
            padding: Tuple[int, int] = (0, 0),
            stride: Tuple[int, int] = (1, 1),
            iterations: int = 1,
            beta_trainable: bool = False,
            patterns_trainable: bool = False,
        ) -> None:
        super().__init__()

        assert channels_in > 0
        assert channels_out > 0
        assert kernel_size[0] > 0
        assert kernel_size[1] > 0
        assert padding[0] >= 0
        assert padding[1] >= 0
        assert stride[0] > 0
        assert stride[1] > 0
        assert iterations > 0

        self.channels_in = channels_in
        self.channels_out = channels_out
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.iterations = iterations

        # beta starts at exp(0) = 1; patterns start as tanh-squashed Gaussian noise.
        self._logbeta = nn.Parameter(torch.tensor([0.0]), requires_grad=beta_trainable)
        self._patterns = nn.Parameter(
            torch.tanh(torch.randn(channels_out, channels_in, kernel_size[0], kernel_size[1])),
            requires_grad=patterns_trainable,
        )

    def update_patterns(self, patterns: torch.Tensor) -> None:
        """Overwrite the stored patterns with ``patterns`` (same shape required).

        The values are copied in place, so the Parameter object, its device,
        its dtype and its ``requires_grad`` flag are all preserved; an
        optimizer that already references the parameter keeps working.
        """
        assert patterns.shape == self._patterns.shape
        with torch.no_grad():
            self._patterns.copy_(patterns)

    def update_beta(self, beta: float) -> None:
        """Set the inverse temperature to ``beta`` (must be > 0).

        ``log(beta)`` is written in place, so ``_logbeta`` keeps its shape
        ``(1,)``, its device and its ``requires_grad`` flag.
        """
        assert beta > 0
        with torch.no_grad():
            self._logbeta.copy_(torch.log(torch.tensor(float(beta))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pattern scores after ``iterations`` Hopfield updates.

        ``x`` must have ``channels_in`` channels on axis 1; the result has
        ``channels_out`` channels on axis 1.
        """
        assert x.shape[1] == self.channels_in

        beta = torch.exp(self._logbeta)
        flat_patterns = self._patterns.view(self.channels_out, -1)

        # beta * P @ P.T, reshaped into a 1x1 convolution kernel.
        weight_loop = (beta * flat_patterns @ flat_patterns.T).view(
            self.channels_out, self.channels_out, 1, 1
        )

        # `bias` must be a tensor or None; the boolean False is not accepted.
        x = nn.functional.conv2d(
            x, beta * self._patterns, bias=None, padding=self.padding, stride=self.stride
        )
        for _ in range(self.iterations):
            x = nn.functional.softmax(x, dim=1)
            x = nn.functional.conv2d(x, weight_loop, bias=None, padding=(0, 0), stride=(1, 1))

        return x

In [13]:
class FCHopfield(nn.Module):
    """Fully connected continuous Hopfield layer with one pattern set per channel.

    For every channel the input vector (``neurons`` values) is scored against
    ``features`` stored patterns, scaled by that channel's ``beta``. The layer
    then repeatedly applies a softmax over the pattern axis followed by a
    multiplication with ``beta * P @ P.T``.

    Args:
        channels: number of independent pattern sets (must be > 0).
        features: number of stored patterns per channel (must be > 0).
        neurons: length of each pattern / input vector (must be > 0).
        iterations: number of softmax + projection steps (must be > 0).
        beta_trainable: whether ``log(beta)`` receives gradients.
        patterns_trainable: whether the patterns receive gradients.
    """

    channels: int  # number of input channels
    features: int  # number of features
    neurons: int  # number of neurons

    iterations: int  # number of iterations

    _logbeta: nn.Parameter  # log(beta), shape (channels,)
    _patterns: nn.Parameter  # patterns, shape (channels, features, neurons)

    def __init__(
            self,
            channels: int,
            features: int,
            neurons: int,
            iterations: int = 1,
            beta_trainable: bool = True,
            patterns_trainable: bool = True,
        ) -> None:
        super().__init__()

        assert channels > 0
        assert features > 0
        assert neurons > 0
        assert iterations > 0

        self.channels = channels
        self.features = features
        self.neurons = neurons
        self.iterations = iterations

        # beta starts at exp(0) = 1; patterns start as tanh-squashed Gaussian noise.
        self._logbeta = nn.Parameter(torch.zeros(channels), requires_grad=beta_trainable)
        self._patterns = nn.Parameter(
            torch.tanh(torch.randn(channels, features, neurons)),
            requires_grad=patterns_trainable,
        )

    def update_patterns(self, patterns: torch.Tensor) -> None:
        """Overwrite the stored patterns with ``patterns`` (same shape required).

        The values are copied in place, so the Parameter object, its device,
        its dtype and its ``requires_grad`` flag are all preserved; an
        optimizer that already references the parameter keeps working.
        """
        assert patterns.shape == self._patterns.shape
        with torch.no_grad():
            self._patterns.copy_(patterns)

    def update_beta(self, beta: torch.Tensor) -> None:
        """Set the per-channel inverse temperature to ``beta`` (all entries > 0).

        ``log(beta)`` is written in place into the existing parameter.
        """
        assert beta.shape == self._logbeta.shape
        assert torch.all(beta > 0)

        with torch.no_grad():
            self._logbeta.copy_(torch.log(beta))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (batch, channels, neurons) to (batch, channels, features)."""
        assert x.shape[1] == self.channels
        assert x.shape[2] == self.neurons

        # Plain `None` indexing instead of `torch.newaxis`, which older
        # PyTorch releases do not provide.
        beta = torch.exp(self._logbeta)[:, None, None]
        L = beta * self._patterns  # L.shape = (channels, features, neurons)
        A = L @ self._patterns.transpose(1, 2)  # A.shape = (channels, features, features)

        x = torch.einsum('cfn, bcn -> bcf', L, x)  # x.shape = (batch, channels, features)
        for _ in range(self.iterations):
            x = nn.functional.softmax(x, dim=2)
            x = torch.einsum('cfp, bcp -> bcf', A, x)  # x.shape = (batch, channels, features)

        return x

In [14]:
def trainer(
        model: nn.Module,
        data_loader: data.DataLoader,
        loss_fn: loss._Loss,
        reg_fn: Callable[[nn.Module], torch.Tensor],
        optimizer: optim.Optimizer,
        n_epochs: int,
        input_converter: None|Callable[[torch.Tensor], torch.Tensor] = None,
        output_converter: None|Callable[[torch.Tensor], torch.Tensor] = None,
        device: torch.device = torch.device("cpu")
    ):
    """Train ``model`` in place for ``n_epochs`` epochs.

    Each batch minimises ``loss_fn(model(x), y) + reg_fn(model)``. After every
    epoch the mean loss and mean regularisation term over the batches are
    shown on the epoch progress bar.

    Args:
        model: network to train; it is moved to ``device`` and put in train mode.
        data_loader: yields ``(x, y)`` batches.
        loss_fn: called as ``loss_fn(prediction, y)``.
        reg_fn: regularisation term computed from the model.
        optimizer: optimizer stepping the model parameters.
        n_epochs: number of passes over ``data_loader``.
        input_converter: optional transformation applied to ``x`` after the
            transfer to ``device``.
        output_converter: optional transformation applied to ``y`` after the
            transfer to ``device``.
        device: device on which the training runs.
    """
    torch.cuda.empty_cache()
    model.to(device)
    model.train()

    # Avoid a division by zero in the epoch averages when the loader is empty.
    n_batches = max(len(data_loader), 1)

    progressBar_epoch = notebook.tqdm(
        range(n_epochs),
        desc="Epochs",
        leave=False)
    for _ in progressBar_epoch:

        loss_epoch = 0.0
        reg_epoch = 0.0

        x: torch.Tensor
        y: torch.Tensor

        progressBar_batch = notebook.tqdm(
            data_loader,
            desc="Batches",
            total=len(data_loader),
            leave=False)
        for x, y in progressBar_batch:

            x, y = x.to(device), y.to(device)

            if input_converter is not None:
                x = input_converter(x)
            if output_converter is not None:
                y = output_converter(y)

            optimizer.zero_grad()
            _y = model(x)
            batch_reg: torch.Tensor = reg_fn(model)
            batch_loss: torch.Tensor = loss_fn(_y, y)
            batch_total: torch.Tensor = batch_loss + batch_reg
            batch_total.backward()
            optimizer.step()

            # Report the regularisation value that was actually optimised,
            # instead of evaluating reg_fn a second time after the step.
            loss_epoch += batch_loss.item()
            reg_epoch += batch_reg.item()

        loss_epoch /= n_batches
        reg_epoch /= n_batches
        progressBar_epoch.set_postfix(
            loss=loss_epoch,
            reg=reg_epoch,
        )

In [15]:
def tester_classification(model: nn.Module, data_loader: data.DataLoader, device: torch.device) -> float:
    """Return the classification accuracy of ``model`` over ``data_loader``.

    The model is moved to ``device`` and switched to evaluation mode. For each
    batch, the predicted class is the index of the largest score along axis 1,
    and it is compared with the target labels. Gradients are disabled.
    """
    model.to(device)
    model.eval()

    n_correct = 0
    n_seen = 0

    inputs: torch.Tensor
    targets: torch.Tensor

    with torch.no_grad():
        for inputs, targets in notebook.tqdm(data_loader, desc="Batches", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            predicted = torch.max(scores, 1).indices
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    return n_correct / n_seen

# Applications

In [8]:
# CNN2d layers

class CNN2d(nn.Module):

    """Convolutional block that changes the channel count and halves the spatial size.

        Usage:
            >>> CNN2d(in_channels, deep, encrease=e, momentum=0.4)
            Returns an image with in_channels*4*e channels and half the
            height and width.
            >>> CNN2d(in_channels, deep, decrease=d, momentum=0.4)
            Returns an image with in_channels*4//d channels and half the
            height and width.

        In detail (out_channels = in_channels * 4 * encrease // decrease):
            - a (3,3) convolution with padding (1,1) encodes the input into
                out_channels channels
            - `deep` times:
                - a per-channel PReLU is applied
                - a 1x1 convolution produces int(out_channels * momentum) channels
                - a second per-channel PReLU followed by a (3,3) convolution
                    with padding (1,1) produces the remaining channels
                - the two results are concatenated along the channel axis
            - batch normalisation is applied over each channel
            - a (2,2) max pooling with stride (2,2) halves the spatial size
            - finally, a SiLU activation is applied

        Raises:
            ValueError: if both `encrease` and `decrease` are greater than 1,
                if either of them is less than 1, or if `deep` is negative.
    """

    def __init__(
        self,
        in_channels: int,
        deep: int,
        encrease: int = 1, decrease: int = 1,
        momentum: float = 0.5):

        super().__init__()

        if encrease > 1 and decrease > 1:
            raise ValueError("encrease and decrease cannot be both greater than 1")
        if encrease < 1 or decrease < 1:
            raise ValueError("encrease and decrease must be not less than 1")
        if deep < 0:
            # deep == 0 is accepted (no inner layers), so the message must not
            # claim that the value has to be strictly positive.
            raise ValueError("deep must be non-negative")

        out_channels = in_channels * 4 * encrease // decrease
        # Split of the output channels between the 1x1 and the 3x3 branch.
        out_channels_Ymomentum = int(out_channels * momentum)
        out_channels_Nmomentum = out_channels - out_channels_Ymomentum

        self.encode = nn.Conv2d(
            in_channels, out_channels, (3, 3), padding=(1, 1)
        )
        # The attribute name `list` is part of the state_dict keys.
        self.list = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        'act': nn.PReLU(out_channels),
                        '1': nn.Sequential(
                            nn.Conv2d(
                                out_channels, out_channels_Ymomentum, (1, 1)
                            ),
                        ),
                        '3': nn.Sequential(
                            nn.PReLU(out_channels),
                            nn.Conv2d(
                                out_channels, out_channels_Nmomentum, (3, 3), padding=(1, 1)
                            ),
                        ),
                    }
                )
                for _ in range(deep)
            ]
        )
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.decode = torch.nn.MaxPool2d((2, 2), (2, 2))
        self.silu = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the encode convolution, the inner layers, batch norm, pooling and SiLU."""
        x = self.encode(x)
        for layer in self.list:
            x = layer['act'](x)
            y1 = layer['1'](x)
            y3 = layer['3'](x)
            # Both branches together restore out_channels channels.
            x = torch.cat([y1, y3], dim=1)
        x = self.bn(x)
        x = self.decode(x)
        x = self.silu(x)
        return x

## MNIST

In [None]:
class MyModel(nn.Module):

    """Classifier built from Hopfield layers.

    input shape: (batch, 1, 32, 32)
    output shape: (batch, 10)
    """

    class View(nn.Module):
        """Reshape every sample to a fixed shape, keeping the batch axis."""

        def __init__(self, *shape):
            super().__init__()
            self.shape = shape

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            batch = x.shape[0]
            return x.view(batch, *self.shape)

    @staticmethod
    def _hopfield_stage(c_in: int, c_mid: int, c_out: int) -> List[nn.Module]:
        """Build one stage: two 3x3 Hopfield convolutions, then batch norm, 2x2 max pooling and SiLU."""
        conv_options = {"padding": (1, 1), "iterations": 1}
        return [
            Conv2dHopfield(c_in, c_mid, (3, 3), **conv_options),
            nn.PReLU(c_mid),
            Conv2dHopfield(c_mid, c_out, (3, 3), **conv_options),
            nn.BatchNorm2d(c_out),
            nn.MaxPool2d((2, 2), (2, 2)),
            nn.SiLU(),
        ]

    def __init__(self):
        super().__init__()

        # The stages are created in order, so module indices and the order of
        # random initialisation match a flat nn.Sequential definition.
        stages = [
            # first convolutional stage
            *self._hopfield_stage(1, 4, 16),
            # second convolutional stage
            *self._hopfield_stage(16, 32, 64),
            # third convolutional stage
            *self._hopfield_stage(64, 64, 64),
        ]
        self.CNN = nn.Sequential(
            *stages,
            # flatten the spatial axes: (batch, 64, 4*4)
            MyModel.View(64, 4*4),
        )

        self.MLP = nn.Sequential(
            FCHopfield(64, 8, 16, iterations=1),
            nn.ReLU(),
            MyModel.View(1, 64*8),

            FCHopfield(1, 10, 64*8, iterations=2),
            nn.Flatten(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.CNN(x)
        return self.MLP(features)

model = MyModel()
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# total number of parameters (frozen parameters are counted as well as trainable ones)
print(f"number of parameters="
    f"{sum(parameter.numel() for parameter in model.parameters())}")

In [None]:
# training
# The string literal below is an unused expression: an alternative,
# colour-image augmentation pipeline kept for reference only.
"""
transform=transforms.Compose(
    [
        transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
        transforms.RandomAffine(degrees=0, scale=(1.0, 1.1), shear=0),
        transforms.ColorJitter(
            contrast=(0.9, 1.5),
            saturation=(0.9, 1.3),
            brightness=(0.9, 1.3),
            hue=(-0.05, 0.05),
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        ),
    ]
)
"""
# Active pipeline: tensor conversion, single-channel output, padding, random
# affine and contrast augmentation, then normalisation with mean 0.5 / std 0.5.
transform=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Pad(2),  # pad to 32x32
        transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10
        ),
        transforms.ColorJitter(contrast=(0.9, 1.5)),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split; it is expected to be on disk already (download=False).
train_set = datasets.MNIST(root=os.path.join(data_root, 'MNIST'), train=True, download=False, transform=transform)
train_loader = data.DataLoader(train_set, batch_size=2_048, shuffle=True)

loss_fn = nn.CrossEntropyLoss()

# NOTE: this rebinds the global name `optimizer`, which the import cell bound
# to the module `torch.optim.optimizer`.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def reg_fn(model: nn.Module) -> torch.Tensor:
    """Return a constant zero: no regularisation is applied."""
    return torch.tensor(0.0)

# 40 epochs, on the GPU when one is available.
trainer(model, train_loader, loss_fn, reg_fn, optimizer, 40, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# testing
# Evaluate on the MNIST test split. The model in this section is trained on
# MNIST, so measuring its accuracy on CIFAR10 images and labels is meaningless.
# NOTE(review): `transform` still contains the random training augmentations,
# so the reported accuracy varies slightly from run to run.
test_set = datasets.MNIST(root=os.path.join(data_root, 'MNIST'), train=False, download=False, transform=transform)
test_loader = data.DataLoader(test_set, batch_size=2_048, shuffle=False)

accuracy = tester_classification(model, test_loader, torch.device("cuda" if torch.cuda.is_available() else "cpu"))

print(f"{accuracy=}")

In [33]:
# save model
# torch.save does not create missing directories, so make sure the target
# folder exists first.
# NOTE(review): the file name says "Cifar10" although this section trains on
# MNIST; the path is left unchanged here.
os.makedirs(os.path.join("data", "models"), exist_ok=True)
torch.save(model.state_dict(), "data/models/Cifar10.pt")

## MNIST
In this example we apply a Hopfield network directly to the MNIST images.

In [71]:
# Augmentation pipeline for the direct Hopfield experiment: tensor conversion,
# single-channel output, padding, random affine and contrast augmentation.
# Unlike the transform defined for the previous experiment, there is no final
# Normalize step.
transform=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Pad(2),  # pad to 32x32
        transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10
        ),
        transforms.ColorJitter(contrast=(0.9, 1.5)),
    ]
)

In [None]:
# load patterns
from PIL import Image
import numpy

mnist_models = './data/models/MNIST'

# 30 grayscale pattern images, stored as pattern{i}_{j}.png with i in 0..9 and
# j in 0..2; image (i, j) goes into slot i*3+j.
patterns = torch.empty(30, 32, 32)
for slot in range(30):
    i, j = divmod(slot, 3)
    img = Image.open(f"{mnist_models}/pattern{i}_{j}.png").convert('L')
    npimg = numpy.array(img)
    patterns[slot] = torch.tensor(npimg)

# add a channel axis and rescale the pixel values from [0, 255] to [-1, 1]
patterns = 2.0 * (patterns.unsqueeze(1).float() / 255.0) - 1.0

# show the patterns in a grid with 10 rows and 3 columns
import matplotlib.pyplot as plt

fig, axs = plt.subplots(10, 3, figsize=(10, 30))
for ax, pattern in zip(axs.flat, patterns[:, 0]):
    ax.imshow(pattern, cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')

plt.show()

In [None]:
# build the Hopfield network from the provided patterns
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2

model = FCHopfield(1, 30, 32*32, iterations=1, beta_trainable=False, patterns_trainable=False)
model.update_patterns(patterns.view(1, 30, -1))
model.update_beta(torch.ones(1)*0.07)

model.to(device)

test_set = datasets.MNIST(root=os.path.join(data_root, 'MNIST'), train=False, download=False, transform=transform)
data_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

# apply the model to one batch of images
x, y = next(iter(data_loader))

# Alternative input: feed the stored patterns themselves instead of test images.
#x = patterns.view(30, 1, 32, 32).to(device)
#batch_size = x.shape[0]
x = x.to(device)
# Normalise the batch on the same device as the data; a hard-coded 'cuda'
# here would fail on machines without a GPU.
x = nn.BatchNorm2d(1, affine=False).to(device)(x)
x = x.view(x.shape[0], 1, -1)
y = y.to(device)

y_pred = model(x)
probs = nn.functional.softmax(y_pred, dim=2)

# show the input images next to their reconstructions: each reconstruction is
# the probability-weighted sum of the stored patterns
reconstructed_images = torch.einsum('phw, bp -> bhw', patterns.view(30, 32, 32), probs.view(batch_size, 30).to('cpu'))
fig, axs = plt.subplots(batch_size, 2, figsize=(5, int(batch_size*5/2)))
for i in range(batch_size):
    axs[i, 0].imshow(x[i].view(32, 32).detach().cpu().numpy(), cmap='gray')
    axs[i, 0].axis('off')
    axs[i, 1].imshow(reconstructed_images[i].detach().cpu().numpy(), cmap='gray')
    axs[i, 1].axis('off')
# column titles on the first row
axs[0, 0].set_title('Original')
axs[0, 1].set_title('Reconstructed')

# savefig does not create missing directories
os.makedirs(os.path.join('data', 'images'), exist_ok=True)
plt.savefig('data/images/reconstructed.png', dpi=1200)

In [None]:
# MNIST training split with the augmentation pipeline defined for this
# experiment; it is expected to be on disk already (download=False).
train_set = datasets.MNIST(root=os.path.join(data_root, 'MNIST'), train=True, download=False, transform=transform)
train_loader = data.DataLoader(train_set, batch_size=16_384, shuffle=True)

class MyModel(nn.Module):
    """Batch-normalised Hopfield layer followed by a linear classifier.

    The Hopfield layer is initialised from the global ``patterns`` tensor
    (viewed as 30 patterns of 32*32 values) with beta = 0.07. Its 30 outputs
    are mapped to 10 class scores by a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.preHopfield = nn.BatchNorm2d(1, affine=False)

        hopfield = FCHopfield(1, 30, 32*32, iterations=1)
        hopfield.update_patterns(patterns.view(1, 30, -1))
        hopfield.update_beta(torch.ones(1)*0.07)
        self.Hopfield = hopfield

        # NOTE: `batchnorm` is registered but not used in forward().
        self.batchnorm = nn.BatchNorm2d(1)
        self.fc = nn.Linear(30, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]
        normalized = self.preHopfield(x)
        scores = self.Hopfield(normalized.view(batch, 1, -1))
        return self.fc(scores.view(batch, -1))

# Train the Hopfield classifier with cross-entropy loss and Adam using its
# default hyper-parameters.
model = MyModel()

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters())

def reg_fn(model: nn.Module) -> torch.Tensor:
    """Return a constant zero: no regularisation is applied."""
    return torch.tensor(0.0)

# 20 epochs, on the GPU when one is available.
trainer(model, train_loader, loss_fn, reg_fn, optimizer, 20, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# Accuracy on the MNIST test split.
# NOTE(review): `transform` contains random augmentations, so the result is
# not exactly reproducible between runs.
test_set = datasets.MNIST(root=os.path.join(data_root, 'MNIST'), train=False, download=False, transform=transform)
test_loader = data.DataLoader(test_set, batch_size=16_384, shuffle=False)

# Bare expression: the returned accuracy is not stored in a variable.
tester_classification(model, test_loader, torch.device("cuda" if torch.cuda.is_available() else "cpu"))

We now show the patterns found by training.

In [None]:
model.to(torch.device("cpu"))

# beta of the Hopfield layer after training
hopfield = model.Hopfield
print(torch.exp(hopfield._logbeta.data))

# patterns of the Hopfield layer after training, as 30 images of 32x32 pixels
found_patterns = hopfield._patterns.data.view(30, 32, 32).clone().to('cpu').numpy()

# same 10x3 grid layout as the one used for the initial patterns
fig, axs = plt.subplots(10, 3, figsize=(10, 30))

for ax, pattern in zip(axs.flat, found_patterns):
    ax.imshow(pattern, cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')

plt.show()

In [None]:
# Norm of the difference between the initial patterns and the patterns read
# back from the trained model.
print(torch.norm(patterns.flatten() - torch.from_numpy(found_patterns).flatten()))