In [2]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import glob, os, sys, re, math, random
import xml.etree.ElementTree as ET
from rules import *
from raven_data import load_question, display_problem

In [40]:
# Candidate rules, tried in this order for every attribute. `predict` keeps the
# first rule that holds on the two complete rows, so earlier entries take
# precedence (e.g. `noise` is only chosen when nothing before it matches).
rule_order = [constant, progression, arithmetic, distribute_three, noise]
def _is_bitwise(name):
    """Return True when an attribute column should be checked in bitwise mode (its name contains 'BW')."""
    return 'BW' in name


def _infer_rules(grid, embedding_names, num_attributes, rules, debug):
    """Assign each attribute column (index 2 onward) the first rule that holds on the two complete rows.

    Returns ``(final_rules, counts)`` where ``final_rules`` maps attribute index
    to rule and ``counts[i]`` is how many attributes were assigned ``rules[i]``.
    Attributes for which no rule holds are left out of ``final_rules``.
    """
    counts = [0] * len(rules)
    final_rules = {}
    complete_rows = [grid[0], grid[1]]
    # Columns 0 and 1 are skipped (presumably the panel's Row/Col position -- TODO confirm).
    for attribute in range(2, num_attributes):
        bitwise = _is_bitwise(embedding_names[attribute])
        for i, rule in enumerate(rules):
            if rule(attribute, is_bitwise=bitwise)(complete_rows):
                final_rules[attribute] = rule
                counts[i] += 1
                if debug:
                    print(f'Attribute {attribute} {embedding_names[attribute]} follows rule {rule.__name__}')
                break  # first matching rule wins, so list order is precedence
    return final_rules, counts


def _choose_candidate(embeddings, grid, embedding_names, final_rules, debug, num_context=8, num_candidates=8):
    """Pick the candidate answer that best completes ``grid`` under ``final_rules``.

    Each candidate is placed in the missing bottom-right cell and checked against
    every inferred rule. If one or more candidates satisfy all rules, the last
    such candidate is returned. Otherwise the candidate satisfying the most rules
    is returned (the lowest index on ties, per ``np.argmax``).
    """
    my_guess = None
    scores = [0] * num_candidates
    for guess in range(num_candidates):
        guess_embedding = embeddings[num_context + guess]
        grid[2][2] = guess_embedding
        passed_all = True
        for attribute, rule in final_rules.items():
            if rule(attribute, is_bitwise=_is_bitwise(embedding_names[attribute]))(grid):
                scores[guess] += 1
            else:
                # Only report the first failing rule per candidate.
                if passed_all and debug:
                    print(f'Guess {guess} failed rule {rule.__name__} for attribute {attribute} {embedding_names[attribute]}')
                passed_all = False
        if passed_all:
            my_guess = guess
            if debug:
                print(f'My guess is {guess} with embedding {guess_embedding}')
    if my_guess is None:
        my_guess = int(np.argmax(scores))
        if debug:
            print("No answer found, guessing", my_guess)
    if debug:
        print("Scores", scores)
    return my_guess


def predict(embeddings, embedding_names, debug=False, rules=None):
    """Guess which of the 8 candidate panels completes a 3x3 RAVEN matrix.

    Args:
        embeddings: per-panel attribute vectors. Indices 0-7 are the context
            panels laid out row-major with the bottom-right cell missing;
            indices 8-15 are the candidate answers. All vectors are indexed with
            the same attribute positions as ``embeddings[0]``.
        embedding_names: name of each attribute column; names containing 'BW'
            are checked with ``is_bitwise=True``.
        debug: print the grid, the inferred rules and per-candidate results.
        rules: rules to try, in precedence order. Each is called as
            ``rule(attribute, is_bitwise=...)`` and must return a checker that
            takes a list of grid rows and returns truthy when the rule holds.
            Defaults to the global ``rule_order``.

    Returns:
        ``(guess, counts)``: ``guess`` is the chosen candidate index (0-7) and
        ``counts[i]`` is how many attributes were assigned ``rules[i]``.
    """
    if rules is None:
        rules = rule_order

    layout = [[0, 1, 2], [3, 4, 5], [6, 7, None]]  # None marks the panel to predict
    grid = [[embeddings[index] if index is not None else None for index in row] for row in layout]
    if debug:
        print("\n".join([str(row) for row in grid]))

    final_rules, counts = _infer_rules(grid, embedding_names, len(embeddings[0]), rules, debug)
    my_guess = _choose_candidate(embeddings, grid, embedding_names, final_rules, debug)
    return my_guess, counts

