In [10]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
import torch
from torchvision import datasets, transforms


### train on handwritten digits dataset

In [None]:
# Load and normalize the digits dataset
digits = load_digits()
X = digits.data.astype(float)
X /= X.max()  # scale pixel values to [0,1]
y = digits.target

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

num_classes = 10
num_features = X_train.shape[1]
weights = np.zeros((num_classes, num_features))
learning_rate = 0.01

predictions = np.argmax(weights @ X_test.T, axis=0)
print(f"Test initilize:", accuracy_score(y_test, predictions))
# Hebbian/perceptron learning
for epoch in range(5):
    for x_vec, label in zip(X_train, y_train):
        activations = weights @ x_vec
        pred_class = np.argmax(activations)
        # Reinforce correct class and penalize predicted class
        weights[label] += learning_rate * x_vec
        weights[pred_class] -= learning_rate * x_vec

    # Evaluate
    predictions = np.argmax(weights @ X_test.T, axis=0)
    print(f"Test accuracy on epoch {epoch}:", accuracy_score(y_test, predictions))


### train on CIFAR-10

In [8]:
# Transformation: convert images to tensors and flatten them
transform = transforms.Compose([
    transforms.ToTensor(),                  # convert PIL image to tensor in [0,1]
    transforms.Lambda(lambda x: x.view(-1)) # flatten 3×32×32 image to 3072‑D vector
])

# Load CIFAR‑10 training and test data
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform)
test_dataset  = datasets.CIFAR10(root='./data', train=False, download=True,
                                transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset,  batch_size=256, shuffle=False)

num_classes  = 10
num_features = 32 * 32 * 3

# Initialize weights as a NumPy array for efficiency
weights = np.zeros((num_classes, num_features), dtype=np.float32)
learning_rate = 1e-3

# Training loop (one epoch for simplicity)
for epoch in range(1):
    for data, target in train_loader:
        # data is a tensor of shape [1,3072]; convert to NumPy
        x = data.numpy().reshape(-1)
        label = target.item()
        # Compute activations
        activations = weights @ x
        pred_class  = activations.argmax()
        # Hebbian/perceptron update: reinforce correct class and penalize predicted class
        weights[label] += learning_rate * x
        weights[pred_class] -= learning_rate * x

# Evaluate on the test set
correct = 0
count   = 0
for data, target in test_loader:
    x_batch = data.numpy().reshape(data.size(0), -1)  # batch_size × 3072
    outputs = weights @ x_batch.T
    preds   = outputs.argmax(axis=0)
    labels  = target.numpy()
    correct += (preds == labels).sum()
    count   += labels.shape[0]
print(f"Test accuracy: {correct / count:.4f}")

Test accuracy: 0.2674


### improve model by Using Oja’s/GHA

$\mathbf{w}_j \leftarrow \mathbf{w}_j + \eta\,y_j \left(\mathbf{x} - \sum_{k\leq j} y_k \mathbf{w}_k\right)$

In [11]:
# Load and normalize data
X, y = load_digits(return_X_y=True)
X = X.astype(float)
X /= X.max()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Generalized Hebbian Algorithm (GHA)
num_components = 32
lr = 0.005
W = np.random.randn(num_components, X_train.shape[1]) * 0.01

for epoch in range(15):
    for x in X_train:
        y_out = W @ x
        for j in range(num_components):
            # reconstruction using already learned components
            recon = np.sum([y_out[k] * W[k] for k in range(j+1)], axis=0)
            W[j] += lr * y_out[j] * (x - recon)
# normalize
W /= np.linalg.norm(W, axis=1, keepdims=True)

# Project data and train classifier
X_train_proj = X_train @ W.T
X_test_proj  = X_test  @ W.T
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train_proj, y_train)
print('Test accuracy:', accuracy_score(y_test, clf.predict(X_test_proj)))

Test accuracy: 0.9527777777777777


### on cifar-10

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import math, random, time
from dataclasses import dataclass
# -------------------------------
# Utilities
# -------------------------------
def set_seed(seed=42):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

@dataclass
class HebbCfg:
    eta: float = 0.05          # Hebbian learning rate for the trace update
    trace_decay: float = 0.9   # decay for running Hebbian trace (0=no decay, 1=keep)
    bp_through_plasticity: bool = True  # allow gradients through H updates
    clip_h: float | None = 1.0 # clip Hebbian trace for stability (None to disable)
    act: str = "tanh"          # nonlinearity to bound post-activity for Hebb ("tanh" or "identity")

