<a href="https://colab.research.google.com/github/Abdulrahman-Aladdin/AI_MDP/blob/main/MDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AI Assignment 3
## Markov Decision Process




In [None]:
!pip install colorama
from colorama import Fore

Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Installing collected packages: colorama
Successfully installed colorama-0.4.6


In [None]:
class MDP:
    """Bundle of everything the solvers need to know about a grid-world MDP.

    Every argument is stored on the instance as-is (no copies are made), so
    mutating e.g. the original `world` list later is visible through
    `mdp.world`.

    Attributes:
        world: reward collected when a transition lands in each state.
        actions: action indices the solvers iterate over.
        states: state indices the solvers iterate over.
        terminal_states: states whose value is fixed and whose action is 's'.
        discount: factor applied to the value of the next state.
        str_actions: one-letter label for each action index.
        next_states: next_states[s][a] is the state reached from s via a.
    """

    def __init__(self, world, actions, states, terminal_states,
                 discount, str_actions, next_states):
        (self.world, self.actions, self.states, self.terminal_states,
         self.discount, self.str_actions, self.next_states) = (
            world, actions, states, terminal_states,
            discount, str_actions, next_states)

In [None]:
def print_three_by_three(arr):
    """Print the first nine entries of `arr` as a 3x3 grid, row by row.

    Each entry is followed by one space (so every line ends with a trailing
    space), and each row is terminated by a newline.
    """
    for base in (0, 3, 6):
        cells = [arr[base + offset] for offset in range(3)]
        # sep=' ' between cells plus end=' \n' reproduces "a b c \n".
        print(*cells, end=' \n')

In [None]:
def is_converged(curr_v, new_v, tol=1E-8):
    """Return True when no state value changed by `tol` or more.

    Args:
        curr_v: value estimates from the previous sweep, indexed by state.
        new_v: value estimates from the current sweep; must have at least as
            many entries as `curr_v`.
        tol: strict per-state absolute tolerance (default 1E-8, the value
            that used to be hard-coded).

    Returns:
        True if abs(curr_v[s] - new_v[s]) < tol for every index s of curr_v.
    """
    # Compare every state rather than a hard-coded 9, so any grid size works.
    return all(abs(curr_v[s] - new_v[s]) < tol for s in range(len(curr_v)))

In [None]:
def is_converged_policy(curr_p, new_p):
    """Return True when both policies pick the same action in every state.

    Args:
        curr_p: action label per state from the previous iteration.
        new_p: action label per state from the current iteration; must have
            at least as many entries as `curr_p`.
    """
    # Iterate over the policy's own length instead of a hard-coded 9.
    return all(curr_p[s] == new_p[s] for s in range(len(curr_p)))

In [None]:
def str_actions_inverse(policy):
    """Map a one-letter action label back to its action index.

    Mirrors str_actions = ['u', 'r', 'd', 'l']: 'u' -> 0, 'r' -> 1, 'd' -> 2.
    Every other value (normally 'l') maps to 3.
    """
    for index, label in enumerate(('u', 'r', 'd')):
        if policy == label:
            return index
    # Anything that is not 'u', 'r' or 'd' falls through to 'l' (index 3).
    return 3

In [None]:
def get_sum(curr_state, curr_action, v, mdp):
    """Expected one-step return of taking `curr_action` in `curr_state`.

    Transition model: the intended action is carried out with probability
    0.8, and each of the two adjacent action indices
    ((curr_action + 1) % 4 and (curr_action - 1) % 4; with the
    ['u', 'r', 'd', 'l'] ordering these are the perpendicular moves) happens
    with probability 0.1. Each outcome contributes
    prob * (world[s'] + discount * v[s']), where s' is the landing state.

    Args:
        curr_state: index of the state the agent is in.
        curr_action: index of the action the agent attempts.
        v: current value estimate for every state, indexed by state.
        mdp: object providing `next_states`, `world` and `discount`.
    """
    outcomes = (
        (curr_action, 0.8),
        ((curr_action + 1) % 4, 0.1),
        ((curr_action - 1) % 4, 0.1),
    )
    total = 0
    for action, prob in outcomes:
        landed = mdp.next_states[curr_state][action]
        total += prob * (mdp.world[landed] + mdp.discount * v[landed])
    return total

