In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
import torchmetrics

from torchhd import functional
from torchhd import embeddings

In [2]:
# Notebook-wide configuration; everything a reader might tune lives here.
BATCH_SIZE = 1        # samples per DataLoader batch
NUM_LEVELS = 10       # number of entries in the Level embedding used for pixel values
DIMENSIONS = 10_000   # width of every hypervector and of the classifier input
IMG_SIZE = 28         # image side length; the model creates IMG_SIZE * IMG_SIZE position vectors

In [3]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
# Both splits share the same ToTensor transform.
transform = torchvision.transforms.ToTensor()

# MNIST is read from the local "data" directory; download=False means the files
# must already be present there.
train_ds = MNIST(root="data", train=True, transform=transform, download=False)
test_ds = MNIST(root="data", train=False, transform=transform, download=False)

# Training batches are shuffled; test batches keep dataset order.
train_ld = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_ld = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [5]:
class Model(nn.Module):
    """Image classifier built on hypervector encodings and a bias-free linear layer.

    Every pixel position owns a ``Random`` embedding vector and every pixel
    value is looked up in a ``Level`` embedding. ``encode`` binds the two,
    combines the result with ``multiset`` and applies ``hard_quantize``;
    ``forward`` feeds that encoding to ``classify``, whose weights start at zero.

    The hypervector width and the number of levels come from the module-level
    ``DIMENSIONS`` and ``NUM_LEVELS`` constants.

    Args:
        num_classes: number of outputs of the classification layer.
        size: image side length; ``size * size`` position vectors are created.
    """

    def __init__(self, num_classes, size):
        super().__init__()

        self.flatten = nn.Flatten()

        # One random hypervector per pixel position, one level hypervector per value bucket.
        self.position = embeddings.Random(size * size, DIMENSIONS)
        self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)

        # The classifier starts from all-zero weights and has no bias term.
        self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
        self.classify.weight.data.zero_()

    def encode(self, x):
        """Return the quantized hypervector encoding of ``x``.

        Assumes ``x`` flattens to ``size * size`` values per sample so that it
        lines up with ``self.position.weight`` -- TODO confirm against callers.
        """
        pixels = self.flatten(x)
        value_hv = self.value(pixels)
        bound_hv = functional.bind(self.position.weight, value_hv)
        bundled_hv = functional.multiset(bound_hv)
        return functional.hard_quantize(bundled_hv)

    def forward(self, x):
        """Encode ``x`` and return the classification layer's output for it."""
        return self.classify(self.encode(x))


In [6]:


# One output per dataset class; the model is placed on the selected device.
model = Model(num_classes=len(train_ds.classes), size=IMG_SIZE).to(device)

In [1]:
# Single-pass training: add each encoded training sample to the classifier
# weight row of its label, then normalize the rows. No gradients are needed.
with torch.no_grad():
    for samples, labels in train_ld:
        samples = samples.to(device)
        labels = labels.to(device)

        samples_hv = model.encode(samples)
        # index_add_ accumulates every sample, including several samples with
        # the same label in one batch. The previous `weight[labels] += ...`
        # form writes each duplicated row only once, silently dropping
        # contributions whenever BATCH_SIZE > 1.
        model.classify.weight.index_add_(
            0, labels, samples_hv.to(model.classify.weight.dtype)
        )

    # F.normalize defaults to L2 normalization over dim=1, i.e. per class row.
    model.classify.weight[:] = F.normalize(model.classify.weight)

NameError: name 'torch' is not defined

In [8]:
# Multiclass accuracy metric; the class count is taken from the dataset, the
# same way the model's output size is, instead of a hard-coded 10.
accuracy = torchmetrics.Accuracy("multiclass", num_classes=len(train_ds.classes))

In [9]:
# Evaluation: predict the highest-scoring class for every test batch and feed
# the result to the accuracy metric. The metric receives CPU tensors.
with torch.no_grad():
    for samples, labels in test_ld:
        samples = samples.to(device)

        outputs = model(samples)
        predictions = outputs.argmax(dim=-1)
        accuracy.update(preds=predictions.cpu(), target=labels)

In [10]:
# Report the accumulated accuracy as a percentage with three decimals.
test_accuracy = accuracy.compute().item()
print(f"Testing accuracy of {(test_accuracy * 100):.3f}%")

Testing accuracy of 82.990%
