In [1]:
import json
import collections
import subprocess

def prettify(atom):

    s = atom['predicate']
    if 'terms' in atom:
        s += '('
        ts = [prettify(t) for t in atom['terms']]
        s += ','.join(ts)
        s += ')'
    return s

  
def parse_json_result(out):
    """Parse the provided JSON text and extract a dict
    representing the predicates described in the first solver result."""
    result = json.loads(out)
    assert len(result['Call']) > 0
    if 'Witnesses' not in result['Call'][0]:
        return []
    
    if len(result['Call'][0]['Witnesses']) == 0:
        return []
    
    all_preds = []
    ids = range(len(result['Call'][0]['Witnesses']))
    
    witness = result['Call'][0]['Witnesses'][0]['Value']

    class identitydefaultdict(collections.defaultdict):
        def __missing__(self, key):
            return key

    preds = collections.defaultdict(list)
    env = identitydefaultdict()

    for atom in witness:
        parsed,dummy = parse_terms(atom)
        preds[parsed[0]['predicate']].append(parsed)
    return preds

def solve(args):
    """Run clingo with the provided argument list and return the parsed JSON result."""

    args = ['clingo','--outf=2'] + args
    clingo = subprocess.Popen(
        ' '.join(args),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True
        )
    out, err = clingo.communicate()
            
    return parse_json_result(out)

def parse_terms(arguments):
    terms = []
    while len(arguments) > 0:
        l_paren = arguments.find('(')
        r_paren = arguments.find(')')
        comma = arguments.find(',')
        if l_paren < 0:
            l_paren = len(arguments)-1
        if r_paren < 0:
            r_paren = len(arguments)-1
        if comma < 0:
            comma = len(arguments)-1
        next = min(l_paren,r_paren,comma)
        next_c = arguments[next]
        if next_c == '(':
        
            pred = arguments[:next]
            sub_terms, arguments = parse_terms(arguments[next+1:]) 
            terms.append({'predicate':pred,'terms':sub_terms})
        elif next_c == ')':
            pred = arguments[:next]
            if pred != '':
                terms.append({'predicate':arguments[:next]})
            arguments = arguments[next+1:]
            return terms,arguments
        elif next_c == ',':
            pred = arguments[:next]
            if pred != '':
                terms.append({'predicate':arguments[:next]})
            arguments = arguments[next+1:]
        else:
            terms.append({'predicate':arguments})
            arguments = ''
    return terms, ''
   


In [2]:
filenames = ['pong.lp','kaboom.lp']

games = []

game_facts = []
types = {}
facts = []
for filename in filenames:
    rules = open(filename,'rb').read().replace(' ','').replace('\n','').split('.')[:-1]
    rules = [parse_terms(rule)[0][0] for rule in rules]
    per_game_facts = []
    for rule in rules:
        if rule['predicate'] == 'type':
            types[rule['terms'][1]['predicate']] = rule['terms'][0]['predicate']
        else: 
            facts.append(rule)
            per_game_facts.append(rule)
    game_facts.append(per_game_facts)
    games.append([prettify(rule) for rule in per_game_facts])

In [3]:
def get_terms(rule,full=True):
    if 'terms' in rule:
        terms = []
        for rule_term in rule['terms']:
            terms += get_terms(rule_term)
        if full:
            return terms + [prettify(rule)]
        else:
            return terms
    else:
        return [rule['predicate']]

In [4]:
all_positives = []
all_pretty_positives = []

positives = [{'predicate':'player_controls','terms':[{'predicate':'paddle_player'}]}]
pretty_positives = [prettify(f) for f in positives]
all_positives.append(positives)
all_pretty_positives.append(pretty_positives)

positives = [{'predicate':'player_controls','terms':[{'predicate':'basket'}]}]
pretty_positives = [prettify(f) for f in positives]
all_positives.append(positives)
all_pretty_positives.append(pretty_positives)



In [5]:
def create_rule_graph(game,positives):
    terms_to_fact = {}
    
    all_terms = {}
    all_rules = {}
    for positive_id,positive in enumerate(positives):
        terms = get_terms(positive)
        terms_to_fact = {term:[-positive_id-1]  for term in terms}
        all_terms[-positive_id-1] = terms
        all_rules[-positive_id-1] = positive
        
    for rule_id,rule in enumerate(game):
        terms = get_terms(rule)
        all_terms[rule_id] = terms
        for term in terms:
            if term not in terms_to_fact:
                terms_to_fact[term] = []
            terms_to_fact[term].append(rule_id)
        all_rules[rule_id] = rule
            
    
    visited = set()
    connections = {}
    
    stack = [sorted(all_terms)[0]]
    
    while len(stack) > 0:
        
        current = stack.pop()
        visited.add(current)
        connections[current] = set()
        for term in all_terms[current]:
            for connection in terms_to_fact[term]:
                if connection not in visited:
                    stack.append(connection)
                elif connection != current:
                    connections[connection].add(current)
                    connections[current].add(connection)
        
    return connections,all_rules
            