# -------------------------------
# Plastic Linear Layer
# W_eff = W + alpha ⊙ H
# H_{t+1} = trace_decay * H_t + eta * (y_hat ⊗ x)
# where y_hat is bounded post-activation (e.g., tanh(pre-softmax logits) or pre-activation).
# alpha is learnable (either scalar or per-weight)
# -------------------------------
class PlasticLinear(nn.Module):
    def __init__(self, in_features, out_features, cfg: HebbCfg,
                 learn_alpha=True, alpha_init=0.0, per_weight_alpha=True, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.cfg = cfg

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

        if per_weight_alpha:
            self.alpha = nn.Parameter(torch.full((out_features, in_features), alpha_init)) if learn_alpha \
                         else torch.full((out_features, in_features), alpha_init, requires_grad=False)
        else:
            self.alpha = nn.Parameter(torch.tensor(alpha_init)) if learn_alpha \
                         else torch.tensor(alpha_init, requires_grad=False)

        # Hebbian trace H (fast weights): same shape as weight
        self.register_buffer("H", torch.zeros(out_features, in_features))

        # Init weights (Kaiming)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    @torch.no_grad()
    def reset_hebb(self):
        self.H.zero_()

    def _bound_post(self, y_pre):
        if self.cfg.act == "tanh":
            return torch.tanh(y_pre)
        elif self.cfg.act == "identity":
            return y_pre
        else:
            raise ValueError("Unknown act")

    def forward(self, x):
        # Effective weights = baseline + plastic component
        H_used = self.H.detach()
        alpha = self.alpha
        if alpha.dim() == 0:
            eff_w = self.weight + alpha * H_used
        else:
            eff_w = self.weight + alpha * H_used  # broadcast matches (out, in)

        y = F.linear(x, eff_w, self.bias)

        # Update Hebbian trace (online) **during forward** for next step
        # Use bounded post-activation y_hat to avoid runaway values.
        with torch.enable_grad():
            # If we don't want to backprop through plasticity, detach x or y_hat (or both)
            x_for_hebb = x if self.cfg.bp_through_plasticity else x.detach()
            y_hat = self._bound_post(y) if self.cfg.bp_through_plasticity else self._bound_post(y.detach())

            # Batch outer product: average over batch for stability
            # y_hat: [B, out], x: [B, in] => H_delta: [out, in]
            H_delta = (y_hat.unsqueeze(2) * x_for_hebb.unsqueeze(1)).mean(dim=0)

        # Hebbian decay + add new outer product
        # IMPORTANT: do NOT wrap in no_grad; we want the option to backprop through H if enabled.
        # But updating a buffer in graph requires a trick; we keep H as buffer and update out-of-graph.
        # So when bp_through_plasticity=True, gradients flow into eff_w via current H usage, not into H itself.
        with torch.no_grad():
            self.H.mul_(self.cfg.trace_decay).add_(self.cfg.eta * H_delta)
            if self.cfg.clip_h is not None:
                self.H.clamp_(-self.cfg.clip_h, self.cfg.clip_h)

        return y

In [2]:
# -------------------------------
# Simple ConvNet feature extractor + plastic classifier head
# -------------------------------
class HebbNet(nn.Module):
    def __init__(self, hebb_cfg: HebbCfg, hidden=256, plastic_layers=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 8x8
        )
        self.flatten_dim = 128 * 8 * 8

        self.hebb_cfg = hebb_cfg
        self.plastic_layers = plastic_layers

        if plastic_layers == 2:
            self.fc1 = PlasticLinear(self.flatten_dim, hidden, hebb_cfg, learn_alpha=True, alpha_init=0.0, per_weight_alpha=False)
            self.fc2 = PlasticLinear(hidden, 10, hebb_cfg, learn_alpha=True, alpha_init=0.0, per_weight_alpha=False)
        elif plastic_layers == 1:
            self.fc1 = nn.Linear(self.flatten_dim, hidden)
            self.fc2 = PlasticLinear(hidden, 10, hebb_cfg, learn_alpha=True, alpha_init=0.0, per_weight_alpha=False)
        else:
            raise ValueError("plastic_layers must be 1 or 2")

    def reset_hebb(self):
        # Reset all plastic layers
        if isinstance(self.fc1, PlasticLinear): self.fc1.reset_hebb()
        if isinstance(self.fc2, PlasticLinear): self.fc2.reset_hebb()

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        if isinstance(self.fc1, PlasticLinear):
            x = self.fc1(x)
        else:
            x = self.fc1(x)
        x = F.relu(x, inplace=True)
        x = self.fc2(x)
        return x

# -------------------------------
# Training
# -------------------------------
def get_loaders(data_root="./data", batch_size=128, num_workers=4):
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_tf = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(32, padding=4),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_tf = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    train_ds = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tf)
    test_ds = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tf)

    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_ld = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_ld, test_ld

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total

In [None]:
set_seed(123)
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Hebbian/backprop config ---
hebb_cfg = HebbCfg(
    eta=0.05,                 # Hebbian update step
    trace_decay=0.95,         # decay of Hebbian trace
    bp_through_plasticity=True,  # let grads flow through plastic path
    clip_h=1.0,
    act="tanh",               # bound post-activity for stability
)

# --- Data ---
train_ld, test_ld = get_loaders(batch_size=128)

# --- Model ---r
model = HebbNet(hebb_cfg, hidden=256, plastic_layers=1).to(device)

# --- Optimizer (learn baseline weights + alpha) ---
# Weight decay only on baseline weights and biases; exclude alpha if you want
decay, no_decay = [], []
for n, p in model.named_parameters():
    if p.requires_grad:
        if n.endswith("weight") or n.endswith("bias"):
            decay.append(p)
        else:
            # alpha etc.
            no_decay.append(p)
optim = torch.optim.AdamW([
    {"params": decay, "weight_decay": 5e-4},
    {"params": no_decay, "weight_decay": 0.0},
], lr=3e-4, betas=(0.9, 0.999))

sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50)
criterion = nn.CrossEntropyLoss()

In [4]:

# --- Train ---
epochs = 30
best_acc = 0.0
for ep in range(1, epochs + 1):
    model.train()
    # Episodic behavior: reset Hebbian trace each epoch (you can reset per batch instead)
    model.reset_hebb()

    t0 = time.time()
    for x, y in train_ld:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        logits = model(x)
        loss = criterion(logits, y)

        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        acc = evaluate(model, test_ld, device)
        print(f"Epoch {ep:02d} | train loss: {loss.item():.4f} | test acc: {acc*100:5.2f}%")

    sched.step()
    train_time = time.time() - t0

    # Eval
    acc = evaluate(model, test_ld, device)
    best_acc = max(best_acc, acc)
    print(f"Epoch {ep:02d} | test acc: {acc*100:5.2f}% | best: {best_acc*100:5.2f}% | time {train_time:5.1f}s")

print("Done. Best accuracy:", best_acc)


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [10, 256]] is at version 4; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).