In [1]:
import gym 
import numpy as np
from helpers import NormalizedEnv, RandomAgent
from qnetwork2 import ReplayBuffer, QNetwork
from heuristic import HeuristicPendulumAgent
from matplotlib import pyplot
import torch.optim as optim
import sys

In [2]:
# Initialization
env = gym.make("Pendulum-v1")
norm_env = NormalizedEnv(env) # accept actions between -1 and 1

#we fix a torque
torque = norm_env.action(norm_env.action_space.sample())
agent = HeuristicPendulumAgent(norm_env, torque)

buffer = ReplayBuffer(10000)
batch_size = 128

num_states = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
hidden_size = 256 # choose as you wish 
critic = QNetwork(num_states + num_actions, hidden_size, num_actions, agent)
optimizer = optim.Adam(critic.parameters(), lr=1e-4)

losses = []
rewards = []

In [4]:
for episode in range(500): 
    state, info = norm_env.reset()
    trunc = False
    
    episode_loss = []
    av_episode_loss = 0
    
    episode_reward = []
    av_episode_reward = 0
    
    while not trunc:
        action = agent.compute_action(state)
        # print(norm_env.step(action))
        next_state, reward, terminated, trunc, info = norm_env.step(action)
        buffer.add(state, action, reward, next_state, trunc)
        
        if len(buffer) > batch_size:
            transition = buffer.sample(batch_size)
            loss = critic.update(transition, trunc, 0.99)
            optimizer.zero_grad()
            optimizer.step()
            episode_loss.append(loss)
        
        state = next_state
        episode_reward.append(reward)
        
        if trunc:
            av_episode_loss = np.mean(episode_loss)
            av_episode_reward = np.mean(episode_reward)
            sys.stdout.write("episode: {}, loss: {}, reward: {} \n".format(episode, av_episode_loss, np.round(av_episode_reward, decimals=2)))
            break
            
    losses.append(av_episode_loss)
    rewards.append(episode_reward)

episode: 0, loss: 19.427560806274414, reward: -2.49 
episode: 1, loss: 19.610759735107422, reward: -1.97 
episode: 2, loss: 19.56987953186035, reward: -2.54 
episode: 3, loss: 19.635753631591797, reward: -2.48 
episode: 4, loss: 19.62099266052246, reward: -2.47 
episode: 5, loss: 19.510387420654297, reward: -2.46 
episode: 6, loss: 19.7239933013916, reward: -2.58 
episode: 7, loss: 19.561967849731445, reward: -1.98 
episode: 8, loss: 19.734922409057617, reward: -3.07 
episode: 9, loss: 19.247848510742188, reward: -1.31 
episode: 10, loss: 18.823850631713867, reward: -1.96 
episode: 11, loss: 19.365005493164062, reward: -1.96 
episode: 12, loss: 19.069969177246094, reward: -1.97 
episode: 13, loss: 19.171710968017578, reward: -1.89 
episode: 14, loss: 18.99932098388672, reward: -2.53 
episode: 15, loss: 18.83452033996582, reward: -2.46 
episode: 16, loss: 18.73809242248535, reward: -1.97 
episode: 17, loss: 18.939746856689453, reward: -2.5 
episode: 18, loss: 18.93325424194336, reward: 

episode: 153, loss: 18.4176025390625, reward: -2.55 
episode: 154, loss: 18.60626220703125, reward: -3.07 
episode: 155, loss: 18.93350601196289, reward: -2.49 
episode: 156, loss: 19.124292373657227, reward: -3.09 
episode: 157, loss: 18.795120239257812, reward: -1.34 
episode: 158, loss: 19.037273406982422, reward: -1.32 
episode: 159, loss: 18.968433380126953, reward: -1.33 
episode: 160, loss: 18.27519416809082, reward: -1.94 
episode: 161, loss: 18.481822967529297, reward: -3.07 
episode: 162, loss: 18.646713256835938, reward: -2.51 
episode: 163, loss: 18.661529541015625, reward: -3.14 
episode: 164, loss: 19.20785140991211, reward: -2.61 
episode: 165, loss: 19.069046020507812, reward: -1.95 
episode: 166, loss: 18.520952224731445, reward: -2.0 
episode: 167, loss: 18.851089477539062, reward: -2.54 


KeyboardInterrupt: 