In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import MNIST
from tqdm import tqdm
import torchhd
from torchhd import embeddings
import torchmetrics
import numpy as np
import struct

# This function converts a 1D binary tensor into a single concatenated hex string.
# It packs the bits eight per byte (first element in the most significant bit, zero-padded at the end) and hex-encodes the resulting bytes.
def binary_tensor_to_hex_str(tensor):
    """Convert a 1D binary tensor to a hex string.

    The elements are cast to ``uint8`` and packed eight per byte, with the
    first element in the most significant bit. When the length is not a
    multiple of 8, the last byte is completed with zero bits at the end.

    Args:
        tensor: 1D ``torch.Tensor`` of 0/1 values. It may live on any device
            and may require grad.

    Returns:
        Lowercase hex string with two characters per packed byte; ``""`` for
        an empty tensor.
    """
    # detach() first: Tensor.numpy() raises on a tensor that requires grad.
    bits = tensor.detach().cpu().numpy().astype(np.uint8)
    # np.packbits zero-pads the final byte on its own, so no manual padding.
    return np.packbits(bits).tobytes().hex()


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")
print("8K MNIST Baseline")

# Experiment configuration.
DIMENSIONS = 8192  # width of every hypervector
IMG_SIZE = 28  # image side length; the encoder uses IMG_SIZE * IMG_SIZE positions
NUM_LEVELS = 1000  # number of entries in the encoder's level embedding
BATCH_SIZE = 30  # batch size of both data loaders

transform = torchvision.transforms.ToTensor()

# Training split: downloaded into ./data when missing, reshuffled every pass.
train_ds = MNIST(root="./data", train=True, download=True, transform=transform)
train_ld = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: kept in dataset order.
test_ds = MNIST(root="./data", train=False, download=True, transform=transform)
test_ld = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

class Encoder(nn.Module):
    """Turn a batch of images into quantized hypervectors.

    Args:
        out_features: dimensionality of the produced hypervectors.
        size: side length of the (square) input images; ``size * size``
            position hypervectors are created.
        levels: number of entries in the value (level) embedding.
    """

    def __init__(self, out_features, size, levels):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        # One random hypervector per flattened pixel position.
        self.position = embeddings.Random(size * size, out_features)
        # Level embedding used to look up a hypervector for each pixel value.
        self.value = embeddings.Level(levels, out_features)

    def forward(self, x):
        """Encode ``x`` and return the hard-quantized result."""
        pixels = self.flatten(x)
        value_hv = self.value(pixels)
        # Bind every pixel's value hypervector to its position hypervector,
        # bundle the bound hypervectors, then quantize the bundle.
        bound_hv = torchhd.bind(self.position.weight, value_hv)
        bundled_hv = torchhd.multiset(bound_hv)
        return torchhd.hard_quantize(bundled_hv)

# Test accuracy (in percent) of every run; summarised after the loop.
accuracies = []

for run in range(1):
    encode = Encoder(DIMENSIONS, IMG_SIZE, NUM_LEVELS)
    encode = encode.to(device)

    num_classes = len(train_ds.classes)
    # NOTE(review): `Centroid2` is neither defined nor imported in the part of
    # the file visible here; it has to be provided elsewhere before this runs.
    # The code below relies on it offering add(), normalize(), a `weight`
    # tensor and being callable as model(hypervectors, "dot"). It presumably
    # mirrors torchhd.models.Centroid -- confirm against its definition.
    model = Centroid2(DIMENSIONS, num_classes)
    model = model.to(device)

    # Single pass over the training set; no gradients are needed.
    with torch.no_grad():
        for samples, labels in tqdm(train_ld, desc="Training"):
            samples = samples.to(device)
            labels = labels.to(device)
            samples_hv = encode(samples)
            model.add(samples_hv, labels)

    accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)

    with torch.no_grad():
        model.normalize()

        # All three output files stay open for the whole export so weights,
        # samples and labels are written in a single pass.
        with open('sample.mem', 'w') as f_samples, \
             open('label.mem', 'w') as f_labels, \
             open('class_weights.mem', 'w') as f_weights:

            print("\nWriting class weights to class_weights.mem...")
            # One FP16 weight per line, written as 4 hex digits in big-endian
            # byte order. tolist() converts all elements to Python floats in
            # one call instead of one .item() call per weight.
            weights_fp16 = model.weight.cpu().half()
            f_weights.writelines(
                f"{struct.pack('>e', w).hex()}\n"
                for w in weights_fp16.flatten().tolist()
            )
            print("Finished writing weights.")

            # Encode the test set, export every sample/label and score it.
            for samples, labels in tqdm(test_ld, desc="Testing and Writing Files"):
                samples = samples.to(device)
                samples_hv = encode(samples)

                # Map bipolar to binary (-1 -> 1, everything else -> 0) and
                # reverse the element order of every hypervector, so element 0
                # becomes the last bit of its hex line. Working on the whole
                # batch and copying it to the CPU once avoids one
                # device-to-host transfer per sample.
                batch_bits = (samples_hv == -1).to(torch.uint8)
                batch_bits = torch.flip(batch_bits, dims=[-1]).cpu()

                # One hex line per sample and one decimal label per line,
                # in matching order.
                for sample_bits, label in zip(batch_bits, labels.tolist()):
                    f_samples.write(f"{binary_tensor_to_hex_str(sample_bits)}\n")
                    f_labels.write(f"{label}\n")

                # samples_hv is already on `device`; the labels were never
                # moved off the CPU, so only the outputs need to come back.
                outputs = model(samples_hv, "dot")
                accuracy.update(outputs.cpu(), labels)

        print("Finished writing samples and labels.")

    acc = accuracy.compute().item() * 100
    print(f"Testing accuracy for run {run + 1}: {acc:.3f}%")
    accuracies.append(acc)

# Summary over all runs.
min_acc = min(accuracies)
max_acc = max(accuracies)
avg_acc = sum(accuracies) / len(accuracies)

print(f"\nMinimum accuracy: {min_acc:.3f}%")
print(f"Maximum accuracy: {max_acc:.3f}%")
print(f"Average accuracy: {avg_acc:.3f}%")