In [None]:
# Automatically reload edited project modules (e.g. `main`, `vsa`, imported below)
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
from main import get_dataset
# Quick-look dataset load for the exploration cell below. get_dataset() returns a
# 3-tuple; the middle item (the train loader, judging by the later
# `vsa, train_dl, test_dl = get_dataset(device)` call) is discarded here.
# NOTE(review): assigning to `_` shadows IPython's "last output" variable, and the
# evaluation cell reloads the dataset with get_dataset(device) — this load may be redundant.
vsa, _, test_dl = get_dataset()

In [None]:
# Visual sanity check: show one test image.
# NOTE(review): assumes dataset[idx] returns (image, ...) with the image in an
# imshow-compatible layout (H, W[, C]) — confirm against the dataset class.
preview_idx = 24
fig, ax = plt.subplots()
ax.imshow(test_dl.dataset[preview_idx][0])
ax.set_title(f"Test sample {preview_idx}");  # trailing ';' suppresses the Text repr

In [None]:
import torch

# Scratch check of a rounding rule that never maps a nonzero value to 0:
#   |x| > 1      -> nearest integer (torch.round, half to even)
#   0 < x <= 1   -> ceil, i.e. 1
#   -1 <= x <= 0 -> floor, i.e. -1 (exactly 0 stays 0)
# The result is cast to int8.
a = torch.tensor([0.1, 1.1, 1.5, -0.2, -0.9, -1.2, 0])
a = torch.where(
    a.abs() > 1,
    a.round(),
    torch.where(a > 0, a.ceil(), a.floor()),
).type(torch.int8)

In [None]:
# get_dataset is imported here too so the evaluation cell does not depend on the
# earlier exploration cell having been run.
from main import factorization, gen_init_estimates, get_dataset, get_model_loss_optimizer
import torch
from vsa import Resonator
from colorama import Fore
import matplotlib.pyplot as plt
# Report verbosity for the evaluation cell: >= 1 prints failed samples, >= 2 also prints passed ones.
VERBOSE = 2

# Problem size. DIM, NUM_POS_X, NUM_POS_Y and NUM_COLOR also select `data_dir` below.
DIM = 2000
MAX_NUM_OBJECTS = 2  # sizes the per-object-count tallies in the evaluation cell
NUM_POS_X = 3
NUM_POS_Y = 3
NUM_COLOR = 3
# Train
TRAIN_EPOCH = 75
TRAIN_BATCH_SIZE = 128
NUM_TRAIN_SAMPLES = 70000
# Test
TEST_BATCH_SIZE = 1
NUM_TEST_SAMPLES = 300
# Resonator
NORMALIZE = True
ACTIVATION = "NONE" # "NONE", "ABS", "NONNEG"
RESONATOR_TYPE = "SEQUENTIAL" # "SEQUENTIAL", "CONCURRENT"
NUM_ITERATIONS = 100

data_dir = f"./data/{DIM}dim-{NUM_POS_X}x-{NUM_POS_Y}y-{NUM_COLOR}color"

In [None]:
import torchhd as hd

# Load data and the trained model, push one test sample through the model and the
# resonator network, and report whether every labelled object was recovered.
device = "cuda" if torch.cuda.is_available() else "cpu"
vsa, train_dl, test_dl = get_dataset(device)

# Checkpoint name mirrors the training config; only the timestamp suffix is run-specific.
checkpoint = f"{data_dir}/model_weights_{TRAIN_BATCH_SIZE}batch_{TRAIN_EPOCH}epoch_{NUM_TRAIN_SAMPLES}samples_09-03-03-38.pt"
model, loss_fn, optimizer = get_model_loss_optimizer()
# map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU-only host).
model.load_state_dict(torch.load(checkpoint, map_location=device))
# The input image is moved to `device` below, so the model must live there as well.
model.to(device)

resonator_network = Resonator(vsa, type=RESONATOR_TYPE, norm=NORMALIZE, activation=ACTIVATION, iterations=NUM_ITERATIONS, device=device)
init_estimates = gen_init_estimates(vsa.codebooks, TEST_BATCH_SIZE)

model.eval()

n = 75  # index of the test sample to inspect
image = test_dl.dataset.data[n]
label = test_dl.dataset.labels[n]
target = test_dl.dataset.targets[n]

image = image.to(device)
# NOTE(review): the /255 and permute(2, 0, 1) imply the stored image is 8-bit (H, W, C);
# this scales to [0, 1] and reshapes to a batch of one (1, C, H, W) — confirm.
image_nchw = (image.type(torch.float32) / 255).permute(2, 0, 1).unsqueeze(0)
# Pure inference: skip building an autograd graph. (Not inference_mode, because
# infer_result is modified in place further down, which inference tensors forbid.)
with torch.no_grad():
    infer_result = model(image_nchw).round().type(torch.int8)

# Tallies indexed by (number of objects in the label) - 1.
incorrect_count = [0] * MAX_NUM_OBJECTS
unconverged = [[0, 0] for _ in range(MAX_NUM_OBJECTS)]    # [correct, incorrect]

# Factorization
outcomes, convergence = factorization(vsa, resonator_network, infer_result, init_estimates)

incorrect = False
message = ""

i = 0  # first (and only) sample in the batch built above
if NORMALIZE:
    # NOTE(review): this runs after factorization, so within this cell it only affects the
    # printout and similarity below; it also writes back into an int8 tensor, which truncates
    # if normalize() returns non-integer values — confirm intended.
    infer_result[i] = resonator_network.normalize(infer_result[i])

print(infer_result[i].tolist())

# Sample: multiple objects. With n labelled objects only the first n outcomes are checked;
# the sample is incorrect if any labelled object is missing from them.
num_objects = len(label)
slot = num_objects - 1
detected = outcomes[i][0:num_objects]
for j in range(num_objects):
    # A convergence value of NUM_ITERATIONS-1 is counted as "did not converge".
    hit_iteration_cap = convergence[i][j] == NUM_ITERATIONS - 1
    if label[j] not in detected:
        message += Fore.RED + "Object {} is not detected.".format(label[j]) + Fore.RESET + "\n"
        incorrect = True
        unconverged[slot][1] += 1 if hit_iteration_cap else 0
    else:
        message += "Object {} is correctly detected.".format(label[j]) + "\n"
        unconverged[slot][0] += 1 if hit_iteration_cap else 0

if incorrect:
    incorrect_count[slot] += 1
    if VERBOSE >= 1:
        print(Fore.BLUE + f"Test {n} Failed:      Convergence = {convergence[i]}" + Fore.RESET)
        print("Inference result similarity = {:.4f}".format(hd.cosine_similarity(infer_result[i], target).item()))
        print(message[:-1])
        print("Outcome = {}".format(detected))
else:
    if VERBOSE >= 2:
        print(Fore.BLUE + f"Test {n} Passed:      Convergence = {convergence[i]}" + Fore.RESET)
        print("Inference result similarity = {:.4f}".format(hd.cosine_similarity(infer_result[i], target).item()))
        print(message[:-1])
# NOTE(review): leftover from iterating over samples; has no effect within this cell.
n += 1
