In [1]:
import numpy as np
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# import torch
# from torchvision.datasets import FashionMNIST, MNIST
# from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
import os
import gzip
import csv
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import Tensor
from torchvision import datasets, transforms
from scipy.io import loadmat, savemat

In [2]:
# Average each row of the EHDGNet hyperdimension table into one hyperdimension.
# NOTE(review): dtype=int makes NumPy do this mean in integer arithmetic, so any
# fractional part is discarded rather than rounded — confirm that is intended.
hyperdims = np.mean(
    loadmat('../EHDGNet_MNIST_nHD.mat')['EHDGNet_MNIST_nHD'],
    axis=1,
    dtype=int,
)
hyperdims

array([ 5000,  6000,  7000,  8000,  9000, 10000, 11000, 11750, 12750,
       13750, 14750])

In [3]:
# 1. Device and hyperparameters
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_splits   = 20   # number of equal, contiguous test splits to score
batch_size = 10   # rows per chunk in encode_dataset_batched (bounds peak memory)
# `hyperdims` (one hyperdimension per experiment) comes from the previous cell.

# 2. Load & preprocess MNIST
transform = transforms.Compose([
    transforms.ToTensor(),                           # → [0,1], shape (1,28,28)
    transforms.Lambda(lambda x: (x > 0.5).float()),  # binarize
    transforms.Lambda(lambda x: x.view(-1))          # flatten → (784,)
])

train_ds = datasets.MNIST(root='../../Data', train=True,  download=True, transform=transform)
test_ds  = datasets.MNIST(root='../../Data', train=False, download=True, transform=transform)

# Images need the transform; labels do not. Reading them from `.targets`
# (instead of iterating the dataset a second time) avoids transforming every
# image twice. copy=True so later edits never alias the dataset's own tensor.
X_train = torch.stack([img for img, _ in train_ds], dim=0).to(device)  # (60000, 784)
y_train = train_ds.targets.to(device=device, copy=True)

X_test  = torch.stack([img for img, _ in test_ds],  dim=0).to(device)  # (10000, 784)
y_test  = test_ds.targets.to(device=device, copy=True)

B = X_train.size(1)                   # 784 pixels
C = len(torch.unique(y_train))        # 10 classes
# Contiguous splits of equal size; the last X_test.size(0) % n_splits samples are not scored.
split_size = X_test.size(0) // n_splits

# Sanity checks on the preprocessed data.
assert X_test.size(1) == B, "train/test feature sizes differ"
assert y_train.size(0) == X_train.size(0) and y_test.size(0) == X_test.size(0)
assert split_size > 0, "n_splits exceeds the number of test samples"

# 3. HDC utility functions

def generate_base_HDVs(D, B, device, generator=None):
    """
    Generate a (B × D) random binary matrix: one base hypervector per input position.

    Each entry is 1 when a uniform [0, 1) draw exceeds 0.5, else 0.

    D:         hypervector dimensionality (columns)
    B:         number of base hypervectors (rows)
    device:    torch device for the result
    generator: optional torch.Generator (living on `device`) for reproducible
               draws; None uses torch's global RNG, exactly as before
    returns:   (B, D) int32 tensor of 0/1
    """
    return (torch.rand(B, D, device=device, generator=generator) > 0.5).int()

@torch.no_grad()
def encode_dataset_batched(X, base_HDVs, batch_size=128):
    """
    Encode the rows of X into binary hypervectors, one chunk at a time.

    Position p contributes base_HDVs[p] when X[n, p] == 0 and its one-step
    circular shift along D (roll by 1) when X[n, p] == 1. The B contributions
    are bundled by element-wise mean, then rounded with torch.round, so an
    exact 0.5 tie rounds to 0.

    Because the bundle is linear in X, each chunk's sum over positions is
        sum_p base[p] + X_chunk @ (perm - base)
    so only (b, D) intermediates are created instead of three (b, B, D) tensors.

    X:          (N, B) float tensor with entries in {0, 1}
    base_HDVs:  (B, D) int tensor with entries in {0, 1}, on X's device
    batch_size: number of rows of X encoded per chunk (must be >= 1)
    returns:    (N, D) int32 tensor with entries in {0, 1}
    raises:     ValueError if batch_size < 1
    """
    N, B = X.shape
    D    = base_HDVs.shape[1]
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Compute in X's floating dtype (float32 for the notebook's inputs); sums of
    # 0/±1 terms are exact integers in float32, so no rounding error creeps in.
    dtype    = X.dtype if X.is_floating_point() else torch.float32
    base     = base_HDVs.to(dtype)                           # (B, D)
    perm     = base_HDVs.roll(shifts=1, dims=1).to(dtype)    # (B, D)
    base_sum = base.sum(dim=0)                               # (D,)   sum if every pixel is 0
    delta    = perm - base                                   # (B, D) change when pixel p is 1

    chunks = []
    for start in range(0, N, batch_size):
        xb     = X[start : start + batch_size].to(dtype)     # (b, B)
        counts = base_sum + xb @ delta                       # (b, D) == weighted.sum(dim=1)
        chunks.append(torch.round(counts / B).int())         # (b, D) == round(mean)

    if not chunks:  # N == 0: torch.cat([]) would raise
        return torch.empty((0, D), dtype=torch.int32, device=X.device)
    return torch.cat(chunks, dim=0)  # (N, D)

