In [None]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import numpy as np

In [None]:
class OUActionNoise(object):
    """Temporally correlated noise from a discretised Ornstein-Uhlenbeck process.

    Every call advances the internal state by one step of
    ``x += theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)``
    and returns the new state.

    Args:
        mu: NumPy array the process reverts towards; its shape is the shape
            of every noise sample.
        sigma: Scale applied to the Gaussian term.
        theta: Mean-reversion rate.
        dt: Step size of the discretisation.
        x0: Optional initial state. When ``None`` the process starts from
            zeros shaped like ``mu``.
    """

    def __init__(self, mu, sigma=0.15, theta=0.2, dt=1e-2, x0=None):
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x0 = x0
        self.reset()

    def __call__(self):
        """Advance the process by one step and return the new state."""
        # Deterministic pull of the previous state towards the mean.
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        # Gaussian perturbation, one independent draw per element of mu.
        diffusion = self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        x = self.x_prev + drift + diffusion
        self.x_prev = x
        return x

    def reset(self):
        """Restore the state to ``x0``, or to zeros shaped like ``mu`` if ``x0`` is None."""
        self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)

In [None]:
class ReplayBuffer(object):
    """Fixed-size circular store of (state, action, reward, next state, mask) transitions.

    Args:
        max_size: Number of transitions kept; once full, the oldest slots
            are overwritten.
        input_shape: Iterable of ints giving the shape of a single state.
        n_actions: Length of a single action vector.
    """

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        # Total number of transitions ever stored (not capped at mem_size).
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        # Numeric "not done" mask (1.0 while the episode continues, 0.0 at
        # its end). It must be a real float dtype so it can be multiplied
        # into value targets and converted to a tensor.
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.float32)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next slot, wrapping around when full.

        Args:
            state: State before the action.
            action: Action taken.
            reward: Scalar reward received.
            state_: State after the action.
            done: Truthy when the transition ended the episode.
        """
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.new_state_memory[index] = state_
        # Stored inverted: 0.0 for a terminal transition, 1.0 otherwise.
        self.terminal_memory[index] = 1.0 - float(done)
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Return a random batch drawn from the slots filled so far.

        Indices come from ``np.random.choice`` with its default arguments,
        so the same slot may appear more than once in a batch.

        Returns:
            Tuple ``(states, actions, reward, new_states, terminal)`` of
            NumPy arrays whose first dimension is ``batch_size``;
            ``terminal`` holds the stored "not done" mask.
        """
        max_memory = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_memory, batch_size)

        states = self.state_memory[batch]
        new_states = self.new_state_memory[batch]
        reward = self.reward_memory[batch]
        actions = self.action_memory[batch]
        terminal = self.terminal_memory[batch]

        return states, actions, reward, new_states, terminal

In [None]:
class CriticalNetwork(nn.Module):
    """Critic: maps a (state, action) pair to a single scalar value.

    The state passes through two Linear + LayerNorm stages, the action
    through one Linear layer, and the two branches are summed before the
    final scalar head. The module owns its Adam optimizer and moves itself
    to CUDA when available, otherwise CPU.

    Args:
        lr: Learning rate for the Adam optimizer.
        input_dim: Iterable whose unpacked contents give the state width.
        fc1_dim: Width of the first hidden layer.
        fc2_dim: Width of the second hidden layer (and of the action branch).
        n_actions: Width of the action vector.
    """

    def __init__(self, lr, input_dim, fc1_dim, fc2_dim, n_actions):
        super().__init__()
        self.input_dim = input_dim
        self.fc1_dim = fc1_dim
        self.fc2_dim = fc2_dim
        self.n_actions = n_actions

        # State branch, stage 1. The init bound is taken from the first
        # dimension of the weight matrix, i.e. the layer's output width.
        self.fc1 = nn.Linear(*self.input_dim, self.fc1_dim)
        self._uniform_init(self.fc1, 1 / np.sqrt(self.fc1.weight.data.size()[0]))
        self.bn1 = nn.LayerNorm(self.fc1_dim)

        # State branch, stage 2 (same bound rule).
        self.fc2 = nn.Linear(self.fc1_dim, self.fc2_dim)
        self._uniform_init(self.fc2, 1 / np.sqrt(self.fc2.weight.data.size()[0]))
        self.bn2 = nn.LayerNorm(self.fc2_dim)

        # Action branch keeps PyTorch's default initialisation.
        self.action_value = nn.Linear(self.n_actions, self.fc2_dim)

        # Scalar head, initialised in a narrow band around zero.
        self.q = nn.Linear(self.fc2_dim, 1)
        self._uniform_init(self.q, 0.003)

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    @staticmethod
    def _uniform_init(layer, bound):
        """Fill the layer's weight, then its bias, from U(-bound, bound)."""
        nn.init.uniform_(layer.weight.data, -bound, bound)
        nn.init.uniform_(layer.bias.data, -bound, bound)

    def forward(self, state, action):
        """Return the value estimate for ``state`` and ``action``.

        The output has a trailing dimension of size 1.
        """
        hidden = F.relu(self.bn1(self.fc1(state)))
        # No ReLU here: the state branch is rectified only after merging.
        state_branch = self.bn2(self.fc2(hidden))
        action_branch = F.relu(self.action_value(action))
        merged = F.relu(state_branch + action_branch)
        return self.q(merged)

