In [1]:
import numpy as np
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_distances
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision import transforms
import torch.nn.functional as F
import torch
from tqdm import tqdm
from scipy.io import savemat, loadmat
import torchvision

In [2]:
# 1. Load MNIST into flattened integer torch Tensors
def load_dataset(path: str, device: torch.device):
    """
    Fetch the MNIST train and test splits and return them as flat tensors.

    path   : root directory handed to torchvision's MNIST (downloaded if absent)
    device : device every returned tensor is moved to

    Returns (X_train, Y_train, X_test, Y_test). Each X is an int64 tensor of
    shape (N, 784) built from the dataset's `.data`; each Y is the matching
    `.targets` tensor.
    """
    def _as_flat_tensors(split):
        # One row per image: 1 channel * 28 * 28 pixels, cast to int64 so the
        # values can be used directly as lookup-table indices.
        pixels = split.data.reshape(-1, 1 * 28 * 28)
        return pixels.to(device).long(), split.targets.to(device)

    train_split = MNIST(path, train=True, download=True)
    test_split = MNIST(path, train=False, download=True)

    X_train, Y_train = _as_flat_tensors(train_split)
    X_test, Y_test = _as_flat_tensors(test_split)
    return X_train, Y_train, X_test, Y_test


# 2. Build bipolar lookup table (torch version)
def lookup_generate(dim: int, datatype: str, n_keys: int, device: torch.device):
    """
    Create a random lookup table of bipolar hypervectors.

    dim      : length of each hypervector
    datatype : must be 'bipolar'; anything else raises ValueError
    n_keys   : number of rows (one hypervector per key)
    device   : device the table is allocated on

    Returns an int8 tensor of shape (n_keys, dim) whose entries are -1 or +1.
    """
    if datatype == 'bipolar':
        # Draw uniform bits, then stretch {0, 1} onto {-1, +1}.
        bits = torch.randint(0, 2, (n_keys, dim), device=device, dtype=torch.int8)
        return 2 * bits - 1
    raise ValueError("Only 'bipolar' supported")

# 3. Encode a batch of images into hypervectors
@torch.no_grad()
def encode_batch(X: torch.LongTensor, position_table: torch.Tensor, grayscale_table: torch.Tensor):
    """
    Bind every pixel's value hypervector with its position hypervector and
    bundle (sum) the results over all pixels.

    X:               (N, P) integer tensor; each entry is a row index into
                     grayscale_table
    position_table:  (P, dim) one hypervector per pixel position
    grayscale_table: (K, dim) one hypervector per pixel value
    → returns (N, dim) tensor of per-image sums
    """
    # Gather value vectors -> (N, P, dim), then bind elementwise with the
    # position vectors broadcast across the batch axis.
    bound = grayscale_table[X] * position_table[None]
    # Bundle across the pixel axis.
    return bound.sum(dim=1)

# 4. Train associative memory by summing all encodings per class
def train_am(X_train, Y_train, position_table, grayscale_table, dim: int):
    """
    Build the associative memory in one shot (no batching).

    Encodes every row of X_train with `encode_batch`, then adds each encoding
    into the row of the memory selected by its label in Y_train.

    Returns a float32 tensor of shape (max_label + 1, dim) on X_train's device.
    """
    encoded = encode_batch(X_train, position_table, grayscale_table).float()  # (N, dim)
    n_classes = 1 + int(Y_train.max().item())
    prototypes = torch.zeros((n_classes, dim), device=X_train.device, dtype=torch.float32)
    # Out-of-place scatter-add: row Y_train[j] accumulates encoded[j].
    return prototypes.index_add(0, Y_train, encoded)

# 5. Single-image prediction (returns class and query HV)
@torch.no_grad()
def predict_(am, img, position_table, grayscale_table):
    """
    Classify one flattened image against the associative memory.

    am  : (C, dim) class prototypes
    img : 1-D index tensor for a single image (a batch axis is added here)

    Returns (pred, qhv): the index of the prototype with the highest cosine
    similarity, and the float query hypervector that was compared.
    """
    single_batch = img.unsqueeze(0)
    qhv = encode_batch(single_batch, position_table, grayscale_table).squeeze(0).float()
    scores = F.cosine_similarity(qhv.unsqueeze(0), am, dim=1)  # one score per class
    best = scores.argmax().item()
    return int(best), qhv

def predict(am, img, position_table, grayscale_table):
    """Return only the predicted class index from `predict_`, dropping the query hypervector."""
    return predict_(am, img, position_table, grayscale_table)[0]

