In [7]:
import gym 
import numpy as np
from helpers import NormalizedEnv, RandomAgent
from qnetwork2 import ReplayBuffer, QNetwork, update
from heuristic import HeuristicPendulumAgent
from matplotlib import pyplot
import torch.optim as optim
import sys

In [12]:
# Initialization
env = gym.make("Pendulum-v1")
norm_env = NormalizedEnv(env) # accept actions between -1 and 1

#we fix a torque
torque = norm_env.action(norm_env.action_space.sample())
agent = HeuristicPendulumAgent(norm_env, torque)

buffer = ReplayBuffer(10000)
batch_size = 128

num_states = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
hidden_size = 256 # choose as you wish 
critic = QNetwork(num_states + num_actions, hidden_size, num_actions, agent)
optimizer = optim.Adam(critic.parameters(), lr=1e-4)

losses = []
rewards = []
avg_rewards = []

In [13]:
for episode in range(200): 
    state, info = norm_env.reset()
    trunc = False
    episode_loss = 0
    #average_loss = 0
    episode_reward = 0
    
    while not trunc:
        action = agent.compute_action(state)
        # print(norm_env.step(action))
        next_state, reward, terminated, trunc, info = norm_env.step(action)
        buffer.add(state, action, reward, next_state, trunc)
        
        if len(buffer) > batch_size:
            transition = buffer.sample(batch_size)
            loss = critic.update(optimizer, agent, transition, trunc, 0.99)
            episode_loss += loss
        
        state = next_state
        episode_reward += reward
        
        if trunc:
           # average_loss = np.mean(episode_loss[-10:]) # average of loss 
            sys.stdout.write("episode: {}, loss: {}, reward: {}, average _reward: {} \n".format(episode, episode_loss, np.round(episode_reward, decimals=2), np.mean(rewards[-10:])))
            break
            
    losses.append(episode_loss)
    rewards.append(episode_reward)
    avg_rewards.append(np.mean(rewards[-10:]))

([array([-0.99935657, -0.03586669,  0.09106901], dtype=float32), array([-0.9999689 , -0.0078914 ,  0.28591472], dtype=float32), array([-0.9995171 ,  0.03107393, -0.06238474], dtype=float32), array([-0.9997824 ,  0.02086123,  0.23974344], dtype=float32), array([-0.98401093, -0.178108  ,  0.04265997], dtype=float32), array([-0.8440219 , -0.53630865,  0.13203564], dtype=float32), array([-0.9995887 , -0.02867753, -0.06715701], dtype=float32), array([-0.9999943 ,  0.00338831, -0.29316634], dtype=float32), array([-0.99946016, -0.03285376, -0.12836993], dtype=float32), array([-0.9999979 ,  0.00205413,  0.0036163 ], dtype=float32), array([-0.9996619 ,  0.02600037, -0.26408368], dtype=float32), array([-0.9999577 ,  0.00920228, -0.00778649], dtype=float32), array([-0.999928  , -0.01199784,  0.01093847], dtype=float32), array([-0.97680384, -0.21413599, -0.52920073], dtype=float32), array([-0.99968773, -0.02498964,  0.02901145], dtype=float32), array([-0.9993793 ,  0.03522801, -0.22665831], dtype=

AttributeError: 'QNetwork' object has no attribute 'update'