In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import os
import cv2
import matplotlib.pyplot as plt
import pygame
from IPython import display

class OurModel(nn.Module):
    """Shared-trunk actor-critic network.

    One hidden layer (512 units, ELU) feeds two heads: a softmax policy over
    ``action_space`` discrete actions and a single state-value output.
    """

    def __init__(self, input_shape, action_space):
        """
        Args:
            input_shape: shape of a single observation (no batch dim); it is
                flattened, so any rank is accepted.
            action_space: number of discrete actions.
        """
        super(OurModel, self).__init__()
        self.flatten = nn.Flatten()
        # int(): np.prod returns a numpy scalar (and the float 1.0 for an
        # empty shape), while nn.Linear expects a plain int.
        self.fc1 = nn.Linear(int(np.prod(input_shape)), 512)
        self.fc2_action = nn.Linear(512, action_space)
        self.fc2_value = nn.Linear(512, 1)
        self.elu = nn.ELU()

    def forward(self, x):
        """Return ``(action_probs, value)`` for a batch ``x`` of shape (B, *input_shape).

        ``action_probs`` has shape (B, action_space) and sums to 1 along the
        last dim; ``value`` has shape (B, 1).
        """
        x = self.flatten(x)
        x = self.elu(self.fc1(x))
        action = torch.softmax(self.fc2_action(x), dim=-1)
        value = self.fc2_value(x)
        return action, value

