Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Massive optimisations to the option generation routine; use a cutoff …

…everywhere - code finally scales
  • Loading branch information...
commit 617b8a4d17d03c3b02f32669c60f953f2e153754 1 parent fff0cb3
@arunchaganty authored
Showing with 25 additions and 16 deletions.
  1. +25 −16 src/Environment.py
View
41 src/Environment.py
@@ -158,21 +158,20 @@ class OptionEnvironment( Environment ):
O = []
@staticmethod
- def make_point_option( g, gr, dest ):
+ def make_point_option( g, gr, dest, max_length ):
"""Create an option that takes all connected states to dest"""
- # Get all paths
- paths = nx.shortest_path( gr, source=dest )
+ paths = nx.predecessor(gr, source=dest, cutoff = max_length)
I = set( paths.keys() )
I.remove( dest )
pi = {}
- for src, path in paths.items():
+ for src, succ in paths.items():
if src == dest: continue
# Next link in the path
- dest_ = path[ -2 ]
+ succ = succ[ 0 ]
# Choose the maximum probability action for this edge
- actions = [ (attrs['action'], attrs['pr']) for src, dest__, attrs in g.edges( src, data=True ) if dest__ == dest_ ]
+ actions = [ (attrs['action'], attrs['pr']) for src, succ_, attrs in g.edges( src, data=True ) if succ_ == succ ]
action = max( actions, key = lambda (a,pr): pr )[ 0 ]
pi[ src ] = ((action, 1.0),)
@@ -181,10 +180,14 @@ def make_point_option( g, gr, dest ):
return Option( I, pi, B )
@staticmethod
- def make_path_option( g, gr, start, dest ):
+ def make_path_option( g, gr, start, dest, length = None ):
"""Create an option that takes a state to a dest"""
+ # HACK (using + 3)
+ if length == None:
+ length = nx.shortest_path_length(g, source=start, target=dest)
+ max_length = length + 2
- o = OptionEnvironment.make_point_option( g, gr, dest )
+ o = OptionEnvironment.make_point_option( g, gr, dest, max_length )
# Start not reachable from dest
if not start in o.I:
return None
@@ -331,14 +334,18 @@ def make_options_from_random_paths( g, gr, count, markov ):
@staticmethod
def make_options_from_small_world( g, gr, count, markov, r = None ):
"""Create an option that takes a state to a random nodes as per a power-law dist"""
+ S = len( g.nodes() )
if r is None:
# Estimate as avg degree / 2 == edges / nodes
- r = len( g.edges() ) / float( len( g.nodes() ) )
+ r = len( g.edges() ) / float( S )
+ max_length = np.power( 16, 1.0/r ) # fn of r
# Get all the edges in the graph
- path_lengths = nx.shortest_path_length( g ).items()
- random.shuffle( path_lengths )
+ path_lengths = nx.all_pairs_shortest_path_length( g, cutoff=max_length ).items()
+ states = range(S)
+ random.shuffle(states)
+
print "Building options...\n\n"
progress = ProgressBar( 0, count, mode='fixed' )
# Needed to prevent glitches
@@ -346,13 +353,15 @@ def make_options_from_small_world( g, gr, count, markov, r = None ):
oldprog = str(progress)
options = []
- for node, dists in path_lengths:
+ for s in states:
if len( options ) > count:
break
- dists.pop( node )
+ dists = path_lengths[s][1]
+ dists.pop( s )
if not dists:
continue
+
neighbours, dists = zip( *dists.items() )
# Create a pr distribution
dists = np.power( np.array( dists, dtype=float ), -r )
@@ -361,12 +370,12 @@ def make_options_from_small_world( g, gr, count, markov, r = None ):
if dists[i] == 1: dists[i] = 0
if not dists.any():
continue
- dest = util.choose( zip( neighbours, dists ) )
+ s_ = util.choose( zip( neighbours, dists ) )
# TODO: Prevent choosing subsumed paths
if markov:
- o = OptionEnvironment.make_markov_path_option( g, gr, node, dest )
+ o = OptionEnvironment.make_markov_path_option( g, gr, s, s_ )
else:
- o = OptionEnvironment.make_path_option( g, gr, node, dest )
+ o = OptionEnvironment.make_path_option( g, gr, s, s_ )
if o:
options.append( o )
Please sign in to comment.
Something went wrong with that request. Please try again.