In [1]:
import json
import collections
import subprocess

def prettify(atom):

    s = atom['predicate']
    if 'terms' in atom:
        s += '('
        ts = [prettify(t) for t in atom['terms']]
        s += ','.join(ts)
        s += ')'
    return s

  
def parse_json_result(out):
    """Parse the provided JSON text and extract a dict
    representing the predicates described in the first solver result."""
    result = json.loads(out)
    assert len(result['Call']) > 0
    if 'Witnesses' not in result['Call'][0]:
        return []
    
    if len(result['Call'][0]['Witnesses']) == 0:
        return []
    
    all_preds = []
    ids = range(len(result['Call'][0]['Witnesses']))
    
    witness = result['Call'][0]['Witnesses'][0]['Value']

    class identitydefaultdict(collections.defaultdict):
        def __missing__(self, key):
            return key

    preds = collections.defaultdict(list)
    env = identitydefaultdict()

    for atom in witness:
        parsed,dummy = parse_terms(atom)
        preds[parsed[0]['predicate']].append(parsed)
    return preds

def solve(args):
    """Run clingo with the provided argument list and return the parsed JSON result."""

    args = ['clingo','--outf=2'] + args
    clingo = subprocess.Popen(
        ' '.join(args),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True
        )
    out, err = clingo.communicate()
            
    return parse_json_result(out)

def parse_terms(arguments):
    terms = []
    while len(arguments) > 0:
        l_paren = arguments.find('(')
        r_paren = arguments.find(')')
        comma = arguments.find(',')
        if l_paren < 0:
            l_paren = len(arguments)-1
        if r_paren < 0:
            r_paren = len(arguments)-1
        if comma < 0:
            comma = len(arguments)-1
        next = min(l_paren,r_paren,comma)
        next_c = arguments[next]
        if next_c == '(':
        
            pred = arguments[:next]
            sub_terms, arguments = parse_terms(arguments[next+1:]) 
            terms.append({'predicate':pred,'terms':sub_terms})
        elif next_c == ')':
            pred = arguments[:next]
            if pred != '':
                terms.append({'predicate':arguments[:next]})
            arguments = arguments[next+1:]
            return terms,arguments
        elif next_c == ',':
            pred = arguments[:next]
            if pred != '':
                terms.append({'predicate':arguments[:next]})
            arguments = arguments[next+1:]
        else:
            terms.append({'predicate':arguments})
            arguments = ''
    return terms, ''
   


In [2]:
filenames = ['pong.lp','kaboom.lp']

games = []
types = {}
facts = []
for filename in filenames:
    rules = open(filename,'rb').read().replace(' ','').replace('\n','').split('.')[:-1]
    rules = [parse_terms(rule)[0][0] for rule in rules]
    per_game_facts = []
    for rule in rules:
        if rule['predicate'] == 'type':
            types[rule['terms'][1]['predicate']] = rule['terms'][0]['predicate']
        else: 
            facts.append(rule)
            per_game_facts.append(rule)
    games.append([prettify(rule) for rule in per_game_facts])

In [3]:

def has_term(rule,term):
    
    if 'terms' in rule:
        
        for rule_term in rule['terms']:
                if has_term(rule_term,term):
                    return True
        return False
    elif rule['predicate'] == term:
        return True
    else:
        return False
def get_terms(rule):
    if 'terms' in rule:
        terms = []
        for rule_term in rule['terms']:
            terms += get_terms(rule_term)
        return terms
    else:
        return [rule['predicate']]
def get_higher_level(rule):
    if 'terms' in rule:
        
        terms = [prettify(rule)]
        for rule_term in rule['terms']:
            terms += get_higher_level(rule_term)
        return terms
    else:
        return []
    
    
def get_predicates(rule):
    if 'terms' in rule:
        
        terms = [rule]
        for rule_term in rule['terms']:
            terms += get_predicates(rule_term)
        return terms
    else:
        return []
    
