### Make environment

In [64]:
import os

# Step up one directory unless the kernel is already sitting in the repository
# folder (presumably so the `src.` imports below resolve -- confirm against the
# repo layout).
curr_dir = os.path.basename(os.getcwd())
if not curr_dir == "irl-environment-design":
    os.chdir("..")

import numpy as np

from src.utils.constants import ParamTuple
from src.utils.make_environment import (
    transition_matrix,
    Environment,
    insert_walls_into_T,
)

# Grid dimensions; states are addressed as (row, col) on the grid and as a
# single flat index after `.flatten()` (row-major: index = row * width + col).
height = 7
width = 7

# Rewards are placed on the 2-D grid first and flattened afterwards.
reward_grid = np.zeros((height, width))
reward_grid[6, 0] = 1
reward_grid[6, 6] = 3

# Optional penalty cells, currently disabled:
# reward_grid[5, 0] = -1
# reward_grid[4, 0] = -1

rewards = reward_grid.flatten()

# Every state with a strictly positive reward is handed to `transition_matrix`
# as an absorbing state below.
goal_states = np.where(rewards > 0)[0]

# Walls are marked on their own 2-D grid.
wall_grid = np.zeros((height, width))
wall_grid[2, 0] = 1
# wall_grid[4, 1] = 1

# Flat indices of the wall cells as a plain Python list. With the single wall at
# (2, 0) this evaluates to [14], the value that was previously hard-coded here
# and silently overrode whatever the grid above said.
wall_states = np.where(wall_grid.flatten() > 0)[0].tolist()

agent_p = 0.9
agent_gamma = 0.7
p_true = 1

true_params = ParamTuple(agent_p, agent_gamma, rewards)

# `p_true` (not `agent_p`) is what gets passed as `p` to build the transitions.
T_true = transition_matrix(height, width, p=p_true, absorbing_states=goal_states)
# NOTE: `T_True` (capital T) is the result of inserting the walls; it differs
# from `T_true` only by letter case. Later cells use `T_True`.
T_True = insert_walls_into_T(T=T_true, wall_indices=wall_states)

### Value Iteration to find fixed point.

In [4]:
import torch

