In [None]:
from frozen_lake import *
# Environment instance that is wrapped into `mdp` further down in this cell.
# NOTE(review): the meaning of `goal=2` is defined in frozen_lake.py (not visible
# here) -- presumably it selects which goal cell is rewarded; confirm there.
env = FrozenLakeEnvMultigoal(goal=2)
#print(env.__doc__)



import numpy as np, numpy.random as nr, gym
import matplotlib.pyplot as plt
%matplotlib inline

# Seed RNGs so you get the same printouts as me
env.seed(0); from gym.spaces import prng; prng.seed(10)
# generate_trajectories samples actions with np.random.choice, which draws from
# NumPy's *global* RNG; neither call above seeds it, so seed it explicitly to
# make the sampled trajectories reproducible on Restart & Run All.
np.random.seed(0)

# Iteration budget (not referenced by the cells visible in this section).
NUM_ITER = 75
# Discount factor passed to the value / policy / IRL routines below.
gamma = 0.95



class MDP(object):
    """Tabular view of a discrete gym-style environment.

    The environment must expose `P` (dict: state -> action -> list of outcome
    tuples whose first three entries are (probability, next_state, reward)),
    `nS`, `nA` and `desc`.
    """

    def __init__(self, env):
        P, nS, nA, desc = MDP.env2mdp(env)
        self.P = P # state transition and reward probabilities, explained below
        self.nS = nS # number of states
        self.nA = nA # number of actions
        self.desc = desc # 2D array specifying what each grid cell means (used for plotting)
        self.env = env
        self.T = self.get_transition_matrix()

    @staticmethod
    def env2mdp(env):
        """Return (P, nS, nA, desc) for `env`.

        Every outcome tuple in `env.P` is truncated to its first three entries,
        (probability, next_state, reward); any trailing fields are dropped.
        """
        P = {
            s: {a: [tup[:3] for tup in tups] for (a, tups) in a2d.items()}
            for (s, a2d) in env.P.items()
        }
        return P, env.nS, env.nA, env.desc

    def get_transition_matrix(self):
        """Return a matrix with index S,A,S' -> P(S'|S,A)"""
        T = np.zeros([self.nS, self.nA, self.nS])
        for s in range(self.nS):
            for a in range(self.nA):
                for transition in self.P[s][a]:
                    # Several outcomes may lead to the same next state (e.g.
                    # slipping into a wall from two directions), so the
                    # probabilities must be accumulated, not overwritten.
                    T[s, a, transition[1]] += transition[0]
        return T
    
# Wrap the seeded environment created at the top of this cell in the tabular MDP helper.
mdp = MDP(env)


#print("mdp.P is a two-level dict where the first key is the state and the second key is the action.")
#print("The 2D grid cells are associated with indices [0, 1, 2, ..., 15] from left to right and top to down, as in")
#print(np.arange(16).reshape(4,4))
#print("mdp.P[state][action] is a list of tuples (probability, nextstate, reward).\n")
#print("For example, state 0 is the initial state, and the transition information for s=0, a=0 is \nP[0][0] =", mdp.P[0][0], "\n")
#print("As another example, state 5 corresponds to a hole in the ice, which transitions to itself with probability 1 and reward 0.")
#print("P[5][0] =", mdp.P[5][0], '\n')

In [None]:
def compute_value(mdp, gamma, reward,  threshold=1e-4):
    """
    Find the optimal value function via value iteration with the max-ent Bellman backup 
    given at https://graphics.stanford.edu/projects/gpirl/gpirl_supplement.pdf.

    Each sweep computes, for every state s,
        V(s) = log sum_a exp( sum_s' T[s, a, s'] * (reward[s] + gamma * V_prev(s')) )
    and stops once the largest per-state change in a sweep is <= threshold.

    mdp: object exposing nS, nA and the transition array T indexed [s, a, s'].
    reward: Vector of rewards for each state.
    threshold: Convergence threshold.
    gamma: MDP gamma factor. float.
    -> Array of values for each state
    """

    v = np.zeros(mdp.nS)

    diff = float("inf")
    while diff > threshold:
        v_prev = np.copy(v)
        diff = 0
        for s in range(mdp.nS):
            # One soft Q-value per action for state s, shape (nA,).
            q_s = np.dot(mdp.T[s], reward[s] + gamma * v_prev)
            # Stable log-sum-exp: log(sum(exp(q))) overflows to inf as soon as
            # a single q exceeds ~709, which large rewards easily reach.
            v[s] = np.logaddexp.reduce(q_s)
            new_diff = abs(v_prev[s] - v[s])
            if new_diff > diff:
                diff = new_diff
    return v


