In [5]:
import gym 
import numpy as np
from helpers import NormalizedEnv, RandomAgent
from qnetwork2 import ReplayBuffer, QNetwork
from heuristic import HeuristicPendulumAgent
from matplotlib import pyplot
import torch.optim as optim
import sys

In [6]:
# Initialization
env = gym.make("Pendulum-v1")
norm_env = NormalizedEnv(env) # accept actions between -1 and 1

#we fix a torque
torque = norm_env.action(norm_env.action_space.sample())
agent = HeuristicPendulumAgent(norm_env, torque)

buffer = ReplayBuffer(10000)
batch_size = 128

num_states = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
hidden_size = 256 # choose as you wish 
critic = QNetwork(num_states + num_actions, hidden_size, num_actions, agent)
optimizer = optim.Adam(critic.parameters(), lr=1e-4)

losses = []
rewards = []

In [7]:
for episode in range(500): 
    state, info = norm_env.reset()
    trunc = False
    
    episode_loss = []
    av_episode_loss = 0
    
    episode_reward = []
    av_episode_reward = 0
    
    while not trunc:
        action = agent.compute_action(state)
        # print(norm_env.step(action))
        next_state, reward, terminated, trunc, info = norm_env.step(action)
        buffer.add(state, action, reward, next_state, trunc)
        
        if len(buffer) > batch_size:
            transition = buffer.sample(batch_size)
            loss = critic.update(transition, trunc, 0.99)
            optimizer.zero_grad()
            optimizer.step()
            episode_loss.append(loss)
        
        state = next_state
        episode_reward.append(reward)
        
        if trunc:
            av_episode_loss = np.mean(episode_loss)
            av_episode_reward = np.mean(episode_reward)
            sys.stdout.write("episode: {}, loss: {}, reward: {} \n".format(episode, av_episode_loss, np.round(av_episode_reward, decimals=2)))
            break
            
    losses.append(av_episode_loss)
    rewards.append(episode_reward)

episode: 0, loss: 75.0203857421875, reward: -8.55 
episode: 1, loss: 68.87335968017578, reward: -7.27 
episode: 2, loss: 74.60659790039062, reward: -9.28 
episode: 3, loss: 78.13536834716797, reward: -9.25 
episode: 4, loss: 80.67928314208984, reward: -9.48 
episode: 5, loss: 81.85594177246094, reward: -9.11 
episode: 6, loss: 81.31925964355469, reward: -8.24 
episode: 7, loss: 81.96721649169922, reward: -9.45 
episode: 8, loss: 81.24459838867188, reward: -7.27 
episode: 9, loss: 80.79081726074219, reward: -8.89 
episode: 10, loss: 81.43118286132812, reward: -9.49 
episode: 11, loss: 81.864013671875, reward: -9.16 
episode: 12, loss: 82.73271179199219, reward: -9.54 
episode: 13, loss: 82.35372924804688, reward: -8.3 
episode: 14, loss: 82.45398712158203, reward: -9.28 
episode: 15, loss: 82.55633544921875, reward: -8.76 
episode: 16, loss: 82.84663391113281, reward: -9.54 
episode: 17, loss: 83.21505737304688, reward: -8.87 
episode: 18, loss: 83.23890686035156, reward: -9.42 
episode

episode: 155, loss: 83.85814666748047, reward: -9.12 
episode: 156, loss: 83.75861358642578, reward: -8.52 
episode: 157, loss: 83.56068420410156, reward: -8.51 
episode: 158, loss: 83.5046615600586, reward: -7.6 
episode: 159, loss: 83.31734466552734, reward: -9.11 
episode: 160, loss: 83.71080780029297, reward: -9.54 
episode: 161, loss: 83.98175811767578, reward: -9.69 
episode: 162, loss: 83.78915405273438, reward: -9.2 
episode: 163, loss: 83.92861938476562, reward: -9.23 
episode: 164, loss: 83.86967468261719, reward: -9.48 
episode: 165, loss: 83.87960815429688, reward: -8.77 
episode: 166, loss: 84.04359436035156, reward: -8.75 
episode: 167, loss: 84.48729705810547, reward: -9.17 
episode: 168, loss: 84.17633056640625, reward: -9.52 
episode: 169, loss: 84.10124969482422, reward: -9.12 
episode: 170, loss: 84.10431671142578, reward: -9.2 
episode: 171, loss: 83.94111633300781, reward: -9.05 
episode: 172, loss: 84.00111389160156, reward: -9.37 
episode: 173, loss: 84.109062194

episode: 308, loss: 85.27156066894531, reward: -8.56 
episode: 309, loss: 85.27562713623047, reward: -7.8 
episode: 310, loss: 85.2242202758789, reward: -9.55 
episode: 311, loss: 85.6860122680664, reward: -9.2 
episode: 312, loss: 85.1669921875, reward: -9.45 
episode: 313, loss: 85.32478332519531, reward: -9.47 
episode: 314, loss: 85.07171630859375, reward: -5.66 
episode: 315, loss: 84.98896789550781, reward: -9.42 
episode: 316, loss: 84.93144226074219, reward: -9.58 
episode: 317, loss: 84.91027069091797, reward: -8.8 
episode: 318, loss: 84.3067398071289, reward: -9.58 
episode: 319, loss: 84.51884460449219, reward: -8.86 
episode: 320, loss: 84.6871566772461, reward: -9.54 
episode: 321, loss: 84.79401397705078, reward: -9.35 
episode: 322, loss: 85.03912353515625, reward: -9.09 
episode: 323, loss: 84.56763458251953, reward: -7.62 
episode: 324, loss: 84.3705825805664, reward: -9.53 
episode: 325, loss: 84.57240295410156, reward: -9.59 
episode: 326, loss: 84.3978500366211, re

episode: 462, loss: 85.34992218017578, reward: -9.47 
episode: 463, loss: 84.95396423339844, reward: -9.65 
episode: 464, loss: 84.98279571533203, reward: -9.63 
episode: 465, loss: 85.27625274658203, reward: -9.37 
episode: 466, loss: 85.15985107421875, reward: -6.7 
episode: 467, loss: 84.31146240234375, reward: -8.76 
episode: 468, loss: 84.48445129394531, reward: -9.19 
episode: 469, loss: 84.57073974609375, reward: -9.54 
episode: 470, loss: 84.48916625976562, reward: -9.34 
episode: 471, loss: 84.42183685302734, reward: -9.43 
episode: 472, loss: 84.53400421142578, reward: -8.43 
episode: 473, loss: 84.50640869140625, reward: -9.15 
episode: 474, loss: 84.54721069335938, reward: -9.41 
episode: 475, loss: 84.08087921142578, reward: -6.83 
episode: 476, loss: 83.95381164550781, reward: -8.5 
episode: 477, loss: 83.70121765136719, reward: -9.71 
episode: 478, loss: 83.62022399902344, reward: -9.11 
episode: 479, loss: 83.89620208740234, reward: -9.24 
episode: 480, loss: 83.8187332