### Deep Q Network for finite action space
TODO: 
- DQN using Double Q-learning with a target network, a replay buffer, and an epsilon-greedy policy

In [4]:
import time
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import homework2 as hw2
from homework2 import Hw2Env

### Define the network
##### It takes the high-level state of the simulation (ee_pos, obj_pos, goal_pos) as input and outputs the Q-value of each action.

In [5]:
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    The network is two ReLU hidden layers followed by a linear output layer.

    Args:
        state_dim: Number of features in the input state vector.
        n_actions: Number of discrete actions, i.e. the size of the output.
        hidden_dim: Width of each of the two hidden layers.
    """

    def __init__(self, state_dim, n_actions=8, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # No activation follows this layer: Q-values are unbounded.
        self.fc3 = nn.Linear(hidden_dim, n_actions)

    def forward(self, state):
        """Return the Q-values for ``state``.

        Args:
            state: Float tensor whose last dimension has size ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size ``n_actions``.
        """
        hidden = torch.relu(self.fc1(state))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)
        
        

### Define the agent
##### it will have the Q network and the target network, and will use the epsilon greedy policy and the replay buffer methods. 
- lr: learning rate, used in the optimizer
- gamma: discount factor that describes how much the agent values future rewards
- epsilon: the probability of selecting a random action instead of the greedy optimal one
- tau: used for polyak averaging, the rate at which the target network is updated

In [11]:
class DQNAgent:
    """Double DQN agent with a target network, replay buffer and epsilon-greedy policy.

    Args:
        state_dim: Size of the state vector fed to the Q-networks.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability of taking a uniformly random action.
        tau: Polyak coefficient, i.e. the fraction of the current target
            weights kept on each soft update
            (``target <- tau * target + (1 - tau) * online``). Values close
            to 1 make the target network follow the online network slowly.
        buffer_size: Maximum number of transitions kept in the replay buffer;
            the oldest transition is dropped once the buffer is full.
    """

    def __init__(self, state_dim, lr=0.001, gamma=0.99, epsilon=0.1, tau=0.999,
                 buffer_size=100_000):
        self.q_network = QNetwork(state_dim)
        self.target_network = QNetwork(state_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.tau = tau
        # Read the action count off the output layer so random exploration
        # always draws from the same action set the network scores.
        self.n_actions = self.q_network.fc3.out_features
        # A bounded deque evicts the oldest transition in O(1).
        self.replay_buffer = deque(maxlen=buffer_size)

    def get_action(self, state):
        """Select an action for ``state`` with the epsilon-greedy policy.

        Returns:
            A random action index with probability ``epsilon``, otherwise the
            greedy action of the online network.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(0, self.n_actions))
        return self.get_optimal_action(state)

    def get_optimal_action(self, state):
        """Return the index of the action with the highest online Q-value."""
        state = torch.as_tensor(state, dtype=torch.float32)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            return torch.argmax(self.q_network(state)).item()

    def estimate_target(self, reward, next_state, done=False):
        """Compute the Double Q-learning target for one transition.

        Action selection and action evaluation are decoupled: the online
        network picks the greedy next action and the target network supplies
        the value of that action.

        Args:
            reward: Reward observed for the transition.
            next_state: State reached after the transition.
            done: When true the episode ended in ``next_state`` and no future
                value is bootstrapped.

        Returns:
            ``reward`` if ``done``, otherwise
            ``reward + gamma * Q_target(next_state, argmax_a Q_online(next_state, a))``.
        """
        if done:
            return reward
        next_state = torch.as_tensor(next_state, dtype=torch.float32)
        # The target is a constant for the loss, so skip gradient tracking.
        with torch.no_grad():
            action = torch.argmax(self.q_network(next_state)).item()
            next_value = self.target_network(next_state)[action].item()
        return reward + self.gamma * next_value

    def update_q_network(self, state, action, target):
        """Take one optimizer step moving ``Q_online(state, action)`` toward ``target``.

        Args:
            state: State in which ``action`` was taken.
            action: Index of the action that was taken.
            target: Scalar regression target for the selected Q-value.
        """
        state = torch.as_tensor(state, dtype=torch.float32)
        target = torch.as_tensor(target, dtype=torch.float32)
        loss = nn.functional.mse_loss(self.q_network(state)[action], target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def load_target_network_with_current_network(self):
        """Overwrite the target network with a copy of the online network (hard update)."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def update_target_network_using_polyak_avg(self):
        """Blend the online weights into the target network (soft update).

        Each target parameter becomes ``tau * target + (1 - tau) * online``.
        ``tau`` weights the old target value: weighting the online parameters
        with a ``tau`` close to 1 would copy the online network almost
        verbatim on every call and remove the stabilising effect of a slowly
        moving target.
        """
        with torch.no_grad():
            for target_param, param in zip(self.target_network.parameters(),
                                           self.q_network.parameters()):
                target_param.mul_(self.tau).add_(param, alpha=1.0 - self.tau)

    def add_to_replay_buffer(self, transition):
        """Store a transition in the replay buffer.

        Args:
            transition: ``(state, action, reward, next_state)`` or
                ``(state, action, reward, next_state, done)``.
        """
        self.replay_buffer.append(transition)

    def sample_replay_buffer(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            A list of transitions. While the buffer holds fewer than
            ``batch_size`` transitions, every stored transition is returned.
        """
        size = len(self.replay_buffer)
        if size == 0 or size < batch_size:
            return list(self.replay_buffer)
        # Sample indices rather than the transitions themselves:
        # np.random.choice only accepts 1-D input and rejects a list of tuples.
        indices = np.random.randint(size, size=batch_size)
        return [self.replay_buffer[int(i)] for i in indices]

    def train(self, batch_size):
        """Run one training round on a sampled batch.

        Every sampled transition triggers its own optimizer step; the target
        network is soft-updated once at the end. Transitions without a
        ``done`` entry are treated as non-terminal.
        """
        for transition in self.sample_replay_buffer(batch_size):
            state, action, reward, next_state, *extra = transition
            done = bool(extra[0]) if extra else False
            target = self.estimate_target(reward, next_state, done)
            self.update_q_network(state, action, target)
        self.update_target_network_using_polyak_avg()

    def save(self, path):
        """Save the online network's state dict to ``path``."""
        torch.save(self.q_network.state_dict(), path)
        

### Train the agent in the simulation environment

In [None]:
# Train the DQN agent online: after every simulation step the transition is
# stored in the replay buffer and the agent runs one training round.
N_ACTIONS = 8
env = Hw2Env(n_actions=N_ACTIONS, render_mode="gui")

state_dim = 6
agent = DQNAgent(state_dim)

n_episodes = 100
batch_size = 32
for episode in range(n_episodes):
    env.reset()
    done = False
    cum_reward = 0.0
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    while not done:
        state = env.high_level_state()
        action = agent.get_action(state)
        _, reward, is_terminal, is_truncated = env.step(action)
        # Read the next state through the same accessor as `state` so both
        # entries of the stored transition share one representation,
        # whatever env.step itself returns.
        next_state = env.high_level_state()
        agent.add_to_replay_buffer((state, action, reward, next_state))
        agent.train(batch_size)
        done = is_terminal or is_truncated
        cum_reward += reward
    end = time.perf_counter()
    # RF: simulated time divided by elapsed real time for the episode.
    print(f"Episode={episode}, reward={cum_reward}, RF={env.data.time/(end-start):.2f}")