# Actor-Critic Method

In [1]:
from __future__ import print_function

import numpy as np
import random
import matplotlib.pyplot as plt
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable

import pygame
import Maze_Solver as maze_solver
from Maze_Solver import MazeSolver, MazeSolverEnv
import Maze_Generator as maze_generator

pygame 2.1.2 (SDL 2.0.18, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# cuda가 설치되어 있다면 cuda를 사용하고 아니라면 cpu만을 사용한다.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

![](https://firebasestorage.googleapis.com/v0/b/aing-biology.appspot.com/o/sutton_barto_reinforcement_learning%2Fchapter13%2F01.PNG?alt=media&token=7dbd1717-7691-44fc-9375-c86fe3d04e7a)

In [3]:
class ActorCritic(nn.Module):
    def __init__(self, inputs, outputs):
        super(ActorCritic, self).__init__()
        
        # for Actor
        self.fc1 = nn.Linear(inputs, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, inputs)
        self.fc4 = nn.Linear(inputs, outputs)
        self.bn1 = nn.BatchNorm1d(256) # 4 actions
        self.bn2 = nn.BatchNorm1d(inputs)
        self.bn3 = nn.BatchNorm1d(outputs)
        self.tanh = nn.Tanh()
        self.head = nn.Softmax(dim=0)
        
        # torch.log makes nan(not a number) error so we have to add some small number in log function
        self.ups=1e-7

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        state = x.to(device)
        x = self.fc1(state)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        
        v = self.fc1(state)
        v = self.fc2(v)
        v = self.fc3(v)
        v = self.fc4(v)
        
        return v, self.head(x)
    
    def pi(self, s, a):
        s = Variable(torch.Tensor(s))
        #s = torch.unsqueeze(s, 0)
        _, probs = self.forward(s)
        probs = torch.squeeze(probs, 0)
        return probs[a]
    
    def get_action(self, state):
        state = Variable(torch.tensor(state))
        #state = torch.unsqueeze(state, 0)
        _, probs = self.forward(state)
        probs = torch.squeeze(probs, 0)
        action = probs.multinomial(num_samples=1)
        action = action.data
        
        action = action[0]
        return action
    
    def epsilon_greedy_action(self, state, epsilon = 0.1):
        state = Variable(torch.tensor(state))
        state = torch.unsqueeze(state, 0)
        _, probs = self.forward(state)
        
        probs = torch.squeeze(probs, 0)
        
        if random.random() > epsilon:
            action = torch.tensor([torch.argmax(probs)])
        else:
            action = torch.rand(probs.shape).multinomial(num_samples=1)
        
        action = action.data
        action = action[0]
        return action
    
    def value(self, s):
        s = Variable(torch.tensor(s))
        s = torch.unsqueeze(s, 0)
        value, _ = self.forward(s)
        value = torch.squeeze(value, 0)
        value = value[0]
        
        return value
'''            
class Critic(nn.Module):
    def __init__(self, inputs):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(inputs, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, inputs)
        self.fc4 = nn.Linear(inputs, 1)
        self.bn1 = nn.BatchNorm1d(128) # 4 actions
        self.bn2 = nn.BatchNorm1d(inputs)
        self.bn3 = nn.BatchNorm1d(1)
        self.tanh = nn.Tanh()
    
    def forward(self, x):
        x = x.to(device)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
                
        return x
''' 
def update_weight(optimizer, values, log_probs, rewards, last_Qval):
    
    # compute Q values
    Qval = last_Qval
    Qvals = torch.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        Qval = rewards[t] + GAMMA * Qval
        Qvals[t] = Qval
    Qvals = Variable(Qvals, requires_grad=True).to(device)
    
    values = torch.tensor(values, dtype=torch.float32).to(device)
    log_probs = torch.tensor(log_probs, dtype=torch.float32).to(device)
    
    advantage = Qvals - values
    actor_loss = (-log_probs * advantage).mean()
    critic_loss = 0.5 * advantage.pow(2).mean()
    loss = actor_loss + critic_loss
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
MAX_EPISODES = 10000
MAX_TIMESTEPS = 1000

ALPHA = 3e-5 # learning rate
GAMMA = 0.99 # step-size for actor
BETA = 0.9 # step-size for critic

#env = MazeSolverEnv()
env = gym.make('CartPole-v0')

num_actions = 2 #env.num_action
num_states = 4 #365

actor_critic = ActorCritic(num_states, num_actions).to(device)

actor_critic.eval()

optimizer = optim.Adam(actor_critic.parameters(), lr=ALPHA)

try:
    returns = []
    
    for i_episode in range(MAX_EPISODES):

        state = env.reset()
        #state = env.init_obs
        done = False

        values = []
        log_probs = []
        rewards = []   # no reward at t = 0

        #while not done:
        for timesteps in range(MAX_TIMESTEPS):
            value = float(actor_critic.value(state).cpu().detach().numpy())
            values.append(value)
            action = actor_critic.get_action(state)
            log_prob = float(torch.log(actor_critic.pi(state, action)).cpu().detach().numpy())
            log_probs.append(log_prob)

            state, reward, done, _ = env.step(action.tolist())

            rewards.append(reward)

            if done or timesteps == MAX_TIMESTEPS:
                last_Qval = float(actor_critic.value(state).cpu().detach().numpy())
                #print("Episode {} finished after {} timesteps".format(i_episode, timesteps+1))
                break

        #print("Episode {} return: {}".format( i_episode + 1, sum(rewards)))

        update_weight(optimizer, values, log_probs, rewards, last_Qval)
        
        #print("Update {} finished".format(i_episode + 1))

        if (i_episode + 1) % 500 == 0:
            print("Episode {} return {}".format(i_episode + 1, sum(rewards)))
            torch.save(actor_critic, './saved_models/actor_critic' + str(i_episode + 1) + '.pt')

        env.close()
        
        returns.append(sum(rewards))
        #env.reset_player(exploring_starts = False)

except KeyboardInterrupt:
    plt.plot(range(len(returns)), returns)

env.close()


Episode 500 return 13.0
Episode 1000 return 9.0
Episode 1500 return 40.0
Episode 2000 return 14.0
Episode 2500 return 15.0
Episode 3000 return 11.0
Episode 3500 return 12.0
Episode 4000 return 14.0
