# Q Learning

In [None]:
import random
import numpy as np

## Problem Setup

First, we need to lay out the problem we're working through. This consists of defining the possible states and actions, and optionally the actions that are legal at each state. In order to perform value iteration, we'll also need the transition probabilities and rewards when moving from state to state.

In [None]:
# State and action spaces for the problem.
states = ["cold", "warm", "overheated"]
actions = ["slow", "fast"]

# Actions that may be taken in each state. "overheated" has no legal actions.
Legal_Actions = {
    "cold": list(actions),
    "warm": list(actions),
    "overheated": [],
}

# transitions[(state, action, next_state)] is the probability of moving to
# next_state when action is taken in state. For every (state, action) pair
# listed here the probabilities over next_state add up to 1.
transitions = {
    # Starting from "cold"
    ("cold", "slow", "cold"): 1,
    ("cold", "fast", "cold"): 0.5,
    ("cold", "slow", "warm"): 0,
    ("cold", "fast", "warm"): 0.5,
    ("cold", "slow", "overheated"): 0,
    ("cold", "fast", "overheated"): 0,

    # Starting from "warm"
    ("warm", "slow", "cold"): 0.5,
    ("warm", "fast", "cold"): 0,
    ("warm", "slow", "warm"): 0.5,
    ("warm", "fast", "warm"): 0,
    ("warm", "slow", "overheated"): 0,
    ("warm", "fast", "overheated"): 1,
}

# rewards[(state, action)] is the reward for taking action in state.
rewards = {
    ("cold", "slow"): 1,
    ("cold", "fast"): 2,

    ("warm", "slow"): 1,
    ("warm", "fast"): -10,
}

## Value, Q, and Policy iteration

Implement a simple function for each type of iteration (hint: they should be very similar).

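For reference, here is a copy of the Bellman equation, where $V^*(s) = \max_a Q^*(s,a)$ (in this problem the reward depends only on the state and action, so $R(s,a,s') = R(s,a)$):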
$$\Large
Q^*(s,a) =\sum_{s'} T(s,a,s')[R(s,a,s') + \gamma V^* (s')]
$$

Feel free to define any helper functions you want or keep everything in the same block

In [None]:
def value_iteration(states, actions, Legal_Actions, T, R, gamma, max_iter=100):
    """Estimate the value of every state by repeated Bellman backups.

    Args:
        states: Sequence of all states; it is iterated once per sweep.
        actions: All actions. Not used by this function (legality comes from
            ``Legal_Actions``); kept so the signature matches the other
            iteration functions.
        Legal_Actions: Mapping from each state to the actions allowed there.
        T: Mapping ``(s, a, s_prime) -> transition probability``.
        R: Mapping ``(s, a) -> reward``.
        gamma: Discount factor applied to the value of the next state.
        max_iter: Number of full sweeps over ``states``.

    Returns:
        Dict mapping each state to its estimated value. A state with no legal
        actions has value 0.

    Raises:
        KeyError: If ``Legal_Actions`` lacks a state, or ``T`` / ``R`` lacks an
            entry needed to evaluate a legal action. These are reported
            instead of being silently turned into a value of 0.
    """
    v = {s: 0 for s in states}

    for _ in range(max_iter):
        for s in states:
            # v is updated in place, so states later in the sweep already see
            # the values computed earlier in the same sweep.
            # default=0 covers states with no legal actions without needing a
            # catch-all exception handler.
            v[s] = max(
                (
                    sum(T[(s, a, s_prime)] * (R[(s, a)] + gamma * v[s_prime]) for s_prime in states)
                    for a in Legal_Actions[s]
                ),
                default=0,
            )
    return v

In [None]:
def q_iteration(states, actions, Legal_Actions, T, R, gamma, max_iter=100):
    """Estimate Q(s, a) for every legal state-action pair by repeated backups.

    Args:
        states: Sequence of all states; it is iterated once per sweep.
        actions: All actions. Not used by this function (legality comes from
            ``Legal_Actions``); kept so the signature matches the other
            iteration functions.
        Legal_Actions: Mapping from each state to the actions allowed there.
        T: Mapping ``(s, a, s_prime) -> transition probability``.
        R: Mapping ``(s, a) -> reward``.
        gamma: Discount factor applied to the value of the next state.
        max_iter: Number of full sweeps over ``states``.

    Returns:
        Dict mapping each legal ``(state, action)`` pair to its estimated
        Q-value. States with no legal actions contribute no entries.

    Raises:
        KeyError: If ``Legal_Actions`` lacks a state, or ``T`` / ``R`` lacks an
            entry needed to evaluate a legal action.
    """
    v = {s: 0 for s in states}
    q = {(s, a): 0 for s in states for a in Legal_Actions[s]}

    for _ in range(max_iter):
        for s in states:
            legal = Legal_Actions[s]

            # Back up every legal (s, a) pair using the current state values.
            for a in legal:
                q[(s, a)] = sum(T[(s, a, s_prime)] * (R[(s, a)] + gamma * v[s_prime]) for s_prime in states)

            # The state's value is its best Q-value. default=0 covers states
            # with no legal actions, so no catch-all handler is needed.
            # v is updated in place, so later states in this sweep see it.
            v[s] = max((q[(s, a)] for a in legal), default=0)

    return q

In [None]:
def policy_iteration(states, actions, Legal_Actions, T, R, gamma, max_iter=10):
    """Compute a greedy policy by iterating Q-value backups.

    Each sweep recomputes Q(s, a) for the legal actions of a state, sets the
    policy for that state to the action with the highest Q-value, and uses
    that Q-value as the state's value. Note that this is a one-step greedy
    backup per state, not a separate full policy-evaluation phase.

    Args:
        states: Sequence of all states; it is iterated once per sweep.
        actions: All actions. Not used by this function (legality comes from
            ``Legal_Actions``); kept so the signature matches the other
            iteration functions.
        Legal_Actions: Mapping from each state to the sequence of actions
            allowed there.
        T: Mapping ``(s, a, s_prime) -> transition probability``.
        R: Mapping ``(s, a) -> reward``.
        gamma: Discount factor applied to the value of the next state.
        max_iter: Number of full sweeps over ``states``.

    Returns:
        Dict mapping each state that has at least one legal action to the
        chosen action. States with no legal actions have no entry. With
        ``max_iter=0`` the randomly initialised policy is returned.

    Raises:
        KeyError: If ``Legal_Actions`` lacks a state, or ``T`` / ``R`` lacks an
            entry needed to evaluate a legal action.
    """
    pi = {}
    q = {}
    v = {}

    # Start from zero values and a random legal action per state. States with
    # no legal actions get no policy entry.
    for s in states:
        v[s] = 0
        legal = Legal_Actions[s]
        if legal:
            pi[s] = random.choice(legal)
        for a in legal:
            q[(s, a)] = 0

    for _ in range(max_iter):
        for s in states:
            legal = Legal_Actions[s]
            if not legal:
                # Nothing to choose from: the state's value stays 0.
                v[s] = 0
                continue

            # Back up every legal (s, a) pair using the current state values.
            for a in legal:
                q[(s, a)] = sum(T[(s, a, s_prime)] * (R[(s, a)] + gamma * v[s_prime]) for s_prime in states)

            # Greedy improvement; on ties the first legal action wins.
            pi[s] = max(legal, key=lambda a: q[(s, a)])
            # v is updated in place, so later states in this sweep see it.
            v[s] = q[(s, pi[s])]

    return pi

In [None]:
# Run all three solvers on the same problem with a discount factor of 0.9.
state_values = value_iteration(
    states, actions, Legal_Actions, T=transitions, R=rewards, gamma=0.9
)
q_values = q_iteration(
    states, actions, Legal_Actions, T=transitions, R=rewards, gamma=0.9
)
learned_policy = policy_iteration(
    states, actions, Legal_Actions, T=transitions, R=rewards, gamma=0.9
)

In [None]:
# Sanity checks on the results computed above.
assert learned_policy["cold"] == "fast", "The policy chose the wrong action when the engine is cold"
assert learned_policy["warm"] == "slow", "The policy chose the wrong action when the engine is warm"
# Compare floats with a tolerance: exact equality would only hold while
# value_iteration and q_iteration perform bit-for-bit identical arithmetic.
assert np.isclose(
    state_values["cold"],
    max(q_values[("cold", a)] for a in Legal_Actions["cold"]),
), "Value iteration did not choose the maximizing action"
assert np.isclose(state_values["cold"], 15.5), "Value iteration converged to the wrong value"