In [None]:
class ActionNetwork(nn.Module):
    """Actor: maps a state to an action vector squashed by ``tanh``.

    Two Linear + LayerNorm + ReLU stages feed a final Linear layer whose
    output goes through ``tanh``. The module owns its Adam optimizer and
    moves itself to CUDA when available, otherwise CPU.

    Args:
        alpha: Learning rate for the Adam optimizer.
        input_dim: Iterable whose unpacked contents give the state width.
        fc1_dim: Width of the first hidden layer.
        fc2_dim: Width of the second hidden layer.
        n_actions: Width of the produced action vector.
    """

    def __init__(self, alpha, input_dim, fc1_dim, fc2_dim, n_actions):
        super().__init__()
        self.input_dim = input_dim
        self.fc1_dim = fc1_dim
        self.fc2_dim = fc2_dim
        self.n_actions = n_actions

        # Hidden stage 1. The init bound is taken from the first dimension
        # of the weight matrix, i.e. the layer's output width.
        self.fc1 = nn.Linear(*self.input_dim, self.fc1_dim)
        self._uniform_init(self.fc1, 1 / np.sqrt(self.fc1.weight.data.size()[0]))
        self.bn1 = nn.LayerNorm(self.fc1_dim)

        # Hidden stage 2 (same bound rule).
        self.fc2 = nn.Linear(self.fc1_dim, self.fc2_dim)
        self._uniform_init(self.fc2, 1 / np.sqrt(self.fc2.weight.data.size()[0]))
        self.bn2 = nn.LayerNorm(self.fc2_dim)

        # Output head, initialised in a narrow band around zero.
        self.mu = nn.Linear(self.fc2_dim, self.n_actions)
        self._uniform_init(self.mu, 0.003)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    @staticmethod
    def _uniform_init(layer, bound):
        """Fill the layer's weight, then its bias, from U(-bound, bound)."""
        nn.init.uniform_(layer.weight.data, -bound, bound)
        nn.init.uniform_(layer.bias.data, -bound, bound)

    def forward(self, state):
        """Return the action for ``state``; every component lies in [-1, 1]."""
        hidden = F.relu(self.bn1(self.fc1(state)))
        hidden = F.relu(self.bn2(self.fc2(hidden)))
        return torch.tanh(self.mu(hidden))