In [43]:
# Evaluate `predict` on every RAVEN problem found under `path`, then report
# accuracy and how often each rule in `rule_order` was assigned to an attribute.
# Generalized components v1: 3.5s / 63.6s
path = '../RAVEN/new_data_0'
subtype = '*'  # glob for the configuration subdirectory (e.g. 'center_single')
items = sorted([x.split('.xml')[0] for x in glob.glob(os.path.join(path, subtype, '*.xml'))])
correct = 0
evaluated = 0  # problems actually scored; smaller than len(items) if the loop stops early
debug = False
global_counts = [0 for _ in range(len(rule_order))]
if debug:
    # NOTE(review): this debug item is under new_data_1, not `path` -- confirm intended.
    items = ['../RAVEN/new_data_1/center_single/RAVEN_551_train']
for index, item in enumerate(items):
    embeddings, embedding_names, answer = load_question(item, display=debug, debug=False)
    if any(len(embedding) != len(embeddings[0]) for embedding in embeddings):
        # predict indexes every panel by the columns of embeddings[0]; stop so the data can be inspected.
        print("Error!", item, " has embeddings of different lengths")
        break
    guess, counts = predict(embeddings, embedding_names, debug=debug)
    evaluated += 1
    global_counts = [total + count for total, count in zip(global_counts, counts)]
    if guess == answer:
        correct += 1
    else:
        if debug:
            grid = [[0, 1, 2], [3, 4, 5], [6, 7, None]]
            print("\n".join([str([embeddings[col] if col is not None else None for col in row]) for row in grid]))
            print("Answer", answer, embeddings[8+answer], "Guess", guess, embeddings[8+guess] if guess is not None else None)
        print("WRONG", item, "Wrong guess", guess, "Correct answer", answer)
print("Final score:", correct, "out of", evaluated)
print("Names of rules:", [rule.__name__ for rule in rule_order])
print("Counts", global_counts)
total_assigned = sum(global_counts)
# Guard against an empty run (no items found / nothing assigned) instead of dividing by zero.
print("Percentages", [count / total_assigned if total_assigned else 0.0 for count in global_counts])

Final score: 700 out of 700
Names of rules: ['constant', 'progression', 'arithmetic', 'distribute_three', 'noise']
Counts [3694, 937, 546, 1003, 420]
Percentages [0.5596969696969697, 0.14196969696969697, 0.08272727272727273, 0.15196969696969698, 0.06363636363636363]


In [5]:
from collections import defaultdict
# Survey the data: for every attribute name, count how often each non-None
# value appears across all panels of every problem in `items` (set by the
# evaluation cell above). `largest` keeps the longest `embedding_names` seen.
unique_value_count = defaultdict(lambda: defaultdict(int))
largest = []
for item in items:
    embeddings, embedding_names, answer = load_question(item, display=False)
    for embedding in embeddings:
        for name, value in zip(embedding_names, embedding):
            if value is None:
                continue  # skip missing (None) entries
            unique_value_count[name][value] += 1
    if len(largest) < len(embedding_names):
        largest = embedding_names
print("Unique values")
for name, values in unique_value_count.items():
    print(name, sorted(values))

Unique values
Row [0, 1, 2]
Col [0, 1, 2]
Type [1, 3, 5]
Size [0, 1, 2, 4, 5]
Color [1, 2, 3, 4, 5, 6, 7]
Angle [1, 2, 6]
BWPosition [12, 44, 88, 176, 186, 207, 233, 268, 277, 319, 359, 372, 414]
Number [2, 3, 4, 5, 6, 7]