# 6. Test on full set
@torch.no_grad()
def test(am, X_test, Y_test, position_table, grayscale_table):
    """
    Score the whole test set at once and print/return the accuracy.

    am     : (C, dim) float class prototypes
    X_test : (N, P) index tensor of test images
    Y_test : (N,) ground-truth labels

    Cosine similarity is computed as a single matrix product divided by the
    outer product of the row norms. Both norms are floored at a small eps so
    an all-zero prototype or query yields a similarity of 0 instead of
    0/0 = NaN, which could otherwise be picked by argmax.

    Returns the accuracy as a Python float in [0, 1].
    """
    eps = 1e-8  # same floor F.cosine_similarity applies by default
    H_test = encode_batch(X_test, position_table, grayscale_table).float()  # (N,dim)
    h_norm = H_test.norm(dim=1, keepdim=True).clamp_min(eps)               # (N,1)
    a_norm = am.norm(dim=1, keepdim=True).clamp_min(eps).t()               # (1,C)
    sims   = (H_test @ am.t()) / (h_norm * a_norm)                         # (N,C)
    preds  = sims.argmax(dim=1)                                            # (N,)
    acc    = (preds == Y_test).float().mean().item()
    print(f"Testing accuracy: {acc:.4f}")
    return acc

# 7. Load a saved model (AM + tables)
def loadmodel(fpath: str, device: torch.device = None):
    """
    Read a pickled (am, position_table, grayscale_table) triple of NumPy
    arrays and return them as torch tensors.

    fpath  : path to the pickle file
    device : if given, all three tensors are moved to it

    SECURITY: pickle.load can execute arbitrary code; only open files you trust.
    """
    with open(fpath, 'rb') as handle:
        am_np, pos_np, gray_np = pickle.load(handle)

    tensors = [torch.from_numpy(arr) for arr in (am_np, pos_np, gray_np)]
    if device is not None:
        tensors = [t.to(device) for t in tensors]

    am, pos, gray = tensors
    return am, pos, gray

# 8. Quantize the AM to a lower bit-width
def quantize(am: torch.Tensor, before_bw: int, after_bw: int) -> torch.Tensor:
    """
    Reduce the magnitude of `am` by the number of bits being dropped.

    When after_bw >= before_bw nothing is dropped and a copy of `am` is
    returned. Otherwise every entry is divided by 2 ** (before_bw - after_bw),
    rounded with torch.round, and cast back to the input dtype.
    """
    if after_bw >= before_bw:
        return am.clone()
    divisor = 2 ** (before_bw - after_bw)
    scaled = am.float() / divisor
    return scaled.round().to(am.dtype)

# 9. Batched AM training
@torch.no_grad()
def train_am_batched(
    X_train: torch.LongTensor,
    Y_train: torch.LongTensor,
    position_table: torch.Tensor,
    grayscale_table: torch.Tensor,
    dim: int,
    batch_size: int = 128,
    device: torch.device = None
) -> torch.Tensor:
    """
    Build the associative memory by encoding X_train in mini-batches.

    X_train, Y_train : (N, P) index tensor and (N,) labels
    position_table   : (P, dim) position hypervectors
    grayscale_table  : (K, dim) value hypervectors
    dim              : hypervector length (number of AM columns)
    batch_size       : rows encoded per step; bounds the size of the
                       intermediate (batch, P, dim) tensor in encode_batch
    device           : where the AM is accumulated; defaults to X_train's
                       device so the accumulator and the encodings always
                       live on the same device

    Returns a float32 tensor of shape (max_label + 1, dim) whose row c is the
    sum of the encodings of all samples labelled c.
    """
    if device is None:
        # Previously the accumulator landed on the default device, which
        # failed with a device mismatch whenever the data lived on the GPU.
        device = X_train.device

    N = X_train.size(0)
    C = int(Y_train.max().item()) + 1
    am = torch.zeros(C, dim, device=device, dtype=torch.float32)
    for i in range(0, N, batch_size):
        xb = X_train[i : i + batch_size]
        yb = Y_train[i : i + batch_size].to(device)
        hb = encode_batch(xb, position_table, grayscale_table).float().to(device)
        # In-place scatter-add avoids reallocating the (C, dim) matrix per batch.
        am.index_add_(0, yb, hb)
    return am

