In [None]:
!pip install gym

In [None]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


#s : stands for state
#a : stands for action
#r : stands for reward
# Thanks to Phil for his help with this project.
class DeepQLearning(nn.Module):
    """Three-layer fully connected network that maps a state to Q-values.

    The module owns its own Adam optimizer and MSE loss so that a caller
    can train it directly through ``self.optimizer`` and ``self.loss``.

    Args:
        ALPHA: Learning rate handed to the Adam optimizer.
        input_dimension: Sequence holding the size of a single state
            vector; it is unpacked into ``nn.Linear``, so it must contain
            exactly one element (e.g. ``[8]``).
        fc1_dims: Number of units in the first hidden layer.
        fc2_dims: Number of units in the second hidden layer.
        number_of_as: Number of actions, i.e. the size of the output layer.
    """

    def __init__(self, ALPHA, input_dimension, fc1_dims, fc2_dims,
                 number_of_as):
        super(DeepQLearning, self).__init__()
        self.input_dimension = input_dimension
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.number_of_as = number_of_as
        # Layer sizes come from the *_dims arguments; the layers themselves
        # do not exist yet at this point and cannot be used as sizes.
        self.function1 = nn.Linear(*self.input_dimension, self.fc1_dims)
        self.function2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.function3 = nn.Linear(self.fc2_dims, self.number_of_as)
        self.optimizer = optim.Adam(self.parameters(), lr=ALPHA)
        self.loss = nn.MSELoss()
        # Fall back to the CPU when no CUDA device is available.
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, s):
        """Return the Q-value of every action for state(s) ``s``.

        Args:
            s: A single state or a batch of states (array-like or tensor)
                whose last dimension matches ``input_dimension``.

        Returns:
            Tensor of Q-values on ``self.device`` with ``number_of_as``
            entries in the last dimension.
        """
        # as_tensor accepts lists, NumPy arrays and existing tensors alike
        # and converts to the float32 dtype of the layer weights.
        s = T.as_tensor(s, dtype=T.float32, device=self.device)
        x = F.relu(self.function1(s))
        x = F.relu(self.function2(x))
        q_values = self.function3(x)
        return q_values

class Agent(object):
    """Epsilon-greedy DQN agent with a fixed-size circular replay memory.

    Args:
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial probability of taking a random action.
        alpha: Learning rate passed to the Q-network's optimizer.
        input_dimension: Sequence describing the shape of one state.
        batch_size: Number of transitions sampled per learning step.
        number_of_as: Number of discrete actions.
        max_mem_size: Capacity of the replay memory.
        eps_end: Lower bound for epsilon.
        eps_dec: Multiplicative decay applied to epsilon per learning step.
    """

    def __init__(self, gamma, epsilon, alpha, input_dimension, batch_size, number_of_as,
                 max_mem_size=100000, eps_end=0.01, eps_dec=0.996):
        # Plain hyper-parameters first: the network and the memory buffers
        # below are sized from them.
        self.gamma = gamma
        self.epsilon = epsilon
        self.EPS_MIN = eps_end
        self.EPS_DEC = eps_dec
        self.ALPHA = alpha
        self.a_space = [i for i in range(number_of_as)]
        self.number_of_as = number_of_as
        self.mem_size = max_mem_size
        self.batch_size = batch_size
        self.mem_cntr = 0

        self.Q_function = DeepQLearning(alpha, number_of_as=number_of_as,
                                        input_dimension=input_dimension,
                                        fc1_dims=256, fc2_dims=256)
        self.s_memory = np.zeros((self.mem_size, *input_dimension))
        self.new_s_memory = np.zeros((self.mem_size, *input_dimension))
        # Actions are stored one-hot encoded.
        self.a_memory = np.zeros((self.mem_size, self.number_of_as),
                                 dtype=np.uint8)
        self.r_memory = np.zeros(self.mem_size)
        # Holds 1 for non-terminal transitions and 0 for terminal ones.
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.uint8)

    def save_a_r_s(self, s, a, r, s_, terminal):
        """Store one transition, overwriting the oldest slot when full.

        Args:
            s: State before the action.
            a: Index of the action taken.
            r: Reward received.
            s_: State after the action.
            terminal: Truthy when the episode ended on this transition.
        """
        index = self.mem_cntr % self.mem_size
        self.s_memory[index] = s
        one_hot = np.zeros(self.number_of_as)
        one_hot[a] = 1.0
        self.a_memory[index] = one_hot
        self.r_memory[index] = r
        self.new_s_memory[index] = s_
        # Stored inverted so it can be used directly as a mask that zeroes
        # the bootstrapped value of terminal transitions in learn().
        self.terminal_memory[index] = 1 - int(terminal)
        self.mem_cntr += 1

    def take_an_action_action(self, s):
        """Pick an action for state ``s`` with an epsilon-greedy policy.

        Returns:
            The greedy action index with probability ``1 - epsilon``,
            otherwise a uniformly random element of ``a_space``.
        """
        rand = np.random.random()
        if rand > self.epsilon:
            # Action selection needs no gradients, and the network is only
            # evaluated when the greedy branch is actually taken.
            with T.no_grad():
                q_values = self.Q_function.forward(s)
            return T.argmax(q_values).item()
        return np.random.choice(self.a_space)

    def learn(self):
        """Run one gradient step on a random batch from replay memory.

        Does nothing until more than ``batch_size`` transitions have been
        stored. Epsilon is decayed towards ``EPS_MIN`` on every step that
        does learn.
        """
        if self.mem_cntr <= self.batch_size:
            return

        device = self.Q_function.device
        self.Q_function.optimizer.zero_grad()

        # Only sample from the part of the memory that has been filled.
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, self.batch_size)

        s_batch = self.s_memory[batch]
        new_s_batch = self.new_s_memory[batch]
        # Recover integer action indices from the one-hot rows.
        val_a = np.array(self.a_space, dtype=np.int64)
        ind_a = T.as_tensor(np.dot(self.a_memory[batch], val_a),
                            dtype=T.int64, device=device)
        r_batch = T.as_tensor(self.r_memory[batch], dtype=T.float32,
                              device=device)
        terminal_batch = T.as_tensor(self.terminal_memory[batch],
                                     dtype=T.float32, device=device)
        batch_index = T.arange(self.batch_size, device=device)

        q_eval = self.Q_function.forward(s_batch)
        # The TD target is treated as a constant: it is built without
        # autograd so gradients only flow through q_eval.
        with T.no_grad():
            q_next = self.Q_function.forward(new_s_batch)
            q_target = q_eval.detach().clone()
            q_target[batch_index, ind_a] = r_batch + \
                self.gamma * T.max(q_next, dim=1)[0] * terminal_batch

        self.epsilon = max(self.epsilon * self.EPS_DEC, self.EPS_MIN)

        loss = self.Q_function.loss(q_eval, q_target)
        loss.backward()
        self.Q_function.optimizer.step()