In [6]:
import random 

import sys
import numpy as np
import hashlib
import os


connections,rules = create_rule_graph(game_facts[0],all_positives[0])
#for connection in sorted(connections):
#    print connection,prettify(rules[connection]), connections[connection]
    
    
def do_walk(rules,connections,min_to_add,max_to_add,visited = None):
    if visited:
        can_add = set(visited)        
    else:
        starting_points = []
        for connection in sorted(connections):
            if connection < 0:
                starting_points.append(connection)
        current = random.choice(starting_points)
        visited = set([current])
        can_add = set([current])
        
    number_of_facts = random.randint(min_to_add,max_to_add)

    while number_of_facts > 0:
        
        branch = random.choice(list(can_add))
        can_visit = connections[branch] - visited
        can_visit = set(v for v in can_visit if v >= 0)
        visiting = random.choice(list(can_visit))
        if len(can_visit) == 1:
            can_add.remove(branch)
        visited.add(visiting)
        can_add.add(visiting)
        current = visiting
        number_of_facts -= 1
    visited = sorted(visited)
    return visited

def find_bridges(vertices,connections):
    low = [-1]*len(vertices)
    pre = [-1]*len(vertices)
    cnt = [0]
    bridges = set()
    def dfs(u,v):
        cnt[0] += 1
        pre[vertices.index(v)] = cnt[0]
        low[vertices.index(v)] = pre[vertices.index(v)]
        for w in connections[v]:
            if w in vertices:
                if (pre[vertices.index(w)] == -1):
                    dfs(v, w)
                    low[vertices.index(v)] = min(low[vertices.index(v)], low[vertices.index(w)])
                    if (low[vertices.index(w)] == pre[vertices.index(w)]):
                        if len(connections[v] & set(vertices)) > 1:
                            bridges.add(v)
                        if len(connections[w] & set(vertices)) > 1:
                            bridges.add(w)
                elif (w != u):
                    low[vertices.index(v)] = min(low[vertices.index(v)], pre[vertices.index(w)])
    dfs(vertices[0],vertices[0])
    return bridges
        
def genericize(rules,connections,visited):
    shared_terms = {}
    for visit in visited:
        for term in get_terms(rules[visit]):                
            if term not in shared_terms:
                shared_terms[term] = 0
            shared_terms[term] += 1
    shared_terms = {term for term in shared_terms if shared_terms[term] > 1}
    
    subsumes = set()
    can_be_generic = set()
    pretty_rules = []
    for visit in visited:
        pretty_rules.append(prettify(rules[visit]))
        
        for term in get_terms(rules[visit]):
            for shared_term in shared_terms:
                if shared_term in term and shared_term != term:
                    subsumes.add(term)
    
    for visit in visited:
        
        for term in get_terms(rules[visit]):
            if term not in subsumes and term not in shared_terms:
                can_be_generic.add(term)
    for rule_id,rule in enumerate(pretty_rules):
        for term in shared_terms:
            pretty_rules[rule_id] = pretty_rules[rule_id].replace(term,term.upper())
            
    to_be_generic = random.randint(0,len(can_be_generic))
    can_be_generic = list(can_be_generic)
    while to_be_generic > 0:
        to_be_generic -= 1
        random.shuffle(can_be_generic)
        before = can_be_generic.pop()
        if '(' in before:
            after =parse_terms(before)[0][0]
            after = after['predicate'] + '(' + ','.join(['_']*len(after['terms'])) + ')'
        else:
            after = before.upper()
                                                        
        for rule_id,rule in enumerate(pretty_rules):
            pretty_rules[rule_id] = pretty_rules[rule_id].replace(before,after)
        
    return visited,pretty_rules

def delete_rule(rules,connections):
    
    bridges = find_bridges(rules,connections)
    can_be_deleted = list(set(rules[1:]) - bridges)
    new_rules = set(rules)
    if len(can_be_deleted) > 0 and len(rules) > 2:
        new_rules.remove(random.choice(can_be_deleted))
    return list(sorted(new_rules))