# 10. Test on a split (non-batched)
@torch.no_grad()
def test_split(am, X_split, Y_split, position_table, grayscale_table):
    """
    Accuracy of the associative memory on one split, computed in a single pass.

    The whole split is encoded at once and every query is compared with every
    prototype through a broadcast cosine similarity.

    Returns the fraction of correct predictions as a Python float.
    """
    queries = encode_batch(X_split, position_table, grayscale_table).float()  # (M, dim)
    # (M, 1, dim) against (1, C, dim) -> (M, C) similarity matrix
    scores = F.cosine_similarity(queries[:, None], am[None], dim=2)
    winners = scores.argmax(dim=1)
    hit_mask = (winners == Y_split).float()
    return hit_mask.mean().item()


@torch.no_grad()
def flip_rows_(tensor: torch.Tensor, perc: float) -> torch.Tensor:
    """
    In-place sign flip of `perc` fraction of elements in *every* row.

    tensor : (B, D)  – any real dtype (+1/-1 HVs work fine); may be
                       non-contiguous (e.g. a transposed view)
    perc   : 0‒1     – fraction of dimensions to flip per row

    Exactly round(D * perc) distinct columns are chosen independently for
    each row.

    Returns the same tensor object (for chaining).
    Raises ValueError if perc is outside [0, 1].
    """
    if not 0.0 <= perc <= 1.0:
        raise ValueError("perc must be in [0,1]")

    # trivial cases
    if perc == 0.0 or tensor.numel() == 0:
        return tensor

    B, D = tensor.shape
    k = int(round(D * perc))            # exact #dims to flip / row
    if k == 0:
        return tensor

    device = tensor.device
    # --- vectorised selection of k unique indices per row -------------
    # 1) random scores per entry
    rnd = torch.rand(B, D, device=device)
    # 2) take indices of the largest k scores along each row
    _, idx = rnd.topk(k, dim=1, largest=True, sorted=False)  # (B,k)
    # 3) pair every selected column with its row index; (B,1) broadcasts
    #    against (B,k). Indexing by (row, col) instead of flattening with
    #    view(-1) keeps this working on non-contiguous tensors.
    rows = torch.arange(B, device=device).unsqueeze(1)       # (B,1)
    # 4) flip in-place (read selected entries, negate, write back)
    tensor[rows, idx] *= -1
    return tensor
    
# 11. Test on a split (batched)
@torch.no_grad()
def test_split_batched(
    am: torch.Tensor,
    X: torch.LongTensor,
    Y: torch.LongTensor,
    position_table: torch.Tensor,
    grayscale_table: torch.Tensor,
    encode_fn,
    flip_perc=0.0,
    batch_size: int = 128,
    device: torch.device = None
) -> float:
    """
    Accuracy on (X, Y) evaluated in mini-batches, with optional query noise.

    am         : (C, dim) class prototypes
    encode_fn  : callable(xb, position_table, grayscale_table) -> (B, dim)
    flip_perc  : fraction of each encoded query's dimensions whose sign is
                 flipped (via flip_rows_) before comparison
    batch_size : rows processed per step
    device     : target passed to `.to()` for each batch of X and Y

    Returns (#correct / #seen) as a Python float.
    """
    n_rows = X.size(0)
    hits = 0
    seen = 0
    for lo in range(0, n_rows, batch_size):
        hi = lo + batch_size
        batch_x = X[lo:hi].to(device)
        batch_y = Y[lo:hi].to(device)

        queries = encode_fn(batch_x, position_table, grayscale_table).float()
        flip_rows_(queries, perc=flip_perc)  # corrupts the queries in place

        # (B, 1, dim) vs (1, C, dim) -> (B, C) similarities
        scores = F.cosine_similarity(queries.unsqueeze(1), am.unsqueeze(0), dim=2)
        guesses = scores.argmax(dim=1)

        hits += (guesses == batch_y).sum().item()
        seen += batch_y.size(0)
    return hits / seen

In [3]:
# Root directory passed to load_dataset, which hands it to torchvision's MNIST
# (download=True) as the dataset location.
mnist_path = '../../Data'

In [4]:
# Disabled alternative (commented out): load `hyperdims` from a .mat file and
# take the integer mean along axis 1 instead of using a fixed value.
# # hyperdims = loadmat('../EHDGNet_MNIST_nHD.mat')['EHDGNet_MNIST_nHD']
# hyperdims = np.mean(hyperdims, axis=1, dtype=int)
# Hypervector dimensionality used for the lookup tables and the associative memory.
hyperdim = 15_000

In [5]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flattened MNIST pixels and labels, already moved to `device` by load_dataset.
X_train, Y_train, X_test, Y_test = load_dataset(mnist_path, device)

