In [1]:
import numpy as np
from typing import Dict, Tuple, List
import matplotlib.pyplot as plt

In [2]:
class MDP:
    """Container for a finite Markov Decision Process.

    The model is stored sparsely: ``transitions`` and ``rewards`` are
    dictionaries keyed by ``(state, action, next_state)`` triples, and any
    triple missing from a table is reported as ``0.0`` by the accessors.
    """

    def __init__(self, states: List, actions: List, transitions: Dict,
                 rewards: Dict, gamma: float = 0.9):
        """Keep references to the model tables and cache the set sizes.

        Args:
            states: All states of the process.
            actions: All actions available to the agent.
            transitions: Maps (state, action, next_state) -> probability.
            rewards: Maps (state, action, next_state) -> reward.
            gamma: Discount factor.
        """
        self.states, self.actions = states, actions
        self.transitions, self.rewards = transitions, rewards
        self.gamma = gamma
        self.n_states, self.n_actions = len(states), len(actions)

    def get_transition_prob(self, state, action, next_state):
        """Look up P(s'|s,a); triples without an entry yield 0.0."""
        triple = (state, action, next_state)
        return self.transitions.get(triple, 0.0)

    def get_reward(self, state, action, next_state):
        """Look up R(s,a,s'); triples without an entry yield 0.0."""
        triple = (state, action, next_state)
        return self.rewards.get(triple, 0.0)

In [3]:
class PolicyEvaluationImprovement:
    """Dynamic-programming routines that operate on a single MDP instance."""

    def __init__(self, mdp: "MDP"):
        # The MDP supplies states, gamma and the transition/reward accessors.
        self.mdp = mdp

    def policy_evaluation(self, policy: Dict, theta: float = 1e-6,
                         max_iterations: int = 1000) -> np.ndarray:
        """
        Evaluate a deterministic policy using iterative policy evaluation.

        Each sweep builds a fresh value array from the previous sweep's
        values (synchronous update) and stops once the largest per-state
        change drops below ``theta``.

        Args:
            policy: Dict mapping state -> action
            theta: Convergence threshold on the largest per-state change
            max_iterations: Maximum number of sweeps

        Returns:
            Value function V(s), indexed in the order of ``mdp.states``

        Raises:
            KeyError: If ``policy`` has no entry for one of the states.
        """
        mdp = self.mdp
        V = np.zeros(mdp.n_states)

        for iteration in range(max_iterations):
            delta = 0
            V_new = np.zeros(mdp.n_states)

            for s_idx, state in enumerate(mdp.states):
                action = policy[state]
                v = 0

                # Expected one-step return under the policy's action.
                # enumerate() yields the successor's index directly, which
                # avoids a linear list.index() search per successor.
                for s_next_idx, next_state in enumerate(mdp.states):
                    prob = mdp.get_transition_prob(state, action, next_state)
                    reward = mdp.get_reward(state, action, next_state)
                    v += prob * (reward + mdp.gamma * V[s_next_idx])

                V_new[s_idx] = v
                delta = max(delta, abs(V_new[s_idx] - V[s_idx]))

            # V_new is freshly allocated every sweep, so no copy is needed.
            V = V_new

            if delta < theta:
                print(f"Policy evaluation converged in {iteration + 1} iterations")
                break
        else:
            # Loop ran out of sweeps without meeting the threshold.
            print(f"Policy evaluation did not converge within "
                  f"{max_iterations} iterations")

        return V

