In [1]:
import numpy as np
import gymnasium as gym
import pickle
from tqdm import tqdm
from itertools import product


In [2]:
# =============================================
# 1. Environment and Data Setup (FIXED)
# =============================================

class CartPoleWrapper:
    def __init__(self, render=False):
        self.render = render
        self.env = gym.make("CartPole-v1", render_mode="human" if render else None)
        self.state_dim = 4
        self.action_dim = 2
        self.discrete_actions = [0, 1]
        
    def reset(self):
        return self.env.reset()[0]
    
    def step(self, action):
        state, reward, done, _, info = self.env.step(action)
        return state, reward, done, info
    
    def get_features(self, state):
        return np.array([
            state[0],               # Cart position
            state[1],               # Cart velocity
            state[2],               # Pole angle
            state[3],               # Pole angular velocity
            state[0] * state[2],    # Position-angle interaction
            state[1] * state[3]     # Velocity-angular velocity interaction
        ])

In [3]:
# =============================================
# 2. Feature Calculations (FIXED)
# =============================================

def compute_expert_feature_expectations(expert_states, gamma=0.99):
    feature_dim = len(CartPoleWrapper().get_features(expert_states[0][0]))
    mu_E = np.zeros(feature_dim)

    for traj in expert_states:
        weight = 1.0
        for state in traj:
            mu_E += weight * CartPoleWrapper().get_features(state)
            weight *= gamma

    return mu_E / len(expert_states)

In [None]:
# =============================================
# 3. Soft Value Iteration (FIXED)
# =============================================
def soft_value_iteration(w, env_wrapper, gamma=0.99, n_iters=100, n_bins=20):
    env = env_wrapper.env
    s_space = [
        np.linspace(-4.8, 4.8, n_bins),
        np.linspace(-3.4, 3.4, n_bins),
        np.linspace(-0.418, 0.418, n_bins),
        np.linspace(-2.0, 2.0, n_bins)
    ]
    
    V = np.zeros(tuple([n_bins] * env_wrapper.state_dim))

    for _ in range(n_iters):
        new_V = np.zeros_like(V)
        for idx in product(*[range(n_bins) for _ in range(env_wrapper.state_dim)]):
            s_cont = np.array([s_space[d][idx[d]] for d in range(env_wrapper.state_dim)])
            Q_values = []
            for a in env_wrapper.discrete_actions:
                original_state = env.env.state
                if env_wrapper.render:
                    env.reset()
                env.env.state = s_cont
                s_next, _, done, _, _ = env.step(a)
                s_next = np.array(s_next)
                s_next_idx = tuple(
                    np.clip(np.digitize(s_next[i], s_space[i]) - 1, 0, n_bins - 1)
                    for i in range(env_wrapper.state_dim)
                )
                cost = np.dot(w, env_wrapper.get_features(s_cont))
                Q = cost if done else cost + gamma * V[s_next_idx]
                Q_values.append(Q)
                env.env.state = original_state
            new_V[idx] = np.log(np.sum(np.exp(Q_values)))
        V = new_V

    return V, s_space

In [15]:
# =============================================
# 4. Learner Feature Expectations (FIXED)
# =============================================

def compute_learner_feature_expectations(w, env_wrapper, V, s_space, 
                                         gamma=0.99, n_trajs=10):
    feature_dim = len(env_wrapper.get_features(env_wrapper.reset()))
    mu_learner = np.zeros(feature_dim)

    for _ in range(n_trajs):
        state = env_wrapper.reset()
        done = False
        weight = 1.0

        while not done:
            state_idx = tuple(
                np.clip(np.digitize(state[i], s_space[i]) - 1, 0, len(s_space[i]) - 1)
                for i in range(env_wrapper.state_dim)
            )
            Q_values = []
            for a in env_wrapper.discrete_actions:
                original_state = env_wrapper.env.env.state
                env_wrapper.env.env.state = state
                if env_wrapper.render:
                    env_wrapper.env.reset()
                s_next, _, done, _ = env_wrapper.step(a)
                s_next = np.array(s_next)
                s_next_idx = tuple(
                    np.clip(np.digitize(s_next[i], s_space[i]) - 1, 0, len(s_space[i]) - 1)
                    for i in range(env_wrapper.state_dim)
                )
                cost = np.dot(w, env_wrapper.get_features(state))
                Q = cost if done else cost + gamma * V[s_next_idx]
                Q_values.append(Q)
                env_wrapper.env.env.state = original_state
            max_Q = np.max(Q_values)
            policy = np.exp(Q_values - max_Q)
            policy = policy / np.sum(policy)
            action = np.random.choice(env_wrapper.discrete_actions, p=policy)
            mu_learner += weight * env_wrapper.get_features(state)
            state, _, done, _ = env_wrapper.step(action)
            weight *= gamma

    return mu_learner / n_trajs

In [6]:
# =============================================
# 5. MaxEnt IRL Main Loop (FIXED)
# =============================================

