In [1]:
import warnings

import numpy as np

def _q_values(env, V, s, gamma):
    """Return the one-step look-ahead action values Q(s, .) under the estimate V.

    Each transition (prob, next_state, reward, done) contributes
    prob * (reward + gamma * V[next_state]); when ``done`` is true the
    successor's value is masked out so only the immediate reward counts.
    """
    q = np.zeros(env.nA)
    for a in range(env.nA):
        for prob, next_state, reward, done in env.P[s][a]:
            q[a] += prob * (reward + gamma * V[next_state] * (not done))
    return q


def value_iteration(env, gamma=1, theta=1e-6, max_iterations=None):
    """
    Perform value iteration for a given discrete MDP.

    States are swept in index order and V is updated in place, so states later
    in a sweep already see values written earlier in that same sweep.

    Parameters:
    env: Environment object with attributes:
        - env.nS: Number of states
        - env.nA: Number of actions
        - env.P: Transition model; env.P[s][a] is an iterable of
          (prob, next_state, reward, done) tuples. When done is true the
          successor's value is not bootstrapped.
    gamma: Discount factor. With the default gamma=1, convergence relies on
        the MDP eventually terminating; otherwise use max_iterations.
    theta: Stop once the largest per-state change within a sweep is < theta.
    max_iterations: Optional cap on the number of sweeps (None = no cap). If the
        cap is reached before convergence, a RuntimeWarning is issued and the
        current estimate is returned.

    Returns:
    V: Value function, float array of shape (env.nS,)
    policy: Greedy policy with respect to V, int array of shape (env.nS,);
        ties are broken toward the lowest action index (np.argmax semantics).
    """
    V = np.zeros(env.nS)

    sweeps = 0
    while True:
        delta = 0
        for s in range(env.nS):
            v_old = V[s]
            # Bellman optimality backup, written in place.
            V[s] = np.max(_q_values(env, V, s, gamma))
            delta = max(delta, abs(v_old - V[s]))
        sweeps += 1

        if delta < theta:
            break
        # Safety valve: with gamma=1 a non-terminating MDP never converges.
        if max_iterations is not None and sweeps >= max_iterations:
            warnings.warn(
                f"value_iteration stopped after {sweeps} sweeps without converging "
                f"(delta={delta:.3g}, theta={theta})",
                RuntimeWarning,
            )
            break

    # Extract the greedy policy from the final value function.
    policy = np.zeros(env.nS, dtype=int)
    for s in range(env.nS):
        policy[s] = np.argmax(_q_values(env, V, s, gamma))

    return V, policy

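# Usage sketch: `env` may be any object exposing nS (number of states), nA (number of
# actions) and P, where P[s][a] is a list of (prob, next_state, reward, done) tuples: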
# class CustomMDPEnv:
#     def __init__(self):
#         self.nS = ...
#         self.nA = ...
#         self.P = ...
# 
# env = CustomMDPEnv()

# V, policy = value_iteration(env)
# print("Optimal Value Function: ", V)
# print("Optimal Policy: ", policy)

In [9]:
P = [[[] for _ in range(2)] for _ in range(36)]
for i in range(36):
    for a in range(2):
        for j in range(36):
            (a_i, b_i) = (i//6 + 1, i%6 + 1)
            (a_j, b_j) = (j//6 + 1, j%6 + 1)
            done = False
            prob = 0
            reward = 0
            
            if a_j == b_i:
                prob = 1.0/6
        
            if a==1 and b_i != 6:
                reward = a_i
                done = True
            
            if b_i == 6 or b_j == 6:
                done = True

            P[i][a].append((prob, j, reward, done)) 

class CustomMDPEnv:
    def __init__(self):
        self.nS = 36
        self.nA = 2
        self.P = P

env = CustomMDPEnv()

In [10]:
# Solve the environment built above with the default gamma/theta and report both results.
V, policy = value_iteration(env)
for label, result in [("Optimal Value Function: ", V), ("Optimal Policy: ", policy)]:
    print(label, result)

Optimal Value Function:  [2.50188088 2.50188096 2.71352812 3.320625   4.125      0.
 2.50188097 2.50188097 2.71352812 3.320625   4.125      0.
 3.         3.         3.         3.320625   4.125      0.
 4.         4.         4.         4.         4.125      0.
 5.         5.         5.         5.         5.         0.
 6.         6.         6.         6.         6.         0.        ]
Optimal Policy:  [0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 1 1 1 0 0 1 1 1 1 1 0 1 1 1 1 1 0]


In [13]:
# Mean optimal value over all states, without hard-coding the state count.
# NOTE(review): presumably this is the expected value under a uniformly random
# start state — confirm that is the intended initial distribution.
print(sum(V) / len(V))

3.3860126397153962


In [15]:
# Pair each state index with its greedy action for inspection.
[(state, action) for state, action in enumerate(policy)]

[(0, np.int64(0)),
 (1, np.int64(0)),
 (2, np.int64(0)),
 (3, np.int64(0)),
 (4, np.int64(0)),
 (5, np.int64(0)),
 (6, np.int64(0)),
 (7, np.int64(0)),
 (8, np.int64(0)),
 (9, np.int64(0)),
 (10, np.int64(0)),
 (11, np.int64(0)),
 (12, np.int64(1)),
 (13, np.int64(1)),
 (14, np.int64(1)),
 (15, np.int64(0)),
 (16, np.int64(0)),
 (17, np.int64(0)),
 (18, np.int64(1)),
 (19, np.int64(1)),
 (20, np.int64(1)),
 (21, np.int64(1)),
 (22, np.int64(0)),
 (23, np.int64(0)),
 (24, np.int64(1)),
 (25, np.int64(1)),
 (26, np.int64(1)),
 (27, np.int64(1)),
 (28, np.int64(1)),
 (29, np.int64(0)),
 (30, np.int64(1)),
 (31, np.int64(1)),
 (32, np.int64(1)),
 (33, np.int64(1)),
 (34, np.int64(1)),
 (35, np.int64(0))]