# In [1]:
from itertools import permutations
import random
import re
import numpy as np
import math 

class Raven2x1Problem(object):
    """Solver for a 2x1 Raven's-matrix style problem (A : B :: C : ?).

    The problem is given as the raw lines of a problem file. ``solve`` parses
    them into ``self.figures`` (figure name -> object name -> attribute dict).
    ``generate`` describes the A->B change and the C->k change for every
    answer option k = '1'..'6'. ``test`` scores each option against A->B and
    returns the key of the best one.
    """

    # Line patterns used by ``solve`` (kept byte-identical to the original).
    _FIGURE_LINE = re.compile('[A-Z0-9]{1}\n{1}')
    _OBJECT_LINE = re.compile('\t{1}[A-Z]{1}\n{1}')
    _ATTRIBUTE_LINE = re.compile('\t\t[a-zA-Z:0-9-,]*')

    # Attribute families and how their raw string values are converted.
    _LIST_ATTRS = ('above', 'left-of', 'inside', 'overlaps')
    _BOOL_ATTRS = ('vertical-flip', 'horizontal-flip')
    _SIZE_ATTRS = ('size',)
    _ANGLE_ATTRS = ('angle',)
    _SHAPE_ATTRS = ('shape',)
    _FILL_ATTRS = ('fill',)
    _SIZE_RANKS = {'small': 1, 'medium': 2, 'large': 3}

    # Any transform mentioning one of these gets an extra -1 in the weighting.
    _RELATION_KEYWORDS = ('above', 'left-of', 'inside', 'overlaps')
    # (keywords, penalty) pairs; the FIRST matching entry wins (if/elif order).
    _CATEGORY_PENALTIES = (
        (('flipped',), 1),
        (('filled',), 1),
        (('rotated',), 2),
        (('expanded', 'shrunk'), 3),
        (('deleted',), 0),
        (('reshaped',), 1),
        (('overlaps',), 0),
    )

    def __init__(self, prob):
        # Raw problem lines (e.g. the result of ``file.readlines()``).
        self.prob = prob
        # figure name -> object name -> attribute name -> converted value
        self.figures = {}

    def solve(self):
        """Parse ``self.prob`` into ``self.figures`` and return the answer line.

        Line kinds recognised:
          * ``"A\\n"`` - a single capital letter or digit: starts a figure.
          * ``"\\tZ\\n"`` - starts an object in the current figure.
          * ``"\\t\\tname:value"`` - an attribute of the latest object.

        Returns:
            ``self.prob[2]``, the raw correct-answer line (newline included).

        Raises:
            ValueError: if ``self.prob`` has fewer than three lines.
        """
        figure_order = []
        object_order = []
        for line in self.prob:
            if self._FIGURE_LINE.match(line):
                # Every header starts a fresh dict, even for a name seen
                # before. Note the answer line (e.g. "5\n") also matches here.
                figure = line.strip("\n")
                self.figures[figure] = {}
                figure_order.append(figure)
            elif self._OBJECT_LINE.match(line):
                obj = line.strip("\t\n")
                self.figures[figure_order[-1]][obj] = {}
                object_order.append(obj)
            elif self._ATTRIBUTE_LINE.match(line):
                key, value = self._parse_attribute(line)
                self.figures[figure_order[-1]][object_order[-1]][key] = value
        if len(self.prob) < 3:
            raise ValueError('problem has no answer line (expected at index 2)')
        return self.prob[2]

    def _parse_attribute(self, line):
        """Split a ``\\t\\tname:value`` line into (name, converted value)."""
        parts = line.strip('\t\n').split(':')
        key, raw = parts[0], parts[1]
        if key in self._LIST_ATTRS or key in self._FILL_ATTRS:
            return key, raw.split(',')
        if key in self._BOOL_ATTRS:
            return key, raw == 'yes'
        if key in self._SIZE_ATTRS:
            return key, self._SIZE_RANKS.get(raw, 0)
        if key in self._ANGLE_ATTRS:
            return key, float(raw)
        if key not in self._SHAPE_ATTRS:
            print('unknown:%s-%s' % (key, raw))
        return key, raw

    def permute_fig_shapes(self, figure):
        """Return every re-pairing of *figure*'s object names with its values.

        The first entry is always the identity pairing (the figure itself).
        """
        names = list(figure)
        return [dict(zip(names, ordering))
                for ordering in permutations(figure.values())]

    def build_transform(self, f1, f2):
        """Map each object of *f1* to the list of changes it undergoes in *f2*.

        Objects of *f1* missing from *f2* map to ``['deleted']``.
        """
        graphs = {}
        for name in f1:
            if name in f2:
                graphs[name] = self.identify_trans(f1[name], f2[name])
            else:
                graphs[name] = ['deleted']
        return graphs

    def identify_trans(self, shapes1, shapes2):
        """List the attribute changes from object *shapes1* to *shapes2*.

        Missing ``size`` counts as 0, missing ``shape`` as ``'square'`` and
        missing ``fill`` as no fill values.
        """
        trans = []

        size1, size2 = shapes1.get('size', 0), shapes2.get('size', 0)
        if size2 > size1:
            trans.append('expanded')
        if size2 < size1:
            trans.append('shrunk')

        fill1, fill2 = shapes1.get('fill'), shapes2.get('fill')
        if fill2 != fill1:
            # Treat an absent fill as empty instead of iterating over None.
            previous = fill1 or []
            trans += ['filled %s' % f for f in (fill2 or []) if f not in previous]

        shape1 = shapes1.get('shape', 'square')
        shape2 = shapes2.get('shape', 'square')
        if shape2 != shape1:
            trans.append('reshaped from %s to %s' % (shape1, shape2))

        angle1, angle2 = shapes1.get('angle'), shapes2.get('angle')
        if angle1 is not None and angle2 is not None:
            # Smallest rotation in degrees, folded into [0, 180].
            diff = round((angle2 - angle1) % 360)
            if diff >= 180:
                diff = 360 - diff
            trans.append('rotated %f' % diff)

        for key in ('left-of', 'above'):
            if shapes2.get(key) != shapes1.get(key):
                trans.append(key + str(shapes2.get(key)))
        inside2 = shapes2.get('inside')
        if inside2 != shapes1.get('inside') and inside2 is not None:
            trans.append('inside' + str(inside2))
        if shapes2.get('overlaps') != shapes1.get('overlaps'):
            trans.append('overlaps' + str(shapes2.get('overlaps')))
        return trans

    def build_permuted_transforms(self, fig1, fig2):
        """Return one transformation graph per object pairing of *fig2*.

        Figures with 7 or more objects (5040+ permutations) use only the
        identity pairing. Index 0 is always the identity pairing.
        """
        candidates = self.permute_fig_shapes(fig2) if len(fig2) < 7 else [fig2]
        return [self.build_transform(fig1, fig) for fig in candidates]

    def weight_transform_graph(self, graphs):
        """Return a (non-positive) weight for a graph; higher means simpler.

        Unchanged objects cost 4. Each transform costs 1 if it mentions a
        relation, plus the penalty of the first matching category.
        """
        scores = 0
        for transforms in graphs.values():
            if not transforms:
                scores -= 4
            for trans in transforms:
                if any(word in trans for word in self._RELATION_KEYWORDS):
                    scores -= 1
                for keywords, penalty in self._CATEGORY_PENALTIES:
                    if any(word in trans for word in keywords):
                        scores -= penalty
                        break
                else:
                    if len(trans) == 0:
                        print("unchanged")
                        scores -= 1
        return scores

    def compare_transforms(self, t1, t2):
        """Return a difference score between two graphs (0 = identical).

        Per object name: +1 if it appears in only one graph. For each graph
        whose list for the name differs from the other's: +1, and +1 more if
        that list contains 'deleted'. +1 if t2's list matches none of t1's
        lists. Finally +5 if t2 holds no transforms at all (no objects, or
        only empty lists).
        """
        differs = 0
        for name, trans in t1.items():
            if name not in t2:
                differs += 1
                continue
            if trans != t2[name]:
                differs += 1
                if 'deleted' in trans:
                    differs += 1
        t1_lists = list(t1.values())
        for name, trans in t2.items():
            if name not in t1:
                differs += 1
                continue
            if trans != t1[name]:
                differs += 1
                if 'deleted' in trans:
                    differs += 1
            if trans not in t1_lists:
                differs += 1
        # Pure-Python replacement for len(np.unique(values)) == 0, which
        # raises on ragged lists with NumPy >= 1.24.
        if not any(t2.values()):
            differs += 5
        return differs

    def generate(self):
        """Build the A->B graph and, per option '1'..'6', the C->option graphs.

        Returns:
            ``(target_transforms, choice_transforms)``: the identity-pairing
            A->B graph, and a dict option -> list of C->option graphs (the
            identity pairing first).
        """
        target_transforms = self.build_permuted_transforms(
            self.figures['A'], self.figures['B'])[0]
        # Always keep at least one graph so test() can read index 0 even
        # when A has no objects.
        limit = max(1, len(target_transforms))
        choice_transforms = {}
        for option in range(1, 7):
            key = str(option)
            transforms = self.build_permuted_transforms(
                self.figures['C'], self.figures[key])
            choice_transforms[key] = transforms[:limit]
        return target_transforms, choice_transforms

    @staticmethod
    def _uniform_transform(values):
        """Return ``[t]`` if all lists share one length and only contain t.

        Mirrors ``len(np.unique(values)) == 1`` without NumPy's ragged-array
        error; returns None when the transforms are not uniform.
        """
        lengths = {len(v) for v in values}
        strings = {s for v in values for s in v}
        if len(lengths) == 1 and len(strings) == 1:
            return sorted(strings)
        return None

    def test(self, target_transforms, choice_transforms):
        """Return the key of the option whose graph best matches the target.

        Options with a difference score above 6 are discarded. The lowest
        score wins, then the highest weight. Exact ties are broken randomly.
        *target_transforms* is not modified.

        Returns:
            The chosen option key, or '' if every option was discarded.
        """
        uniform = self._uniform_transform(list(target_transforms.values()))
        candidates = []
        for choice, transforms in choice_transforms.items():
            choice_trans = transforms[0]
            # Per-choice copy: shapes added for one option must not affect
            # how later options are scored (or the caller's dict).
            target = dict(target_transforms)
            # If every target object changes the same way, assume extra
            # objects in the option change that way too.
            if uniform is not None and len(choice_trans) > len(target):
                for name in choice_trans:
                    if name not in target:
                        target[name] = list(uniform)
            score = self.compare_transforms(target, choice_trans)
            weight = self.weight_transform_graph(choice_trans)
            if score <= 6:
                candidates.append((choice, score, weight))

        best_choice = ''
        best_score = float('inf')
        best_weight = float('-inf')
        for choice, score, weight in candidates:
            if score < best_score:
                best_choice, best_score, best_weight = choice, score, weight
            elif score == best_score:
                if weight > best_weight:
                    best_choice, best_weight = choice, weight
                elif weight == best_weight and round(random.random()) == 1:
                    best_choice = choice
        return best_choice

