# Actor-Critic Method

In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable

import pygame
import Maze_Solver as maze_solver
from Maze_Solver import MazeSolver, MazeSolverEnv
import Maze_Generator as maze_generator

pygame 2.1.2 (SDL 2.0.18, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
class ActorCritic(nn.Module):
    """Actor-critic network made of two independent convolutional towers.

    Both towers take the same 5-channel image-like observation and apply three
    3x3 convolutions (stride 1, no padding) followed by one linear layer:

    * the actor tower ends in a softmax over ``outputs`` actions;
    * the critic tower ends in a single, unbounded state-value estimate.

    Args:
        h: Height of the observation fed to the convolutions.
        w: Width of the observation fed to the convolutions.
        outputs: Number of discrete actions.

    Raises:
        ValueError: If ``h`` or ``w`` is too small to survive the three
            un-padded 3x3 convolutions (each one removes 2 pixels per axis).
    """

    def __init__(self, h, w, outputs):
        super(ActorCritic, self).__init__()

        # Number of Linear input connections depends on output of conv2d layers
        # and therefore the input image size, so compute it.
        def conv2d_size_out(size, kernel_size = 3, stride = 1):
            return (size - (kernel_size - 1) - 1) // stride  + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        if convw <= 0 or convh <= 0:
            # Without this check two negative sizes multiply to a positive
            # number and the mistake only surfaces later, inside forward().
            raise ValueError(
                "observation of size {}x{} is too small for three 3x3 convolutions".format(h, w)
            )
        linear_input_size = convw * convh * 32

        self.actor_conv1 = nn.Conv2d(5, 16, kernel_size=3, stride=1)
        self.actor_bn1 = nn.BatchNorm2d(16)
        self.actor_conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.actor_bn2 = nn.BatchNorm2d(32)
        self.actor_conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.actor_bn3 = nn.BatchNorm2d(32)
        self.actor_bn4 = nn.BatchNorm1d(4) # 4 actions
        self.actor_tanh = nn.Tanh()

        self.critic_conv1 = nn.Conv2d(5, 16, kernel_size=3, stride=1)
        self.critic_bn1 = nn.BatchNorm2d(16)
        self.critic_conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.critic_bn2 = nn.BatchNorm2d(32)
        self.critic_conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.critic_bn3 = nn.BatchNorm2d(32)
        self.critic_bn4 = nn.BatchNorm1d(1)
        self.critic_tanh = nn.Tanh()
        # NOTE: the BatchNorm / Tanh modules above are not used by forward().
        # They are kept so that parameter names (state_dict keys) stay unchanged.

        # torch.log makes nan(not a number) error so we have to add some small number in log function
        self.ups=1e-7

        self.actor_fc1 = nn.Linear(linear_input_size, outputs)
        self.head = nn.Softmax(dim=1)

        self.critic_fc1 = nn.Linear(linear_input_size, 1)

    @staticmethod
    def _as_batch(state):
        """Wrap one observation into a batch of size 1 (no dtype/device change)."""
        return torch.unsqueeze(torch.as_tensor(state), 0)

    # Called with either one element to determine next action, or a batch
    # during optimization.
    def forward(self, x):
        """Run both towers on a batch of observations.

        Args:
            x: Tensor of shape (batch, 5, h, w). It is moved to the device and
                dtype of the network's own parameters before use.

        Returns:
            Tuple ``(value, probs)``: ``value`` has shape (batch, 1) and
            ``probs`` has shape (batch, outputs) with rows summing to 1.
        """
        # Follow the parameters rather than a module-level global, so the
        # network works wherever it has been moved and whatever float dtype
        # the observations arrive in (e.g. float64 NumPy arrays).
        reference = self.actor_conv1.weight
        state = x.to(device=reference.device, dtype=reference.dtype)

        probs = F.relu(self.actor_conv1(state))
        probs = F.relu(self.actor_conv2(probs))
        probs = F.relu(self.actor_conv3(probs))
        probs = torch.flatten(probs, 1)
        probs = self.head(self.actor_fc1(probs))

        value = F.relu(self.critic_conv1(state))
        value = F.relu(self.critic_conv2(value))
        value = F.relu(self.critic_conv3(value))
        value = torch.flatten(value, 1)
        # The value head stays linear: a ReLU here would clamp V(s) at >= 0
        # and make negative returns impossible to represent.
        value = self.critic_fc1(value)

        return value, probs

    def get_action(self, state):
        """Sample an action from the current policy for a single observation.

        Returns:
            0-dim integer tensor (on the network's device) holding the action.
        """
        # Sampling is not differentiated, so skip building an autograd graph.
        with torch.no_grad():
            _, probs = self.forward(self._as_batch(state))
        action = torch.squeeze(probs, 0).multinomial(num_samples=1)
        return action[0]

    def pi(self, s, a):
        """Return pi(a | s), the probability of action ``a`` in state ``s``.

        The result stays attached to the autograd graph so it can be used in
        a policy-gradient loss.
        """
        _, probs = self.forward(self._as_batch(s))
        return torch.squeeze(probs, 0)[a]

    def epsilon_greedy_action(self, state, epsilon = 0.1):
        """Pick the most probable action, or a uniformly random one.

        With probability ``epsilon`` a uniformly random action is returned,
        otherwise the arg-max of the policy.

        Returns:
            0-dim integer CPU tensor holding the action.
        """
        with torch.no_grad():
            _, probs = self.forward(self._as_batch(state))
        probs = torch.squeeze(probs, 0)

        if random.random() > epsilon:
            return torch.argmax(probs).cpu()
        # Uniform draw over the action indices.
        return torch.randint(probs.shape[0], (1,))[0]

    def value(self, s):
        """Return the critic's estimate V(s) as a 0-dim tensor (with grad)."""
        value, _ = self.forward(self._as_batch(s))
        return torch.squeeze(value, 0)[0]
    
def update_weight(optimizer, states, actions, rewards, last_state, entropy_term=0,
                  model=None, gamma=None, done=False):
    """Perform one actor-critic gradient step from a recorded rollout.

    Discounted returns are built backwards from a bootstrap value, then the
    policy-gradient loss ``-log pi(a|s) * advantage`` and the critic loss
    ``0.5 * advantage**2`` are averaged over the rollout and minimised.

    Args:
        optimizer: Optimizer holding the network parameters.
        states: Sequence of observations, one per step.
        actions: Sequence of action indices (ints or 0-dim tensors).
        rewards: Sequence of scalar rewards received after each action.
        last_state: Observation reached after the final action; used to
            bootstrap the return unless ``done`` is True.
        entropy_term: Accepted for interface compatibility; currently unused.
        model: Network whose ``forward(batch)`` returns ``(values, probs)``.
            Defaults to the module-level ``actor_critic``.
        gamma: Discount factor. Defaults to the module-level ``GAMMA``.
        done: Set to True when ``last_state`` is terminal, so the return is
            bootstrapped from 0 instead of V(last_state).

    Returns:
        The scalar loss as a float, or None when the rollout is empty (in
        which case no optimisation step is taken).
    """
    if model is None:
        model = actor_critic
    if gamma is None:
        gamma = GAMMA

    steps = list(zip(states, actions, rewards))
    if not steps:
        # Nothing to learn from; avoids a division by zero / backward on a constant.
        return None

    # Bootstrap value for the tail of the rollout. It is a fixed regression
    # target, so it is computed without tracking gradients.
    if done:
        running_return = 0.0
    else:
        with torch.no_grad():
            tail = torch.as_tensor(last_state, dtype=torch.float32).unsqueeze(0)
            tail_value, _ = model.forward(tail)
        running_return = float(tail_value.reshape(-1)[0])

    # Discounted returns, accumulated from the last step back to the first.
    returns = []
    for _, _, reward in reversed(steps):
        running_return = float(reward) + gamma * running_return
        returns.append(running_return)
    returns.reverse()

    # One batched forward pass gives every V(s_t) and pi(.|s_t) at once.
    batch = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s, _, _ in steps])
    values, probs = model.forward(batch)
    values = values.reshape(-1)

    taken = torch.as_tensor([int(a) for _, a, _ in steps],
                            dtype=torch.long, device=probs.device)
    chosen_probs = probs.gather(1, taken.unsqueeze(1)).squeeze(1)
    # Keep log() finite when a probability underflows to zero.
    eps = getattr(model, "ups", 1e-7)
    log_probs = torch.log(chosen_probs.clamp_min(eps))

    targets = torch.tensor(returns, dtype=values.dtype, device=values.device)
    advantages = targets - values

    # The advantage is a constant weight for the policy gradient; detaching it
    # stops the actor term from pushing gradients into the critic.
    actor_loss = -(log_probs * advantages.detach())
    critic_loss = 0.5 * advantages.pow(2)
    loss = (actor_loss + critic_loss).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return float(loss.detach())

