# In [None]:
import os
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import threading
from threading import Lock
from matplotlib import pylab
import time
import matplotlib.pyplot as plt

class OurModel(nn.Module):
    """Small fully connected network with a policy head and a value head.

    The input is flattened, passed through one 512-unit ELU layer, and then
    fed to two independent linear heads.

    Args:
        input_shape: Shape of a single observation; its product is the number
            of input features of the first linear layer.
        action_space: Number of outputs of the policy head.
    """

    def __init__(self, input_shape, action_space):
        super().__init__()
        n_inputs = np.prod(input_shape)
        self.flatten = nn.Flatten()
        # Layers are created in this order so parameter initialisation under a
        # fixed seed, and the state_dict keys, stay the same.
        self.fc1 = nn.Linear(n_inputs, 512)
        self.fc2_action = nn.Linear(512, action_space)
        self.fc2_value = nn.Linear(512, 1)
        self.elu = nn.ELU()

    def forward(self, x):
        """Return ``(action_probabilities, value)`` for a batch of inputs.

        The probabilities are a softmax over the last dimension of the policy
        head; the value is the raw output of the value head.
        """
        hidden = self.elu(self.fc1(self.flatten(x)))
        policy = F.softmax(self.fc2_action(hidden), dim=-1)
        return policy, self.fc2_value(hidden)