def score_rule(games,per_game_positives,generated_rules):
    generated_rules = generated_rules[1]
    number_correct = 0
    number_incorrect = 0
    for game,positives in zip(games,per_game_positives):
        rule_string = '.\n'.join(game) + '.\n'
        rule_string += generated_rules[0] + ':-' + ','.join(generated_rules[1:]) + '.\n'
        
        hashed_name = 'temp' + hashlib.sha224(rule_string).hexdigest()
        with open(hashed_name,'wb') as outfile:
            outfile.write(rule_string)
            target_head = parse_terms(generated_rules[0])[0][0]
            outfile.write('#show {}/{}.'.format(target_head['predicate'],len(target_head['terms'])))
                
        
        solved = solve([hashed_name])
        
        is_good = True
        found = []
        total_found = 0
        for t in solved:
            for tt in solved[t]:
                for ttt in tt:
                    if prettify(ttt) in positives:
                        found.append(prettify(ttt))
                        number_correct +=1
                    else:
                        number_incorrect += 1
            
        
                
    
    return number_correct*5-number_incorrect + np.log(len(games)+1)*len(generated_rules)

    
    
chosen_rules = do_walk(rules,connections,min_to_add=1,max_to_add=5)
print delete_rule(chosen_rules,connections)
genericize(rules,connections,chosen_rules)

[-1, 18, 25, 30]


([-1, 4, 18, 25, 30],
 ['player_controls(PADDLE_PLAYER)',
  'singular(PADDLE_PLAYER)',
  'result(MOVE_UP,moves(PADDLE_PLAYER,north,LOW))',
  'precondition(overlaps(BALL,SCREEN_LEFT),COMPUTER_SCORE)',
  'precondition(overlaps(BALL,PADDLE_PLAYER),PLAYER_HIT)'])

In [7]:
population_size = 300
population = [genericize(rules,connections,do_walk(rules,connections,min_to_add=1,max_to_add=5)) for _ in range(population_size)]

In [8]:
from multiprocessing import Pool
poolsize = 7
temperature = 5

def curried(generated_rules):
    return np.exp(-score_rule(games,all_positives,generated_rules)/temperature)

pool = Pool(poolsize)


generations = 10
mutation_probability = 0.95
deletion_probability = 0.35
addition_probability = 0.95
change_probability = 0.35


for gen in range(generations):
    print 'GENERATION', gen
    probs = []
    for ii in range(0,len(population),poolsize):
        probs += pool.map(curried,population[ii:(ii+poolsize)])
    probs = np.array(probs)
    probs /= np.sum(probs)
    
    print probs
    index = np.argmax(probs == np.max(probs))
    print index, probs[index]
    rule = population[index][1]
    print rule[0], ':-'
    print '\n\t'.join(rule[1:])
    new_population = []
    for p in range(population_size):
        chosen = population[np.argmax(np.random.multinomial(1,probs,1))]
        chosen = (list(chosen[0]),list(chosen[1]))
    
        new_population.append(chosen)
        
    mutations = mutation_probability*population_size
    iters = 0
    while mutations > 0 and iters < population_size:

        iters +=1
        p = random.randint(0,len(new_population)-1)
        member = new_population[p][0]
        to_delete = 0
        if random.random() < deletion_probability:
            if len(member) != 2:
                to_delete = 1
                while to_delete > 0:
                    to_delete -= 1
                    member = delete_rule(member,connections)
                
        if random.random() < addition_probability:
            can_add = 5-len(member)
            if can_add > 0:
                member = do_walk(rules,connections,1,can_add,set(member))
        mutations -= 1
        new_population[p] = genericize(rules,connections,member)
    population = new_population
            

GENERATION 0
[ 0.00206943  0.00172806  0.00400059  0.00138718  0.00400059  0.00211065
  0.00268168  0.00257796  0.00172806  0.00400059  0.00172806  0.00268168
  0.00169431  0.00172806  0.00327541  0.00172806  0.00334065  0.00169431
  0.00172806  0.00211065  0.00488634  0.00215269  0.00206943  0.00172806
  0.00908083  0.00321144  0.00215269  0.00206943  0.00498366  0.00908083
  0.00321144  0.00215269  0.00138718  0.00268168  0.00743475  0.00215269
  0.00138718  0.00172806  0.00400059  0.00262931  0.00215269  0.00908083
  0.00908083  0.00172806  0.00268168  0.00169431  0.00172806  0.00215269
  0.00262931  0.00908083  0.00321144  0.00327541  0.00321144  0.00113573
  0.00400059  0.00608706  0.00908083  0.00211065  0.00268168  0.00321144
  0.00327541  0.00172806  0.00400059  0.00138718  0.00215269  0.00268168
  0.00138718  0.00252761  0.00334065  0.00138718  0.00172806  0.00257796
  0.00215269  0.00262931  0.00169431  0.00321144  0.00172806  0.00113573
  0.00138718  0.00334065  0.00268168  

IndexError: list index out of range

In [None]:

probs = []
for ii in range(0,len(population),poolsize):
    probs += pool.map(curried,population[ii:(ii+poolsize)])

probs = np.array(probs)
probs /= np.sum(probs)
index = np.argmax(probs == np.max(probs))

In [None]:
print population[index]