In [4]:
def policy_improvement(self, V: np.ndarray,
                       old_policy: Dict = None) -> Tuple[Dict, bool]:
    """Build the greedy policy with respect to a value function.

    For every state, each action is scored by its one-step lookahead value
    ``sum_s' P(s'|s,a) * (R(s,a,s') + gamma * V[s'])`` and the highest
    scoring action is selected.

    Args:
        V: State values, indexed in the order of ``self.mdp.states``.
        old_policy: Optional previous policy (state -> action). When given,
            the returned flag reports whether the greedy policy matches it.

    Returns:
        ``(policy, policy_stable)``. ``policy`` maps state -> greedy action.
        ``policy_stable`` is True when ``old_policy`` is omitted, otherwise
        True only if every state keeps the action it had in ``old_policy``.
    """
    mdp = self.mdp
    policy = {}

    for state in mdp.states:
        best_action = None
        best_value = float('-inf')

        # Find best action for this state
        for action in mdp.actions:
            action_value = 0

            # enumerate() gives the successor index without a list search.
            for s_next_idx, next_state in enumerate(mdp.states):
                prob = mdp.get_transition_prob(state, action, next_state)
                reward = mdp.get_reward(state, action, next_state)
                action_value += prob * (reward + mdp.gamma * V[s_next_idx])

            # Strict '>' means ties go to the action listed first.
            if action_value > best_value:
                best_value = action_value
                best_action = action

        policy[state] = best_action

    # Without a reference policy there is nothing to compare against, so
    # the flag keeps its historical value of True.
    policy_stable = old_policy is None or all(
        old_policy.get(state) == policy[state] for state in mdp.states
    )

    return policy, policy_stable

# Bind the module-level function defined in this cell as a method of
# PolicyEvaluationImprovement, so solver instances can call
# solver.policy_improvement(V).
PolicyEvaluationImprovement.policy_improvement = policy_improvement

In [5]:
def policy_iteration(self, initial_policy: Dict = None,
                    theta: float = 1e-6,
                    max_iterations: int = 1000) -> Tuple[Dict, np.ndarray]:
    """Alternate policy evaluation and greedy improvement until stable.

    Args:
        initial_policy: Optional starting policy (state -> action). It is
            copied, not modified. When omitted, every state gets an action
            drawn uniformly from ``self.mdp.actions`` using NumPy's global
            random state.
        theta: Convergence threshold forwarded to ``policy_evaluation``.
        max_iterations: Upper bound on evaluate/improve rounds. Guards
            against endless switching between policies that the
            approximate evaluation cannot tell apart.

    Returns:
        ``(policy, V)`` where ``V`` is the evaluation of the returned
        policy, indexed in the order of ``self.mdp.states``.
    """
    mdp = self.mdp

    # Initialize random policy if not provided
    if initial_policy is None:
        # Draw an index instead of an element: np.random.choice converts
        # the action list to an ndarray, which rejects tuple-valued actions
        # and hands back NumPy scalars rather than the original objects.
        n_actions = len(mdp.actions)
        policy = {state: mdp.actions[np.random.randint(n_actions)]
                  for state in mdp.states}
    else:
        policy = initial_policy.copy()

    print("\n=== Policy Iteration ===")

    for iteration in range(1, max_iterations + 1):
        print(f"\nIteration {iteration}")

        # Policy Evaluation
        V = self.policy_evaluation(policy, theta)

        # Policy Improvement (the returned flag is not used; stability is
        # decided by comparing against the policy that was just evaluated)
        new_policy, _ = self.policy_improvement(V)

        if all(policy[s] == new_policy[s] for s in mdp.states):
            print(f"\nPolicy converged in {iteration} iterations!")
            break

        policy = new_policy
    else:
        print(f"\nPolicy iteration stopped after {max_iterations} "
              f"iterations without reaching a stable policy")
        # The last improvement step replaced the policy after V was
        # computed, so evaluate once more to return a matching pair.
        V = self.policy_evaluation(policy, theta)

    return policy, V

# Bind the module-level function defined in this cell as a method of
# PolicyEvaluationImprovement, so solver instances can call
# solver.policy_iteration().
PolicyEvaluationImprovement.policy_iteration = policy_iteration

