## CSC 580
## Assignment-4
## Hithesh Shanmugam

# In [1]:
import random
import numpy as np
import csv #importing csv to write the qtable
class Agent:
    """ 
    An AI agent which controls the snake's movements.

    Uses tabular Q-learning with epsilon-greedy exploration. A state is given
    as a list of binary digits (one per feature); it is encoded as an integer
    that indexes a row of the Q table, whose columns are the actions.
    """
    def __init__(self, env, params):
        self.env = env
        self.action_space = env.action_space  # 4 actions for SnakeGame
        self.state_space = env.state_space    # 12 features for SnakeGame
        self.gamma = params['gamma']
        self.alpha = params['alpha']
        self.epsilon = params['epsilon'] 
        self.epsilon_min = params['epsilon_min'] 
        self.epsilon_decay = params['epsilon_decay']
        # Q table as a 2D array: one row per possible binary state
        # (2 ** number of features, i.e. 4096 rows for 12 features) and one
        # column per action. Derived from state_space instead of hard-coding 4096.
        self.Q = np.zeros((2 ** self.state_space, self.action_space))
        # Integer ids of the unique states whose Q values have been updated.
        self.visited_states = set()

    
    @staticmethod
    def state_to_int(state_list):
        """ Map state as a list of binary digits, e.g. [0,1,0,0,1,1,1] to an integer."""
        return int("".join(str(x) for x in state_list), 2)
    
    @staticmethod
    def state_to_str(state_list):
        """ Map state as a list of binary digits, e.g. [0,1,0,0,1,1,1], to a string e.g. '0100111'. """
        return "".join(str(x) for x in state_list)

    @staticmethod
    def binstr_to_int(state_str):
        """ Map a state binary string, e.g. '0100111', to an integer."""
        return int(state_str, 2)

    # (A) 
    def init_state(self, state):
        """
        Initialize the state's entry in Q, if anything needed at all.

        Q is pre-allocated with zeros for every possible state, so a new state
        already has a zero entry. A state that has already been updated keeps
        its learned values; zeroing its row here would erase what the agent
        learned about it each time this is called.
        """
        state_id = self.state_to_int(state)  # converting binary list to int
        if state_id not in self.visited_states:
            self.Q[state_id, :] = 0  # only (re)initialize rows never updated
        
    # (A)
    def select_action(self, state):
        """
        Do the epsilon-greedy action selection. Note: 'state' is an original list of binary digits.
        It should call the function select_greedy() for the greedy case.

        With probability epsilon a uniformly random action is returned,
        otherwise the greedy action for 'state'.
        """
        if np.random.uniform(0, 1) < self.epsilon:  # explore
            return np.random.choice(self.action_space)  # choose a random action
        return self.select_greedy(state)  # exploit: choose the action greedily
        
    # (A)
    def select_greedy(self, state):
        """ 
        Greedy choice of action based on the Q-table.

        Returns the index of the highest Q value for 'state'; ties resolve to
        the lowest action index (np.argmax behavior).
        """
        state_int = self.state_to_int(state)  # converting binary list to int
        return np.argmax(self.Q[state_int])
    
    # (A)
    def update_Qtable(self, state, action, reward, next_state):
        """
        Update the Q-table (and anything else necessary) after an action is taken.
        Note that both 'state' and 'next_state' are an original list of binary digits.

        Applies the Q-learning (TD) update
            Q[s, a] += alpha * (reward + gamma * max_a' Q[s', a'] - Q[s, a]),
        records 's' as visited, and decays epsilon once per call.
        """
        state_int = self.state_to_int(state)
        next_state_int = self.state_to_int(next_state)
        max_q_next_state = np.max(self.Q[next_state_int, :])  # best value reachable from s'
        td_target = reward + self.gamma * max_q_next_state
        td_error = td_target - self.Q[state_int, action]
        self.Q[state_int, action] += self.alpha * td_error
        self.visited_states.add(state_int)  # track unique updated states
        self.adjust_epsilon()  # epsilon decays once per learning step
    
    
    # (A)
    def num_states_visited(self):
        """ Returns the number of unique states visited (states whose Q row was updated)."""
        return len(self.visited_states)  # already a set, no copy needed
    
    # (A)
    def write_qtable(self, filepath):
        """
        Write the content of the Q-table to an output CSV file.

        Writes a header row ``state,action,qa_value`` followed by one row per
        (state, action) pair for every state in the table.
        """
        num_states, num_actions = self.Q.shape
        with open(filepath, 'w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(["state", "action", "qa_value"])
            # Iterate over all rows of Q (2 ** state_space states), not over
            # range(state_space), which would cover only the first 12 states.
            writer.writerows(
                (state, action, self.Q[state, action])
                for state in range(num_states)
                for action in range(num_actions)
            )

    def adjust_epsilon(self):
        """ Implements the epsilon decay, never letting epsilon drop below epsilon_min."""
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)