In [None]:
class Agent(object):
    """Actor-critic agent in the style of deep deterministic policy gradient (DDPG).

    Holds an online actor and critic, a target copy of each, a replay
    buffer and an Ornstein-Uhlenbeck noise source. Target networks start as
    exact copies of the online networks and afterwards track them through a
    soft update after every learning step.

    Args:
        alpha: Learning rate of the actor networks.
        beta: Learning rate of the critic networks.
        input_dim: Shape of a single state (passed to the buffer and networks).
        tau: Soft-update coefficient; weight given to the online parameters.
        env: Accepted for interface compatibility; not used by this class.
        gamma: Discount factor applied to the bootstrapped value.
        n_actions: Width of the action vector.
        max_size: Capacity of the replay buffer.
        layer1_size: Width of the first hidden layer of every network.
        layer2_size: Width of the second hidden layer of every network.
        batch_size: Number of transitions sampled per learning step.
    """

    def __init__(self,
                 alpha,
                 beta,
                 input_dim,
                 tau,
                 env,
                 gamma=0.99,
                 n_actions=2,
                 max_size=10**7,
                 layer1_size=400,
                 layer2_size=300,
                 batch_size=64
                 ):
        self.alpha = alpha
        self.beta = beta
        self.tau = tau
        self.gamma = gamma
        self.batch_size = batch_size

        self.memo = ReplayBuffer(max_size, input_dim, n_actions)

        self.actor = ActionNetwork(alpha, input_dim, layer1_size, layer2_size, n_actions)
        self.tar_actor = ActionNetwork(alpha, input_dim, layer1_size, layer2_size, n_actions)

        self.critic = CriticalNetwork(beta, input_dim, layer1_size, layer2_size, n_actions)
        self.tar_critic = CriticalNetwork(beta, input_dim, layer1_size, layer2_size, n_actions)

        self.noise = OUActionNoise(mu=np.zeros(n_actions))

        # tau=1 makes the targets exact copies of the online networks.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return the actor's action for ``observation`` plus exploration noise.

        Returns:
            NumPy array holding the noisy action.
        """
        device = self.actor.device
        self.actor.eval()
        observation = torch.tensor(observation, dtype=torch.float).to(device)
        # Acting needs no gradients.
        with torch.no_grad():
            mu = self.actor(observation)
        mu_prime = mu + torch.tensor(self.noise(), dtype=torch.float).to(device)
        self.actor.train()

        return mu_prime.cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memo.store_transition(state, action, reward, new_state, done)

    # Original (misspelled) name, kept so existing callers continue to work.
    remenber = remember

    def learn(self):
        """Run one critic update, one actor update and a soft target update.

        Returns immediately while the buffer holds fewer than ``batch_size``
        transitions.
        """
        if self.memo.mem_cntr < self.batch_size:
            return

        state, action, reward, new_state, done = \
            self.memo.sample_buffer(self.batch_size)

        device = self.critic.device
        reward = torch.tensor(reward, dtype=torch.float).to(device)
        new_state = torch.tensor(new_state, dtype=torch.float).to(device)
        state = torch.tensor(state, dtype=torch.float).to(device)
        # Coerce through NumPy so the tensor conversion does not depend on
        # the buffer's storage dtype for the mask.
        done = torch.tensor(np.asarray(done, dtype=np.float32)).to(device)
        action = torch.tensor(action, dtype=torch.float).to(device)

        self.tar_actor.eval()
        self.tar_critic.eval()

        # Bootstrapped target; `done` is the "not done" mask from the
        # buffer, so the future value is dropped on terminal transitions.
        # Targets are constants with respect to the optimisation.
        with torch.no_grad():
            tar_action = self.tar_actor.forward(new_state)
            critic_value_ = self.tar_critic.forward(new_state, tar_action)
            tar = reward.view(-1, 1) + self.gamma * critic_value_ * done.view(-1, 1)

        # Critic step: regress the value estimate onto the target.
        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_value = self.critic.forward(state, action)
        critic_loss = F.mse_loss(critic_value, tar)
        critic_loss.backward()
        self.critic.optimizer.step()

        # Actor step: the optimizer minimises, so the critic's value is
        # negated to push the actor towards higher-valued actions.
        self.critic.eval()
        self.actor.train()
        self.actor.optimizer.zero_grad()
        mu = self.actor.forward(state)
        actor_loss = -torch.mean(self.critic.forward(state, mu))
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        """Blend online parameters into the target networks.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.

        Args:
            tau: Blend coefficient; ``self.tau`` is used when ``None``.
        """
        if tau is None:
            tau = self.tau
        self._soft_update(self.critic, self.tar_critic, tau)
        self._soft_update(self.actor, self.tar_actor, tau)

    @staticmethod
    def _soft_update(online, target, tau):
        """Write the tau-weighted blend of matching parameters into ``target``."""
        target_params = dict(target.named_parameters())
        with torch.no_grad():
            for name, online_param in online.named_parameters():
                target_param = target_params[name]
                target_param.copy_(tau * online_param + (1 - tau) * target_param)