In [6]:
def value_iteration(self, theta: float = 1e-6,
                   max_iterations: int = 1000) -> Tuple[Dict, np.ndarray]:
    """Compute state values by value iteration and read off a greedy policy.

    Each sweep replaces every state's value with the best one-step
    lookahead value over all actions, computed from the previous sweep's
    values. Sweeping stops once the largest per-state change is below
    ``theta`` or after ``max_iterations`` sweeps.

    Args:
        theta: Convergence threshold on the largest per-state change.
        max_iterations: Maximum number of sweeps.

    Returns:
        ``(policy, V)`` where ``policy`` maps state -> greedy action with
        respect to the final ``V``, and ``V`` is indexed in the order of
        ``self.mdp.states``.

    Raises:
        ValueError: If the MDP has states but no actions (``max`` of an
            empty sequence).
    """
    mdp = self.mdp

    def lookahead(state, action, values):
        """Expected reward plus discounted successor value for (state, action)."""
        total = 0
        # enumerate() gives the successor index without a list search.
        for s_next_idx, next_state in enumerate(mdp.states):
            prob = mdp.get_transition_prob(state, action, next_state)
            reward = mdp.get_reward(state, action, next_state)
            total += prob * (reward + mdp.gamma * values[s_next_idx])
        return total

    V = np.zeros(mdp.n_states)

    print("\n=== Value Iteration ===")

    for iteration in range(max_iterations):
        delta = 0
        V_new = np.zeros(mdp.n_states)

        for s_idx, state in enumerate(mdp.states):
            # Bellman optimality backup: best action value under the old V
            V_new[s_idx] = max(lookahead(state, action, V)
                               for action in mdp.actions)
            delta = max(delta, abs(V_new[s_idx] - V[s_idx]))

        # V_new is freshly allocated every sweep, so no copy is needed.
        V = V_new

        if delta < theta:
            print(f"Value iteration converged in {iteration + 1} iterations")
            break
    else:
        # Loop ran out of sweeps without meeting the threshold.
        print(f"Value iteration did not converge within "
              f"{max_iterations} iterations")

    # Extract optimal policy: greedy with respect to the final values.
    policy = {}
    for state in mdp.states:
        best_action = None
        best_value = float('-inf')

        for action in mdp.actions:
            action_value = lookahead(state, action, V)

            # Strict '>' means ties go to the action listed first.
            if action_value > best_value:
                best_value = action_value
                best_action = action

        policy[state] = best_action

    return policy, V

# Bind the module-level function defined in this cell as a method of
# PolicyEvaluationImprovement, so solver instances can call
# solver.value_iteration().
PolicyEvaluationImprovement.value_iteration = value_iteration

In [7]:
def create_gridworld_mdp(grid_size: int = 4, obstacles: List = None,
                         goal: Tuple = None, goal_reward: float = 10.0,
                         step_reward: float = -0.1,
                         gamma: float = 0.9) -> "MDP":
    """Build a deterministic square gridworld as an MDP.

    States are ``(row, col)`` cells in row-major order with obstacle cells
    removed. Each of the four actions moves one cell; a move that would
    leave the grid or enter an obstacle leaves the agent where it is.
    Every transition whose resulting cell is the goal earns
    ``goal_reward`` (this includes blocked moves made while standing on
    the goal); every other transition earns ``step_reward``.

    Args:
        grid_size: Number of rows and columns.
        obstacles: Cells to remove from the grid. Defaults to
            ``[(1, 1), (1, 3)]``. Cells outside the grid are ignored.
        goal: Goal cell. Defaults to ``(3, 3)``.
        goal_reward: Reward for a transition that ends on the goal.
        step_reward: Reward for any other transition.
        gamma: Discount factor passed to the MDP.

    Returns:
        The constructed MDP.

    Raises:
        ValueError: If the goal is outside the grid or is an obstacle
            (for example the default goal with ``grid_size < 4``).
    """
    # Fill in the defaults here so no mutable default argument is shared.
    if obstacles is None:
        obstacles = [(1, 1), (1, 3)]
    if goal is None:
        goal = (3, 3)

    blocked = {tuple(cell) for cell in obstacles}
    goal = tuple(goal)

    # Row-major list of free cells; this order defines value-array indices.
    states = [(i, j) for i in range(grid_size) for j in range(grid_size)
              if (i, j) not in blocked]
    free_cells = set(states)

    if goal not in free_cells:
        raise ValueError(
            f"goal {goal} is not a free cell of the "
            f"{grid_size}x{grid_size} grid"
        )

    actions = ['up', 'down', 'left', 'right']

    # Define action effects as (row delta, column delta)
    action_effects = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1)
    }

    transitions = {}
    rewards = {}

    for state in states:
        for action in actions:
            d_row, d_col = action_effects[action]
            candidate = (state[0] + d_row, state[1] + d_col)

            # free_cells contains only in-bounds, unblocked cells, so one
            # set lookup rejects both walls and obstacles: stay in place.
            next_state = candidate if candidate in free_cells else state
            transitions[(state, action, next_state)] = 1.0

            # Set rewards
            if next_state == goal:
                rewards[(state, action, next_state)] = goal_reward
            else:
                rewards[(state, action, next_state)] = step_reward

    return MDP(states, actions, transitions, rewards, gamma=gamma)