def get_term_positions(term,rule):
    found = []
    if 'terms' in rule:
        for i,t in enumerate(rule['terms']):
            if t['predicate'] == term:
                found.append(i)
    return found
    

In [4]:
import random
import sys
import numpy as np
import hashlib
import os


max_rules = 2
temperature = 5

In [5]:
import unionfind 


def replace(fact,source,target):
    if 'terms' in fact:
        terms = []
        for fact_term in fact['terms']:
            terms.append(replace(fact_term,source,target))
        return {'predicate':fact['predicate'],
                'terms':terms}
    else:
        pred = fact['predicate']
        if pred == source:
            pred = target
        return {'predicate':pred}
    
def create_rule_graph(game,positives):
    terms_to_fact = {}
    
    all_terms = {}
    all_rules = {}
    for positive_id,positive in enumerate(positives):
        terms = get_terms(positive)
        terms_to_fact = {term:[-positive_id-1]  for term in terms}
        all_terms[-positive_id-1] = terms
        all_rules[-positive_id-1] = positive
        
    
    for rule_id,rule in enumerate(game):
        terms = get_terms(rule)
        all_terms[rule_id] = terms
        for term in terms:
            if term not in terms_to_fact:
                terms_to_fact[term] = []
            terms_to_fact[term].append(rule_id)
        all_rules[rule_id] = rule
        
    children = {}
    for term in terms_to_fact:        
        for fact in terms_to_fact[term]:
            rule = all_rules[fact]
            all_predicates = get_predicates(rule)
            for predicate in all_predicates:
                positions = get_term_positions(term,predicate)
                if predicate['predicate'] not in children:
                    children[predicate['predicate']] = {}
                for position in positions:
                    if position not in children[predicate['predicate']]:
                        children[predicate['predicate']][position] = set()
                    children[predicate['predicate']][position].add(term)
                
    
                
    term2id = {t:i for i,t in enumerate(sorted(terms_to_fact))}
    id2term = {i:t for t,i in term2id.items()}
    union = unionfind.unionfind(len(terms_to_fact))
    
    for rule in children:
        for pos in children[rule]:
            children[rule][pos] = list(children[rule][pos])
            for ii in range(len(children[rule][pos])):
                for jj in range(ii+1,len(children[rule][pos])):
                    union.unite(term2id[children[rule][pos][ii]],
                                term2id[children[rule][pos][jj]])
    
    implicit_types = union.groups()
    term2type = {}
    for group in implicit_types:
        group = [id2term[t] for t in group]
        for t in group:
            term2type[t] = group
            
            
    for term in terms_to_fact:
        for other in term2type[term]:
            terms_to_fact[term] += terms_to_fact[other]
        terms_to_fact[term] = list(set(terms_to_fact[term]))
    visited = set()
    connections = {}
    
    stack = list(sorted([i for i in all_terms if i < 0]))
    #print stack
    while len(stack) > 0:        
        #print 'stack', stack
        current = stack.pop(0)
        visited.add(current)
        if current not in connections:
            connections[current] = set()
        for term in all_terms[current]:
            for connection in terms_to_fact[term]:
                if connection not in visited and connection not in stack:
                    if connection not in connections:
                        connections[connection] = set()
                    connections[connection].add(current)
                    connections[current].add(connection)
                        
                    stack.append(connection)
                elif connection != current:
                    connections[connection].add(current)
                    connections[current].add(connection)
    
    
    
    
    return connections,all_rules
  

In [6]:
all_positives = []
all_raw_positives = []

test = 'player_affects_outcome'