def maxent_irl(expert_states, env_wrapper, lr=0.1, n_irl_iters=50, 
               n_vi_iters=100, n_trajs=10, n_bins=20):
    feature_dim = len(env_wrapper.get_features(env_wrapper.reset()))
    w = np.random.randn(feature_dim) * 0.1
    mu_E = compute_expert_feature_expectations(expert_states)
    losses = []

    for it in tqdm(range(n_irl_iters)):
        V, s_space = soft_value_iteration(w, env_wrapper, n_iters=n_vi_iters, n_bins=n_bins)
        mu_learner = compute_learner_feature_expectations(w, env_wrapper, V, s_space, n_trajs=n_trajs)
        gradient = mu_E - mu_learner
        w += lr * gradient
        loss = np.linalg.norm(gradient)
        losses.append(loss)
        tqdm.write(f"Iter {it+1}/{n_irl_iters}, Loss: {loss:.4f}")
    
    return w, losses

In [7]:
# =============================================
# 6. Executionvec and Expert Data Handling (FIXED)
# =============================================

env_wrapper = CartPoleWrapper(render=False)

# Load expert data
with open("./expert_data/ckpt0.pkl", "rb") as f:
    exp_data = pickle.load(f)
exp_states = exp_data["states"]
timestep_lens = exp_data["timestep_lens"]

# Reconstruct expert state trajectories
expert_states = []
current = 0
for length in timestep_lens:
    episode = exp_states[current:current+length]
    expert_states.append(np.array(episode))
    current += length

# Run MaxEnt IRL
learned_weights, losses = maxent_irl(
    expert_states,
    env_wrapper,
    lr=0.1,
    n_irl_iters=10,   # Keep low for initial testing
    n_vi_iters=50,
    n_trajs=5,
    n_bins=10
)

print("Learned weights:", learned_weights)
print("Losses:", losses)

  logger.warn(
  logger.warn(
 10%|█         | 1/10 [00:36<05:28, 36.55s/it]

Iter 1/10, Loss: 4.0877


 20%|██        | 2/10 [01:12<04:49, 36.14s/it]

Iter 2/10, Loss: 3.1016


 30%|███       | 3/10 [01:47<04:10, 35.82s/it]

Iter 3/10, Loss: 4.7907


 40%|████      | 4/10 [02:24<03:36, 36.00s/it]

Iter 4/10, Loss: 5.3009


 50%|█████     | 5/10 [03:00<03:00, 36.03s/it]

Iter 5/10, Loss: 4.2242


 60%|██████    | 6/10 [03:35<02:23, 35.94s/it]

Iter 6/10, Loss: 3.8395


 70%|███████   | 7/10 [04:12<01:48, 36.08s/it]

Iter 7/10, Loss: 3.1224


 80%|████████  | 8/10 [04:48<01:12, 36.19s/it]

Iter 8/10, Loss: 3.8025


 90%|█████████ | 9/10 [05:25<00:36, 36.24s/it]

Iter 9/10, Loss: 4.7877


100%|██████████| 10/10 [06:01<00:00, 36.20s/it]

Iter 10/10, Loss: 4.0550
Learned weights: [-3.22282126 -1.53768332 -0.18691442 -0.14425328  0.00549319 -1.02069414]
Losses: [np.float64(4.087674021724666), np.float64(3.101588865713964), np.float64(4.790657737092499), np.float64(5.300858085215364), np.float64(4.224191682578222), np.float64(3.8394848624636153), np.float64(3.1224014190494733), np.float64(3.8025145236611846), np.float64(4.787674889808097), np.float64(4.055021773748474)]





In [11]:
def test_policy(w, env_wrapper, V, s_space, gamma=0.99, n_episodes=5, render=False):
    """
    Evaluate the learned policy by running it in the real environment.
    """
    for ep in range(n_episodes):
        state = env_wrapper.reset()
        done = False
        total_reward = 0
        steps = 0

        while not done:
            if render:
                env_wrapper.env.render()  # show the environment

            # Discretize state
            state_idx = tuple(
                np.clip(np.digitize(state[i], s_space[i]) - 1, 0, len(s_space[i]) - 1)
                for i in range(env_wrapper.state_dim)
            )

            # Compute Q-values for both actions
            Q_values = []
            for a in env_wrapper.discrete_actions:
                original_state = env_wrapper.env.env.state
                env_wrapper.env.env.state = state
                s_next, _, done, _ = env_wrapper.step(a)
                s_next = np.array(s_next)

                s_next_idx = tuple(
                    np.clip(np.digitize(s_next[i], s_space[i]) - 1, 0, len(s_space[i]) - 1)
                    for i in range(env_wrapper.state_dim)
                )

                cost = np.dot(w, env_wrapper.get_features(state))
                if done:
                    Q = cost
                else:
                    Q = cost + gamma * V[s_next_idx]
                Q_values.append(Q)

                env_wrapper.env.env.state = original_state

            # Softmax policy
            max_Q = np.max(Q_values)
            policy = np.exp(Q_values - max_Q)
            policy = policy / np.sum(policy)

            action = np.random.choice(env_wrapper.discrete_actions, p=policy)

            # Step
            state, reward, done, _ = env_wrapper.step(action)
            total_reward += reward
            steps += 1

        print(f"Episode {ep + 1}: Total reward = {total_reward}, Steps = {steps}")

    env_wrapper.env.close()  # close the rendering window after testing


In [17]:

env_wrapper = CartPoleWrapper(render=True)
V, s_space = soft_value_iteration(
    learned_weights, env_wrapper, n_iters=50, n_bins=10
)
test_policy(learned_weights, env_wrapper, V, s_space, n_episodes=5, render=True)



  logger.warn(


TypeError: unsupported operand type(s) for +=: 'NoneType' and 'int'