# Monster task model

In [20]:
from LOTlib3.Grammar import Grammar
from LOTlib3.DefaultGrammars import DNF
from math import log
from LOTlib3.Hypotheses.LOTHypothesis import LOTHypothesis
from LOTlib3.Hypotheses.Likelihoods.BinaryLikelihood import BinaryLikelihood
from LOTlib3.DataAndObjects import FunctionData, Obj
from LOTlib3.Eval import primitive


from LOTlib3 import break_ctrlc
from LOTlib3.Miscellaneous import qq
from LOTlib3.TopN import TopN
from LOTlib3.Samplers.MetropolisHastings import MetropolisHastingsSampler

# Generative grammar

`grammar.add_rule( <NONTERMINAL>, <FUNCTION>, <ARGUMENTS>, <PROBABILITY>)`

Note: the nonterminal arguments of a rule are passed to its function as ordinary Python arguments. For example:

`is_color_(OBJECT, 'red')` $\rightarrow$ `OBJECT.color == 'red'`

In [257]:
@primitive
def fequal_(x, f, y):
    """Return True when feature `f` of object `x` equals int(y).

    `f` is an attribute name looked up on `x` with getattr; `y` is coerced
    with int() before the equality test.
    """
    feature_value = getattr(x, f)
    target = int(y)
    return feature_value == target


@primitive
def fle_(x, f, y):
    """Return True when feature `f` of object `x` is strictly less than int(y).

    NOTE(review): despite the "le" in its name, this uses a strict `<`, not
    `<=`. With the thresholds 1..6 that the grammar offers and stimuli whose
    F0 lies in 1..6, fle_(x, 'F0', 1) can never be true. Kept as-is because
    the recorded MCMC results depend on it; confirm the intended semantics.
    """
    feature_value = getattr(x, f)
    bound = int(y)
    return feature_value < bound


ndims = 1  # NOTE(review): not referenced by any visible cell

# The grammar starts at START and always yields a single quantified formula.
grammar = Grammar(start='START')

# START -> Q (the quantifier)
grammar.add_rule('START', '', ['Q'], 1.0)

# Q -> forall_(FUNCTION, SET). SET always expands to the name S, which is the
# argument MyHypothesis binds via its display string `lambda S: ...`, so the
# quantifier ranges over the objects supplied in each datum's input.
grammar.add_rule('Q', 'forall_', ['FUNCTION', 'SET'], 1.0)
grammar.add_rule('SET', 'S', None, 1.0)

# FUNCTION -> a lambda introducing a bound variable of type X; the X slots
# of the feature predicates below are filled by it (e.g. `y2` in samples).
grammar.add_rule('FUNCTION', 'lambda', ['DISJ'], 1.0, bv_type='X')

# Boolean skeleton: DISJ is a chain of or_ terms ending in False,
# CONJ is a chain of and_ terms ending in True.
grammar.add_rule('DISJ', 'or_', ['CONJ', 'DISJ'], 1.0)
grammar.add_rule('DISJ', 'False', None, 1.0)
grammar.add_rule('CONJ', 'and_', ['CONJ', 'P'], 1.0)
grammar.add_rule('CONJ', 'True', None, 1.0)

# Every predicate is a feature predicate.
grammar.add_rule('P', '', ['FP'], 1.0)

# Feature predicates: for each value 1..6 add an equality test and then a
# strict less-than test (fequal_ before fle_, same rule order as before).
for threshold in range(1, 7):
    for pred_name in ('fequal_', 'fle_'):
        grammar.add_rule('FP', pred_name, ['X', 'F', str(threshold)], 1.0)

# The only feature currently available is F0, emitted as a string literal.
grammar.add_rule('F', "'F0'", None, 1.0)

# Print a few random expansions with their log prior probability.
for _ in range(10):
    sample_tree = grammar.generate()
    print(f'{grammar.log_probability(sample_tree):.3f}', sample_tree)