def compute_policy(mdp, gamma, reward=None, V=None, threshold=1e-4):
    """
    Compute the stochastic max-ent policy pi(a|s) for `mdp`.

    The policy is the softmax over actions of gamma * sum_s' T[s, a, s'] * V(s').
    The per-state reward is identical for every action of a state, so it
    cancels in the normalisation; `reward` is therefore only needed to compute
    V when V is not supplied. Normalising explicitly makes every row sum to
    exactly 1, so the rows are valid probability vectors even when V is only
    converged up to `threshold`.

    mdp: object exposing nS, nA and the transition array T indexed [s, a, s'].
    gamma: MDP gamma factor. float.
    reward: Vector of rewards for each state (required if V is None).
    V: Optional precomputed value vector, one entry per state.
    threshold: Convergence threshold forwarded to compute_value.
    -> Array of shape (nS, nA) with the action probabilities of each state.
    Raises Exception if neither reward nor V is given.
    """
    if reward is None and V is None: raise Exception('Cannot compute V: no reward provided')
    if V is None: V = compute_value(mdp, gamma, reward, threshold=threshold)

    V = np.asarray(V, dtype=float)
    # Discounted expected next-state value for every (s, a) pair, shape (nS, nA).
    expected_next = gamma * np.dot(mdp.T, V)
    # Row-wise log-normaliser, computed stably.
    log_norm = np.logaddexp.reduce(expected_next, axis=1, keepdims=True)
    policy = np.exp(expected_next - log_norm)

    return policy


def generate_trajectories(mdp, policy, T=20, D=50):
    """
    Sample D rollouts of T steps each from `mdp.env` by following `policy`.

    mdp: object exposing `env` (with reset() and step(action)) and `nA`.
    policy: array indexed [state, action] holding action probabilities.
    -> Integer array of shape (D, T, 2); entry [d, t] is (state, action).

    The environment is reset once before the first rollout and once after
    every rollout. Termination flags returned by step() are ignored.
    """
    environment = mdp.env
    action_ids = range(mdp.nA)
    rollouts = np.zeros([D, T, 2]).astype(int)

    state = environment.reset()
    for episode in range(D):
        for step in range(T):
            chosen = np.random.choice(action_ids, p=policy[state, :])
            rollouts[episode, step, :] = [state, chosen]
            state, _reward, _done, _info = environment.step(chosen)
        state = environment.reset()

    return rollouts

In [None]:
# Second MDP, built from a fresh environment with is_slippery=False and goal=1.
# NOTE(review): presumably is_slippery=False makes transitions deterministic --
# confirm in frozen_lake.py. This environment instance is never seeded.
mdp1 = MDP(FrozenLakeEnvMultigoal(is_slippery=False, goal=1))


# Same array that MDP.__init__ already stores as mdp1.T.
t1 = mdp1.get_transition_matrix()

# Per-state reward vector: zero everywhere except state 63.
# NOTE(review): the hard-coded 64 / 63 assume mdp1.nS == 64 -- confirm.
r1 = np.zeros(64)
r1[63] = 1.0

# Policy for r1, with a tighter convergence threshold than the 1e-4 default.
policy1 = compute_policy(mdp1, gamma, r1, threshold=1e-8)

In [None]:
# Sample 100 demonstration rollouts of 50 steps each by following policy1.
trajectories1 = generate_trajectories(mdp1, policy1, T=50, D=100)

# generate_trajectories returns an array of shape (D, T, 2).
trajectories1.shape

In [None]:
def compute_irl_log_likelihood(mdp, gamma, trajectories, V, r):
    """
    Sum the per-step term r[s] - V[s] + gamma * sum_s' T[s, a, s'] * V[s']
    over every (state, action) pair of every trajectory.

    mdp: object exposing the transition array T indexed [s, a, s'].
    trajectories: iterable of trajectories, each an iterable of (s, a) pairs.
    V: value vector indexed by state.
    r: reward vector indexed by state.
    -> The accumulated sum (0 when there are no steps).
    """
    log_likelihood = 0

    for trajectory in trajectories:
        for state, action in trajectory:
            next_state_probs = mdp.T[state, action, :]
            step_term = r[state] - V[state] + np.dot(next_state_probs, gamma * V)
            log_likelihood = log_likelihood + step_term

    return log_likelihood



