<a href="https://colab.research.google.com/github/AbhiramDream/pailab/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

class SimpleGridWorldMDP:
    """Deterministic grid-world MDP with value- and policy-iteration solvers.

    States are ``(row, col)`` tuples covering every cell of ``grid``; only the
    grid's dimensions are used, the cell values themselves are never read.
    Each action moves one cell and is clamped at the border, so walking into
    a wall leaves the agent where it is.  The reward of a transition is looked
    up by the state that is *entered* (missing entries count as 0).  Terminal
    states are absorbing and their value is held at 0 by both solvers.

    Attributes:
        grid: The grid passed to the constructor.
        terminal_states: Collection of absorbing ``(row, col)`` states.
        rewards: Mapping ``(row, col) -> reward``.
        gamma: Discount factor applied to the successor state's value.
        states: List of all ``(row, col)`` states in row-major order.
        actions: Supported action names; their order breaks argmax ties.
        transitions: ``transitions[state][action] == (next_state, reward)``.
    """

    # (row offset, column offset) for every supported action.
    _MOVES = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1),
    }

    def __init__(self, grid, terminal_states, rewards, gamma=0.99):
        """Store the model parameters and precompute the transition table.

        Args:
            grid: Rectangular sequence of rows; defines the state space.
            terminal_states: Collection of absorbing ``(row, col)`` states.
            rewards: Mapping from state to the reward earned on entering it.
            gamma: Discount factor.
        """
        self.grid = grid
        self.terminal_states = terminal_states
        self.rewards = rewards
        self.gamma = gamma
        # An empty grid has no rows to measure, so it yields no states.
        n_cols = len(grid[0]) if grid else 0
        self.states = [(i, j) for i in range(len(grid)) for j in range(n_cols)]
        self.actions = ['up', 'down', 'left', 'right']
        self.transitions = self.build_transitions()

    def build_transitions(self):
        """Return ``{state: {action: (next_state, reward)}}`` for all pairs."""
        return {
            state: {
                action: self.get_next_state_and_reward(state, action)
                for action in self.actions
            }
            for state in self.states
        }

    def get_next_state_and_reward(self, state, action):
        """Return the ``(next_state, reward)`` outcome of ``action``.

        Args:
            state: Current ``(row, col)`` state.
            action: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.

        Returns:
            The successor state and the reward for entering it.  A terminal
            state always maps to itself with its own reward, whatever the
            action.

        Raises:
            ValueError: If ``action`` is not a supported action name (and
                ``state`` is not terminal).
        """
        if state in self.terminal_states:
            return state, self.rewards.get(state, 0)

        try:
            d_row, d_col = self._MOVES[action]
        except KeyError:
            raise ValueError(
                f"Unknown action {action!r}; expected one of {sorted(self._MOVES)}"
            ) from None

        i, j = state
        # Clamp to the grid so that moving into a wall is a no-op.
        next_i = min(max(i + d_row, 0), len(self.grid) - 1)
        next_j = min(max(j + d_col, 0), len(self.grid[0]) - 1)
        next_state = (next_i, next_j)
        return next_state, self.rewards.get(next_state, 0)

    def _q_value(self, state, action, V):
        """Return the one-step backup ``reward + gamma * V[next_state]``."""
        next_state, reward = self.transitions[state][action]
        return reward + self.gamma * V[next_state]

    def _greedy_action(self, state, V):
        """Return the action with the highest backup; first action wins ties."""
        q_values = [self._q_value(state, action, V) for action in self.actions]
        return self.actions[np.argmax(q_values)]

    def _evaluate_policy(self, policy, V, theta):
        """Sweep ``V`` in place under ``policy`` until the largest change < theta."""
        while True:
            delta = 0
            for state in self.states:
                if state in self.terminal_states:
                    continue
                old_value = V[state]
                V[state] = self._q_value(state, policy[state], V)
                delta = max(delta, abs(old_value - V[state]))
            if delta < theta:
                return V

    def value_iteration(self, theta=1e-5):
        """Solve the MDP with in-place value iteration.

        Args:
            theta: Convergence threshold on the largest per-sweep value change.

        Returns:
            ``(policy, V)`` where ``policy`` maps every non-terminal state to
            its greedy action and ``V`` maps every state to its value.
        """
        V = {state: 0 for state in self.states}
        while True:
            delta = 0
            for state in self.states:
                if state in self.terminal_states:
                    continue
                old_value = V[state]
                V[state] = max(self._q_value(state, action, V) for action in self.actions)
                delta = max(delta, abs(old_value - V[state]))
            if delta < theta:
                break

        policy = {
            state: self._greedy_action(state, V)
            for state in self.states
            if state not in self.terminal_states
        }
        return policy, V

    def policy_iteration(self, theta=1e-5):
        """Solve the MDP with policy iteration from a random initial policy.

        Args:
            theta: Convergence threshold used by each policy-evaluation phase.

        Returns:
            ``(policy, V)`` where ``policy`` maps every non-terminal state to
            an action and ``V`` holds the values from the last evaluation.
        """
        policy = {
            state: np.random.choice(self.actions)
            for state in self.states
            if state not in self.terminal_states
        }
        V = {state: 0 for state in self.states}

        while True:
            V = self._evaluate_policy(policy, V, theta)
            policy_stable = True
            for state in self.states:
                if state in self.terminal_states:
                    continue
                old_action = policy[state]
                policy[state] = self._greedy_action(state, V)
                if old_action != policy[state]:
                    policy_stable = False
            if policy_stable:
                break
        return policy, V


# Example usage: solve a 3x4 grid world with both solvers and print the
# resulting policies, one state per line in sorted (row, col) order.
grid = [
    [0, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 0],
]

# Entering (0, 3) pays +1 and ends the episode; entering (1, 1) costs 1.
terminal_states = [(0, 3)]
rewards = {(0, 3): 1, (1, 1): -1}

mdp = SimpleGridWorldMDP(grid, terminal_states, rewards)

policy_vi, value_vi = mdp.value_iteration()
print("Policy (Value Iteration):")
# States are unique keys, so sorting the items orders them by state alone.
for state, chosen_action in sorted(policy_vi.items()):
    print(f"State {state}: {chosen_action}")

policy_pi, value_pi = mdp.policy_iteration()
print("\nPolicy (Policy Iteration):")
for state, chosen_action in sorted(policy_pi.items()):
    print(f"State {state}: {chosen_action}")


Policy (Value Iteration):
State (0, 0): right
State (0, 1): right
State (0, 2): right
State (1, 0): up
State (1, 1): up
State (1, 2): up
State (1, 3): up
State (2, 0): up
State (2, 1): right
State (2, 2): up
State (2, 3): up

Policy (Policy Iteration):
State (0, 0): right
State (0, 1): right
State (0, 2): right
State (1, 0): up
State (1, 1): up
State (1, 2): up
State (1, 3): up
State (2, 0): up
State (2, 1): right
State (2, 2): up
State (2, 3): up
