In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from envs.CarRacing import CarRacing
from networks.PolicyNet import PolicyNetImage
from networks.ValueNet import ValueNetImage
from memory.RewardMemory import Memory
from tqdm.notebook import tqdm
import numpy as np
import sys
import traceback
from networks.utils import *
import torch

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
env = CarRacing()
critic = ValueNetImage(input_channels=3, hidden_dim=128)
actor = PolicyNetImage(input_channels=3, hidden_dim=128, action_dims=3)

memory = Memory(3, env.states, 3, 10000)

sigma = 1.0
_lambda = 0.97
gamma = 0.99
batch_size = 256

rewards_per_ep = {} # store rewards for each episode

for e in tqdm(range(200)):
  state = env.reset()
  cont = 0
  ep_reward = 0

  for t in range(500):
    action, logp = actor(state)
    action_dt = action.cpu().detach().numpy().reshape(3)
    print(action_dt)
    logp_dt = logp.cpu().detach()
    # print(action_np.shape)
    #print(type(action_np))

    v = critic(state)
    v_dt = v.cpu().detach()

    obs, r, terminal, truncated, info = env.step(action_dt)
    ep_reward += r
    if action_dt[1] > 0.5:
       ep_reward += 0.1

    memory.add(state, action_dt, r, obs, terminal, v_dt, logp_dt)

    if terminal or truncated:
      break

  rewards_per_ep[e] = [ep_reward, cont]

  if terminal:
      v = 0
  else:
      v = critic(obs)
      
  v_dt = v.cpu().detach()

  i = len(memory)
  memory.adv[i] = 0
  memory.rtg[i] = v_dt

  for i in reversed(range(len(memory)-t, len(memory))):
      delta = memory.r[i] + gamma*memory.v[i+1]-memory.v[i]
      adv = gamma*_lambda*memory.adv[i+1] + delta
      rtg = gamma*memory.rtg[i+1]+memory.r[i]
      
      memory.adv[i] = adv
      memory.rtg[i] = rtg
      
  if e % 10 == 0 and e != 0:
      for i in tqdm(range(80), leave=False):
          # update(self, s, a_prev, logp_a, adv, clip_ratio=0.2)
          #s, a, r, sp, terminal, v=0, logp=0
          actor.update(memory.s, memory.a, memory.logp, memory.adv)
          # update(self, inputs, targets)
          critic.update(memory.s, memory.rtg)
      
      #tqdm.write(f"Ep {e}: Actor and Critic updated | Ep reward: {ep_reward} | Last 100 rewards: {np.mean(list(rewards_per_ep.values())[-100:], axis=0)[0]}")
      memory.reset()



    

# Close environment
#env.plotnetwork(critic, actor)
env.close()

  0%|          | 0/200 [00:00<?, ?it/s]

[-0.01914418  0.71493495 -0.1326167 ]
[0.5887221 1.0638251 0.4575361]
[ 0.8121351   0.7563456  -0.58157045]
[0.18943891 1.0517627  0.05613476]
[ 0.93572867  0.47736216 -0.61933637]
[ 0.5197234   0.24025477 -0.1601747 ]
[0.67743707 0.93220496 0.18610378]
[ 0.4406666  -0.9103477   0.93279684]
[-0.13306949  0.30347615  1.0505118 ]
[ 1.0630323  0.7986758 -0.1574809]
[-0.46803054  0.69512856 -0.8922265 ]
[ 1.6233827  -1.3224074   0.08786148]
[-0.4906613   0.49102342 -0.9587671 ]
[1.1761818  0.28422058 0.72525394]
[ 0.07025111 -0.02572078  0.79642993]
[0.85357404 0.36714283 0.84280455]
[-1.0225673   0.97858727  0.03982335]
[-0.3422409  1.2317141  1.3776133]
[-2.7728996e-01 -1.8250942e-04  7.5963134e-01]
[0.8884176  1.4873381  0.25265384]
[-0.22192034 -0.02900133  0.72523606]
[-0.2688854   0.79304445 -0.4236369 ]
[-0.02118799  0.08079705  1.2579219 ]
[1.3351903  0.7755492  0.82116437]
[ 1.6699901  -0.49431154  0.46175513]
[-0.4328545   0.46849424  0.28369927]
[0.52481556 0.23737955 0.30555826

KeyboardInterrupt: 

In [5]:
env.close()

In [None]:
# try:
  # for e in tqdm(range(5)):
  #   state = env.reset()
    
  #   ep_reward = 0
  #   print(state.shape)
  #   for t in range(20):
      
  #     v = critic(state)

  #     obs, r, terminal, truncated, info = env.step(np.array([0,1,0]))

      #   if terminal or truncated:
      #     break
# except BaseException as ex:
#     # Get current system exception
#     ex_type, ex_value, ex_traceback = sys.exc_info()

#     # Extract unformatter stack traces as tuples
#     trace_back = traceback.extract_tb(ex_traceback)

#     # Format stacktrace
#     stack_trace = list()

#     for trace in trace_back:
#         stack_trace.append("File : %s , \nLine : %d, \nFunc.Name : %s, \nMessage : %s" % (trace[0], trace[1], trace[2], trace[3]))

#     print("Exception type : %s " % ex_type.__name__)
#     print("Exception message : %s" %ex_value)
#     # Print each line in stack trace separately
#     print("Stack trace : ")
#     for trace in stack_trace:
#         print(trace)
#     env.close()