# n_splits   : number of consecutive test-set chunks scored separately
# split_size : samples per chunk (integer division; any remainder is not evaluated)
# flip_percs : fractions of query dimensions to sign-flip, 0.0 up to 0.5 in steps of 0.05
# accuracies : result grid indexed as [flip level, split]
# n_class    : NOTE(review) not referenced by any cell visible here
# q_bit      : target bit-width passed to quantize(after_bw=...)
n_splits   = 20
split_size = X_test.size(0) // n_splits
flip_percs = np.arange(0.0, 0.51, 0.05)
accuracies = np.zeros((len(flip_percs), n_splits))
n_class    = 10
q_bit      = 16

In [6]:
# Robustness sweep: for every flip percentage, draw fresh random lookup
# tables, train an associative memory on the full training set, then score
# each test split with that fraction of every query's dimensions sign-flipped.
for i, perc in enumerate(flip_percs):
    print(f"\n==> Flip Percentage: {perc}")
    # a) lookup tables (one hypervector per pixel position / per pixel value)
    position_table  = lookup_generate(hyperdim, 'bipolar', 28*28, device=device)
    grayscale_table = lookup_generate(hyperdim, 'bipolar', 256, device=device)

    # b) train AM
    # batch_size=1 keeps the (batch, 784, dim) intermediate in encode_batch small.
    am = train_am_batched(
        X_train, Y_train,
        position_table, grayscale_table,
        dim=hyperdim,
        batch_size=1,
        device=device
    )
    # c) quantize AM
    # quantize returns an unchanged copy whenever after_bw >= before_bw,
    # which is the case while q_bit is 16.
    am_q = quantize(am, before_bw=16, after_bw=q_bit)

    # d) evaluate on splits
    for split_idx in tqdm(range(n_splits)):
        start = split_idx * split_size
        end   = start + split_size

        acc = test_split_batched(
            am_q,
            X_test[start:end],
            Y_test[start:end],
            position_table,
            grayscale_table,
            encode_batch,
            flip_perc=perc,
            batch_size=12,
            device=device
        )
        accuracies[i, split_idx] = acc

    # Report with the actual split count instead of a hard-coded "20".
    print(f"Accuracy average for {n_splits} splits:", accuracies[i].mean())

    # --- Free GPU memory ---------------------------------------------
    # Delete the large tensors you no longer need
    del position_table, grayscale_table, am, am_q
    # Run empty_cache so PyTorch can reuse that memory immediately
    torch.cuda.empty_cache()



==> Flip Percentage: 0.0


100%|██████████| 20/20 [00:04<00:00,  4.80it/s]


Accuracy average for 20 splits: 0.8230999999999999

==> Flip Percentage: 0.05


100%|██████████| 20/20 [00:04<00:00,  4.64it/s]


Accuracy average for 20 splits: 0.8199

==> Flip Percentage: 0.1


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.8164

==> Flip Percentage: 0.15000000000000002


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.8147000000000002

==> Flip Percentage: 0.2


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.8152000000000001

==> Flip Percentage: 0.25


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.8107

==> Flip Percentage: 0.30000000000000004


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.8031000000000003

==> Flip Percentage: 0.35000000000000003


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.7859

==> Flip Percentage: 0.4


100%|██████████| 20/20 [00:04<00:00,  4.68it/s]


Accuracy average for 20 splits: 0.7338

==> Flip Percentage: 0.45


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]


Accuracy average for 20 splits: 0.5477000000000001

==> Flip Percentage: 0.5


100%|██████████| 20/20 [00:04<00:00,  4.69it/s]

Accuracy average for 20 splits: 0.10390000000000002





In [7]:
# Mean accuracy per flip percentage, averaged over the test splits (axis 1).
np.mean(accuracies, axis=1)

array([0.8231, 0.8199, 0.8164, 0.8147, 0.8152, 0.8107, 0.8031, 0.7859,
       0.7338, 0.5477, 0.1039])

In [8]:
# NOTE(review): savemat is already imported in the first cell; this re-import is redundant.
from scipy.io import savemat
# Save the accuracy grid (scaled to percent) as a MATLAB .mat file.
# NOTE(review): the variable key is spelled 'VanillaHHDC_MNIST' (double H) while
# the file is 'VanillaHDC_MNIST.mat' — confirm which spelling downstream readers
# expect before changing either.
savemat('VanillaHDC_MNIST.mat', {'VanillaHHDC_MNIST': accuracies*100})