class A3CAgent:
    """Actor-critic agent for a discrete-action gym environment.

    Two ``OurModel`` networks are kept: ``Actor`` supplies the action
    probabilities and ``Critic`` supplies the state-value estimate. Training
    can run in a single environment (``run``) or in several worker threads
    that share the networks and serialise their updates through ``lock``
    (``train`` / ``train_threading``).

    Args:
        env_name: Environment id passed to ``gym.make``.
    """

    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.action_size = self.env.action_space.n
        # max_average starts at -inf so the first measured average always
        # counts as the best one, whatever the environment's score range is.
        self.EPISODES, self.episode, self.max_average = 1000, 0, float("-inf")
        self.lock = Lock()
        self.state_size = self.env.observation_space.shape

        self.ROWS = 80
        self.COLS = 80
        self.REM_STEP = 4
        self.lr = 0.000025

        # Per-episode history consumed by PlotModel.
        self.scores, self.episodes, self.average = [], [], []
        # Dedicated figure for the score plot, created here (by the thread
        # that constructs the agent) instead of at class-definition time.
        self._figure = pylab.figure(figsize=(18, 9))

        self.Save_Path = 'Models'
        os.makedirs(self.Save_Path, exist_ok=True)
        self.path = '{}_A3C_{}'.format(self.env_name, self.lr)
        self.Model_name = os.path.join(self.Save_Path, self.path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.Actor = OurModel(input_shape=self.state_size, action_space=self.action_size).to(self.device)
        self.Critic = OurModel(input_shape=self.state_size, action_space=1).to(self.device)

        self.actor_optimizer = optim.Adam(self.Actor.parameters(), lr=self.lr)
        self.critic_optimizer = optim.Adam(self.Critic.parameters(), lr=self.lr)

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return only the observation.

        Accepts both reset conventions: a bare observation, or an
        ``(observation, info_dict)`` pair.
        """
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    @staticmethod
    def _step_env(env, action):
        """Step ``env`` and return ``(observation, reward, done, info)``.

        Accepts both step conventions: the 4-tuple with a single ``done``
        flag, or the 5-tuple with separate ``terminated`` / ``truncated``
        flags (merged into one ``done``).
        """
        result = env.step(action)
        if len(result) == 5:
            observation, reward, terminated, truncated, info = result
            return observation, reward, bool(terminated or truncated), info
        observation, reward, done, info = result
        return observation, reward, done, info

    def discount_rewards(self, rewards, gamma=0.99, reset_on_reward=False):
        """Return the normalised discounted returns of one episode.

        Args:
            rewards: Per-step rewards, in order.
            gamma: Discount factor.
            reset_on_reward: When true, the running return is reset at every
                non-zero reward. That treats each non-zero reward as the end
                of an independent sub-game and only makes sense for sparse
                rewards; with a reward on every step it would make every
                return equal to the immediate reward, so it is off by default.

        Returns:
            A float32 tensor on ``self.device`` with one entry per step,
            shifted to zero mean and scaled by the standard deviation. With
            fewer than two steps the result is all zeros (the standard
            deviation is undefined there).
        """
        # Explicit float buffer: deriving the dtype from integer rewards would
        # truncate every discounted value.
        discounted_r = np.zeros(len(rewards), dtype=np.float64)
        running = 0.0
        for i in reversed(range(len(rewards))):
            if reset_on_reward and rewards[i] != 0:
                running = 0.0
            running = running * gamma + rewards[i]
            discounted_r[i] = running

        discounted_r = torch.tensor(discounted_r, dtype=torch.float32).to(self.device)
        if discounted_r.numel() < 2:
            return torch.zeros_like(discounted_r)
        return (discounted_r - discounted_r.mean()) / (discounted_r.std() + 1e-5)

    def replay(self, states, actions, rewards):
        """Run one actor update and one critic update from a finished episode.

        Args:
            states: Observations, one per step.
            actions: Action indices taken at those observations.
            rewards: Rewards received after each action.

        Does nothing when the episode is empty.
        """
        if len(states) == 0:
            return

        states = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32) for s in states]
        ).to(self.device)
        actions = torch.tensor(actions, dtype=torch.long).to(self.device)
        returns = self.discount_rewards(rewards)

        action_probs, _ = self.Actor(states)
        # The value estimate must come from the Critic: critic_optimizer only
        # holds the Critic's parameters, so values taken from the Actor would
        # never train it.
        _, values = self.Critic(states)
        values = values.squeeze(-1)  # keep the batch axis for 1-step episodes
        advantages = returns - values.detach()

        dist = Categorical(action_probs)
        actor_loss = -(dist.log_prob(actions) * advantages).mean()
        critic_loss = F.mse_loss(values, returns)

        # The two losses live on separate graphs, so no retain_graph is needed.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def act(self, state):
        """Sample an action index from the Actor's distribution for ``state``."""
        state = torch.as_tensor(np.asarray(state), dtype=torch.float32)
        state = state.unsqueeze(0).to(self.device)
        with torch.no_grad():
            action_probs, _ = self.Actor(state)
        return Categorical(action_probs).sample().item()

    def load(self, actor_path, critic_path):
        """Load Actor and Critic weights from the two state-dict files."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.Actor.load_state_dict(torch.load(actor_path, map_location=self.device))
        self.Critic.load_state_dict(torch.load(critic_path, map_location=self.device))

    def save(self, actor_path='/content/A3C_Actor.pt', critic_path='/content/A3C_Critic.pt'):
        """Write the Actor and Critic state dicts to disk.

        Args:
            actor_path: Destination file for the Actor weights.
            critic_path: Destination file for the Critic weights.

        Missing parent directories are created first.
        """
        for model, path in ((self.Actor, actor_path), (self.Critic, critic_path)):
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(model.state_dict(), path)

    def PlotModel(self, score, episode):
        """Record an episode result, redraw the score plot and save it.

        Args:
            score: Total reward of the finished episode.
            episode: Episode index used as the x coordinate.

        Returns:
            The mean of the last (up to) 50 recorded scores, as a float.
        """
        self.scores.append(score)
        self.episodes.append(episode)
        recent = self.scores[-50:]
        self.average.append(sum(recent) / len(recent))

        if getattr(self, '_figure', None) is None:
            self._figure = pylab.figure(figsize=(18, 9))
        axes = self._figure.gca()
        # Redraw from scratch; otherwise every call stacks two more line
        # artists on top of the previous ones.
        axes.clear()
        axes.plot(self.episodes, self.average, 'r')
        axes.plot(self.episodes, self.scores, 'b')
        axes.set_ylabel('Score', fontsize=18)
        axes.set_xlabel('Steps', fontsize=18)
        try:
            self._figure.savefig(self.path + ".png")
        except OSError as err:
            # Saving the plot is best effort, but say why it failed.
            print(f"Could not save plot to {self.path}.png: {err}")

        return float(self.average[-1])

    def run(self):
        """Train in ``self.env`` for ``self.EPISODES`` episodes, then close it.

        Each episode is played to the end with sampled actions and followed by
        one ``replay`` update. No checkpoint is written by this method.
        """
        for e in range(self.EPISODES):
            state = self._reset_env(self.env)
            done = False
            score = 0
            states, actions, rewards = [], [], []
            while not done:
                action = self.act(state)
                next_state, reward, done, info = self._step_env(self.env, action)
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                state = next_state
                score += reward
            self.replay(states, actions, rewards)
            print(f"Episode {e+1}/{self.EPISODES}: Score = {score}")
        self.env.close()

    def train(self, n_threads):
        """Start ``n_threads`` daemon worker threads, each with its own env.

        Returns:
            The list of started ``threading.Thread`` objects, so a caller can
            ``join`` them; the method itself does not wait for them.
        """
        self.env.close()
        envs = [gym.make(self.env_name) for _ in range(n_threads)]
        threads = [
            threading.Thread(target=self.train_threading, daemon=True, args=(self, env, i))
            for i, env in enumerate(envs)
        ]
        for t in threads:
            time.sleep(2)  # stagger worker start-up
            t.start()
        return threads

    def train_threading(self, agent, env, thread):
        """Worker loop: play episodes in ``env`` and update the shared networks.

        Args:
            agent: Agent used to pick actions (``train`` passes ``self``).
            env: Environment owned by this worker; closed when the loop ends.
            thread: Worker index, used only in the progress message.
        """
        try:
            while self.episode < self.EPISODES:
                score, done = 0, False
                state = self._reset_env(env)
                states, actions, rewards = [], [], []
                while not done:
                    action = agent.act(state)
                    next_state, reward, done, _ = self._step_env(env, action)
                    states.append(state)
                    actions.append(action)
                    rewards.append(reward)
                    score += reward
                    state = next_state

                # One critical section for the update and the bookkeeping.
                # The context manager releases the lock even if replay or
                # save raises, so one failing worker cannot block the others.
                with self.lock:
                    self.replay(states, actions, rewards)
                    average = self.PlotModel(score, self.episode)
                    if average >= self.max_average:
                        self.max_average = average
                        self.save()
                        SAVING = "SAVING"
                    else:
                        SAVING = ""
                    print(f"episode: {self.episode}/{self.EPISODES}, thread: {thread}, score: {score}, average: {average:.2f} {SAVING}")
                    if self.episode < self.EPISODES:
                        self.episode += 1
        finally:
            env.close()

    def test(self, Actor_name, Critic_name, episodes=None):
        """Load saved weights and play rendered episodes with greedy actions.

        Args:
            Actor_name: Path of the Actor state-dict file.
            Critic_name: Path of the Critic state-dict file.
            episodes: Number of episodes to play; defaults to ``self.EPISODES``.
        """
        self.load(Actor_name, Critic_name)
        self.Actor.eval()
        self.Critic.eval()
        n_episodes = self.EPISODES if episodes is None else episodes

        # Own figure, so the rendering never lands on the score plot.
        _, axes = plt.subplots(figsize=(18, 9))
        img = None
        text_action = axes.text(0, -10, '', fontsize=12, color='red')
        text_reward = axes.text(100, -10, '', fontsize=12, color='blue')
        text_step = axes.text(300, -10, '', fontsize=12, color='green')
        text_score = axes.text(500, -10, '', fontsize=12, color='blue')

        for e in range(n_episodes):
            # Start from the real first observation, not an all-zero state.
            state = self._reset_env(self.env)
            done = False
            i = 0
            score = 0

            while not done:
                frame = self.env.render(mode='rgb_array')
                if img is None:
                    img = axes.imshow(frame)
                else:
                    img.set_data(frame)

                state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32).to(self.device)
                with torch.no_grad():
                    action_probs, _ = self.Actor(state_t.unsqueeze(0))
                action = torch.argmax(action_probs, dim=1).item()

                next_state, reward, done, info = self._step_env(self.env, action)
                score += reward

                text_action.set_text(f'Action: {action}')
                text_reward.set_text(f'Reward: {reward}')
                text_step.set_text(f'Step: {i}')
                text_score.set_text(f'Score: {score}')

                state = next_state
                i += 1

                plt.pause(0.01)
                plt.draw()

            print("episode: {}/{}, score: {}".format(e, n_episodes, score))




# In [None]:
# Build an agent for the CartPole-v1 environment and train it with the
# single-environment loop (A3CAgent.run).
agent = A3CAgent('CartPole-v1')
agent.run()


# In [None]:
# Load the actor/critic weights from the two checkpoint files and play rendered episodes.
# NOTE(review): these paths are the ones A3CAgent.save() writes, and in this file save()
# is only called from train_threading(); run() does not save. Confirm the checkpoints
# exist before running this cell.
agent.test('/content/A3C_Actor.pt','/content/A3C_Critic.pt')