In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from envs.CarRacing import CarRacing
from networks.CQNet import CQNetImage
from networks.MuNet  import MuNetImage
from memory.RewardMemory import Memory
#from memory.TorchMemory import Memory
from tqdm.notebook import tqdm
import numpy as np
from networks.utils import *
import torch
import matplotlib.pyplot as plt

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
# Create environment
env = CarRacing()

# Create network
critic = CQNetImage(input_channels=3, hidden_dim=128, action_dim=3).to(device)
actor = MuNetImage(input_channels=3, hidden_dim=128, action_dim=3).to(device)

critic_copy = CQNetImage(input_channels=3, hidden_dim=128, action_dim=3).to(device)
actor_copy = MuNetImage(input_channels=3, hidden_dim=128, action_dim=3).to(device)
critic_copy.copyfrom(critic)
actor_copy.copyfrom(actor)

# Create replay memory
memory = Memory(3, env.states, 3, 30000)

epsilon = 1.0
espilon_decay = 0.97
epsilon_min = 0.05

gamma = 0.99
batch_size = 256

sigma = 0.25

ddpg_rewards_per_ep = {} # store rewards for each episode

# Update target netwoks
update_target = 200
update_counter = 0

for e in range(200):
    # Reset environment
    state = env.reset()
    ep_reward = 0
    cont = 0
    
    for t in range(1000):
        action = actor(state).cpu().data.numpy()
        action = np.random.normal(loc=action, scale=sigma).reshape(3)

        # Step environment
        obs, r, terminal, truncated, info = env.step(2*action)

        if action[1] < action[2]:
            r -= 5
        if action[1] > 0.5:
            r += 10
        if action[2] > 0.5:
            r -= 10
        if action[2] < 0.2:
            r += 2 

        ep_reward += r
        memory.add(state[0, 2, :, :], action, r, terminal)
        
        # If at least 1000 transitions in memory, sample minibatch and learn (exercise 2.3)
        if len(memory) >= batch_size:
            minibatch = memory.sample(batch_size)
            states, actions, rewards, next_states, terminals = minibatch
            

            # Compute target values for critic network
            next_actions = actor_copy(next_states)
            #print(next_states.shape)
            #print(next_actions.shape)
            next_q_values = critic_copy(next_states, next_actions)
            targets = rewards + gamma * next_q_values * (1 - terminals)

            # Update critic and actor networks
            critic.update(states, actions, targets, cpu=False)
            actor.update(states, critic)

            # Update target networks if necessary
            update_counter += 1
            if update_counter >= update_target:
                critic_copy.copyfrom(critic)
                actor_copy.copyfrom(actor)
                update_counter = 0

        state = obs

        if terminal or truncated:
            break
        
        cont += 1
    
    ddpg_rewards_per_ep[e] = [ep_reward, cont]

    if epsilon > epsilon_min:
        epsilon *= espilon_decay

# Close environment
# env.plotnetwork(actor, critic)
env.close()

tensor([[-0.0459, -0.0024,  0.0263]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0456, -0.0024,  0.0260]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0019,  0.0257]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0459, -0.0016,  0.0256]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0013,  0.0258]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0459, -0.0011,  0.0258]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0015,  0.0258]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0016,  0.0259]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0459, -0.0015,  0.0256]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0015,  0.0256]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0016,  0.0258]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[-0.0458, -0.0016,  0.02

KeyboardInterrupt: 