# In [5]:
import os

# Directory holding the 2x1 problem files. Override it with the
# RAVEN_PROBLEM_DIR environment variable. An env var is used because inside
# Jupyter, sys.argv holds the kernel's own arguments.
DEFAULT_PROBLEM_DIR = "/Users/ArnoldYanga/Downloads/2x1BasicProblemsTXT/"


def main(problem_dir=DEFAULT_PROBLEM_DIR):
    """Solve every ``.txt`` problem in *problem_dir* and print a report.

    For each file, print the problem name, the correct answer, the agent's
    answer and CORRECT!/INCORRECT!, then print the overall success rate.
    Non-``.txt`` entries (e.g. ``.DS_Store``) are skipped.
    """
    files = sorted(name for name in os.listdir(problem_dir)
                   if name.endswith(".txt"))
    counter = 0
    for file_name in files:
        with open(os.path.join(problem_dir, file_name), "r") as handle:
            problem = handle.readlines()

        raven = Raven2x1Problem(problem)
        correct_answer = raven.solve().strip("\n")

        goal_state, current_state = raven.generate()
        # test() returns the chosen option key as a string ('' if none).
        answer_received = list(raven.test(goal_state, current_state))

        print("Problem - ", file_name.replace(".txt", ""))
        print("Correct answer - ", correct_answer)
        print("Answer received from agent - ", answer_received)

        if correct_answer in answer_received:
            counter += 1
            print("CORRECT!")
        else:
            print("INCORRECT!")
        print("\n")
    print("Success Rate: " + str(counter) + "/" + str(len(files)))


