In [None]:
import pickle

import keras
import matplotlib.pyplot as plt
import nengo
import nengo_spa as spa
import numpy as np

from image_to_memory import decode_image, encode_memory_shape
from mnist_image_generator import gen_images
from region_query_utils import (
    direction_quad,
    generate_rectangle_region,
    generate_space_table,
    get_quads,
    lookup_space_table,
    predict_single_query,
    saccades,
)
from utils import create_vectors, encode_point

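# Putting it all together: generate multi-digit MNIST images, locate and classify the digits,
# store them in a spatial semantic-pointer memory, and evaluate directional queries against it.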

In [None]:
# Vocabulary: one symbol per MNIST digit class, ordered so that objs[k] names digit k.
objs = [
    "ZERO", "ONE", "TWO", "THREE", "FOUR",
    "FIVE", "SIX", "SEVEN", "EIGHT", "NINE",
]
D = 512  # dimensionality handed to create_vectors

obj_dic, vec_dic = create_vectors(objs, D)
# Spatial axis vectors (presumably the x/y bases for location encoding — confirm in utils).
X, Y = vec_dic['X'], vec_dic['Y']

In [None]:
n = 1000  # number of images
m = 4  # digits per image (max 4 right now due to spread factor)

# The query/evaluation cells compare object 0 against object 1, so at least two
# digits are required; the generator's spread factor caps the count at four.
assert 2 <= m <= 4, "m (digits per image) must be between 2 and 4"

image_data = gen_images(n, m)

# Ground-truth digit positions, one column per digit of each image.
xs_original = np.array(image_data['x'])
ys_original = np.array(image_data['y'])

In [None]:
# Saccade-derived x/y coordinates for every image (presumably estimated object
# centres — confirm in region_query_utils), converted to numpy arrays.
xs, ys = map(np.array, saccades(image_data['images']))

In [None]:
# MNIST digit images are 28 x 28 pixels; decode_image receives half that size
# (presumably the crop half-width around each saccade point — confirm).
im_dim = 28
# Pre-trained digit classifier, loaded from a path relative to the notebook.
model = keras.models.load_model('mnist_net.h5')

pred_obj_list = decode_image(image_data['images'], xs, ys, im_dim//2, model)

In [None]:
# Semantic pointers for every digit, stacked in vocabulary order.
obj_vectors = np.array([obj_dic[name] for name in objs])

square = generate_rectangle_region([-1,1],[-1,1], X,Y)
# nengo_spa's SemanticPointer.normalized() returns a NEW pointer rather than
# normalizing in place, so the result must be assigned back.
square = square.normalized()

# Store objects in memory as squares rather than points.
memory_data = encode_memory_shape(pred_obj_list, xs,ys, obj_vectors, [X,Y], square, n, m)

In [None]:
# Quadrant of the first object (index 0) as seen from the second object (index 1):
# the offset (object 0 - object 1) points from the second object towards the first.
dirs = direction_quad(
    xs_original[:, 0] - xs_original[:, 1],
    ys_original[:, 0] - ys_original[:, 1],
)

In [None]:
# Quadrant region vectors in the X/Y encoding (the third argument presumably sets
# the extent of the encoded space — confirm against get_quads).
UP_RIGHT, DOWN_RIGHT, UP_LEFT, DOWN_LEFT = get_quads(X,Y, 5)
# Layout: first index picks LEFT (0) / RIGHT (1), second picks DOWN (0) / UP (1).
region_selector = np.array([[DOWN_LEFT, UP_LEFT],[DOWN_RIGHT, UP_RIGHT]])

# 100 x 100 grid of locations over [-5, 5] x [-5, 5] (presumably used to decode
# a location vector back to coordinates — confirm in region_query_utils).
loc_table = generate_space_table(np.linspace(-5, 5, 100),np.linspace(-5, 5, 100),D, X,Y)

In [None]:
obj_list = np.array(image_data['obj_list'])
# Query pointer for each image: the semantic pointer of its second object (column 1).
query_obj = [obj_dic[objs[label]] for label in obj_list[:, 1]]

In [None]:
# Ask the memory, per image, which stored object lies in direction `dirs`
# from the query object.
preds = predict_single_query(
    memory_data['obj_loc_memory'],
    memory_data['obj_memory'],
    query_obj,
    dirs,
    obj_dic,
    region_selector,
    loc_table,
)

In [None]:
# Several objects may lie in the queried direction, so a prediction counts as
# correct if it matches ANY object that sits in the same quadrant (relative to
# the query object, index 1) as the reference object (index 0) used for `dirs`.
query_idx = 1

inregion_map = np.zeros((n, m), dtype=bool)
for i in range(m):
    if i == query_idx:
        # The query object is never a valid answer to "what lies in direction d
        # of the query object". Its zero offset could otherwise be mapped to some
        # quadrant by direction_quad and be wrongly accepted.
        continue
    dirs_i = direction_quad(xs_original[:, i] - xs_original[:, query_idx],
                            ys_original[:, i] - ys_original[:, query_idx])
    inregion_map[:, i] = np.all(np.array(dirs) == np.array(dirs_i), axis=0)

# Digit labels of acceptable answers per image; -1 marks objects outside the queried direction.
inregion = np.where(inregion_map, obj_list, -1)

accuracy = np.mean(np.any(inregion == preds[:, None], axis=1))
print(accuracy)