In [None]:
def get_max_action(curr_state, old_v, mdp):
    """Return the best one-step lookahead value and action label of a state.

    Args:
        curr_state: index of the state to evaluate.
        old_v: value estimates from the previous sweep, indexed by state.
        mdp: the MDP (provides terminal_states, actions, str_actions and the
            attributes get_sum reads).

    Returns:
        (value, label). Terminal states give (0, 's'). Otherwise value is the
        largest get_sum over mdp.actions and label is mdp.str_actions of the
        first action reaching it (ties go to the earlier action).
    """
    if curr_state in mdp.terminal_states:
        return 0, 's'
    # Start from -inf instead of a finite sentinel such as -1E9, so a state
    # whose every action scores below the sentinel still reports its true
    # best value and action.
    max_val = float('-inf')
    act = 0
    for action in mdp.actions:
        sigma = get_sum(curr_state, action, old_v, mdp)
        if sigma > max_val:
            max_val = sigma
            act = action
    return max_val, mdp.str_actions[act]

In [None]:
def MDP_value_iteration(mdp):
    """Solve `mdp` with value iteration, print the result and return it.

    Starting from all-zero values, every sweep sets each state's value to its
    best one-step lookahead (get_max_action), computed from the previous
    sweep's values. Sweeps repeat until is_converged reports no change. The
    converged values (rounded to 2 decimals) and the greedy policy are then
    printed as 3x3 grids.

    Returns:
        (values, policy): the converged, unrounded value list and the action
        label per state ('s' for terminal states). Callers that ignore the
        return value are unaffected.
    """
    n = len(mdp.world)  # one entry per state, instead of a hard-coded 9
    old_v = [0] * n
    policy = list(range(n))
    while True:
        # Jacobi-style sweep: read only old_v, write into a fresh list.
        new_value = list(old_v)
        for state in mdp.states:
            new_value[state], policy[state] = get_max_action(state, old_v, mdp)
        done = is_converged(old_v, new_value)
        old_v = new_value
        if done:
            break
    print(Fore.GREEN + 'Values:' + Fore.RESET)
    print_three_by_three([round(v, 2) for v in old_v])
    print(Fore.GREEN + 'Policy:' + Fore.RESET)
    print_three_by_three(policy)
    print('\n------------------------------------------------------------')
    return old_v, policy

In [None]:
def MDP_policy_evaluation(mdp, old_v, policy, ndigits=2):
    """Iteratively evaluate a fixed policy and return its state values.

    Each sweep computes, for every non-terminal state, the expected one-step
    return (get_sum) of the action the policy prescribes, using the previous
    sweep's values. Terminal states are pinned to 0. Sweeps repeat until
    is_converged reports that the values stopped changing.

    Args:
        mdp: the MDP being solved.
        old_v: initial value estimates, indexed by state (not modified).
        policy: action label per state ('u', 'r', 'd' or 'l'); entries for
            terminal states are ignored.
        ndigits: digits to round the returned values to (default 2, the
            previously hard-coded behaviour); pass None to get them unrounded.

    Returns:
        A new list with the evaluated value of every state.
    """
    n = len(mdp.world)
    while True:
        # Allocate a fresh list every sweep. Reusing one list made old_v and
        # new_v the same object after the first sweep, so is_converged then
        # compared the list with itself and evaluation always stopped after
        # at most two sweeps.
        new_v = [0] * n
        for state in mdp.states:
            if state in mdp.terminal_states:
                continue  # terminal value stays 0
            action = str_actions_inverse(policy[state])
            new_v[state] = get_sum(state, action, old_v, mdp)
        done = is_converged(old_v, new_v)
        old_v = new_v
        if done:
            break
    if ndigits is None:
        return old_v
    return [round(v, ndigits) for v in old_v]

In [None]:
def MDP_policy_extraction(mdp, v):
    """Return the greedy policy with respect to the value estimates `v`.

    Terminal states get 's'. Every other state gets the label of the action
    with the largest get_sum (ties go to the earlier action in mdp.actions).

    Args:
        mdp: the MDP being solved.
        v: value estimate per state, indexed by state.

    Returns:
        A list with one action label per state.
    """
    policy = list(range(len(v)))
    for state in mdp.states:
        if state in mdp.terminal_states:
            policy[state] = 's'
            continue
        # -inf rather than a finite sentinel like -1E9, so the true best
        # action is chosen even when every expected return is very negative.
        max_val = float('-inf')
        max_action = 0
        for action in mdp.actions:
            val = get_sum(state, action, v, mdp)
            if val > max_val:
                max_val = val
                max_action = action
        policy[state] = mdp.str_actions[max_action]
    return policy