if test == 'player_controls':

    positives = [{'predicate':'player_controls','terms':[{'predicate':'paddle_player'}]}]
    all_raw_positives.append(positives[-1])
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)


    positives = [{'predicate':'player_controls','terms':[{'predicate':'basket'}]}]
    all_raw_positives.append(positives[-1])
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)
elif test == 'moves':    
    positives = [{'predicate':'moves','terms':[{'predicate':'paddle_player'}]},
                {'predicate':'moves','terms':[{'predicate':'paddle_computer'}]}]
    all_raw_positives += positives
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)


    positives = [{'predicate':'moves','terms':[{'predicate':'basket'}]},
                 {'predicate':'moves','terms':[{'predicate':'bomb'}]},
                {'predicate':'moves','terms':[{'predicate':'bomber'}]}]
    all_raw_positives += positives
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)
elif test == 'player_affects_outcome':    
    positives = [{'predicate':'player_affects_outcome','terms':[{'predicate':'player_serve'}]},
                {'predicate':'player_affects_outcome','terms':[{'predicate':'move_up'}]},
                {'predicate':'player_affects_outcome','terms':[{'predicate':'move_down'}]},
                {'predicate':'player_affects_outcome','terms':[{'predicate':'player_hit'}]}]
    all_raw_positives += positives
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)


    positives = [{'predicate':'player_affects_outcome','terms':[{'predicate':'defuse'}]},
                 {'predicate':'player_affects_outcome','terms':[{'predicate':'move_right'}]},
                {'predicate':'player_affects_outcome','terms':[{'predicate':'move_left'}]}]
    all_raw_positives += positives
    positives = [prettify(f) for f in positives]
    all_positives.append(positives)

connections,rules = create_rule_graph(facts,all_raw_positives)

for c in sorted(connections):
    print prettify(rules[c])
    
    for t in sorted([prettify(rules[cc]) for cc in connections[c]]):
        print '\t->',t

player_affects_outcome(move_left)
	-> player_affects_outcome(defuse)
	-> player_affects_outcome(move_down)
	-> player_affects_outcome(move_right)
	-> player_affects_outcome(move_up)
	-> player_affects_outcome(player_hit)
	-> player_affects_outcome(player_serve)
	-> precondition(above(paddle_computer,ball),move_up_computer)
	-> precondition(below(paddle_computer,ball),move_down_computer)
	-> precondition(control_event(player_input(down_arrow,held)),move_down)
	-> precondition(control_event(player_input(left_arrow,held)),move_left)
	-> precondition(control_event(player_input(right_arrow,held)),move_right)
	-> precondition(control_event(player_input(space,pressed)),player_serve)
	-> precondition(control_event(player_input(up_arrow,held)),move_up)
	-> precondition(eq(computer_start,true),computer_serve)
	-> precondition(eq(paused,false),tick)
	-> precondition(eq(player_start,true),player_serve)
	-> precondition(ge(score_computer,high),lose)
	-> precondition(ge(score_player,high),win)
	-> p

In [None]:
def coarsenings(head,body):
    possible_coarsenings = []
    all_high_level_terms = set()
    term_usage = {}
    terms = get_terms(head)
    for term in terms:
        if term not in term_usage:
            term_usage[term] = []
        term_usage[term].append(-1)


    for pred_id,predicate in enumerate(body):
        high_level_terms = get_higher_level(predicate)
        all_high_level_terms |= set(high_level_terms)
        terms = get_terms(predicate)
        for term in terms:
            if term not in term_usage:
                term_usage[term] = []
            term_usage[term].append(pred_id)
    safe_terms = set(all_high_level_terms)

    for high_level in all_high_level_terms:
        for term in term_usage:
            if len(term_usage[term]) > 1 and term in high_level:
                safe_terms.remove(high_level)
                break

    return list(safe_terms)

def coarsen(coarsenings_,body):
    new_rules = []
    coarsening2ind = {coarsening:'V{}'.format(i) for i,coarsening in enumerate(coarsenings_)}
    ind2coarsening = {'V{}'.format(i):coarsening for i,coarsening in enumerate(coarsenings_)}

    new_body = []
    for b in body:

        pretty_b = prettify(b)
        for i in sorted(ind2coarsening):
            c = ind2coarsening[i]
            pretty_b = pretty_b.replace(c,i)

        new_body.append(parse_terms(pretty_b)[0][0])
    return new_body
   

