In [None]:
import sys

# Add the parent directory to the import search path so that the
# `src.*` packages imported below resolve when this notebook is run
# from its own sub-directory.
sys.path.append('..')

In [None]:
import gym
import torch

from src.part3.MLP import MultiLayerPerceptron as MLP
from src.part4.PolicyGradient import REINFORCE
from src.common.train_utils import EMAMeter, to_tensor
from src.common.memory.episodic_memory import EpisodicMemory

In [None]:
# Build the CartPole-v1 environment and read the problem dimensions
# that are used to size the policy network.
env = gym.make('CartPole-v1')

# `action_space.n` is the size of the discrete action set.
a_dim = env.action_space.n
# First (and, for CartPole, only) axis of the observation shape.
s_dim = env.observation_space.shape[0]

In [None]:
# Bookkeeping objects (independent of the network, so built first).
# NOTE(review): `max_size` / `gamma` semantics live in EpisodicMemory;
# presumably a capacity limit and the discount factor -- confirm there.
memory = EpisodicMemory(max_size=100, gamma=1.0)
# Running average of episode returns, used only for progress logging.
ema = EMAMeter()

# Policy network wrapped by the REINFORCE agent.
# NOTE(review): `[128]` is presumably the hidden-layer width list -- confirm
# against MultiLayerPerceptron's signature.
net = MLP(s_dim, a_dim, [128])
agent = REINFORCE(net)

In [None]:
# Training schedule.
n_eps = 10000       # total number of episodes to run
update_every = 1    # run a policy update every this many episodes
print_every = 50    # log the EMA of episode returns every this many episodes

for ep in range(n_eps):
    s = env.reset()
    # Episode return kept as a plain Python number so the EMA meter and the
    # log line below deal with scalars rather than (1, 1) tensors.
    cum_r = 0.0

    while True:
        # Use the observation width read from the environment (s_dim)
        # instead of a hard-coded 4.
        s = to_tensor(s, size=(1, s_dim))
        a = agent.get_action(s)
        ns, r, done, info = env.step(a.item())

        # Accumulate the raw scalar reward *before* it is wrapped in a tensor.
        cum_r += r

        # Tensor copies for the memory; `done` itself stays the env's flag so
        # the loop-exit test below does not depend on tensor truthiness.
        r_t = torch.ones(1, 1) * r
        done_t = torch.ones(1, 1) * done

        memory.push(s, a, r_t, torch.tensor(ns), done_t)

        s = ns
        if done:
            break

    ema.update(cum_r)
    if ep % print_every == 0:
        print("Episode {} || EMA: {} ".format(ep, ema.s))

    if ep % update_every == 0:
        # Only states, actions and the last returned element `g` are passed
        # to the update; rewards and next states are discarded here.
        s, a, _, _, done, g = memory.get_samples()
        agent.update_episodes(s, a, g, use_norm=True)