def encode_class_HDVs(H_train, y_train, C):
    """
    Bundle the training hypervectors of each class into one prototype.

    Each prototype is the element-wise mean of that class's rows of H_train,
    rounded with torch.round (an exact 0.5 tie rounds to 0).

    H_train: (N, D) int tensor of encoded training samples
    y_train: (N,) integer labels; only labels 0..C-1 contribute
    C:       number of classes
    returns: (C, D) int32 tensor; row c is the prototype of class c
    raises:  ValueError if some class in 0..C-1 has no training samples
             (its mean would be NaN, and casting NaN to int gives garbage)
    """
    prototypes = []
    for c in range(C):
        members = H_train[y_train == c]          # (Nc, D)
        if members.shape[0] == 0:
            raise ValueError(f"class {c} has no training samples")
        mean = members.float().mean(dim=0)       # (D,)
        prototypes.append(torch.round(mean).int())

    if not prototypes:  # C == 0: torch.stack([]) would raise
        return torch.empty((0, H_train.shape[1]), dtype=torch.int32, device=H_train.device)
    return torch.stack(prototypes, dim=0)        # (C, D)

@torch.no_grad()
def predict(H_test, class_HDVs):
    """
    Label each query with the prototype it disagrees with in the fewest
    coordinates (Hamming distance); on a tie the lowest class index wins.

    H_test:     (M, D) query hypervectors
    class_HDVs: (C, D) class prototypes
    returns:    (M,) tensor of predicted class indices
    """
    mismatch = torch.ne(H_test.unsqueeze(dim=1), class_HDVs.unsqueeze(dim=0))  # (M, C, D) bool
    hamming = torch.sum(mismatch, dim=2)                                        # (M, C)
    return torch.argmin(hamming, dim=1)                                         # (M,)

In [4]:
# Score the HDC classifier at every hyperdimension on n_splits contiguous test
# splits. n_splits, split_size, batch_size, B and C come from the config cell;
# they are not redefined here so split_size can never disagree with n_splits.
BASE_SEED  = 0  # base-HDV seed, offset by idx so each D is reproducible on its own
accuracies = np.zeros((len(hyperdims), n_splits), dtype=float)
for idx, D in enumerate(hyperdims):
    print(f"\n==> Hyperdimension: {D}")
    torch.manual_seed(BASE_SEED + idx)
    base_HDVs  = generate_base_HDVs(D, B, device)
    H_train    = encode_dataset_batched(X_train, base_HDVs, batch_size)
    class_HDVs = encode_class_HDVs(H_train, y_train, C)
    # H_train is (N_train, D) int32 — gigabytes at large D — and is not needed
    # after bundling; free it before the next D allocates its own.
    del H_train

    for i in tqdm(range(n_splits)):
        s, e = i * split_size, (i + 1) * split_size
        Hs    = encode_dataset_batched(X_test[s:e], base_HDVs, batch_size)
        preds = predict(Hs, class_HDVs)
        accuracies[idx, i] = (preds == y_test[s:e]).float().mean().item()

    print(f"Average accuracy: {accuracies[idx].mean().item():.4f} for Hyperdim {D}")


==> Hyperdimension: 5000


100%|██████████| 20/20 [00:01<00:00, 12.89it/s]


Average accuracy: 0.7948 for Hyperdim 5000

==> Hyperdimension: 6000


100%|██████████| 20/20 [00:01<00:00, 10.87it/s]


Average accuracy: 0.7973 for Hyperdim 6000

==> Hyperdimension: 7000


100%|██████████| 20/20 [00:02<00:00,  9.27it/s]


Average accuracy: 0.8066 for Hyperdim 7000

==> Hyperdimension: 8000


100%|██████████| 20/20 [00:02<00:00,  8.13it/s]


Average accuracy: 0.7953 for Hyperdim 8000

==> Hyperdimension: 9000


100%|██████████| 20/20 [00:02<00:00,  7.24it/s]


Average accuracy: 0.8066 for Hyperdim 9000

==> Hyperdimension: 10000


100%|██████████| 20/20 [00:03<00:00,  6.56it/s]


Average accuracy: 0.8093 for Hyperdim 10000

==> Hyperdimension: 11000


100%|██████████| 20/20 [00:03<00:00,  5.86it/s]


Average accuracy: 0.8044 for Hyperdim 11000

==> Hyperdimension: 11750


100%|██████████| 20/20 [00:03<00:00,  5.54it/s]


Average accuracy: 0.8059 for Hyperdim 11750

==> Hyperdimension: 12750


100%|██████████| 20/20 [00:03<00:00,  5.11it/s]


Average accuracy: 0.8104 for Hyperdim 12750

==> Hyperdimension: 13750


100%|██████████| 20/20 [00:04<00:00,  4.74it/s]


Average accuracy: 0.8095 for Hyperdim 13750

==> Hyperdimension: 14750


100%|██████████| 20/20 [00:04<00:00,  4.43it/s]

Average accuracy: 0.8036 for Hyperdim 14750





In [6]:
from scipy.io import loadmat, savemat
# Persist accuracies (rows: hyperdimensions, columns: test splits) as percentages.
savemat('HoloGN_MNIST.mat', mdict={'HoloGN_MNIST': 100 * accuracies})