In [1]:
from load_env import enviroment
from agent import *
import numpy as np

In [2]:
# Agent creation: every hyper-parameter is passed by keyword so it can be tuned in this cell.

agent = PPG(
    state_dim=11,
    num_actions=13,
    actor_hidden_dim=32,
    critic_hidden_dim=256,
    epochs=1,
    epochs_aux=6,
    minibatch_size=64,
    lr=0.0005,
    betas=(0.9, 0.999),
    lam=0.95,
    gamma=0.99,
    beta_s=0.01,
    eps_clip=0.2,
    value_clip=0.4,
    save_path='ppg_0.pt',
)

# Load a saved model (optional): uncomment the line below to resume.
# NOTE(review): the path must be a quoted string; confirm load()'s signature before use.
# agent.load('ppg_0.pt')

In [18]:
def train(env_name = 5555,agent = agent, num_episodes = 50000,max_steps = 500,update_steps = 5000,
          num_policy_updates_per_aux = 32,seed = None, save_every = 1000):
    """Collect experience from the environment and periodically update ``agent``.

    Args:
        env_name: Identifier handed unchanged to ``enviroment(...)``.
        agent: Object exposing ``actor``, ``critic``, ``learn``, ``learn_aux``
            and ``save``. NOTE: the default is bound when this cell is
            executed, so re-run this cell after re-creating ``agent`` or pass
            the agent explicitly.
        num_episodes: Number of episodes to run.
        max_steps: Upper bound on the number of steps in one episode.
        update_steps: Number of environment steps, counted across episodes,
            between two consecutive ``agent.learn`` calls.
        num_policy_updates_per_aux: ``agent.learn_aux`` runs once every this
            many ``agent.learn`` calls.
        seed: When set, seeds both torch and numpy.
        save_every: ``agent.save()`` runs every this many episodes.
    """
    # Environment load
    env = enviroment(env_name)

    memories = deque([])
    aux_memories = deque([])

    if exists(seed):
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Total number of environment steps taken so far, across all episodes.
    current_steps = 0
    num_policy_updates = 0

    try:
        for eps in range(1, num_episodes + 1):
            for _ in tqdm(range(max_steps), ncols=100, desc=f'Episode {eps}/{num_episodes}'):
                current_steps += 1

                # The current observation is re-read from the environment on every step.
                state = env.state()
                state_tensor = torch.from_numpy(np.asarray(state)).float()
                action_probs, _ = agent.actor(state_tensor)
                value = agent.critic(state_tensor)

                # Sample an action from the categorical distribution given by the actor.
                dist = Categorical(action_probs)
                action = dist.sample()
                action_log_prob = dist.log_prob(action)
                action = action.item()

                next_state, reward, done = env.step(action, state)
                memories.append(Memory(state, action, action_log_prob, reward, done, value))

                # Gate updates on the global step counter. The per-episode index
                # never reaches update_steps when update_steps > max_steps, which
                # would silently disable learning.
                if current_steps % update_steps == 0:
                    agent.learn(memories, aux_memories, np.asarray(next_state))
                    num_policy_updates += 1
                    memories.clear()

                    if num_policy_updates % num_policy_updates_per_aux == 0:
                        agent.learn_aux(aux_memories)
                        aux_memories.clear()

                if done:
                    env.restart()
                    break

            if eps % save_every == 0:
                agent.save()
    finally:
        # Release the environment even if training fails or is interrupted.
        env.close()

In [None]:
# Launch training. Only update_steps (10) differs from train()'s defaults;
# the other arguments are spelled out so the run configuration is visible here.
train(
    env_name=5555,
    agent=agent,
    num_episodes=50000,
    max_steps=500,
    update_steps=10,
    num_policy_updates_per_aux=32,
    seed=None,
    save_every=1000,
)