In [1]:
import gym
import random
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions
import torch.nn.functional as F

In [2]:
class A2C(nn.Module):
    def __init__(self, state_size, action_size):
        super(A2C, self).__init__()
        self.common = nn.Linear(state_size,64)
        self.actor  = nn.Linear(64,action_size)
        self.critic = nn.Linear(64,1)
        # Random Uniform
        torch.nn.init.uniform_(self.common.weight,-1e-3,1e-3)
        torch.nn.init.uniform_(self.actor.weight,-1e-3,1e-3)
        torch.nn.init.uniform_(self.critic.weight,-1e-3,1e-3)
        
    def forward(self, x):
        x       = F.relu(self.common(x))
        policy  = F.softmax(self.actor(x))
        value   = self.critic(x)
        return policy, value

In [3]:
class A2CAgent:
    def __init__(self, state_size, action_size, device):
        self.state_size = state_size
        self.action_size= action_size
        self.device = device
        
        # Hyper params for learning
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        
        self.model = A2C(self.state_size,self.action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def get_action(self, state):
        policy_dist, _ = self.model.forward(state)
        dist = distributions.Categorical(policy_dist)
        action = dist.sample()
        return action
        
    def train_model(self, state, action, reward, next_state, done):
        policy_dist, value  = self.model.forward(state)
        next_state = torch.FloatTensor([next_state]).to(device)
        _,      next_value  = self.model.forward(next_state)
        target = reward + (1 - done) * self.discount_factor * next_value[0]
        
        # For policy network
        dist = distributions.Categorical(policy_dist)
        action = dist.sample()
        cross_entropy = - dist.log_prob(action)
        advantage = (target - value[0]).detach()
        actor_loss = (cross_entropy * advantage).sum()
        
        # For value network
        critic_loss = F.smooth_l1_loss(target.detach(),value[0][0]).sum()
        
        # integrate losses
        loss = 0.2 * actor_loss + critic_loss
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().numpy()

In [4]:
%matplotlib tk

ENV_NAME = 'CartPole-v1'
EPISODES = 1000
END_SCORE = 400
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("DEVICE : ", device)

if __name__ == "__main__":
    env = gym.make(ENV_NAME)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = A2CAgent(state_size, action_size, device)
    print('Env Name : ',ENV_NAME)
    print('States {}, Actions {}'
            .format(state_size, action_size))

    scores, episodes, losses = [], [], []
    score_avg = 0
    
    end = False
    
    for e in range(EPISODES):
        # Episode initialization
        done = False
        score = 0
        loss_list = []
        
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        
        while not done:
            # env.render()

            # Interact with env.
            state = torch.FloatTensor([state]).to(device)
            action = agent.get_action(state)
            next_state, reward, done, info = env.step(action.item())
            loss = agent.train_model(state, action, reward, next_state, done)
            state = next_state

            # 
            score += reward
            loss_list.append(loss)
            if done:
                # agent.update_target_model()
                
                score_avg = 0.9 * score_avg + 0.1 * score if score_avg != 0 else score
                print('epi: {:3d} | score avg {:3.2f} | loss: {:.4f}'.format(e, score_avg, np.mean(loss_list)))

                # Save data for plot
                scores.append(score_avg)
                episodes.append(e)
                losses.append(np.mean(loss_list))
                
                
                # View data
                if (e%10 == 0):
                    plt.subplot(211)
                    plt.plot(episodes, scores, 'b')
                    plt.xlabel('episode')
                    plt.ylabel('average score')
                    plt.title('cartpole A2C TORCH')
                    plt.grid()
                    
                    plt.subplot(212)
                    plt.plot(episodes, losses, 'b')
                    plt.xlabel('episode')
                    plt.ylabel('loss')
                    plt.grid()
                    
                    plt.savefig('./save_model/cartpole_a2c_TORCH.png')

                if score_avg > END_SCORE:
                    agent.model.save_weights('./save_model/cartpole_a2c_TORCH', save_format='tf')
                    end = True
                    break
        if end == True:
            np.save('./save_model/data/cartpole_a2c_TORCH_epi',episodes)
            np.save('./save_model/data/cartpole_a2c_TORCH_score',scores)
            np.save('./save_model/data/cartpole_a2c_TORCH_loss',losses)
            env.close()
            print("End")
            break

DEVICE :  cuda
Env Name :  CartPole-v1
States 4, Actions 2
epi:   0 | score avg 33.00 | loss: 0.6318
epi:   1 | score avg 31.90 | loss: 0.6366
epi:   2 | score avg 32.31 | loss: 0.6812
epi:   3 | score avg 30.48 | loss: 0.7259
epi:   4 | score avg 29.33 | loss: 0.5804
epi:   5 | score avg 28.40 | loss: 0.8328
epi:   6 | score avg 28.86 | loss: 0.7155
epi:   7 | score avg 28.67 | loss: 0.9045
epi:   8 | score avg 27.01 | loss: 1.5422
epi:   9 | score avg 25.40 | loss: 0.6226
epi:  10 | score avg 24.66 | loss: 0.7873
epi:  11 | score avg 23.20 | loss: 2.4145
epi:  12 | score avg 22.18 | loss: 0.7634
epi:  13 | score avg 21.06 | loss: 1.7681
epi:  14 | score avg 21.25 | loss: 0.6855
epi:  15 | score avg 20.43 | loss: 0.8449
epi:  16 | score avg 19.89 | loss: 1.6943
epi:  17 | score avg 19.10 | loss: 1.3196
epi:  18 | score avg 18.49 | loss: 2.1287
epi:  19 | score avg 17.94 | loss: 1.9902
epi:  20 | score avg 17.74 | loss: 1.0581
epi:  21 | score avg 16.97 | loss: 1.5501
epi:  22 | score 

KeyboardInterrupt: 