-45.761 forall_(lambda y2: or_(and_(and_(and_(and_(True, fle_(y2, 'F0', 1)), fle_(y2, 'F0', 4)), fequal_(y2, 'F0', 1)), fle_(y2, 'F0', 6)), or_(and_(True, fequal_(y2, 'F0', 5)), or_(and_(and_(and_(True, fle_(y2, 'F0', 2)), fequal_(y2, 'F0', 5)), fle_(y2, 'F0', 4)), or_(True, or_(and_(and_(and_(and_(True, fequal_(y2, 'F0', 4)), fequal_(y2, 'F0', 6)), fequal_(y2, 'F0', 6)), fle_(y2, 'F0', 3)), False))))), S)
-0.693 forall_(lambda y2: False, S)
-0.693 forall_(lambda y2: False, S)
-3.466 forall_(lambda y2: or_(True, or_(True, False)), S)
-11.208 forall_(lambda y2: or_(and_(and_(True, fle_(y2, 'F0', 4)), fle_(y2, 'F0', 4)), or_(True, or_(True, False))), S)
-13.981 forall_(lambda y2: or_(True, or_(True, or_(and_(and_(True, fle_(y2, 'F0', 1)), fle_(y2, 'F0', 3)), or_(True, or_(True, False))))), S)
-26.693 forall_(lambda y2: or_(True, or_(and_(and_(and_(and_(True, fequal_(y2, 'F0', 3)), fequal_(y2, 'F0', 3)), fle_(y2, 'F0', 4)), fle_(y2, 'F0', 5)), or_(True, or_(and_(and_(True, fle_(y2, 'F0', 

# Data

In [258]:
# datum1 = FunctionData(input=[('1', '1')], output=True, alpha=0.99)
# datum2 = FunctionData(input=[('6', '6')], output=False, alpha=0.99)
# data = [datum1, datum2]

class stim1d:
    """A one-dimensional stimulus with a single feature attribute, ``F0``.

    The grammar's feature predicates read this attribute by name (getattr),
    so the attribute must be called exactly ``F0``.
    """

    def __init__(self, F0):
        self.F0 = F0

    def __repr__(self):
        # Compact form, e.g. s[3], used when data are printed.
        return f's[{self.F0}]'

# Training data: each datum's input is one singleton set of stimuli.
# Stimuli with F0 in 1..3 are labeled True; F0 in 4..6 are labeled False.
data = [
    FunctionData(input=[{stim1d(F0=value)}], output=value <= 3, alpha=0.99)
    for value in range(1, 7)
]

# Keep the individual names; later cells refer to datum1, datum2, ...
datum1, datum2, datum3, datum4, datum5, datum6 = data

print(*data, sep='\n')


<{s[1]} -> True>
<{s[2]} -> True>
<{s[3]} -> True>
<{s[4]} -> False>
<{s[5]} -> False>
<{s[6]} -> False>


# Inference

In [259]:
class MyHypothesis(BinaryLikelihood, LOTHypothesis):
    """LOT hypothesis over the notebook grammar, mixed with BinaryLikelihood.

    Hypotheses are displayed as ``lambda S: <expression>``, so ``S`` is the
    name the grammar's ``forall_`` quantifies over.

    NOTE(review): the default ``grammar=grammar`` is evaluated when this
    class is defined; re-running only the grammar cell will not update it.
    """

    def __init__(self, grammar=grammar, **kwargs):
        # Invoke LOTHypothesis.__init__ directly rather than through super().
        LOTHypothesis.__init__(
            self,
            grammar=grammar,
            display='lambda S: %s',
            **kwargs
        )

**Two hypotheses**

In [260]:
# Draw two random hypotheses and compare their prior, likelihood and
# posterior on `data`.
# NOTE(review): the saved output of this cell shows a TooBigException,
# presumably raised for a randomly generated hypothesis that is too large;
# re-running may succeed. TODO confirm where it is raised.
h1 = MyHypothesis()
h2 = MyHypothesis()

# datum1 is a positive example (output=True). datum4 is the first negative
# example (output=False). datum2 is also labeled True, so it must not be
# presented as the negative example.
print(
    '• The hypotheses are: ',
    f'\t• {h1} (prior = {h1.compute_prior():.3f})',
    f'\t• {h2} (prior = {h2.compute_prior():.3f})',
    '• Given:',
    f'\t• a positive example: {datum1}',
    f'\t• a negative example: {datum4}',
    f'• The likelihood of h1 (={h1.compute_likelihood(data):.3f})',
    f'• The likelihood of h2 (={h2.compute_likelihood(data):.3f})',
    '• Thus the posterior beliefs are:',
    f'\t• h1 = {h1.compute_posterior(data):.3f}',
    f'\t• h2 = {h2.compute_posterior(data):.3f}',
    sep='\n'
)

TooBigException: 

**MCMC**

In [261]:
h0 = MyHypothesis()  # random starting hypothesis for the chain

top = TopN(N=10)  # presumably retains the 10 highest-scoring samples
thin = 100        # print a progress line every `thin` samples


def _report(step, hyp):
    """Print one summary line (posterior, prior, likelihood, formula) for `hyp`."""
    print(f'{step}), post={hyp.posterior_score:.2f}, prior={hyp.prior:.2f}, '
          f'lik={hyp.likelihood:.2f}, qq={qq(hyp)}')


# break_ctrlc wraps the sampler, presumably so Ctrl-C ends the loop cleanly.
for i, h in enumerate(break_ctrlc(MetropolisHastingsSampler(h0, data, steps=1000))):
    top << h
    if not i % thin:
        _report(i, h)

print('=== WINNERS ===')
for i, h in enumerate(top):
    _report(i, h)

0), post=-17.99, prior=-2.08, lik=-15.91, qq="lambda S: forall_(lambda y2: or_(True, False), S)"
100), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
200), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
300), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
400), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
500), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
600), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
700), post=-5.29, prior=-5.26, lik=-0.03, qq="lambda S: forall_(lambda y2: or_(and_(True, fle_(y2, 'F0', 4)), False), S)"
800), post=-5.29, prior=-5.26, lik=-0.03, qq="lam