## El objetivo de este ejercicio es calcular la función de valor de una política dada usando Diferencias Temporales.
<img src="Prediccion_TD.PNG">

In [2]:
%matplotlib inline

import gym
import matplotlib
import numpy as np
import sys

from collections import defaultdict
import itertools

# Put the parent directory on the import path (only once) before the
# project-local `lib` imports below are executed.
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
from lib import plotting

# Select matplotlib's 'ggplot' style.
matplotlib.style.use('ggplot')

In [3]:
# Build the gridworld environment with its default constructor arguments.
env = GridworldEnv()

In [7]:
def td_prediction(policy, env, num_episodes, discount_factor=1.0, alpha=0.01):
    """
    Temporal-difference (one-step TD) prediction.

    Estimates the state-value function of a given policy by sampling episodes
    and applying a TD update after every environment step.

    Args:
        policy: Callable that receives an observation and returns an action.
        env: Environment whose ``reset()`` returns the initial state and whose
            ``step(action)`` returns ``(next_state, reward, done, info)``.
        num_episodes: Total number of episodes to sample.
        discount_factor: Discount factor gamma.
        alpha: Step size applied to each TD error.

    Returns:
        A ``defaultdict(float)`` mapping state -> estimated value.
    """
    # The value function being learned; unseen states default to 0.0.
    V = defaultdict(float)

    for ep in range(num_episodes):

        # Progress report every 1000 episodes.
        if ep % 1000 == 0:
            print('Episode: %d' % ep)

        state = env.reset()
        done = False

        while not done:
            action = policy(state)
            next_state, reward, done, _ = env.step(action)

            # Always look the successor up so it is registered in V (the
            # returned mapping then also lists terminal states), but never
            # bootstrap from it once the episode has ended: nothing follows a
            # terminating transition, so its target is the reward alone.
            next_value = V[next_state]
            if done:
                next_value = 0.0

            td_target = reward + discount_factor * next_value
            V[state] += alpha * (td_target - V[state])
            state = next_state

    return V

In [8]:
def sample_policy(observation):
    """
    Random policy: the observation is ignored and an action index is drawn
    with ``np.random.choice`` from the 4 available ones.
    """
    n_actions = 4
    chosen_action = np.random.choice(n_actions)
    return chosen_action

In [9]:
# Estimate the value function of the random policy from 500,000 sampled episodes.
V = td_prediction(sample_policy, env, num_episodes=500000)

Episode: 0
Episode: 1000
Episode: 2000
Episode: 3000
Episode: 4000
Episode: 5000
Episode: 6000
Episode: 7000
Episode: 8000
Episode: 9000
Episode: 10000
Episode: 11000
Episode: 12000
Episode: 13000
Episode: 14000
Episode: 15000
Episode: 16000
Episode: 17000
Episode: 18000
Episode: 19000
Episode: 20000
Episode: 21000
Episode: 22000
Episode: 23000
Episode: 24000
Episode: 25000
Episode: 26000
Episode: 27000
Episode: 28000
Episode: 29000
Episode: 30000
Episode: 31000
Episode: 32000
Episode: 33000
Episode: 34000
Episode: 35000
Episode: 36000
Episode: 37000
Episode: 38000
Episode: 39000
Episode: 40000
Episode: 41000
Episode: 42000
Episode: 43000
Episode: 44000
Episode: 45000
Episode: 46000
Episode: 47000
Episode: 48000
Episode: 49000
Episode: 50000
Episode: 51000
Episode: 52000
Episode: 53000
Episode: 54000
Episode: 55000
Episode: 56000
Episode: 57000
Episode: 58000
Episode: 59000
Episode: 60000
Episode: 61000
Episode: 62000
Episode: 63000
Episode: 64000
Episode: 65000
Episode: 66000
Episode:

In [12]:
# Show the learned (state, value) pairs ordered by state.
print(sorted(V.items()))

[(0, 0.0), (1, -12.871366265024633), (2, -19.576608305129007), (3, -21.9483992661738), (4, -14.446212645663802), (5, -17.78094110444044), (6, -20.031838207267285), (7, -20.170447103491174), (8, -20.429692276242), (9, -19.961929222272612), (10, -18.052938183416344), (11, -14.684376350901978), (12, -22.32021329183584), (13, -19.976718298350107), (14, -13.936367413526023), (15, 0.0)]


In [13]:
# Expected values, for comparison with the estimates printed above:
np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])

array([  0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,
       -20, -14,   0])