In [None]:


import itertools

def get_neighbors(points,connections,rules):
    neighbors = set()
    
    for point in points:
        for conn in connections[point]:
            neighbors.add(tuple(sorted(set(points) | set([conn]))))
    return neighbors

def get_all_combinations(points,lhs,rules):
    rules_to_use = [rules[point] for point in points]
    uniques = set()
        
    for fact_id, fact in enumerate(rules_to_use):
        terms = set(get_terms(fact))
        uniques |= terms #set([(fact_id,term) for term in terms])
    unique_combos = []
    for ii in range(0,len(uniques)+1):
        unique_combos += list(itertools.combinations(uniques,ii))
    output = []
    for uniques in unique_combos:   
        uniques = sorted(uniques)
        unique_mapping = {}
        for unique_id,u in enumerate(uniques):
            unique_mapping[u] = 'V{}'.format(unique_id) #random.randint(0,len(by_type[t])))
        final_facts = [] 
        for fact_id, fact in enumerate(rules_to_use):            
            terms = set(get_terms(fact))
            for term in terms:
                if term in unique_mapping:
                    fact = replace(fact,term,unique_mapping[term])
            final_facts.append(fact)
        target_form = lhs
        terms = list(set(get_terms(target_form)))
        for term in terms:
            if term in unique_mapping:
                target_form = replace(target_form,term,unique_mapping[term])
        potential_coarsenings = coarsenings(target_form,final_facts)
        
        coarsening_combos = []
        for ii in range(0, len(potential_coarsenings)+1):
            coarsening_combos += list(itertools.combinations(potential_coarsenings,ii))
        
        for coarsening in sorted(coarsening_combos):
            coarsened = coarsen(coarsening,final_facts)
            
            
        
            by_term = {}
            for fact_id, fact in enumerate(coarsened):
                terms = set(get_terms(fact))
                for term in terms:
                    if term not in by_term:
                        by_term[term] = set()
                    by_term[term].add(fact_id)
                    
            able_to_be_negated = []
            for fact_id, fact in enumerate(coarsened):
                terms = set(get_terms(fact))
                can_be_negated = True
                for term in terms:
                    if len(by_term[term]) == 1 and term[0].isupper():
                        can_be_negated = False
                        break
                if can_be_negated:
                    able_to_be_negated.append(fact_id)
            
            negation_combos = []
            for ii in range(0,len(able_to_be_negated)+1):
                negation_combos  += list(itertools.combinations(able_to_be_negated,ii))
            
            
            for negation_combo in negation_combos:
                negations = [False]*len(coarsened)
                for n in negation_combo:
                    negations[n] = True
                output.append((target_form,coarsened,negations))
    return output

def score_rule(games,rule,per_game_positives):
    head, body, negations = rule
    probability = 0
    for game,positives in zip(games,per_game_positives):
            
        rule_string = '.\n'.join(game)
        
        
        rule_string += '.\n' + rule_to_string(rule)
        hashed_name = 'temp' + hashlib.sha224(rule_string).hexdigest()
        with open(hashed_name,'wb') as outfile:
            outfile.write('.\n'.join(game) + '.\n')
            outfile.write(rule_to_string(rule))
            outfile.write('#show {}/{}.'.format(head['predicate'],len(head['terms'])))
                
        
        solved = solve([hashed_name])
        
        is_good = True
        found = []
        total_found = 0
        for t in solved:
            for tt in solved[t]:
                for ttt in tt:
                    if prettify(ttt) in positives:
                        found.append(prettify(ttt))
                    else:
                        is_good = False
                        break
                if not is_good:
                    break
            if not is_good:
                break
            if is_good:
                total_found += 1
        if is_good:
            if total_found == 0:
                probability += np.log(1e-20)
            else:
                probability += np.log(float(total_found)/float(len(positives)))
        else:
            probability += np.log(1e-20)
        
    return -probability