def compute_s_a_visitations(mdp, gamma, trajectories):
    """
    Accumulate empirical visitation statistics from demonstration trajectories.

    mdp: object exposing nS, nA and the transition array T indexed [s, a, s'].
    gamma: MDP gamma factor. float.
    trajectories: iterable of trajectories, each an iterable of (s, a) pairs.
    -> (mu_hat_sa, v_hat_s) where
       mu_hat_sa[s, a] counts how often the pair (s, a) occurs, and
       v_hat_s[s] is the visit count of s minus gamma times the transition
       probability into s, summed over every observed (s, a) step.
    """
    # Use the `mdp` argument throughout (the previous version read the
    # notebook-global `mdp1`, silently ignoring the MDP that was passed in).
    mu_hat_sa = np.zeros((mdp.nS, mdp.nA))
    v_hat_s = np.zeros((mdp.nS))
    for traj in trajectories:
        for (s, a) in traj:
            mu_hat_sa[s, a] += 1
            v_hat_s[s] += 1
            # Vectorised over all successor states s'.
            v_hat_s -= gamma * mdp.T[s, a, :]
    return (mu_hat_sa, v_hat_s)
        
        
def compute_mu_tilda(mdp, gamma, V, policy, v_hat_s, mu_tilda = None, threshold = 1e-4):
    """
    Computes occupancy measure of a MDP under a given policy -- 
    the expected discounted number of times that policy π visits state s.

    Iterates the fixed point
        mu[s', a'] = policy[s', a'] * (v_hat_s[s'] + gamma * sum_{s,a} T[s, a, s'] * mu[s, a])
    until the largest entry-wise change is <= threshold.

    mdp: object exposing nS, nA and the transition array T indexed [s, a, s'].
    gamma: MDP gamma factor. float.
    V: value vector; only its length is checked against mdp.nS.
    policy: array of shape (nS, nA) with action probabilities.
    v_hat_s: per-state source term, shape (nS,).
    mu_tilda: optional warm start, either a per-state vector of shape (nS,)
        (spread over actions according to `policy`) or a full (nS, nA) array.
    threshold: Convergence threshold.
    -> Array of shape (nS,): the state-action occupancy summed over actions.
    """
    assert V.shape[0] == mdp.nS
    assert policy.shape == (mdp.nS, mdp.nA)    

    if mu_tilda is None:
        mu_tilda = np.zeros((mdp.nS, mdp.nA))
    else:
        mu_tilda = np.array(mu_tilda, dtype=float)
        if mu_tilda.ndim == 1:
            # Expand a per-state warm start to (nS, nA). Tiling with
            # (nA, 1) would give shape (nA, nS), which does not line up
            # with T[:, :, s'] and policy.
            mu_tilda = mu_tilda[:, None] * policy

    v_hat_s = np.asarray(v_hat_s, dtype=float)
    diff = float("inf")

    while diff > threshold:
        # Discounted probability mass flowing into each state s'. It does
        # not depend on the action taken in s', so compute it once per sweep.
        inflow = gamma * np.einsum('sap,sa->p', mdp.T, mu_tilda)
        mu_tilda_new = policy * (v_hat_s + inflow)[:, None]

        diff = np.amax(np.abs(mu_tilda - mu_tilda_new))
        mu_tilda = mu_tilda_new

    return np.sum(mu_tilda, 1)


def irl_log_likelihood_and_grad(mdp, gamma, r, trajectories):
    """
    Evaluate the IRL log likelihood of `trajectories` under reward `r` and its
    gradient with respect to the per-state reward.

    mdp: object exposing nS, nA and the transition array T indexed [s, a, s'].
    gamma: MDP gamma factor. float.
    r: Vector of rewards for each state.
    trajectories: iterable of trajectories, each an iterable of (s, a) pairs.
    -> (L_D, dL_D_dr): the log likelihood and a gradient array with one entry
       per state (empirical state visit counts minus the occupancy returned
       by compute_mu_tilda).
    """
    # Every step uses the `mdp` argument; the previous version solved for V
    # and the policy on the notebook-global `mdp1` regardless of what was passed.
    V = compute_value(mdp, gamma, r)
    # IRL log likelihood term
    L_D = compute_irl_log_likelihood(mdp, gamma, trajectories, V, r)

    # IRL log likelihood gradient w.r.t reward
    # Pass the reward explicitly so the policy never depends on a global.
    policy = compute_policy(mdp, gamma, reward=r, V=V)
    mu_hat, v_hat_s = compute_s_a_visitations(mdp, gamma, trajectories)
    mu_tilda = compute_mu_tilda(mdp, gamma, V, policy, v_hat_s)

    dL_D_dr = np.sum(mu_hat, 1) - mu_tilda

    return L_D, dL_D_dr

# Evaluate the log likelihood and its gradient at r1, the reward that was used
# to build policy1 and hence to sample trajectories1.
irl_log_likelihood_and_grad(mdp1, gamma, r1, trajectories1)