In [None]:
!pip install ortools

Collecting ortools
  Downloading ortools-9.9.3963-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (24.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.8/24.8 MB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting absl-py>=2.0.0 (from ortools)
  Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.7/133.7 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting protobuf>=4.25.3 (from ortools)
  Downloading protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.8/302.8 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting immutabledict>=3.0.0 (from ortools)
  Downloading immutabledict-4.2.0-py3-none-any.whl (4.7 kB)
Installing collected packages: protobuf, immutabledict, absl-py, ortools
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.20.3
    Uninstalling protobuf-3.2

In [None]:
import numpy as np
from ortools.linear_solver import pywraplp
import matplotlib.pyplot as plt

In [None]:
def value(policy, n_states, transition_probabilities, reward, discount,
                    threshold=1e-2):
    """
    Evaluate a fixed policy by repeated in-place sweeps over the states.

    Each sweep overwrites v[s] with the expected one-step return under the
    policy's action, using whatever values are currently stored in v (states
    updated earlier in the same sweep are already visible to later ones).
    Sweeping stops once the largest single-state change in a sweep is no
    longer above `threshold`.

    policy: List of action ints for each state.
    n_states: Number of states. int.
    transition_probabilities: Array indexed as [state, action, next_state]
        giving transition probabilities.
    reward: Vector of rewards, indexed by the state being entered.
    discount: MDP discount factor. float.
    threshold: Convergence threshold, default 1e-2. float.
    -> Array of values for each state
    """
    v = np.zeros(n_states)

    largest_change = float("inf")
    while largest_change > threshold:
        largest_change = 0
        for s in range(n_states):
            previous = v[s]
            action = policy[s]

            # Accumulate the expected backup term by term, in state order.
            backup = 0
            for k in range(n_states):
                backup += (transition_probabilities[s, action, k] *
                           (reward[k] + discount * v[k]))
            v[s] = backup

            largest_change = max(largest_change, abs(previous - v[s]))

    return v

In [None]:
def v_tensor(value, transition_probability, feature_dimension, n_states,
             n_actions, policy):
    """
    Build the helper tensor of on-policy minus off-policy expected values
    used in large linear IRL.

    For each state, the expected per-basis value of the next state under the
    policy's action is compared with the same expectation under every other
    action. The off-policy actions are packed, in increasing action order,
    into the n_actions - 1 slots of the second axis (the on-policy action is
    skipped, so later actions shift down by one).

    value: NumPy matrix for the value function. The (i, j)th component
        represents the value of the jth state under the ith basis function.
    transition_probability: NumPy array mapping (state_i, action, state_k) to
        the probability of transitioning from state_i to state_k under action.
        Shape (N, A, N).
    feature_dimension: Dimension of the feature matrix. int.
    n_states: Number of states sampled. int.
    n_actions: Number of actions. int.
    policy: NumPy array mapping state ints to action ints.
    -> v helper tensor, shape (n_states, n_actions - 1, feature_dimension).
    """

    tensor = np.zeros((n_states, n_actions-1, feature_dimension))
    for state in range(n_states):
        chosen = policy[state]
        on_policy = np.dot(transition_probability[state, chosen], value.T)

        # Next free position along the off-policy axis for this state.
        slot = 0
        for action in range(n_actions):
            if chosen == action:
                # The on-policy action has no slot of its own.
                continue
            off_policy = np.dot(transition_probability[state, action], value.T)
            tensor[state, slot] = on_policy - off_policy
            slot += 1

    return tensor

In [1]:
def infinite_state_space_IRL(N, k, d, policy, V):
    """
    Solve the sampled-state linear program for linear-reward IRL with GLOP.

    Decision variables:
        a_i      one coefficient per basis function, bounded to [-1, 1]
        Y[s][a]  one free variable per sampled state and off-policy slot
        Z[s]     one free variable per sampled state

    With x = sum_i a_i * V[s, a, i], the constraints are
        Y[s][a] <= x
        Y[s][a] <= 2 * x      (together: Y <= min(x, 2x), so a negative
                               margin is counted twice)
        Z[s]    <= Y[s][a]    (Z is bounded by the worst slot of its state)
    and the objective maximises sum_s Z[s].

    N: Number of sampled states. int.
    k: Number of actions. int.
    d: Number of basis functions. int.
    policy: Sequence mapping state ints to on-policy action ints.
    V: Array indexed as V[state, slot, basis], where the k - 1 slots are the
        off-policy actions in increasing order with the on-policy action
        skipped (the layout produced by v_tensor in this notebook).
    -> (list of the d coefficient values, objective value), or (None, None)
       if the solver cannot be created or no optimal solution is found.
    """
    solver = pywraplp.Solver.CreateSolver('GLOP')
    if solver is None:
        # CreateSolver returns None when the requested backend is unavailable.
        return None, None
    inf = solver.infinity()

    A = [solver.NumVar(-1, 1, f"a_{i}") for i in range(d)]

    Z = []
    Y = []
    for i in range(N):
        Z.append(solver.NumVar(-inf, inf, f"Z_{i}"))
        Y.append([solver.NumVar(-inf, inf, f"Y_{i}_{j}")
                  for j in range(k-1)])

    # Only the Z variables appear in the objective; every other variable
    # keeps the default coefficient of zero.
    objective = solver.Objective()
    for i in range(N):
        objective.SetCoefficient(Z[i], 1)
    objective.SetMaximization()

    for s in range(N):
        on_policy = policy[s]
        # Position of the current off-policy action along V's second axis.
        a = 0
        for j in range(k):
            if j == on_policy:
                continue

            # Y[s][a] - x <= 0  and  Y[s][a] - 2x <= 0.
            constraint1 = solver.Constraint(-inf, 0)
            constraint2 = solver.Constraint(-inf, 0)
            for i in range(d):
                constraint1.SetCoefficient(A[i], -(V[s, a, i]))
                constraint2.SetCoefficient(A[i], -2*V[s, a, i])
            constraint1.SetCoefficient(Y[s][a], 1)
            constraint2.SetCoefficient(Y[s][a], 1)

            # Z[s] <= Y[s][a], written as Z[s] - Y[s][a] <= 0 because a
            # constraint's bounds must be numbers, not solver variables.
            constraint3 = solver.Constraint(-inf, 0)
            constraint3.SetCoefficient(Z[s], 1)
            constraint3.SetCoefficient(Y[s][a], -1)

            a += 1

    status = solver.Solve()
    if status == pywraplp.Solver.OPTIMAL:
        solution = [A[i].solution_value() for i in range(d)]
        optimal_value = solver.Objective().Value()
        return solution, optimal_value
    return None, None