In [3]:
import gym 
import numpy as np
from helpers import NormalizedEnv, RandomAgent
from qnetwork2 import ReplayBuffer, QNetwork
from heuristic import HeuristicPendulumAgent
from matplotlib import pyplot
import torch.optim as optim
import sys

In [4]:
# Initialization
env = gym.make("Pendulum-v1")
norm_env = NormalizedEnv(env) # accept actions between -1 and 1

#we fix a torque
torque = norm_env.action(norm_env.action_space.sample())
agent = HeuristicPendulumAgent(norm_env, torque)

buffer = ReplayBuffer(10000)
batch_size = 128

num_states = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
hidden_size = 256 # choose as you wish 
critic = QNetwork(num_states + num_actions, hidden_size, num_actions, agent)
optimizer = optim.Adam(critic.parameters(), lr=1e-4)

losses = []
rewards = []
avg_rewards = []

In [3]:
for episode in range(200): 
    state, info = norm_env.reset()
    trunc = False
    episode_loss = 0
    #average_loss = 0
    episode_reward = 0
    
    while not trunc:
        action = agent.compute_action(state)
        # print(norm_env.step(action))
        next_state, reward, terminated, trunc, info = norm_env.step(action)
        buffer.add(state, action, reward, next_state, trunc)
        
        if len(buffer) > batch_size:
            transition = buffer.sample(batch_size)
            loss = critic.update(optimizer, transition, trunc, 0.99)
            episode_loss += loss
        
        state = next_state
        episode_reward += reward
        
        if trunc:
           # average_loss = np.mean(episode_loss[-10:]) # average of loss 
            sys.stdout.write("episode: {}, loss: {}, reward: {}, average _reward: {} \n".format(episode, episode_loss, np.round(episode_reward, decimals=2), np.mean(rewards[-10:])))
            break
            
    losses.append(episode_loss)
    rewards.append(episode_reward)
    avg_rewards.append(np.mean(rewards[-10:]))

  states = torch.FloatTensor(states)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


episode: 0, loss: 6139.49658203125, reward: -1838.28, average _reward: nan 
episode: 1, loss: 16785.60546875, reward: -1833.15, average _reward: -1838.27585434772 
episode: 2, loss: 17376.490234375, reward: -1914.23, average _reward: -1835.7151080675276 
episode: 3, loss: 17351.92578125, reward: -1815.39, average _reward: -1861.885110247916 
episode: 4, loss: 16947.1328125, reward: -1626.89, average _reward: -1850.2605661189125 
episode: 5, loss: 17034.421875, reward: -1897.82, average _reward: -1805.585911824471 
episode: 6, loss: 17098.322265625, reward: -1831.43, average _reward: -1820.9586071058472 
episode: 7, loss: 17384.103515625, reward: -1941.13, average _reward: -1822.4544551925242 
episode: 8, loss: 17307.041015625, reward: -1830.2, average _reward: -1837.2883373137652 
episode: 9, loss: 17319.98046875, reward: -1866.34, average _reward: -1836.5003860611696 
episode: 10, loss: 17296.51953125, reward: -1732.13, average _reward: -1839.4840883237862 
episode: 11, loss: 17135.46

episode: 91, loss: 17086.263671875, reward: -1926.19, average _reward: -1752.7983381917431 
episode: 92, loss: 17132.609375, reward: -1934.17, average _reward: -1755.1151631366308 
episode: 93, loss: 17084.341796875, reward: -1842.87, average _reward: -1760.1899358432813 
episode: 94, loss: 17135.35546875, reward: -1812.82, average _reward: -1762.2041045854226 
episode: 95, loss: 17100.6953125, reward: -1912.44, average _reward: -1783.8237217729202 
episode: 96, loss: 17085.31640625, reward: -1735.4, average _reward: -1799.0847570623864 
episode: 97, loss: 17117.76953125, reward: -1927.22, average _reward: -1799.0737280501428 
episode: 98, loss: 17116.951171875, reward: -1807.51, average _reward: -1828.3827075817376 
episode: 99, loss: 17056.07421875, reward: -1885.52, average _reward: -1826.2573686667045 
episode: 100, loss: 17128.919921875, reward: -1930.09, average _reward: -1831.9140416140403 
episode: 101, loss: 17216.171875, reward: -1943.12, average _reward: -1871.4243408165662 

episode: 181, loss: 17103.5, reward: -1898.76, average _reward: -1761.9273166842552 
episode: 182, loss: 17011.984375, reward: -1824.94, average _reward: -1761.2938153643695 
episode: 183, loss: 17063.6796875, reward: -1943.25, average _reward: -1753.751777782683 
episode: 184, loss: 17143.072265625, reward: -1923.57, average _reward: -1822.0178740244482 
episode: 185, loss: 17051.1640625, reward: -1847.71, average _reward: -1823.1411667805507 
episode: 186, loss: 17141.154296875, reward: -1951.3, average _reward: -1833.3000239282828 
episode: 187, loss: 17147.595703125, reward: -1943.63, average _reward: -1836.7978546136942 
episode: 188, loss: 17171.2265625, reward: -1859.9, average _reward: -1856.223990962118 
episode: 189, loss: 17208.71875, reward: -1566.65, average _reward: -1862.7821986368358 
episode: 190, loss: 17135.501953125, reward: -1807.38, average _reward: -1858.7049087117496 
episode: 191, loss: 17185.310546875, reward: -1920.53, average _reward: -1856.7087999027954 
ep