In [1]:
import collections
import numpy as np
import heapq

In [2]:
def sample_gumbel(mu):
    """Sample a Gumbel(mu)."""    
    return -np.log(np.random.exponential()) + mu


def sample_truncated_gumbel(mu, b):
    """Sample a Gumbel(mu) truncated to be less than b."""    
    return -np.log(np.random.exponential() + np.exp(-b + mu)) + mu

  
def sample_gumbel_argmax(logits):
    """Sample from a softmax distribution over logits.

    TODO: check this is correct.

    Args:
    logits: A flat numpy array of logits.

    Returns:
    A sample from softmax(logits).
    """
    return np.argmax(-np.log(np.random.exponential(size=logits.shape)) + logits)


def logsumexp(logits):
    c = np.max(logits)
    return np.log(np.sum(np.exp(logits - c))) + c


def log_softmax(logits, axis=1):
    """Normalize logits per row so that they are logprobs.    
    """
    maxes = np.max(logits, axis=axis, keepdims=True)
    offset_logits = logits - maxes

    log_zs = np.log(np.sum(np.exp(offset_logits), axis=axis, keepdims=True))
    return offset_logits - log_zs

In [3]:
class Policy:
    def __init__(self, num_actions):
        self.W = np.random.randn(num_actions)
        self.b = np.random.randn(num_actions)
        self.num_actions = num_actions

    def __call__(self,state):
        one_hot_state = np.zeros(self.num_actions)
        one_hot_state[state]=1

        out = self.W*one_hot_state + self.b

        return log_softmax(out, axis=0)

In [4]:
# Make a node without a state, and also don't allow `next_actions` to be None.
# Just put in all possible next actions when the node is created.
Node = collections.namedtuple('Node', 
                              [
                                  'prefix',
                                  'logprob_so_far',
                                  'max_gumbel', 
                                  'next_actions',
                              ])

# Namedtuple for storing results
Trajectory = collections.namedtuple('Trajectory', ['actions', 'gumbel'])


def sample_trajectory_gumbels(init_state, max_length,num_actions):
    """Samples an independent Gumbel(logprob) for each trajectory in top-down-ish order.

    Args:
    action_logprobs: A num_actions array of log probabilities of actions. Here
      we assume that the distribution over actions doesn't depend on any state
      (so it's independent per timestep).
    max_length: Maximum length of a trajectory to allow.
    """
    policy = Policy(num_actions)
    action_logprobs = policy(init_state)
    
    # Start with a node for all trajectories.
    root_node = Node(prefix=[], 
                   logprob_so_far=0,
                   max_gumbel=sample_gumbel(0), 
                   next_actions=range(num_actions))
    queue = []
    heapq.heappush(queue, root_node)
    final_trajectories = []

    while queue:
        parent = heapq.heappop(queue)  # TODO replace it with priority queue to pop the maximum

        if len(parent.prefix) == max_length:
            final_trajectories.append(Trajectory(actions=parent.prefix,
                                               gumbel=parent.max_gumbel))
            continue

        # Choose one action from amongst the set of candidates to inherit the max
        # gumbel. Call this the "special" action.
        
        current_state = parent.prefix[-1] if len(parent.prefix)>0 else init_state
        action_logprobs = policy(current_state)
        
        next_action_logprobs = action_logprobs[parent.next_actions]
        special_action_index = sample_gumbel_argmax(next_action_logprobs)
        special_action = parent.next_actions[special_action_index]
        special_action_logprob = action_logprobs[special_action]

        special_child = Node(prefix=parent.prefix + [special_action],
                             logprob_so_far=parent.logprob_so_far + special_action_logprob,
                             max_gumbel=parent.max_gumbel, 
                             next_actions=range(num_actions))  # All next actions are possible.

        heapq.heappush(queue,special_child)

        # Sample the max gumbel for the non-chosen actions and create an "other
        # children" node if there are any alternatives left.
        other_actions = [i for i in parent.next_actions if i != special_action]

        assert len(other_actions) == len(parent.next_actions) - 1

        if other_actions:
            other_max_location = logsumexp(action_logprobs[other_actions])
            other_max_gumbel = sample_truncated_gumbel(parent.logprob_so_far + other_max_location, 
                                                     parent.max_gumbel)
            other_children = Node(prefix=parent.prefix,
                                logprob_so_far=parent.logprob_so_far,
                                max_gumbel=other_max_gumbel,
                                next_actions=other_actions)

            heapq.heappush(queue,other_children)
    return final_trajectories

In [5]:
num_actions = 4
trajectory_length = 3
init_state = 0

trajectories = sample_trajectory_gumbels(init_state,
                                         trajectory_length,
                                         num_actions)

print("Expected {} vs actual {}".format(num_actions**trajectory_length, 
                                        len(trajectories)))
for t in trajectories:
    print (t)

Expected 64 vs actual 64
Trajectory(actions=[0, 0, 0], gumbel=-4.5858048132116345)
Trajectory(actions=[0, 0, 1], gumbel=-4.425509736272827)
Trajectory(actions=[0, 0, 2], gumbel=-6.694908748987112)
Trajectory(actions=[0, 0, 3], gumbel=-3.859501004192877)
Trajectory(actions=[0, 1, 0], gumbel=-3.5402832531686412)
Trajectory(actions=[0, 1, 1], gumbel=0.9782085030670724)
Trajectory(actions=[0, 1, 2], gumbel=-4.68739911668616)
Trajectory(actions=[0, 1, 3], gumbel=-5.464147651437554)
Trajectory(actions=[0, 2, 0], gumbel=-6.354922982709571)
Trajectory(actions=[0, 2, 1], gumbel=-6.302920002128424)
Trajectory(actions=[0, 2, 2], gumbel=-6.42671518203762)
Trajectory(actions=[0, 2, 3], gumbel=-3.8761070531691013)
Trajectory(actions=[0, 3, 0], gumbel=-4.462718952794387)
Trajectory(actions=[0, 3, 1], gumbel=-1.8819105907406584)
Trajectory(actions=[0, 3, 2], gumbel=-5.93267905198373)
Trajectory(actions=[0, 3, 3], gumbel=-4.558111131141923)
Trajectory(actions=[1, 0, 0], gumbel=-3.41378553766058)
Trajec