In [None]:
# Training budget.
MAX_EPISODES = 10_000   # number of training episodes
MAX_TIMESTEPS = 1_000   # upper bound on steps within one episode

# Optimisation hyper-parameters.
ALPHA = 0.0003  # learning rate handed to the optimizer
GAMMA = 0.99    # discount factor applied to future rewards

env = MazeSolverEnv()

def train_ActorCritic(env):
    """Train an ActorCritic agent on ``env`` and plot the logged returns.

    Each episode rolls the current policy out for at most ``MAX_TIMESTEPS``
    steps, then performs one ``update_weight`` call on the whole rollout.
    Every 10th episode its return is printed and recorded; every 500th
    episode the whole model is saved under ``./saved_models``. Training can
    be stopped early with a keyboard interrupt; the curve collected so far is
    still plotted and the environment is always closed.

    Args:
        env: Environment exposing ``num_action``, ``num_obs``, ``init_obs``,
            ``step(action)``, ``reset_player(exploring_starts=...)`` and
            ``close()``.

    Returns:
        List of the episode returns recorded every 10 episodes.
    """
    import os  # local import: only needed to create the checkpoint directory

    # update_weight() looks the network up under the module-level name
    # `actor_critic`, so the model has to be published there; as a plain
    # local it would be invisible and the update would raise NameError.
    global actor_critic

    num_actions = env.num_action
    num_states = env.num_obs

    actor_critic = ActorCritic(num_states[0], num_states[1], num_actions).to(device)

    optimizer = optim.Adam(actor_critic.parameters(), lr=ALPHA)

    checkpoint_dir = './saved_models'
    log_every = 10
    returns = []

    try:
        for i_episode in range(MAX_EPISODES):

            # NOTE(review): assumes env.reset_player() refreshes env.init_obs
            # for the next episode -- TODO confirm against MazeSolverEnv.
            state = env.init_obs

            states = []
            actions = []
            rewards = []   # no reward at t = 0

            for _step in range(MAX_TIMESTEPS):

                states.append(state)

                action = actor_critic.get_action(state)
                actions.append(action)

                state, reward, done, _ = env.step(action.tolist())
                rewards.append(reward)

                if done:
                    break

            # `state` is now the observation reached after the final action.
            # NOTE(review): the return is bootstrapped from V(state) even when
            # the episode ended with done=True; a terminal state would ideally
            # bootstrap from zero.
            if rewards:
                update_weight(optimizer, states, actions, rewards, state)

            episode = i_episode + 1
            if episode % log_every == 0:
                episode_return = sum(rewards)
                print("Episode {} return {}".format(episode, episode_return))
                returns.append(episode_return)

                if episode % 500 == 0:
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    torch.save(actor_critic, checkpoint_dir + '/actor_critic' + str(episode) + '.pt')

            # every 10th episodes, turn off the exploring starts
            # (the reset prepares the *next* episode, hence the +2)
            env.reset_player(exploring_starts = (i_episode + 2) % 10 != 0)

    except KeyboardInterrupt:
        # Deliberate early stop: fall through to the plot below.
        pass
    finally:
        try:
            # One point per logged episode: 10, 20, ..., 10 * len(returns).
            logged_episodes = [log_every * (k + 1) for k in range(len(returns))]
            plt.plot(logged_episodes, returns)
        finally:
            # Release the environment even if training or plotting failed.
            env.close()

    return returns


Episode 10 return -536.0000000000069
Episode 20 return -543.0000000000064
Episode 30 return -592.0000000000024
Episode 40 return -551.0000000000061
Episode 50 return -547.500000000007
Episode 60 return -510.5000000000088
Episode 70 return -563.0000000000043
Episode 80 return -537.0000000000063
Episode 90 return -560.5000000000048
Episode 100 return -552.5000000000052
Episode 110 return -546.0000000000059
Episode 120 return -554.5000000000052
Episode 130 return -524.5000000000076
Episode 140 return -526.5000000000075
Episode 150 return -505.00000000000875
Episode 160 return -639.0
Episode 170 return -721.4999999999958
Episode 180 return -789.999999999993