if __name__ == "__main__":
    main(os.environ.get("RAVEN_PROBLEM_DIR", DEFAULT_PROBLEM_DIR))

Problem -  2x1BasicProblem01
Correct answer -  5
Answer received from agent -  ['5']
CORRECT!


Problem -  2x1BasicProblem02
Correct answer -  6
Answer received from agent -  ['6']
CORRECT!


Problem -  2x1BasicProblem03
Correct answer -  4
Answer received from agent -  ['4']
CORRECT!


Problem -  2x1BasicProblem04
Correct answer -  3
Answer received from agent -  ['3']
CORRECT!


Problem -  2x1BasicProblem05
Correct answer -  2
Answer received from agent -  ['2']
CORRECT!


Problem -  2x1BasicProblem06
Correct answer -  5
Answer received from agent -  ['5']
CORRECT!


Problem -  2x1BasicProblem07
Correct answer -  2
Answer received from agent -  ['2']
CORRECT!


Problem -  2x1BasicProblem08
Correct answer -  1
Answer received from agent -  ['2']
INCORRECT!


Problem -  2x1BasicProblem09
Correct answer -  1
Answer received from agent -  ['1']
CORRECT!


Problem -  2x1BasicProblem10
Correct answer -  4
Answer received from agent -  ['4']
CORRECT!


Problem -  2x1BasicProblem11
Correct a