In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import glob, os, sys, re, math, random
import xml.etree.ElementTree as ET
from rules import *
from raven_data import load_question, display_problem

In [2]:
def predict(embeddings, embedding_names, debug=False, rule_order=None):
    """Pick the candidate panel that best completes a 3x3 puzzle grid.

    The first eight entries of ``embeddings`` fill the grid row by row with
    the bottom-right cell left open; entries 8..15 are the eight candidate
    answers for that open cell.

    For every attribute index from 2 upwards, the first rule in
    ``rule_order`` that accepts the two complete rows is recorded. Each
    candidate is then placed in the open cell and checked against every
    recorded rule on the full grid.

    Args:
        embeddings: Indexable sequence of at least 16 per-panel attribute
            sequences, all of the same length.
        embedding_names: Attribute names, indexed like one embedding. An
            attribute whose name contains ``'BW'`` is passed to the rule
            with ``is_bitwise=True``.
        debug: When true, print the grid, the detected rules and the
            scoring details.
        rule_order: Optional list of rule factories in priority order. Each
            is called as ``rule(attribute, is_bitwise=...)`` and must return
            a callable that takes a list of rows and returns a truthy value
            when the rows satisfy the rule. Defaults to
            ``[constant, progression, arithmetic, distribute_three, noise]``.

    Returns:
        int: Index (0..7) of the chosen candidate. If at least one candidate
        satisfies every detected rule, the last such candidate is returned;
        otherwise the first candidate with the highest number of satisfied
        rules is returned.
    """
    if rule_order is None:
        rule_order = [constant, progression, arithmetic, distribute_three, noise]

    num_context = 8      # panels already placed in the grid
    num_candidates = 8   # answer options that follow the context panels

    def build_checker(rule, attribute):
        # A fresh checker is built for every call, exactly as before, so
        # rules that keep internal state are not shared between checks.
        return rule(attribute, is_bitwise=('BW' in embedding_names[attribute]))

    layout = [[0, 1, 2], [3, 4, 5], [6, 7, None]]
    grid = [[embeddings[cell] if cell is not None else None for cell in row] for row in layout]
    if debug:
        print("\n".join([str(row) for row in grid]))

    # Detect, per attribute, the first rule that explains the two full rows.
    final_rules = {}
    for attribute in range(2, len(embeddings[0])):  # shape, size, color, angle
        for rule in rule_order:
            if build_checker(rule, attribute)([grid[0], grid[1]]):
                final_rules[attribute] = rule
                if debug:
                    print(f'Attribute {attribute} {embedding_names[attribute]} follows rule {rule.__name__}')
                break

    # Try every candidate in the open cell and count the rules it satisfies.
    my_guess = None
    scores = [0 for _ in range(num_candidates)]
    for guess in range(num_candidates):
        guess_embedding = embeddings[num_context + guess]
        grid[2][2] = guess_embedding
        satisfies_all = True
        for attribute, rule in final_rules.items():
            if build_checker(rule, attribute)(grid):
                scores[guess] += 1
            else:
                satisfies_all = False
        if satisfies_all:
            # NOTE(review): a later full match overwrites an earlier one, while
            # the fallback below prefers the earliest best score - confirm that
            # this asymmetry is intended.
            my_guess = guess
            if debug:
                print(f'My guess is {guess} with embedding {guess_embedding}')

    if my_guess is None:
        # No candidate satisfied every rule: fall back to the first candidate
        # with the highest score. list.index keeps the result a plain int
        # (np.argmax returned a NumPy integer here), so both paths agree.
        my_guess = scores.index(max(scores))
        if debug:
            print("No answer found, guessing", my_guess)
    if debug:
        print("Scores", scores)
    return my_guess

In [41]:
# Run the solver over every problem found under `path` and report accuracy.
path = '../RAVEN/new_data_0'
subtype = '*'
# A problem is identified by its XML path with everything from '.xml' removed.
items = sorted(
    xml_file.partition('.xml')[0]
    for xml_file in glob.glob(os.path.join(path, subtype, '*.xml'))
)
debug = False
correct = 0
if debug:
    # In debug mode only this single problem is inspected.
    items = ['../RAVEN/new_data_0/distribute_nine/RAVEN_91_train']