class A2CAgent:
    """Advantage Actor-Critic (A2C) agent for discrete-action gym environments.

    Each training episode is collected in full (``remember``) and then used
    for a single actor-critic update (``replay``).
    """

    # Default checkpoint location used by ``save`` and ``test`` (the path the
    # original code hard-coded).
    DEFAULT_MODEL_PATH = '/content/A2C_model.pt'

    def __init__(self, env_name):
        """Create the environment, the policy/value network and the optimizer.

        Args:
            env_name: environment id passed to ``gym.make``.
        """
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.action_size = self.env.action_space.n
        self.EPISODES, self.max_average = 1000, -21.0
        self.lr = 0.000025
        self.state_size = self.env.observation_space.shape

        # Frame-preprocessing settings; not read by any method below
        # (presumably left over from an image-based variant).
        self.ROWS = 80
        self.COLS = 80
        self.REM_STEP = 4

        # NOTE(review): the pygame window/clock are not used by any method
        # below, and set_mode needs a display -- confirm they are required.
        pygame.init()
        self.screen = pygame.display.set_mode((640, 480))
        self.clock = pygame.time.Clock()
        self.metadata = {"render_fps": 30}

        # Per-episode rollout buffers, cleared by ``replay``.
        self.states, self.actions, self.rewards = [], [], []
        # Training history used by ``plot_model``.
        self.scores, self.episodes, self.average = [], [], []

        self.Save_Path = 'Models'
        self.image_memory = np.zeros(self.state_size)

        os.makedirs(self.Save_Path, exist_ok=True)
        self.path = '{}_A2C_{}'.format(self.env_name, self.lr)
        self.Model_name = os.path.join(self.Save_Path, self.path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = OurModel(input_shape=self.state_size, action_space=self.action_size).to(self.device)
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=self.lr)

    def _reset_env(self):
        """Reset the environment and return only the observation.

        Accepts both the classic gym API (``reset() -> obs``) and the
        gym>=0.26 API (``reset() -> (obs, info)``).
        """
        out = self.env.reset()
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out[0]
        return out

    def _step_env(self, action):
        """Step the environment and return ``(obs, reward, done, info)``.

        The gym>=0.26 5-tuple ``(obs, reward, terminated, truncated, info)``
        is folded into the classic 4-tuple.
        """
        out = self.env.step(action)
        if len(out) == 5:
            obs, reward, terminated, truncated, info = out
            return obs, reward, terminated or truncated, info
        return out

    def remember(self, state, action, reward):
        """Append one transition; the action is stored one-hot encoded."""
        self.states.append(state)
        action_onehot = np.zeros([self.action_size])
        action_onehot[action] = 1
        self.actions.append(action_onehot)
        self.rewards.append(reward)

    def act(self, state):
        """Sample an action index from the current policy for a single state."""
        state = torch.as_tensor(np.asarray(state, dtype=np.float32), device=self.device)
        state = state.unsqueeze(0)  # Add batch dimension
        with torch.no_grad():  # inference only: no autograd graph needed
            action_probs, _ = self.model(state)
        action_probs = action_probs.cpu().numpy()[0]

        # Check for NaNs
        if np.any(np.isnan(action_probs)):
            print("NaNs detected in action_probs:", action_probs)
            action_probs = np.nan_to_num(action_probs, nan=1.0 / self.action_size)  # Replace NaNs with equal probabilities

        # Ensure probabilities sum to 1
        action_probs = action_probs / np.sum(action_probs)

        # Clip the probabilities to ensure valid range
        action_probs = np.clip(action_probs, 1e-10, 1.0)
        action_probs = action_probs / np.sum(action_probs)  # Normalize again after clipping

        action = np.random.choice(self.action_size, p=action_probs)
        return action

    def discount_rewards(self, reward, gamma=0.99, reset_on_nonzero=False):
        """Return normalized discounted returns for one episode.

        Args:
            reward: sequence of per-step rewards (ints or floats).
            gamma: discount factor.
            reset_on_nonzero: restart the running return at every non-zero
                reward (Pong-style "point scored" boundaries). The original
                code always did this; with CartPole, where every reward is
                non-zero, all returns became equal and normalization divided
                0 by 0, producing NaNs. Off by default.

        Returns:
            Float array of returns shifted to zero mean and scaled by the
            standard deviation (an epsilon keeps a constant vector finite).
        """
        rewards = np.asarray(reward, dtype=np.float64)
        discounted_r = np.zeros_like(rewards)
        if rewards.size == 0:
            return discounted_r
        running_add = 0.0
        for i in reversed(range(len(rewards))):
            if reset_on_nonzero and rewards[i] != 0:
                running_add = 0.0
            running_add = running_add * gamma + rewards[i]
            discounted_r[i] = running_add

        discounted_r -= discounted_r.mean()
        discounted_r /= discounted_r.std() + 1e-8  # avoid 0/0 when returns are constant
        return discounted_r

    def replay(self):
        """Run one actor-critic update on the buffered episode, then clear the buffers."""
        if not self.states:
            return
        # np.stack keeps one row per step even for multi-dimensional observations.
        states = torch.as_tensor(np.stack(self.states), dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(np.stack(self.actions), dtype=torch.float32, device=self.device)
        discounted_r = torch.as_tensor(self.discount_rewards(self.rewards), dtype=torch.float32, device=self.device)

        self.model.train()
        self.optimizer.zero_grad()
        action_probs, values = self.model(states)
        values = values.squeeze(-1)  # (T, 1) -> (T,), also correct when T == 1

        advantages = discounted_r - values
        critic_loss = advantages.pow(2).mean()
        # Clamp so a zero probability cannot yield log(0) = -inf.
        action_log_probs = torch.log(action_probs.clamp_min(1e-8))
        # Detach: the policy gradient must not be propagated into the critic
        # through the advantage estimate.
        actor_loss = -(action_log_probs * actions).sum(dim=1) * advantages.detach()
        actor_loss = actor_loss.mean()
        loss = actor_loss + critic_loss
        loss.backward()
        self.optimizer.step()

        self.states, self.actions, self.rewards = [], [], []

    def load(self, model_path):
        """Load weights from ``model_path`` onto this agent's device and switch to eval mode."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

    def save(self, path=None):
        """Save the model weights to ``path`` (default: ``DEFAULT_MODEL_PATH``)."""
        if path is None:
            path = self.DEFAULT_MODEL_PATH
        torch.save(self.model.state_dict(), path)

    def plot_model(self, score, episode):
        """Record an episode score and return the running average of the last 50.

        Every 100th episode (100, 200, ...) the score history is written to
        ``self.path + ".png"``.
        """
        self.scores.append(score)
        self.episodes.append(episode)
        recent = self.scores[-50:]
        self.average.append(sum(recent) / len(recent))
        if episode > 0 and episode % 100 == 0:
            # Draw on a fresh figure and close it, so lines do not pile up on
            # one implicit figure across calls.
            fig, ax = plt.subplots()
            ax.plot(self.episodes, self.scores, 'b')
            ax.plot(self.episodes, self.average, 'r')
            ax.set_ylabel('Score', fontsize=18)
            ax.set_xlabel('Episode', fontsize=18)
            fig.savefig(self.path + ".png")
            plt.close(fig)
        return self.average[-1]

    def run(self):
        """Train for ``self.EPISODES`` episodes (one update per episode), then save."""
        for e in range(self.EPISODES):
            state = self._reset_env()
            done, score, SAVING = False, 0, ''
            while not done:
                action = self.act(state)
                next_state, reward, done, _ = self._step_env(action)
                self.remember(state, action, reward)
                state = next_state
                score += reward
            average = self.plot_model(score, e)
            # Update before saving so the checkpoint includes the last episode.
            self.replay()
            if e == (self.EPISODES - 1):
                print("Saving trained model")
                self.save()
            print("episode: {}/{}, score: {}, average: {:.2f} {}".format(e, self.EPISODES, score, average, SAVING))
        self.env.close()

    def test(self, model_path=None, *, episodes=None, pause=1.0):
        """Play the greedy (argmax) policy from a checkpoint and animate it inline.

        Args:
            model_path: checkpoint to load; defaults to ``DEFAULT_MODEL_PATH``,
                where ``save`` writes by default.
            episodes: number of episodes to play; defaults to ``self.EPISODES``.
            pause: seconds passed to ``plt.pause`` after each frame.
        """
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH
        if episodes is None:
            episodes = self.EPISODES
        self.load(model_path)
        for e in range(episodes):
            state = self._reset_env()
            done, i, score = False, 0, 0
            fig = plt.figure()
            img = plt.imshow(self.env.render(mode='rgb_array'))
            text_action = plt.text(0, -10, '', fontsize=12, color='red')
            text_reward = plt.text(100, -10, '', fontsize=12, color='blue')
            text_step = plt.text(300, -10, '', fontsize=12, color='green')
            text_score = plt.text(500, -10, '', fontsize=12, color='blue')

            while not done:
                img.set_data(self.env.render(mode='rgb_array'))  # Update the data

                state_t = torch.as_tensor(
                    np.reshape(state, (1, *self.state_size)),
                    dtype=torch.float32,
                    device=self.device,
                )
                with torch.no_grad():
                    action_probs, _ = self.model(state_t)
                action = int(np.argmax(action_probs.cpu().numpy()))
                state, reward, done, _ = self._step_env(action)
                score += reward
                text_action.set_text(f'Action: {action}')
                text_reward.set_text(f'Reward: {reward}')
                text_step.set_text(f'Step: {i}')
                text_score.set_text(f'Score: {score}')
                i += 1

                plt.pause(pause)
                display.display(plt.gcf())
                display.clear_output(wait=True)

            plt.close(fig)
            print("episode: {}/{}, score: {}".format(e, episodes, score))


In [None]:
agent = A2CAgent(env_name='CartPole-v1')  # builds the env, network and optimizer
agent.run()  # trains for agent.EPISODES episodes and saves the final weights


In [None]:
# test() requires the checkpoint path; this is where save() writes it by default.
agent.test('/content/A2C_model.pt')