In [None]:
import numpy as np
import pickle
import scipy.ndimage
import matplotlib.pyplot as plt
import math
import fceulib
import networkx as nx
import nxpd
import sets
# TODO: UnionFind, probably via import

In [None]:
"""UnionFind.py

Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein.
"""

class UnionFind:
    """Union-find data structure.

    Each unionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following two methods:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.union(item1, item2, ...) merges the sets containing each item
      into a single larger set.  If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.

    Attributes:
        weights: maps each root to the number of items in its set
            (entries for former roots are left behind and are stale).
        parents: maps every known item to its parent; a root is its
            own parent.
    """

    def __init__(self):
        """Create a new empty union-find structure."""
        self.weights = {}
        self.parents = {}

    def __getitem__(self, object):
        """Find and return the name of the set containing the object.

        An unknown object is registered as a new singleton set and is
        returned as its own name.  Lookups compress the traversed path so
        that later lookups of the same items reach the root directly.
        """

        # check for previously unknown object
        if object not in self.parents:
            self.parents[object] = object
            self.weights[object] = 1
            return object

        # find path of objects leading to the root
        path = [object]
        root = self.parents[object]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # compress the path and return
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or unioned by this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        The heaviest set (most members) absorbs the others; among equally
        heavy sets the one whose item was passed first wins.  Calling
        union() with no arguments is a no-op.
        """
        if not objects:
            return

        # Collect each distinct root once.  Two arguments that already share
        # a set would otherwise have that set's weight added twice.
        roots = []
        seen = set()
        for x in objects:
            r = self[x]
            if r not in seen:
                seen.add(r)
                roots.append(r)

        # Compare by weight only: comparing (weight, item) tuples would fall
        # back to ordering the items themselves on ties, which fails for
        # items that are hashable but not mutually orderable.
        heaviest = max(roots, key=lambda r: self.weights[r])
        for r in roots:
            if r != heaviest:
                self.weights[heaviest] += self.weights[r]
                self.parents[r] = heaviest

In [None]:
# Controller inputs for the recorded movie.  button_changes() below treats each
# entry as an 8-bit button mask -- presumably one entry per frame; verify
# against fceulib.
inputVec = fceulib.readInputs('movie.fm2')

# NOTE(security): pickle.load can execute arbitrary code; only load files that
# were produced by the earlier stages of this pipeline.
# Files are opened in binary mode and closed deterministically.
with open('collisions.pkl', 'rb') as collisions_file:
    collisions = pickle.load(collisions_file)

with open("modes.pkl", 'rb') as modes_file:
    (modes, path, merged, unions, track, all_times) = pickle.load(modes_file)

plt.plot(track[:,0],track[:,1])
plt.show()
# TODO: define axis elsewhere?
axis = 2
# Frame-to-frame difference of the tracked column selected by `axis`.
velocities = track[1:,axis]-track[:-1,axis]

In [None]:
# First value of column 0 of the track; later cells add it to a velocity index
# to look up button and collision records.
start_time = track[0,0]

# start time of a path segment -> (previous mode, mode entered at that time)
transitions = {}
# Edges into [outer] from [inner]
entries_from = {tgt: {src: [] for src in merged}
                for tgt in merged}
# Edges into [outer]
entries = {m: [] for m in merged}

# Segment 0 has no predecessor, so iteration starts at 1.  (The original loop
# carried an unreachable `t == 0` branch that assigned prev = -1; the graph
# cell adds the -1 -> unions[0] edge explicitly instead.)
for t in range(1,len(path)):
    prev = unions[t-1]
    start = all_times[path[t][1][0]]

    entries_from[unions[t]][prev].append(start)
    entries[unions[t]].append(start)
    transitions[start] = (prev,unions[t])
    #print (path[t][1][0],start),":",prev,"->",unions[t],"\n",path[unions[t]][1][2][0],path[unions[t]][1][2][1].params

In [None]:
# Build a directed multigraph with one node per merged mode and one edge per
# observed (source mode, target mode) pair, then plot the velocity trace and
# render the graph inline.
G = nx.MultiDiGraph()
# (src, tgt) -> list of transition start times, formatted as strings
to_add = {}
for tgt,srcs in entries_from.items():
    G.add_node(tgt,label=str(tgt))
    # Let's learn about tgt
    # The mode's type tag and fitted parameters are read from the path entry
    # indexed by the mode id.
    mtype = path[tgt][1][2][0]
    mparams = path[tgt][1][2][1].params
    params = {"type": mtype}
    # Pick out the parameters to show for each mode type; 'c0' and 'cP'
    # contribute nothing beyond the type itself.
    if mtype == 'c0':
        pass
    elif mtype == 'cP':
        pass
    elif mtype == 'cN':
        params["N"] = mparams[1]
    elif mtype == 'acc0':
        params["acc"] = mparams[1]
    elif mtype == 'accP':
        params["acc"] = mparams[1]
    elif mtype == 'accN':
        params["N"] = mparams[1]
        params["acc"] = mparams[2]
    # Append one "key: value" line per parameter to the node label.
    # NOTE(review): `G.node` is the older networkx accessor; newer releases
    # expose `G.nodes` instead -- confirm against the installed version.
    for k,v in params.items():
        G.node[tgt]["label"] = (G.node[tgt]["label"] + "\n" + "{}: {}".format(k,v))
    print tgt,params
    for src,times in srcs.items():
        for t in times:
            if (src,tgt) not in to_add:
                to_add[(src,tgt)] = []
            to_add[(src,tgt)].append('{}'.format(t))
# One edge per (src, tgt) pair that was observed at least once.  The collected
# times are only printed (concatenated with no separator), not attached to
# the edge.
for e in to_add:
    print e[0],e[1],''.join(to_add[e])
    G.add_edge(e[0],e[1])

# -1 stands for "before the first segment": it points at the initial mode.
G.add_node(-1)
G.add_edge(-1, unions[0])

m2i = {m:i for i,m in enumerate(merged)}
plt.plot(velocities)
# Mark the velocity at every segmentation time with a red x.
plt.plot(np.array(all_times),velocities[np.array(all_times,dtype='int')],'rx')
# Draw a horizontal bar spanning each path entry's time range, placed at
# (mode id + 10) on the y axis.
# NOTE(review): `u_` is computed but never used in this cell.
for u in sorted(unions):
    t0 = all_times[path[u][1][0]]
    t1 = all_times[path[u][1][1]]
    u_ = m2i[unions[u]]
    plt.plot([t0,t1],[unions[u]+10,unions[u]+10])#,colors[u])
        
# Only the window x in [50, 100] is shown.
plt.xlim((50,100))
plt.show()


nxpd.draw(G, show='ipynb')

In [None]:

def button_changes(button_masks):
    """Decode per-step button bitmasks into (previous, current) button tuples.

    Each mask is converted with int() and read as 8 bits, most significant
    bit first, named by the letters of 'RLDUTSBA' (so 128 is 'R' and 1 is
    'A').  Despite the name, an entry is produced for every step, not only
    for steps where the buttons changed.

    Args:
        button_masks: iterable of values accepted by int().

    Returns:
        dict mapping the step index t to a pair
        (buttons held at step t-1, buttons held at step t); the pair for
        t == 0 uses an empty tuple as the previous state.
    """
    button_names = 'RLDUTSBA'

    def decode(mask):
        # Walk the bits from bit 7 down to bit 0, keeping the set ones.
        bits = int(mask)
        return tuple(name
                     for pos, name in enumerate(button_names)
                     if bits & (0x80 >> pos))

    mask_times = {}
    previous = decode(0)
    for t, b in enumerate(button_masks):
        current = decode(b)
        mask_times[t] = (previous, current)
        previous = current

    return mask_times

# step index -> (buttons held on the previous step, buttons held on this step)
button_change_times = button_changes(inputVec)
# Debug aid (disabled): dump the decoded buttons for every step.  The original
# loop sorted all keys only to execute `pass`, so it has been commented out.
# for t in sorted(button_change_times):
#     print t, button_change_times[t]

In [None]:
def sign(num):
    """Return 1 for a positive number, -1 for a negative one, otherwise 0."""
    if num > 0:
        return 1
    return -1 if num < 0 else 0

def button_diff(btnsA, btnsB):
    """Return the set of buttons that are in btnsA but not in btnsB."""
    return set(btnsA).difference(btnsB)

def button_intersect(btnsA, btnsB):
    """Return the set of buttons present in both btnsA and btnsB."""
    return set(btnsA).intersection(btnsB)

def button_union(btnsA, btnsB):
    """Return the set of buttons present in btnsA, btnsB, or both."""
    return set(btnsA).union(btnsB)

def button_preds(button_pairs):
    """Turn (previous buttons, current buttons) pairs into button predicates.

    For every pair this emits ("release", b) for buttons only in the previous
    state, ("press", b) for buttons only in the current state, and
    ("hold", b) for every button in the current state (so a freshly pressed
    button yields both "press" and "hold").

    Args:
        button_pairs: iterable of 2-item sequences (previous, current).

    Returns:
        list of distinct predicate tuples, in set iteration order.
    """
    found = set()
    for before, after in ((pair[0], pair[1]) for pair in button_pairs):
        found.update(("release", btn) for btn in button_diff(before, after))
        found.update(("press", btn) for btn in button_diff(after, before))
        found.update(("hold", btn) for btn in after)
    return list(found)

In [None]:
# Build, for every velocity sample, the set of "base" predicates that hold
# there: button events, collisions, and qualitative facts about the velocity.
#
# One distinct set per sample.  (`[set()] * n` would alias a single set object
# across all n slots.)
base_preds = [set() for i in range(0, len(velocities))]
for t in range(0,len(velocities)):
    frame = start_time+t

    # TODO: button lag variables
    buttons_i = [button_change_times[frame+i]
                 for i in range(0, 1)]
    here_i = button_preds(buttons_i)

    # TODO: stopped colliding/started colliding?  That would mean
    #   I could say "started colliding with X on bottom and also zin,-1"
    #   to help find solid things.
    #     ... no... acc,0 should be enough (walking right across solid tiles)
    #     but I should also consider a more sophisticated notion of collision,
    #     e.g. "bottom" is good but it should be the lowest bottom tile.
    #     How can I get that?  Can I get that?
    #     (OTOH, maybe this isn't even necessary if e.g. "touching my feet
    #     against sky" doesn't cause vy=0 as often as "touching my feet
    #     against ground" does, so let's be sure that's surfaced!)
    # TODO: collision lag variables?
    for coli in collisions.get(frame,set()):
        here_i.append(("col",coli))

    vel1 = velocities[t]
    # The first sample has no predecessor.  Indexing velocities[t-1] at t == 0
    # would wrap around to the *last* sample, so the first sample is compared
    # with itself instead (which yields acc == 0 and no zero-crossing events).
    vel0 = velocities[t-1] if t > 0 else vel1

    # Direction of the velocity change since the previous sample.
    if vel0 < vel1:
        here_i.append(("acc",1))
    if vel0 > vel1:
        here_i.append(("acc",-1))
    if vel0 == vel1:
        here_i.append(("acc",0))

    # Sign of the current velocity.
    if vel1 < 0:
        here_i.append(("vel",-1))
    if vel1 > 0:
        here_i.append(("vel",1))
    if vel1 == 0:
        here_i.append(("vel",0))

    # "zc": velocity crossed zero; the value is the sign it crossed into.
    if vel0 < 0 and vel1 > 0:
        here_i.append(("zc",1))
    if vel0 > 0 and vel1 < 0:
        here_i.append(("zc",-1))

    # "zin": velocity arrived at zero; the value is the direction of the change.
    if vel0 < 0 and vel1 == 0:
        here_i.append(("zin",1))
    if vel0 > 0 and vel1 == 0:
        here_i.append(("zin",-1))

    # "zout": velocity left zero; the value is the sign of the new velocity.
    if vel0 == 0 and vel1 < 0:
        here_i.append(("zout",-1))
    if vel0 == 0 and vel1 > 0:
        here_i.append(("zout",1))

    # TODO: touched global min/max of velocity for current mode, e.g.
    #   mode_max = max(velocities_in_cur_mode)
    #   mode_min = min(velocities_in_cur_mode)
    #   vel1 == mode_max and vel0 != mode_max  -> "touched max"
    #   vel1 == mode_min and vel0 != mode_min  -> "touched min"
    #   vel1 == mode_max / vel1 == mode_min    -> "in max" / "in min"
    base_preds[t] = set(here_i)



In [None]:
def intervals_any_contains(intervals, t):
    """Tell whether t lies inside at least one closed interval [s, e].

    Passing None for intervals means "no restriction": every t is accepted.
    An empty collection of intervals accepts nothing.
    """
    return intervals is None or any(s <= t <= e for (s, e) in intervals)

def intervals_summed_length(intervals):
    """Return the total of (end - start) over all (start, end) intervals."""
    total = 0
    for start, end in intervals:
        total = total + (end - start)
    return total

def count_events(preds,intervals):
    """Count how often each predicate occurs, overall and per timestep.

    Only timesteps accepted by intervals_any_contains(intervals, t) are
    counted.  When a counted timestep is a key of the module-level
    `transitions` dict, two extra events are counted for it:
    ("tr", (src, dest)) and the wildcard ("tr", ("*", dest)).

    Args:
        preds: sequence whose item t is the collection of predicates at t.
        intervals: list of (start, end) pairs, or None for "all timesteps".

    Returns:
        (all_counts, counts_by_time): event -> total count, and
        timestep -> {event -> count at that timestep}.
    """
    all_counts = {}
    counts_by_time = {}

    def tally(when, event):
        # Bump both the global and the per-timestep counter of `event`.
        all_counts[event] = all_counts.get(event, 0) + 1
        at_time = counts_by_time[when]
        at_time[event] = at_time.get(event, 0) + 1

    for t, ps in enumerate(preds):
        if intervals_any_contains(intervals, t):
            counts_by_time[t] = {}
            for p in ps:
                tally(t, p)
            if t in transitions:
                tr = transitions[t]
                tally(t, ("tr", tr))
                (_, dest) = tr
                tally(t, ("tr", ("*", dest)))
    return all_counts, counts_by_time

def count_coevents(preds,intervals):
    """Count co-occurring predicate pairs, overall and per timestep.

    For each timestep accepted by intervals_any_contains(intervals, t),
    every ordered pair (p1, p2) of predicates present at t is counted
    (including p1 paired with itself).  When the timestep is a key of the
    module-level `transitions` dict, each predicate p1 is additionally paired
    with the transition event: (("tr", (src, dest)), p1) and the wildcard
    (("tr", ("*", dest)), p1).

    Args:
        preds: sequence whose item t is the collection of predicates at t.
        intervals: list of (start, end) pairs, or None for "all timesteps".

    Returns:
        (all_counts, counts_by_time): pair -> total count, and
        timestep -> {pair -> count at that timestep}.
    """
    all_counts = {}
    counts_by_time = {}

    def tally(when, pair):
        # Bump both the global and the per-timestep counter of `pair`.
        all_counts[pair] = all_counts.get(pair, 0) + 1
        at_time = counts_by_time[when]
        at_time[pair] = at_time.get(pair, 0) + 1

    for t, ps in enumerate(preds):
        if intervals_any_contains(intervals, t):
            counts_by_time[t] = {}
            for p1 in ps:
                for p2 in ps:
                    tally(t, (p1, p2))
            if t in transitions:
                for p1 in ps:
                    tr = transitions[t]
                    tally(t, (("tr", tr), p1))
                    (_, dest) = tr
                    tally(t, (("tr", ("*", dest)), p1))
    return all_counts, counts_by_time

In [None]:
# Baseline statistics over every timestep: passing intervals=None makes the
# interval filter accept all of them.
all_counts, counts_by_time = count_events(base_preds, intervals=None)
all_cocounts, cocounts_by_time = count_coevents(base_preds, intervals=None)

In [None]:
# Let's figure out which tiles block movement on which sides!
# co-occurrence of (col, BLAH) and acc0 for each BLAH.
# cluster together tiles which block on a given side (for now, all those with co-occurrence over threshold)
# then add new preds!

def cond_prob(e1s, e2, all_counts, counts_by_time):
    """Estimate P(at least one of e1s | e2) from event counts.

    Args:
        e1s: iterable of events; a timestep matches if any of them is present.
        e2: the conditioning event; must be a key of all_counts.
        all_counts: event -> total number of occurrences.
        counts_by_time: timestep -> {event -> count}; its size is used as the
            number of timesteps.

    Returns:
        (fraction of timesteps with e2 and some e1) divided by
        (all_counts[e2] / number of timesteps), as a float.
    """
    n_steps = float(len(counts_by_time))
    p2 = all_counts[e2] / n_steps
    joint = sum(1
                for cs in counts_by_time.values()
                if any(e1 in cs for e1 in e1s) and (e2 in cs))
    p12 = joint / n_steps
    return p12 / p2

block_chance = {}
for thing,count in all_counts.items():
    # TODO: generalize back to all sides, but note "colliding on right with something" -> "vely=0" is not that sensible.
    #  need a notion of acc,vel,zin,zout and _other axis_ acc,vel,zin,zout.
    if thing[0] != "col": 
        continue
    block_chance[thing] = cond_prob([("vel",0),("acc",0)], 
                                    thing,
                                    all_counts,
                                    counts_by_time)

merged_by_side = {}
# TODO: generalize back to all sides
for side in ["bottom","right","left","top"]:
    # Let's pretend colliding with sprites is the same as colliding with tiles?  Maybe needed for moving platforms?
    blockings = filter(lambda (col,prob):(col[1][0][0] != "solid" and 
                                          col[1][1] == side and 
                                          prob > 0.8),
                       block_chance.items())
    merged_by_side[side] = set()
    for bcol,bprob in blockings:
        merged_by_side[side].add(bcol)
    merged_by_side[side] = sets.ImmutableSet(merged_by_side[side])
        
    
#color_tiles = pickle.load(open('tile2colorized.pkl'))
for side,bcols in merged_by_side.items():
    print "----\n{}\n----".format(side)
    for bc in bcols:
        print block_chance[bc],bc[1][0]

In [None]:
# Let's add new preds now!
# Every timestep keeps all of its base predicates and gains, per side, one
# synthetic ("col", (("solid", <blocking set>), side)) predicate when it
# collides with any member of that side's blocking set.
#
# The base predicates are copied once per timestep, outside the side loop, so
# the copy no longer depends on merged_by_side being non-empty and is not
# repeated for each side.
extended_preds = []
for pset in base_preds:
    extended = set(pset)
    for side,equiv in merged_by_side.items():
        if any(pred[0] == "col" and pred in equiv for pred in pset):
            extended.add(("col", (("solid", equiv), side)))
    extended_preds.append(extended)

# Recompute the global statistics over the extended predicates.
all_counts,counts_by_time = count_events(extended_preds,None)
all_cocounts,cocounts_by_time = count_coevents(extended_preds,None)

all_counts,counts_by_time

In [None]:
#TODO: update to support intervals?

def count_conditional_events(preds,condition):
    """Count events over the timesteps at which `condition` holds.

    For each timestep whose predicate collection contains `condition`, every
    predicate present is counted once.  If that timestep is also a key of the
    module-level `transitions` dict, ("tr", (src, dest)) and the wildcard
    ("tr", ("*", dest)) are counted as well.

    Args:
        preds: sequence whose item t is the collection of predicates at t.
        condition: predicate that must be present for a timestep to count.

    Returns:
        dict mapping event -> number of qualifying timesteps it occurred in.
    """
    # The per-timestep breakdown that used to be built here was never
    # returned, so only the totals are accumulated.
    all_counts = {}
    for t,ps in enumerate(preds):
        if condition not in ps:
            continue
        for p in ps:
            all_counts[p] = all_counts.get(p,0)+1
        if t in transitions:
            tr = transitions[t]
            key = ("tr",tr)
            all_counts[key] = all_counts.get(key,0)+1
            (_,dest) = tr
            keystar = ("tr",("*",dest))
            all_counts[keystar] = all_counts.get(keystar,0)+1
    return all_counts

def count_joint_events(preds,conditions):
    """Count the timesteps at which all of `conditions` hold together.

    A condition holds at timestep t when it is one of the predicates at t, or
    when it equals the transition recorded for t in the module-level
    `transitions` dict (timesteps without a transition compare against -1).

    Args:
        preds: sequence whose item t is the collection of predicates at t.
        conditions: iterable of predicates and/or transition values.

    Returns:
        int number of timesteps satisfying every condition.
    """
    count = 0
    for t, ps in enumerate(preds):
        trans = transitions.get(t, -1)
        # Stop at the first condition that is neither a predicate here nor
        # this timestep's transition.
        has_unmet = any(condition not in ps and condition != trans
                        for condition in conditions)
        if not has_unmet:
            count += 1
    return count
    
def count_conditional_coevents(preds,condition):
    """Count co-occurring pairs over the timesteps at which `condition` holds.

    For each timestep whose predicate collection contains `condition`, every
    ordered pair (p1, p2) of predicates present is counted (including p1 with
    itself).  If that timestep is also a key of the module-level
    `transitions` dict, each predicate p1 is additionally paired with the
    transition: (("tr", (src, dest)), p1) and (("tr", ("*", dest)), p1).

    Args:
        preds: sequence whose item t is the collection of predicates at t.
        condition: predicate that must be present for a timestep to count.

    Returns:
        dict mapping pair -> number of qualifying timesteps it occurred in.
    """
    # The per-timestep breakdown that used to be built here was never
    # returned, so only the totals are accumulated.
    all_counts = {}
    for t,ps in enumerate(preds):
        if condition not in ps:
            continue
        for p1 in ps:
            for p2 in ps:
                p = (p1,p2)
                all_counts[p] = all_counts.get(p,0)+1
        if t in transitions:
            for p1 in ps:
                tr = transitions[t]
                key = (("tr",tr),p1)
                all_counts[key] = all_counts.get(key,0)+1
                (_,dest) = tr
                keystar = (("tr",("*",dest)),p1)
                all_counts[keystar] = all_counts.get(keystar,0)+1
    return all_counts

In [None]:
# mode id -> list of (start, end) periods during which that mode was active.
# A period ends where the next path segment starts; the final one ends at
# len(velocities).
mode_periods = {}

for t in range(0,len(path)):
    cur = unions[t]
    start = all_times[path[t][1][0]]
    end = all_times[path[t+1][1][0]] if t + 1 < len(path) else len(velocities)
    mode_periods.setdefault(cur, []).append((start,end))

# (src, tgt) -> the periods of src whose end time is a recorded transition
# into tgt.
transition_leadin_intervals = {}

for src in merged:
    for tgt in merged:
        if src != tgt:
            intervals = []
            for (s,e) in mode_periods[src]:
                if e in transitions and transitions[e][1] == tgt:
                    intervals.append((s,e))
            transition_leadin_intervals[(src,tgt)] = intervals

In [None]:
##### tile2colorized = pickle.load(open('tile2colorized.pkl'))
# NOTE(security): pickle.load can execute arbitrary code; only load files that
# were produced by the earlier stages of this pipeline.
with open('id2sprites.pkl', 'rb') as sprites_file:
    id2colorized = pickle.load(sprites_file)[1]

inv_len = 1.0/float(len(track))
# transition (src, tgt) -> {predicate -> normalized PMI with that transition}
npmis = {}
# (transition, predicate) -> un-normalized PMI
pmis = {}

for m1 in merged:
    for m2 in merged:
        if m1 == m2: continue
        transition = (m1,m2)
        intervals = transition_leadin_intervals[transition]
        intvl_len = intervals_summed_length(intervals)
        if intvl_len == 0: continue
        # Statistics restricted to the periods of m1 that end in a
        # transition into m2.
        tr_counts,_by_time = count_events(extended_preds, intervals)
        tr_cocounts,_by_time = count_coevents(extended_preds, intervals)
        # NOTE(review): intervals are matched inclusively (s <= t <= e) when
        # counting, but the normalizer sums e - s, so these ratios can
        # exceed 1 -- confirm whether that is intended.
        tr_inv_len = 1.0/float(intvl_len)
        tr_pred = ("tr",transition)
        for pred in tr_counts:
            # Do not score the transition against itself.  (The original
            # compared `pred` with the bare (m1, m2) tuple, which never
            # matches a ("tr", ...) key.)
            if pred == tr_pred: continue
            p1 = tr_counts[pred]*tr_inv_len
            p2 = tr_counts[tr_pred]*tr_inv_len
            cooccur = (tr_pred,pred)
            if cooccur in tr_cocounts:
                p12 = tr_cocounts[cooccur]*tr_inv_len
            else:
                p12 = 0.0
            # Create the per-transition dict on first use in *either* branch;
            # previously the "never co-occurs" branch raised KeyError when it
            # ran before any co-occurring predicate had been seen.
            tr_npmis = npmis.setdefault(transition, {})
            if p12/(p1*p2) != 0.0:
                pmi = np.log(p12/(p1*p2))
                npmi = pmi/-np.log(p12)
                pmis[(transition,pred)] = pmi
                tr_npmis[pred] = npmi
            else:
                # Never co-occurs with the transition: minimum NPMI.
                tr_npmis[pred] = -1.0

# Report, for each transition, the predicates whose NPMI exceeds 0.4.
# The cmp function sorts by NPMI in descending order and reversed() then walks
# that list backwards, so predicates are printed in ascending NPMI order.
# The commented-out block below is an unfinished experiment on pairwise
# (conditional) PMI; it refers to a `preds` name not defined in the visible
# cells.
for t in sorted(npmis):
    for e1,pmi in reversed(sorted(npmis[t].items(), lambda a,b:sign(b[1] - a[1]))):
        if pmi > 0.4:
#             conditioned_px = count_conditional_events(preds,e1)
#             conditioned_pxy = count_conditional_coevents(preds,e1)
            print t, e1, pmi
#             if False:
#                 for pmi2 in reversed(sorted(npmis[t])):
#                     if pmi != pmi2:
#                         e2 =  npmis[t][pmi2]
#                         pmi_t_e1 = pmis[(t,e1)]


#                         p_t_e2_I_e1 = conditioned_pxy.get((("tr",t),e2),0)/float(all_counts[e1])
#                         p_t_I_e1 = conditioned_px.get(("tr",t),0)/float(all_counts[e1])
#                         p_e2_I_e1 = conditioned_px.get(e2,0)/float(all_counts[e1])
#                         if p_t_I_e1 == 0.0 or p_e2_I_e1 == 0.0:
#                             p_t_e2_I_e1 = 0.0
#                             p_t_I_e1 = 1
#                             p_e2_I_e1 = 1

#                         pmi_t_e1_e2 = (pmi_t_e1 + np.log(p_t_e2_I_e1/(p_t_I_e1*p_e2_I_e1)))

#                         p_t_e1_e2 = count_joint_events(preds,[e1,e2,t])*inv_len
#                         if p_t_e1_e2 == 0.0:
#                             npmi_t_e1_e2 = -1.0
#                         else:
#                             npmi_t_e1_e2 = pmi_t_e1_e2/-np.log(p_t_e1_e2)
#                         if npmi_t_e1_e2 > 0.0:
#                             print 'pmi({};{},{})'.format(t,e1,e2), npmi_t_e1_e2, pmi,pmi2
#                             if 'col' in e2 and 'tile' in e2[1][0]:
#                                 plt.imshow(tile2colorized[e2[1][0][:2]]/255.)
#                                 plt.show()
    # Blank line between transitions.
    print ''

In [None]:
# Joe's output version

relevance_threshold = 0.6
universality_threshold = 0.9
negation_threshold = -1.0

# TODO: more selective prioritizing of buttons (press/release over hold) vs 
# collisions vs qualitative stuff.

for tr,pred_npmis in npmis.items():# + paired_npmis.items():
    cond_probs = {k: cond_prob([k], ("tr",tr), all_counts, counts_by_time) 
                  for k in pred_npmis.keys()}
    relevants = {k: npmi for (k,npmi) in pred_npmis.items() if npmi >= relevance_threshold}
    negations = {k: npmi for (k,npmi) in pred_npmis.items() if npmi <= negation_threshold}
    conjuncts = set(filter(lambda k:cond_probs[k] >= universality_threshold, relevants.keys()))
    disjuncts = set(filter(lambda k:cond_probs[k] < universality_threshold and k not in conjuncts, relevants.keys()))
    print "-------\nTransition:",tr
    print "-----\nCond Probs:"
    print "\n".join(map(str,sorted(cond_probs.items(), lambda a,b:sign(b[1]-a[1]))))
    print "-----\nNPMIs:"
    print "\n".join(map(str,sorted(pred_npmis.items(), lambda a,b:sign(b[1]-a[1]))))
    print "-----\nRelevant:\n","\n".join(map(str,relevants.items()))
    print "-----\nConjuncts:\n"," & ".join(map(str,conjuncts))
    print "-----\nDisjuncts:\n"," | ".join(map(str,disjuncts))
    print "-----\nNegations:\n"," & ".join(map(lambda k:"~" + str(k),negations))
    
    


In [None]:
#Adam's output version

import operator
relevance_threshold = 0.4
universality_threshold = 0.8
negation_threshold = -1.0
for tr in sorted(npmis):# + paired_npmis.items():
    if hasattr(tr, "__len__"):
        if tr[0] != tr[1]:
            pred_npmis = npmis[tr]
            universals = {k: npmi for (k,npmi) in pred_npmis.items() if npmi >= universality_threshold}
            relevants = {k: npmi for (k,npmi) in pred_npmis.items() if npmi >= relevance_threshold and k not in universals}
            negations = {k: npmi for (k,npmi) in pred_npmis.items() if npmi <= negation_threshold}
            print "\n"
            print "-------\nTransition:",tr
            print "-----\nUniversals:\n","\n".join(map(str,map(lambda a: a[0],sorted(universals.items(), key=operator.itemgetter(1)))))
            print "-----\nRelevant:\n","\n".join(map(str,map(lambda a: a[0],sorted(relevants.items(), key=operator.itemgetter(1)))))
            #print "-----\nNegations:\n"," & ".join(map(lambda k:"~" + str(k),negations))

In [None]:
# Persist the learned edge statistics for the next stage of the pipeline.
# NOTE(review): `paired_npmis` and `new_preds` are not defined in any cell
# visible here (the predicate list built above is named `extended_preds`);
# on a fresh kernel this raises NameError unless they are defined elsewhere.
# Confirm what the consumer of edges.pkl expects before changing the tuple.
with open("edges.pkl",'wb') as edges_file:
    pickle.dump((npmis, paired_npmis, entries, entries_from, new_preds, modes, merged, unions, track, inputVec, all_times),
                edges_file)