$$HDC\ Image\ Baseline: Binarized\ MNIST$$

# Setup

In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np, time, os

# Repro / device
def set_seed(seed=123):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU plus all
    CUDA devices), and writes ``PYTHONHASHSEED`` into ``os.environ``.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    # The generators are independent, so the seeding order does not matter.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Fix all RNG seeds first, then pick CUDA when it is available and CPU otherwise.
set_seed(123)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE


device(type='cpu')

# Data

In [2]:
# We use raw grayscale MNIST but binarize at encode time (threshold=0.5).
BATCH_SIZE = 512
# Dataset cache directory. Defaults to the original Colab location; set the
# MNIST_DATA_DIR environment variable to run the notebook elsewhere.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "/content/data")
# The training/eval loops copy batches with non_blocking=True, which is only
# asynchronous when the source tensors live in pinned host memory. Pinning is
# pointless on CPU, so enable it only when DEVICE is a CUDA device.
PIN_MEMORY = DEVICE.type == "cuda"
transform = transforms.ToTensor()

train_ds = datasets.MNIST(DATA_DIR, train=True,  download=True, transform=transform)
test_ds  = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

len(train_ds), len(test_ds)


100%|██████████| 9.91M/9.91M [00:00<00:00, 41.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.02MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.10MB/s]


(60000, 10000)

# Utils

In [3]:
@torch.no_grad()
def bipolar_sign(x: torch.Tensor) -> torch.Tensor:
    """Map a tensor elementwise onto {-1, +1}.

    Entries satisfying ``x >= 0`` become +1 (so exact zeros / ties resolve to
    +1 deterministically); every other entry becomes -1. The result has the
    same shape and dtype as ``x``.
    """
    plus_one = torch.ones_like(x)
    return torch.where(x >= 0, plus_one, plus_one.neg())

def _rand_bip(shape, device=None):
    device = device or DEVICE
    r = torch.randint(0, 2, shape, device=device, dtype=torch.int8)
    return r.float().mul_(2).sub_(1)  # {0,1} -> {-1,+1}

def make_position_hvs(n_positions: int, dim: int, device=None) -> torch.Tensor:
    """Build the position codebook: one random bipolar hypervector per pixel position.

    Args:
        n_positions: Number of pixel positions (rows of the codebook).
        dim: Hypervector dimensionality (columns of the codebook).
        device: Target device; falls back to the notebook-wide ``DEVICE`` when falsy.

    Returns:
        Tensor of shape ``[n_positions, dim]`` produced by ``_rand_bip``.
    """
    target = device or DEVICE
    codebook_shape = (n_positions, dim)
    return _rand_bip(codebook_shape, device=target)

@torch.no_grad()
def encode_binary_images_to_hvs(imgs: torch.Tensor,
                                pos_hvs: torch.Tensor,
                                threshold: float = 0.5) -> torch.Tensor:
    """Encode a batch of single-channel images as bipolar hypervectors.

    Strategy: binarize -> flatten -> sum position HVs for '1' pixels -> majority sign.

    Args:
        imgs: Image batch of shape ``[B, 1, H, W]``.
        pos_hvs: Position codebook of shape ``[H*W, D]`` (required by the matmul).
        threshold: Pixels with value ``>= threshold`` count as active ('1').

    Returns:
        Tensor of shape ``[B, D]`` on ``pos_hvs.device`` with entries in {-1, +1}.
        Zero sums (including images with no active pixel) resolve to +1 through
        the tie rule of ``bipolar_sign``.
    """
    B, C, H, W = imgs.shape
    assert C == 1, f"expected single-channel images, got C={C}"
    P = H * W
    # reshape (not view) so non-contiguous batches, e.g. transposed images,
    # are flattened instead of raising a RuntimeError.
    active = imgs.reshape(B, P) >= threshold                          # [B,P], bool
    # Single .to() call: the bool mask is what crosses devices, then it is cast.
    x = active.to(device=pos_hvs.device, dtype=torch.float32)         # [B,P], {0,1}
    hv = x @ pos_hvs                                                  # [B,D]
    return bipolar_sign(hv)

@torch.no_grad()
def build_class_prototypes(enc_loader: DataLoader,
                           pos_hvs: torch.Tensor,
                           n_classes: int = 10,
                           threshold: float = 0.5) -> torch.Tensor:
    """Bundle the encoded samples of each class into one bipolar prototype.

    Args:
        enc_loader: Iterable yielding ``(imgs, labels)`` batches.
        pos_hvs: Position codebook; its second dimension sets the prototype width.
        n_classes: Number of classes; labels outside ``range(n_classes)`` are never accumulated.
        threshold: Binarization threshold forwarded to the encoder.

    Returns:
        Tensor of shape ``[n_classes, D]`` on ``pos_hvs.device`` with entries in {-1, +1}.
    """
    target = pos_hvs.device
    bundle = torch.zeros((n_classes, pos_hvs.shape[1]), device=target, dtype=torch.float32)

    for batch_imgs, batch_labels in enc_loader:
        batch_imgs = batch_imgs.to(target, non_blocking=True)
        batch_labels = batch_labels.to(target, non_blocking=True)
        encoded = encode_binary_images_to_hvs(batch_imgs, pos_hvs, threshold=threshold)

        # Add this batch's hypervectors to the running sum of their class.
        for cls in range(n_classes):
            members = batch_labels == cls
            if not members.any():
                continue
            bundle[cls] += encoded[members].sum(dim=0)

    # Majority vote per dimension collapses each sum to a bipolar prototype.
    return bipolar_sign(bundle)

@torch.no_grad()
def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Both inputs are cast to float32. Row norms are clamped to at least 1e-8,
    so an all-zero row yields similarity 0 instead of a division by zero.

    Args:
        a: Tensor of shape ``[N, D]``.
        b: Tensor of shape ``[M, D]``.

    Returns:
        Tensor of shape ``[N, M]`` where entry ``(i, j)`` is ``cos(a[i], b[j])``.
    """
    lhs, rhs = a.float(), b.float()
    lhs_norm = torch.linalg.norm(lhs, dim=1, keepdim=True).clamp(min=1e-8)      # [N,1]
    rhs_norm = torch.linalg.norm(rhs, dim=1, keepdim=True).clamp(min=1e-8).t()  # [1,M]
    dots = lhs @ rhs.t()                                                        # [N,M]
    return dots / (lhs_norm * rhs_norm)

@torch.no_grad()
def predict_with_prototypes(imgs: torch.Tensor,
                            pos_hvs: torch.Tensor,
                            prototypes: torch.Tensor,
                            threshold: float = 0.5) -> torch.Tensor:
    """Classify images by nearest class prototype under cosine similarity.

    Args:
        imgs: Image batch accepted by ``encode_binary_images_to_hvs``.
        pos_hvs: Position codebook used for encoding.
        prototypes: One prototype hypervector per row; the row index is the class id.
        threshold: Binarization threshold forwarded to the encoder.

    Returns:
        1-D tensor holding, for each image, the index of the most similar prototype.
    """
    query_hvs = encode_binary_images_to_hvs(imgs, pos_hvs, threshold=threshold)
    return cosine_sim(query_hvs, prototypes).argmax(dim=1)

@torch.no_grad()
def evaluate_loader(data_loader: DataLoader,
                    pos_hvs: torch.Tensor,
                    prototypes: torch.Tensor,
                    threshold: float = 0.5) -> float:
    """Compute top-1 accuracy (in percent) of the prototype classifier.

    Args:
        data_loader: Iterable yielding ``(imgs, labels)`` batches.
        pos_hvs: Position codebook used for encoding; batches are moved to its device.
        prototypes: Class prototypes, one per row.
        threshold: Binarization threshold forwarded to the predictor.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: If ``data_loader`` yields no samples, since accuracy is
            undefined (previously surfaced as a bare ``ZeroDivisionError``).
    """
    correct = total = 0
    for imgs, labels in data_loader:
        imgs   = imgs.to(pos_hvs.device, non_blocking=True)
        labels = labels.to(pos_hvs.device, non_blocking=True)
        preds = predict_with_prototypes(imgs, pos_hvs, prototypes, threshold=threshold)
        correct += (preds == labels).sum().item()
        total   += labels.numel()
    if total == 0:
        raise ValueError("evaluate_loader: data_loader yielded no samples; accuracy is undefined")
    return 100.0 * correct / total


# Model

## Initialize HDC space

In [4]:
DIM = 10_000  # We typically use 8k–20k; 10k is a solid default.
H = W = 28    # image height and width in pixels
P = H * W     # one position hypervector per pixel

pos_hvs = make_position_hvs(n_positions=P, dim=DIM, device=DEVICE)
pos_hvs.shape, pos_hvs.device

(torch.Size([784, 10000]), device(type='cpu'))

## Train prototypes

In [5]:
# perf_counter is monotonic, so the measured duration cannot be skewed by
# system clock adjustments the way time.time() can.
t0 = time.perf_counter()
prototypes = build_class_prototypes(train_loader, pos_hvs, n_classes=10, threshold=0.5)
print("Prototypes:", prototypes.shape, "built in %.2fs" % (time.perf_counter() - t0))


Prototypes: torch.Size([10, 10000]) built in 18.35s


In [11]:
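# Optional: cache the trained prototypes so a later session can skip training.
# They are only valid together with the same pos_hvs (i.e. the same seed and DIM).
# torch.save(prototypes, "prototypes_MNIST.pt")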

## Evaluate

In [None]:
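# Optional: restore cached prototypes instead of retraining (assign the result, otherwise it is discarded).
# prototypes = torch.load("prototypes_MNIST.pt", map_location=DEVICE)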

In [7]:
# Score the bundled class prototypes on the held-out MNIST test split.
acc = evaluate_loader(
    data_loader=test_loader,
    pos_hvs=pos_hvs,
    prototypes=prototypes,
    threshold=0.5,
)
print(f"Test accuracy (MNIST): {acc:.2f}%")

Test accuracy (MNIST): 81.58%