In [None]:
def MDP_policy_iteration(mdp):
    """Solve `mdp` with policy iteration, print the result and return it.

    Starts from the policy that picks 'u' everywhere and alternates policy
    evaluation (MDP_policy_evaluation) with greedy policy extraction
    (MDP_policy_extraction) until the extracted policy stops changing. The
    final values and policy are printed as 3x3 grids.

    Returns:
        (values, policy): the values from the last evaluation and the
        converged policy. Callers that ignore the return value are
        unaffected.
    """
    n = len(mdp.world)  # one entry per state, instead of a hard-coded 9
    policy = ['u'] * n
    values = [0] * n
    while True:
        # Warm-start each evaluation from the previous policy's values.
        values = MDP_policy_evaluation(mdp, values, policy)
        new_policy = MDP_policy_extraction(mdp, values)
        if is_converged_policy(new_policy, policy):
            break
        policy = new_policy
    print(Fore.GREEN + 'Final values:' + Fore.RESET)
    print_three_by_three(values)
    print(Fore.GREEN + 'Final policy:' + Fore.RESET)
    print_three_by_three(policy)
    print('\n------------------------------------------------------------')
    return values, policy

In [None]:
def main():
    """Solve the 3x3 grid world with both algorithms for several values of r.

    States are numbered row-major:
        0 1 2
        3 4 5
        6 7 8
    State 0 (reward r, taken from `top_left_rewards`) and state 2 (reward
    +10) are terminal; landing in any other state gives -1.
    """
    top_left_rewards = [100, 3, 0, -3]

    # Reward collected when a transition lands in each state.
    world = [-1] * 9
    world[2] = 10

    # next_states[s][a] for a in (u, r, d, l); bumping into a wall keeps the
    # agent where it is.
    next_states = [
        [0, 1, 3, 0],
        [1, 2, 4, 0],
        [2, 2, 5, 1],
        [0, 4, 6, 3],
        [1, 5, 7, 3],
        [2, 5, 8, 4],
        [3, 7, 6, 6],
        [4, 8, 7, 6],
        [5, 8, 8, 7],
    ]

    mdp = MDP(world=world, actions=[0, 1, 2, 3], states=range(9),
              terminal_states=(0, 2), discount=0.99,
              str_actions=['u', 'r', 'd', 'l'], next_states=next_states)

    solvers = (
        ('Running value iteration:', MDP_value_iteration),
        ('Running policy iteration:', MDP_policy_iteration),
    )
    for heading, solve in solvers:
        print(Fore.GREEN + heading + Fore.RESET)
        for r in top_left_rewards:
            # mdp.world is this same list, so the MDP sees the new reward.
            world[0] = r
            print(f'\nr -> {r}')
            solve(mdp)


In [None]:
# Entry point: run value iteration and policy iteration for every r value.
if __name__ == '__main__':
    main()

[32mRunning value iteration:[39m

r -> 100
[32mValues:[39m
0 99.2 0 
99.2 96.72 90.11 
96.45 94.3 91.68 
[32mPolicy:[39m
s l s 
u l d 
u l l 

------------------------------------------------------------

r -> 3
[32mValues:[39m
0 9.56 0 
6.45 8.2 9.56 
5.63 6.86 8.05 
[32mPolicy:[39m
s r s 
r r u 
r r u 

------------------------------------------------------------

r -> 0
[32mValues:[39m
0 9.56 0 
6.14 8.2 9.56 
5.6 6.86 8.05 
[32mPolicy:[39m
s r s 
r r u 
r r u 

------------------------------------------------------------

r -> -3
[32mValues:[39m
0 9.56 0 
5.84 8.2 9.56 
5.56 6.86 8.05 
[32mPolicy:[39m
s r s 
r r u 
r r u 

------------------------------------------------------------
[32mRunning policy iteration:[39m

r -> 100
[32mFinal values:[39m
0 99.18 0 
99.18 96.67 89.35 
96.39 94.21 91.46 
[32mFinal policy:[39m
s l s 
u l d 
u l l 

------------------------------------------------------------

r -> 3
[32mFinal values:[39m
0 9.51 0 
5.95 8.03 9.53 
4.

For r = 100:   
**Values**:    
0 99.2 0    
99.2 96.72 90.11    
96.45 94.3 91.68    
**Policy**:   
s  l  s    
u  l  d    
u  l  l   

We notice that the optimal policy for cell (1, 2) is to move down. Moving down removes any chance of slipping into the +10 cell, so we keep a higher chance of eventually reaching the +100 cell.     
We also notice that the state values grow larger the closer we get to the top-left cell (+100).


r -> 3    
**Values**:    
0 9.56 0     
6.45 8.2 9.56     
5.63 6.86 8.05     
**Policy**:   
s r s     
r r u     
r r u     

We notice that for r = 3 the reward in the top-left cell is not worth pursuing; collecting the +10 is much better. So the optimal policy avoids the r cell and instead heads for the +10 square. The same reasoning applies to r = 0 and r = -3, which yield the same optimal policy: always head for the +10 square.

We also notice that the optimal action in cell (1, 0) is r (move right). Because the agent slips sideways with probability 0.1, there is a small chance of entering the top-left cell early. That is acceptable here, because moving down instead would lower our score by a lot, which is never worth it.
