## Training a Hopfield network on MNIST

**Objective:** Implement and train a Hopfield network. The goal is to reconstruct the perturbed/masked patterns.

**Dataset:** MNIST

**Tasks:**
1. Data preparation: Binarize the MNIST images to have pixel values -1 or 1 (i.e., map pixels with value less than 127 to -1, and pixels with value greater or equal to 127 to 1). 
2. Choose a subset of 500 binarized MNIST images as above, and resize to 7x7 (by downsampling).
3. Choose 6 of these patterns and design a Hopfield Network to memorize these patterns while discouraging the memorization of other patterns.
4. Randomly pick one of the memorized patterns and flip 8 randomly-selected pixels of the corresponding image (i.e., changing $1 \rightarrow -1$ and $-1 \rightarrow 1$). Use the trained Hopfield Network to recover the original stored pattern.
5. Repeat this process 500 times and report the success rate.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [3]:
""" Data Preparation """

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Define a transform to normalize the data and apply basic augmentations
transform = transforms.Compose([
    transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),  # slight rotation and translation
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

# Load the MNIST dataset
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [16]:
""" Hopfield network Architecture """
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.nn import Parameter
from collections import defaultdict

# Please implement the Hopfield network with pytorch
class HopfieldNet(nn.Module):
  # You can add other necessary parameters to be passed in the _init_ function.
  # You can add other functions in the HopfieldNet class if necessary.
    def __init__(self, pattern_size):
        self.weights = np.zeros((pattern_size ** 2, pattern_size ** 2))

    def train(self, patterns):
        num_patterns = len(patterns)
        # pattern_size = patterns.shape[1]

        for pattern in patterns:
            pattern_flat = pattern.flatten()
            self.weights += np.outer(pattern_flat, pattern_flat)
        
        self.weights /= num_patterns
        np.fill_diagonal(self.weights, 0)

    def energy(self, pattern):
        return -0.5 * np.dot(pattern.flatten(), np.dot(self.weights, pattern.flatten()))

    def update(self, pattern, max_iter=100):
        pattern_flat = pattern.flatten()
        for _ in range(max_iter):
            new_pattern = np.sign(np.dot(self.weights, pattern_flat))
            if np.array_equal(new_pattern, pattern_flat):
                break
            pattern_flat = new_pattern
        return pattern_flat.reshape(pattern.shape)

In [37]:
"""Training"""
# Implement a training process that is suitable for your HopfieldNet class
# It is recommended to use "import time, t0 = time.time(), print(f"{i_epoch} ({time.time() - t0}s): {loss}")" to render the training time

subset_indices = np.random.choice(len(train_dataset), 500, replace=False)
subset_loader = DataLoader(train_dataset, batch_size=500, sampler=torch.utils.data.sampler.SubsetRandomSampler(subset_indices))

# Get the sampled images
subset_images, _ = next(iter(subset_loader))

# Resize to 7x7
subset_images_resized = F.interpolate(subset_images.view(-1, 1, 28, 28), size=7)  # Resize to 7x7
subset_images_resized = subset_images_resized.view(-1, 7*7)

# Normalize data between -1 and 1
subset_images_resized = torch.where(subset_images_resized < 0.5, torch.tensor(-1.0), torch.tensor(1.0))

train_patterns = subset_images_resized[:6]

# Train the Hopfield Network
hopfield_net = HopfieldNet(pattern_size=7)
hopfield_net.train(train_patterns)

# Step 4: Perturbation and Recovery
def perturb_image(image, num_flips=8):
    # Convert PyTorch tensor to NumPy array
    image_np = image.numpy()

    # Perform perturbation
    perturbed_image = image_np.copy()
    indices = np.random.choice(np.arange(image_np.size), num_flips, replace=False)
    perturbed_image.flat[indices] *= -1

    # Convert NumPy array back to PyTorch tensor
    perturbed_image_tensor = torch.from_numpy(perturbed_image)
    return perturbed_image_tensor

def test_recovery(original_image, perturbed_image, hopfield_net):
    recovered_image = hopfield_net.update(perturbed_image)
    return np.array_equal(original_image, recovered_image)

# Step 5: Success Rate Calculation
success_count = 0
num_trials = 500

for _ in range(num_trials):
    index = np.random.randint(len(train_patterns))
    original_image = train_patterns[index]
    # print(original_image)
    perturbed_image = perturb_image(original_image)
    # print(perturbed_image)
    if test_recovery(original_image, perturbed_image, hopfield_net):
        success_count += 1

success_rate = success_count / num_trials
print("Success rate:", success_rate)

Success rate: 0.174