for index, item in enumerate(items):
    embeddings, embedding_names, answer = load_question(item, display=debug)
    # Every panel must expose the same number of attributes; stop otherwise.
    if not all(len(embedding) == len(embeddings[0]) for embedding in embeddings):
        print("Error!", item, " has embeddings of different lengths")
        break
    guess = predict(embeddings, embedding_names, debug=debug)
    if guess == answer:
        correct += 1
        continue
    # From here on the guess was wrong.
    if debug:
        grid = [[0, 1, 2], [3, 4, 5], [6, 7, None]]
        print("\n".join(
            str([None if col is None else embeddings[col] for col in row])
            for row in grid
        ))
        print(
            "Answer", answer, embeddings[8 + answer],
            "Guess", guess, None if guess is None else embeddings[8 + guess],
        )
    print("WRONG", item, "Wrong guess", guess, "Correct answer", answer)
print("Final score:", correct, "out of", len(items))

WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_23_train Wrong guess 5 Correct answer 1
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_44_train Wrong guess 7 Correct answer 4
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_4_train Wrong guess 5 Correct answer 1
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_69_test Wrong guess 7 Correct answer 4
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_75_train Wrong guess 7 Correct answer 2
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_93_train Wrong guess 4 Correct answer 2
WRONG ../RAVEN/new_data_0/distribute_four/RAVEN_96_val Wrong guess 4 Correct answer 6
WRONG ../RAVEN/new_data_0/distribute_nine/RAVEN_1_train Wrong guess 6 Correct answer 5
WRONG ../RAVEN/new_data_0/distribute_nine/RAVEN_37_val Wrong guess 4 Correct answer 3
WRONG ../RAVEN/new_data_0/distribute_nine/RAVEN_48_test Wrong guess 6 Correct answer 0
WRONG ../RAVEN/new_data_0/distribute_nine/RAVEN_82_train Wrong guess 5 Correct answer 4
WRONG ../RAVEN/new_data_0/distribute_nin

In [5]:
from collections import defaultdict
# For each attribute name, count how often every non-None value occurs
# across all panels of all problems in `items`.
unique_value_count = defaultdict(lambda: defaultdict(int))
largest = []
for item in items:
    embeddings, embedding_names, answer = load_question(item, display=False)
    for embedding in embeddings:
        for name, value in zip(embedding_names, embedding):
            if value is None:
                continue
            unique_value_count[name][value] += 1
    # Keep the longest name list seen so far; on a tie the earlier one stays.
    largest = max(largest, embedding_names, key=len)
print("Unique values")
for name, values in unique_value_count.items():
    print(name, sorted(values))

[0.5, 0.5, 0.33, 0.33]
[0.16, 0.83, 0.33, 0.33]
[0.83, 0.5, 0.33, 0.33]
[0.16, 0.5, 0.33, 0.33]
[0.83, 0.83, 0.33, 0.33]
[0.5, 0.16, 0.33, 0.33]
done
[0.5, 0.16, 0.33, 0.33]
[0.16, 0.5, 0.33, 0.33]
[0.83, 0.16, 0.33, 0.33]
[0.16, 0.16, 0.33, 0.33]
[0.83, 0.5, 0.33, 0.33]
[0.16, 0.83, 0.33, 0.33]
done
[0.16, 0.83, 0.33, 0.33]
[0.16, 0.16, 0.33, 0.33]
[0.5, 0.83, 0.33, 0.33]
[0.83, 0.83, 0.33, 0.33]
[0.83, 0.16, 0.33, 0.33]
[0.16, 0.5, 0.33, 0.33]
done
[0.5, 0.83, 0.33, 0.33]
[0.16, 0.16, 0.33, 0.33]
[0.83, 0.16, 0.33, 0.33]
[0.5, 0.16, 0.33, 0.33]
[0.83, 0.5, 0.33, 0.33]
done
[0.5, 0.5, 0.33, 0.33]
[0.83, 0.83, 0.33, 0.33]
[0.5, 0.83, 0.33, 0.33]
[0.16, 0.83, 0.33, 0.33]
[0.83, 0.16, 0.33, 0.33]
done
[0.5, 0.16, 0.33, 0.33]
[0.83, 0.5, 0.33, 0.33]
[0.5, 0.5, 0.33, 0.33]
[0.16, 0.5, 0.33, 0.33]
[0.5, 0.83, 0.33, 0.33]
done
[0.5, 0.83, 0.33, 0.33]
[0.83, 0.5, 0.33, 0.33]
[0.5, 0.5, 0.33, 0.33]
done
[0.5, 0.5, 0.33, 0.33]
[0.83, 0.16, 0.33, 0.33]
[0.5, 0.16, 0.33, 0.33]
done
[0.5, 0.16, 0.