Skip to content

Commit

Permalink
FOIL (#625)
Browse files Browse the repository at this point in the history
* Added predicate_symbols

* Added FOIL

* Updated README
  • Loading branch information
Chipe1 authored and norvig committed Aug 24, 2017
1 parent a065c3b commit 718224a
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -112,7 +112,7 @@ Here is a table of algorithms, the figure, name of the algorithm in the book and
| 19.2 | Current-Best-Learning | `current_best_learning` | [`knowledge.py`](knowledge.py) | Done |
| 19.3 | Version-Space-Learning | `version_space_learning` | [`knowledge.py`](knowledge.py) | Done |
| 19.8 | Minimal-Consistent-Det | `minimal_consistent_det` | [`knowledge.py`](knowledge.py) | Done |
| 19.12 | FOIL | | |
| 19.12 | FOIL | `FOIL_container` | [`knowledge.py`](knowledge.py) | Done |
| 21.2 | Passive-ADP-Agent | `PassiveADPAgent` | [`rl.py`][rl] | Done |
| 21.4 | Passive-TD-Agent | `PassiveTDAgent` | [`rl.py`][rl] | Done |
| 21.8 | Q-Learning-Agent | `QLearningAgent` | [`rl.py`][rl] | Done |
Expand Down
116 changes: 115 additions & 1 deletion knowledge.py
@@ -1,9 +1,12 @@
"""Knowledge in learning, Chapter 19"""

from random import shuffle
from math import log
from utils import powerset
from collections import defaultdict
from itertools import combinations
from itertools import combinations, product
from logic import (FolKB, constant_symbols, predicate_symbols, standardize_variables,
variables, is_definite_clause, subst, expr, Expr)

# ______________________________________________________________________________

Expand Down Expand Up @@ -231,6 +234,117 @@ def consistent_det(A, E):
# ______________________________________________________________________________


class FOIL_container(FolKB):
"""Holds the kb and other necessary elements required by FOIL"""

def __init__(self, clauses=[]):
self.const_syms = set()
self.pred_syms = set()
FolKB.__init__(self, clauses)

def tell(self, sentence):
if is_definite_clause(sentence):
self.clauses.append(sentence)
self.const_syms.update(constant_symbols(sentence))
self.pred_syms.update(predicate_symbols(sentence))
else:
raise Exception("Not a definite clause: {}".format(sentence))

def foil(self, examples, target):
"""Learns a list of first-order horn clauses
'examples' is a tuple: (positive_examples, negative_examples).
positive_examples and negative_examples are both lists which contain substitutions."""
clauses = []

pos_examples = examples[0]
neg_examples = examples[1]

while pos_examples:
clause, extended_pos_examples = self.new_clause((pos_examples, neg_examples), target)
# remove positive examples covered by clause
pos_examples = self.update_examples(target, pos_examples, extended_pos_examples)
clauses.append(clause)

return clauses

def new_clause(self, examples, target):
"""Finds a horn clause which satisfies part of the positive
examples but none of the negative examples.
The horn clause is specified as [consequent, list of antecedents]
Return value is the tuple (horn_clause, extended_positive_examples)"""
clause = [target, []]
# [positive_examples, negative_examples]
extended_examples = examples
while extended_examples[1]:
l = self.choose_literal(self.new_literals(clause), extended_examples)
clause[1].append(l)
extended_examples = [sum([list(self.extend_example(example, l)) for example in
extended_examples[i]], []) for i in range(2)]

return (clause, extended_examples[0])

def extend_example(self, example, literal):
"""Generates extended examples which satisfy the literal"""
# find all substitutions that satisfy literal
for s in self.ask_generator(subst(example, literal)):
s.update(example)
yield s

def new_literals(self, clause):
"""Generates new literals based on known predicate symbols.
Generated literal must share atleast one variable with clause"""
share_vars = variables(clause[0])
for l in clause[1]:
share_vars.update(variables(l))

for pred, arity in self.pred_syms:
new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
for args in product(share_vars.union(new_vars), repeat=arity):
if any(var in share_vars for var in args):
yield Expr(pred, *[var for var in args])

def choose_literal(self, literals, examples):
"""Chooses the best literal based on the information gain"""
def gain(l):
pre_pos = len(examples[0])
pre_neg = len(examples[1])
extended_examples = [sum([list(self.extend_example(example, l)) for example in
examples[i]], []) for i in range(2)]
post_pos = len(extended_examples[0])
post_neg = len(extended_examples[1])
if pre_pos + pre_neg == 0 or post_pos + post_neg == 0:
return -1

# number of positive example that are represented in extended_examples
T = 0
for example in examples[0]:
def represents(d):
return all(d[x] == example[x] for x in example)
if any(represents(l_) for l_ in extended_examples[0]):
T += 1

return T * log((post_pos*(pre_pos + pre_neg) + 1e-4) / ((post_pos + post_neg)*pre_pos))

return max(literals, key=gain)

def update_examples(self, target, examples, extended_examples):
"""Adds to the kb those examples what are represented in extended_examples
List of omitted examples is returned"""
uncovered = []
for example in examples:
def represents(d):
return all(d[x] == example[x] for x in example)
if any(represents(l) for l in extended_examples):
self.tell(subst(example, target))
else:
uncovered.append(example)

return uncovered


# ______________________________________________________________________________


def check_all_consistency(examples, h):
"""Check for the consistency of all examples under h"""
for e in examples:
Expand Down
34 changes: 22 additions & 12 deletions logic.py
Expand Up @@ -196,7 +196,7 @@ def tt_entails(kb, alpha):
True
"""
assert not variables(alpha)
symbols = prop_symbols(kb & alpha)
symbols = list(prop_symbols(kb & alpha))
return tt_check_all(kb, alpha, symbols, {})


Expand All @@ -216,23 +216,33 @@ def tt_check_all(kb, alpha, symbols, model):


def prop_symbols(x):
"""Return a list of all propositional symbols in x."""
"""Return the set of all propositional symbols in x."""
if not isinstance(x, Expr):
return []
return set()
elif is_prop_symbol(x.op):
return [x]
return {x}
else:
return list(set(symbol for arg in x.args for symbol in prop_symbols(arg)))
return {symbol for arg in x.args for symbol in prop_symbols(arg)}


def constant_symbols(x):
"""Return a list of all constant symbols in x."""
"""Return the set of all constant symbols in x."""
if not isinstance(x, Expr):
return []
return set()
elif is_prop_symbol(x.op) and not x.args:
return [x]
return {x}
else:
return list({symbol for arg in x.args for symbol in constant_symbols(arg)})
return {symbol for arg in x.args for symbol in constant_symbols(arg)}


def predicate_symbols(x):
"""Return a set of (symbol_name, arity) in x.
All symbols (even functional) with arity > 0 are considered."""
if not isinstance(x, Expr) or not x.args:
return set()
pred_set = {(x.op, len(x.args))} if is_prop_symbol(x.op) else set()
pred_set.update({symbol for arg in x.args for symbol in predicate_symbols(arg)})
return pred_set


def tt_true(s):
Expand Down Expand Up @@ -549,7 +559,7 @@ def dpll_satisfiable(s):
function find_pure_symbol is passed a list of unknown clauses, rather
than a list of all clauses and the model; this is more efficient."""
clauses = conjuncts(to_cnf(s))
symbols = prop_symbols(s)
symbols = list(prop_symbols(s))
return dpll(clauses, symbols, {})


Expand Down Expand Up @@ -652,7 +662,7 @@ def WalkSAT(clauses, p=0.5, max_flips=10000):
"""Checks for satisfiability of all clauses by randomly flipping values of variables
"""
# Set of all symbols in all clauses
symbols = set(sym for clause in clauses for sym in prop_symbols(clause))
symbols = {sym for clause in clauses for sym in prop_symbols(clause)}
# model is a random assignment of true/false to the symbols in clauses
model = {s: random.choice([True, False]) for s in symbols}
for i in range(max_flips):
Expand All @@ -663,7 +673,7 @@ def WalkSAT(clauses, p=0.5, max_flips=10000):
return model
clause = random.choice(unsatisfied)
if probability(p):
sym = random.choice(prop_symbols(clause))
sym = random.choice(list(prop_symbols(clause)))
else:
# Flip the symbol in clause that maximizes number of sat. clauses
def sat_count(sym):
Expand Down

0 comments on commit 718224a

Please sign in to comment.