# Dynamic Programming

## Policy Iteration on FrozenLake

In [31]:
import py_inforce as pin
import gym
import numpy as np

# Number of evaluation rollouts. Used for the loop bound; the average below is
# taken over len(returns), so the two can never drift apart.
N_EPISODES = 100

env = gym.make('FrozenLake-v0', is_slippery=True)

# Compute a policy with policy iteration. The arguments are passed exactly as
# before; presumably .95 is the discount factor and `thresh` the convergence
# tolerance -- confirm against the py_inforce signature.
policy = pin.policy_iteration(env, .95, thresh=0.00001)

returns = []

# Roll the policy out and record the undiscounted return of every episode.
for episode in range(N_EPISODES):
    state = env.reset()

    done = False
    ret = 0

    while not done:
        # The policy row for the current state is expected to flag the chosen
        # action with a 1. argmax selects that same (first) index without
        # depending on an exact floating-point equality test.
        action = int(np.argmax(policy[state, :]))
        state, reward, done, _ = env.step(action)
        ret += reward
    returns.append(ret)

# Mean return over all evaluation episodes (displayed as the cell output).
sum(returns) / len(returns)

0.78

## Value Iteration on FrozenLake

In [23]:
import gym
import numpy as np
import py_inforce as pin

# Number of evaluation rollouts. Used for the loop bound; the average below is
# taken over len(returns), so the two can never drift apart.
N_EPISODES = 100

env = gym.make('FrozenLake-v0', is_slippery=True)

# value_iteration returns a pair; only the first element (the policy) is used
# here, the second is discarded.
policy, _ = pin.value_iteration(env)

returns = []

# Roll the policy out and record the undiscounted return of every episode.
for episode in range(N_EPISODES):
    state = env.reset()

    done = False
    ret = 0

    while not done:
        # The policy row for the current state is expected to flag the chosen
        # action with a 1. argmax selects that same (first) index without
        # depending on an exact floating-point equality test.
        action = int(np.argmax(policy[state, :]))
        state, reward, done, _ = env.step(action)
        ret += reward
    returns.append(ret)

# Mean return over all evaluation episodes (displayed as the cell output).
sum(returns) / len(returns)

0.84

# Policy Gradient Methods

## REINFORCE on CartPole

In [13]:
from torch.distributions import Categorical
import gym
import torch.nn as nn
from py_inforce.generic.mlp import MLP
from py_inforce.policy_based.REINFORCE import REINFORCE
import torch.optim as optim
import torch
import numpy as np
import py_inforce as pin

env = gym.make('CartPole-v0')
in_dim = env.observation_space.shape[0] # 4
out_dim = env.action_space.n # 2

# Policy network: observation in, one logit per discrete action out.
cart_agent = MLP([in_dim, 128, 128, out_dim], nn.ReLU)
optimizer = optim.Adam(cart_agent.parameters(), lr=cart_agent.lr)

# Train with REINFORCE.
# Stops when the agent achieves a score of 200 just once (EARLY predicate);
# `bf` subtracts the mean from whatever it is given -- presumably a baseline
# applied to the returns, confirm against py_inforce.
pin.REINFORCE(
    cart_agent,
    env,
    Categorical,
    optimizer,
    200,
    bf=lambda x: x - x.mean(),
    MAX_EPISODES=500,
    EARLY=lambda x: x == 200,
)

# Evaluate the trained agent on a single episode, sampling actions from the
# categorical distribution defined by the network's logits.
done = False

state = env.reset()
rewards = 0

# No parameters are updated during evaluation, so skip autograd bookkeeping.
with torch.no_grad():
    while not done:
        # Keep the raw observation in `state`; use a separate name for the
        # float32 tensor that is fed to the network.
        obs_tensor = torch.from_numpy(state.astype(np.float32))
        action_dist = Categorical(logits=cart_agent.forward(obs_tensor))
        action = action_dist.sample()
        # .item() hands the environment a plain Python int.
        state, reward, done, _ = env.step(action.item())
        rewards += reward
        # Uncomment to watch the episode:
        # env.render()

# Total reward collected in the evaluation episode (displayed as cell output).
rewards

10.0

# Temporal Difference Learning

## SARSA