In [8]:
def print_results(mdp: MDP, policy: Dict, V: np.ndarray):
    """Write a per-state report of the chosen action and value, then a summary.

    Args:
        mdp: The MDP whose ``states`` order indexes ``V``.
        policy: Dict mapping state -> action.
        V: State values, one entry per state in ``mdp.states``.
    """
    print("\n=== Optimal Policy ===")
    for position, state in enumerate(mdp.states):
        chosen_action = policy[state]
        state_value = V[position]
        print(f"State {state}: {chosen_action} (Value: {state_value:.3f})")

    state_count = len(mdp.states)
    print(f"\nTotal states: {state_count}")
    print(f"Discount factor: {mdp.gamma}")

In [9]:
# Build the 4x4 gridworld MDP; `mdp` is reused by the solver and by
# print_results in the cells that follow.
print("Creating Gridworld MDP...")
mdp = create_gridworld_mdp(grid_size=4)

# Bind a solver to this MDP; the policy_iteration / value_iteration methods
# attached to the class in earlier cells are called on this instance.
solver = PolicyEvaluationImprovement(mdp)

Creating Gridworld MDP...


In [10]:
# Method 1: Policy Iteration
# policy_iteration() draws its starting policy from NumPy's global random
# state when no initial policy is passed. Seed it first so the iteration
# log below is the same on every Restart & Run All.
np.random.seed(42)

print("\n" + "="*50)
print("METHOD 1: POLICY ITERATION")
print("="*50)
policy_pi, V_pi = solver.policy_iteration()
print_results(mdp, policy_pi, V_pi)


METHOD 1: POLICY ITERATION

=== Policy Iteration ===

Iteration 1
Policy evaluation converged in 154 iterations

Iteration 2
Policy evaluation converged in 154 iterations

Iteration 3
Policy evaluation converged in 154 iterations

Iteration 4
Policy evaluation converged in 154 iterations

Iteration 5
Policy evaluation converged in 154 iterations

Iteration 6
Policy evaluation converged in 154 iterations

Iteration 7
Policy evaluation converged in 154 iterations

Policy converged in 7 iterations!

=== Optimal Policy ===
State (0, 0): down (Value: 58.639)
State (0, 1): right (Value: 65.266)
State (0, 2): down (Value: 72.629)
State (0, 3): left (Value: 65.266)
State (1, 0): down (Value: 65.266)
State (1, 2): down (Value: 80.810)
State (2, 0): down (Value: 72.629)
State (2, 1): down (Value: 80.810)
State (2, 2): down (Value: 89.900)
State (2, 3): down (Value: 100.000)
State (3, 0): right (Value: 80.810)
State (3, 1): right (Value: 89.900)
State (3, 2): right (Value: 100.000)
State (3, 3):

In [11]:
# Method 2: Value Iteration
# Solve the same MDP with value_iteration() and print the result in the same
# format as Method 1 so the two policies and value tables can be compared.
print("\n" + "="*50)
print("METHOD 2: VALUE ITERATION")
print("="*50)
policy_vi, V_vi = solver.value_iteration()
print_results(mdp, policy_vi, V_vi)


METHOD 2: VALUE ITERATION

=== Value Iteration ===
Value iteration converged in 154 iterations

=== Optimal Policy ===
State (0, 0): down (Value: 58.639)
State (0, 1): right (Value: 65.266)
State (0, 2): down (Value: 72.629)
State (0, 3): left (Value: 65.266)
State (1, 0): down (Value: 65.266)
State (1, 2): down (Value: 80.810)
State (2, 0): down (Value: 72.629)
State (2, 1): down (Value: 80.810)
State (2, 2): down (Value: 89.900)
State (2, 3): down (Value: 100.000)
State (3, 0): right (Value: 80.810)
State (3, 1): right (Value: 89.900)
State (3, 2): right (Value: 100.000)
State (3, 3): down (Value: 100.000)

Total states: 14
Discount factor: 0.9