def rule_to_string(rule):
    head,body,negations = rule
    
    rule_text = []
    for n,r in zip(negations,body):
        if n:
            rule_text.append('not '+prettify(r))
        else:
            rule_text.append(prettify(r))
    return prettify(head) + ':-' + ','.join(list(sorted(rule_text))) + '.\n'
def breadth_first(connections,rules):    
    starting_points = []
    for connection in sorted(connections):
        if connection < 0:
            starting_points.append([connection])
            
    to_visit = [tuple(pt) for pt in starting_points]
    max_size = 3
    tested_rules = set()
    visited = set()
    visited |= set(to_visit)
    while len(to_visit) > 0:
        current = to_visit.pop(0)
        print len(visited),  len(current), len(tested_rules)
        if len(current) > max_size:
            break
            
        lhses = [c for c in current if c < 0]
        to_test = []
        for lhs in lhses:
            if len(current) > 1:
                shrunk = set(current)
                shrunk.remove(lhs)
                to_test += get_all_combinations(shrunk,rules[lhs],rules)                
            #to_test += get_all_combinations(current,rules[lhs],rules)
        scored_rules = {}
        
        
        for rule in to_test:
            terms = get_terms(rule[0])
            
            for r in rule[1]:
                terms += get_terms(r)
            terms = set(terms)
            specifics = 0
            general = 0
            for term in terms:
                if term[0].islower():
                    specifics +=1
                else:
                    general += 1
            score = 1000*specifics+general
            if score not in scored_rules:
                scored_rules[score] = {}
                
            rule_string = rule_to_string(rule)
            
            if rule_string not in tested_rules:
                scored_rules[score][rule_to_string(rule)] = rule
                tested_rules.add(rule_string)
                #print 'testing', rule_string
        
        
        for score in sorted(scored_rules):
            for rule in scored_rules[score].values():
                log_prob = score_rule(games,rule,all_positives)
                if log_prob < 46 :
                    print score,log_prob
                    print prettify(rule[0]) , ':-'
                    for r in rule[1]:
                        print '\t', prettify(r)
                    print ''
                    if log_prob < 0.0001:
                        return rule
                
            
        neighbors = get_neighbors(current,connections,rules)
        neighbors = [tuple(sorted(neighbor)) for neighbor in neighbors]
        neighbors = [neighbor for neighbor in neighbors if neighbor not in visited]
        visited |= set(neighbors)
        to_visit += neighbors
        import os
        os.system('rm temp*')
breadth_first(connections,rules)
        

7 1 0
61 1 0
109 1 0
157 1 0
205 1 0
253 1 0
301 1 0
349 2 0
433 2 13
485 2 20
567 2 29
617 2 37
697 2 55
745 2 68
823 2 77
900 2 88
976 2 105
1051 2 113
1125 2 123
1198 2 134
1270 2 144
1341 2 165
1380 2 171
1449 2 186
1517 2 194
1584 2 202
1650 2 212
1684 2 221
1748 2 231
1811 2 249
1873 2 259
1934 2 268
1994 2 284
2053 2 291
2111 2 303
2 2.48490664979
player_affects_outcome(V1) :-
	precondition(control_event(V0),V1)

2 2.48490664979
player_affects_outcome(V2) :-
	precondition(control_event(V0),V2)

3 2.48490664979
player_affects_outcome(V2) :-
	precondition(control_event(player_input(V1,V0)),V2)

1002 2.48490664979
player_affects_outcome(V1) :-
	precondition(control_event(player_input(V0,held)),V1)

2137 2 320
2193 2 328
2248 2 333
2302 2 345
2355 2 350
2407 2 360
2458 2 365
2477 2 370
2526 2 379
2574 2 389
2590 2 395
2605 2 407
2619 2 413
2663 2 422
2706 2 429
2748 2 438
2758 2 444
2798 2 454
2837 2 463
2844 2 473
2881 2 492
2886 2 498
2921 2 505
2924 2 505
2957 2 510
2989 2 520
30