# In [1]:
import numpy as np
import torch


class Agent:
    """Agent that currently acts randomly and records its latest transition.

    Instance attributes:
        episode_length: number of steps per episode (100).
        num_steps_taken: total steps taken so far, across all episodes.
        state: the most recent state passed to / produced for the agent.
        action: the most recent action returned by ``get_next_action``.
        total_reward: running sum of rewards from ``set_next_state_and_distance``.
    """

    def __init__(self):
        """Initialise step counters, the latest state/action and the reward total."""
        # Set the episode length
        self.episode_length = 100
        # Reset the total number of steps which the agent has taken
        self.num_steps_taken = 0
        # The state variable stores the latest state of the agent in the environment
        self.state = None
        # The action variable stores the latest action which the agent has applied to the environment
        self.action = None
        # Running reward sum; must exist before set_next_state_and_distance adds to it.
        self.total_reward = 0.0

    def has_finished_episode(self):
        """Return True when the step count is a multiple of ``episode_length``.

        This is also True before any step has been taken (``num_steps_taken == 0``).
        """
        return self.num_steps_taken % self.episode_length == 0

    def get_next_action(self, state):
        """Return the next action for ``state`` and record the state and action.

        The action is currently random: a float32 array of shape (2,) with each
        component drawn uniformly from [-0.01, 0.01). Also increments
        ``num_steps_taken``.
        """
        action = np.random.uniform(low=-0.01, high=0.01, size=2).astype(np.float32)
        # Update the number of steps which the agent has taken
        self.num_steps_taken += 1
        # Store the state and action; they form the first half of the next transition
        self.state = state
        self.action = action
        return action

    def set_next_state_and_distance(self, next_state, distance_to_goal):
        """Record the result of applying ``self.action`` at ``self.state``.

        Args:
            next_state: the state reached after the action.
            distance_to_goal: distance from ``next_state`` to the goal.

        Returns:
            The transition tuple ``(state, action, reward, next_state)``, where
            ``reward = 1 - distance_to_goal``.
        """
        # Convert the distance to a reward: closer to the goal means a larger reward
        reward = 1 - distance_to_goal
        transition = (self.state, self.action, reward, next_state)
        self.state = next_state
        # NOTE(review): total_reward is never reset between episodes; reset it if a
        # per-episode total is wanted.
        self.total_reward += reward
        return transition

    def get_greedy_action(self, state):
        """Return the greedy action for ``state``.

        Placeholder: always returns the fixed float32 action [0.02, 0.0] and ignores
        ``state``. It should eventually return the action with the highest Q-value.
        """
        return np.array([0.02, 0.0], dtype=np.float32)
    
class DQN:
    """Holds a Q-network and the Adam optimiser used to train it online."""

    def __init__(self):
        """Create the Q-network (2 inputs, 4 outputs) and its optimiser."""
        # 2-D state in, 4 outputs out (presumably one value per discrete action).
        self.q_network = Network(input_dimension=2, output_dimension=4)
        # Adam optimiser; the learning rate sets the size of each gradient step.
        self.optimiser = torch.optim.Adam(self.q_network.parameters(), lr=0.001)

    def train_q_network(self, transition):
        """Take one gradient step on a single transition.

        Args:
            transition: ``(state, action, reward, next_state)`` tuple.

        Returns:
            The loss computed before the update, as a Python float.
        """
        # Clear gradients left over from the previous step
        self.optimiser.zero_grad()
        loss = self._calculate_loss(transition)
        # Gradients of the loss with respect to the Q-network parameters
        loss.backward()
        self.optimiser.step()
        return loss.item()

    def _calculate_loss(self, transition):
        """Return the squared error between the predicted value of the taken action
        and the immediate reward.

        Only the immediate reward (``transition[2]``) is the target; ``next_state``
        is not used, so this trains a one-step reward predictor (online, batch of one).

        Args:
            transition: ``(state, action, reward, next_state)``. ``action`` is used to
                index the network output, so it is assumed to be an integer action
                index in [0, 4) — TODO confirm against the agent's action format.
        """
        # No batch dimension is added here (.unsqueeze(0) would add one; .squeeze()
        # removes size-1 dims). The network is assumed to accept a single unbatched state.
        state_tensor = torch.tensor(transition[0], dtype=torch.float32)
        # Call the module rather than .forward() so any registered hooks run.
        predicted_reward_tensor = self.q_network(state_tensor)[transition[1]]
        actual_immediate_reward_tensor = torch.tensor(transition[2], dtype=torch.float32)
        # torch.nn has no ``MSE``; the mean-squared-error criterion is ``MSELoss``.
        loss = torch.nn.MSELoss()(predicted_reward_tensor, actual_immediate_reward_tensor)
        return loss

# Run the demo only when this file is executed directly (or as a notebook), not when imported.
if __name__ == "__main__":
    # NOTE(review): ``Environment`` is not defined or imported in the visible code;
    # it must be brought into scope before this runs — TODO confirm its source module.
    environment = Environment(display=True, magnification=500)
    # Agent.__init__ takes no arguments, so the environment is not passed to it.
    agent = Agent()