def soft_q_iteration_torch(
    R: torch.Tensor,  # R is a one-dimensional tensor with shape (n_states,)
    T_agent: torch.Tensor,
    gamma: float,
    beta: float,  # Inverse temperature parameter for the softmax function
    tol: float = 1e-6,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Iterate Q-values under a Boltzmann policy until the state values settle.

    Each sweep computes
    ``Q[s, a] = R[s] + gamma * sum_s' T_agent[s, a, s'] * V[s']``,
    turns Q into a policy with ``softmax(beta * Q)`` over the action axis, and
    takes the policy-weighted average of Q as the next value estimate.

    Args:
        R: Per-state reward, shape ``(n_states,)``. Cast to ``T_agent``'s dtype
            and device before use.
        T_agent: Transition tensor of shape ``(n_states, n_actions, n_states)``.
        gamma: Discount factor applied to the expected next-state value.
        beta: Inverse temperature of the softmax policy.
        tol: Stop once the max absolute change in V between sweeps is below this.

    Returns:
        ``(Q, V, policy)``. ``Q`` and ``policy`` have shape
        ``(n_states, n_actions)`` and come from the final sweep; ``V`` has shape
        ``(n_states,)`` and is the value estimate that ``Q`` was computed from
        (the next iterate differs from it by less than ``tol``).

    Note:
        A later cell in this notebook defines another function with the same
        name and a different signature; once that cell has run, it replaces
        this one.
    """
    # Unpacking also checks that T_agent is three-dimensional.
    n_states, _, _ = T_agent.shape

    # Hoisted out of the loop: R as an (n_states, 1) column that broadcasts
    # across the action axis.
    reward_col = R.to(device=T_agent.device, dtype=T_agent.dtype).unsqueeze(1)
    V = torch.zeros(n_states, dtype=T_agent.dtype, device=T_agent.device)

    while True:
        # Batched form of the per-(s, a) dot product: (S, A, S) @ (S,) -> (S, A).
        Q = reward_col + gamma * torch.matmul(T_agent, V)

        # torch.softmax shifts by the row maximum internally, so large values of
        # beta * Q no longer overflow to inf / NaN (which previously made the
        # convergence test below unsatisfiable).
        policy = torch.softmax(beta * Q, dim=1)

        # Value of each state under the probabilistic policy.
        V_new = torch.sum(policy * Q, dim=1)

        if torch.max(torch.abs(V - V_new)) < tol:
            break

        V = V_new

    return Q, V, policy

In [5]:
# Run the iteration on the walled transition tensor with beta = 1; rewards and
# transitions are converted from NumPy to float32 tensors at the call site.
Q_star, V_star, policy_star = soft_q_iteration_torch(
    R=torch.tensor(rewards, dtype=torch.float32),
    T_agent=torch.tensor(T_True, dtype=torch.float32),
    gamma=agent_gamma,
    beta=1,
)

### Soft Bellman Operator for fixed point.

In [51]:
def soft_bellman_update(R, gamma, T, Q):
    """Apply one soft Bellman backup to ``Q``.

    Computes ``R + gamma * (T @ logsumexp(Q, dim=1))``: the log-sum-exp over
    the action axis of ``Q`` gives one soft value per state, and ``T`` maps
    those to an expected next-state value for every leading index of ``T``.

    Args:
        R: Reward tensor; must broadcast against ``T @ v`` (the notebook passes
            a tensor of the same shape as ``Q``).
        gamma: Discount factor.
        T: Transition tensor whose last axis indexes next states.
        Q: Q-values with actions along dim 1.

    Returns:
        The backed-up Q tensor.
    """
    # torch.logsumexp is the overflow-safe form of log(sum(exp(Q))); the naive
    # form returns inf as soon as any entry of Q exceeds ~88 in float32.
    soft_value = torch.logsumexp(Q, dim=1)
    return R + gamma * torch.matmul(T, soft_value)

def soft_bellman_fp(R, gamma, T, Q):
    """Return the fixed-point residual of the soft Bellman backup at ``Q``.

    The result is ``soft_bellman_update(R, gamma, T, Q) - Q``, which is zero
    exactly when ``Q`` is a fixed point of the backup.
    """
    backed_up = soft_bellman_update(R, gamma, T, Q)
    residual = backed_up - Q
    return residual

def soft_q_iteration_torch(R, gamma, T, Q_init=None, tol=1e-6):
    """Iterate ``soft_bellman_update`` until successive Q tensors agree.

    Note:
        This definition reuses the name of the earlier ``soft_q_iteration_torch``
        in this notebook (which took ``T_agent`` and ``beta`` and returned a
        3-tuple) and replaces it once this cell has run.

    Args:
        R: Reward tensor passed straight to ``soft_bellman_update``; its first
            dimension is taken as the number of states.
        gamma: Discount factor.
        T: Transition tensor; its second-to-last axis is taken as the action
            axis (assumes the ``(n_states, n_actions, n_states)`` layout used
            elsewhere in this notebook).
        Q_init: Optional starting Q tensor. Defaults to zeros of shape
            ``(R.shape[0], T.shape[-2])`` with ``T``'s dtype and device.
        tol: Stop once the max absolute change between iterates is below this.

    Returns:
        The last Q iterate whose backup moved by less than ``tol``.
    """
    if Q_init is None:
        # Take the action count from T instead of a literal 4, so the default
        # works for any number of actions (a mismatch used to raise a shape
        # error on the first convergence check).
        Q_init = torch.zeros(R.shape[0], T.shape[-2], dtype=T.dtype, device=T.device)

    Q = Q_init

    while True:
        Q_new = soft_bellman_update(R, gamma, T, Q)
        if torch.max(torch.abs(Q - Q_new)) < tol:
            break
        Q = Q_new
    return Q

In [56]:
# Leaf reward tensor; a later cell requests gradients with respect to it.
R = torch.tensor(rewards, dtype=torch.float32, requires_grad=True)
# Copy R(s) into four action columns so it has the same shape as Q, i.e. R(s, a).
R_stretched = R.unsqueeze(1).repeat(1, 4)

Q_star = soft_q_iteration_torch(
    R_stretched,
    agent_gamma,
    torch.tensor(T_True, dtype=torch.float32),
    Q_init=None,
    tol=1e-6,
)

# Boltzmann policy over actions at inverse temperature beta.
# Plain exponentials are used here: no row maximum is subtracted.
beta = 1
exp_Q = (beta * Q_star).exp()
policy = exp_Q / exp_Q.sum(dim=1, keepdim=True)

# State values as the policy-weighted average of Q.
V_new = (policy * Q_star).sum(dim=1)

### Calculate Derivative of Value Function.

In [57]:
# Fresh leaf copy of the stretched rewards, cut off from the graph that links
# R_stretched back to R. `detach().requires_grad_(True)` is the supported
# spelling of the deprecated `torch.autograd.Variable(t, requires_grad=True)`.
R_grad = R_stretched.detach().requires_grad_(True)

# Residual of the soft Bellman backup evaluated at Q_star.
psi = soft_bellman_fp(R=R_grad, gamma=agent_gamma, T=torch.tensor(T_True, dtype=torch.float32), Q=Q_star)

# Vector-Jacobian product with an all-ones vector, accumulated into R.grad.
# NOTE(review): the gradient is requested for R, not R_grad. Because R_grad is
# detached, R reaches psi only through Q_star (built from R_stretched in the
# Q-iteration cell), so this differentiates through that iteration -- confirm
# this is the intended derivative.
psi.backward(torch.ones_like(psi), inputs=[R])

In [60]:
# Sign-flipped gradient that the backward pass stored on R; the bare name on the
# last line makes it the cell's displayed output.
psi_grad = R.grad.neg()
psi_grad

tensor([4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000,
        4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000,
        4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000,
        4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000,
        4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000,
        4.0000, 4.0000, 4.0000, 4.0000])