In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import gymnasium as gym
import numpy as np
from collections import deque

In [2]:
# HyperParameters

gamma = 0.99
NUM_EPISODES = 10000
MAX_STEPS = 1000
early_stop = 200
lr = 0.001
print_step = 100

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class actor_net(nn.Module):
    def __init__(self, observation_space, action_space):
        super(actor_net, self).__init__()
        self.input_layer = nn.Linear(observation_space, 128)
        self.output_layer = nn.Linear(128, action_space)

    def forward(self, x):
        x = F.relu(self.input_layer(x))
        actions = self.output_layer(x)
        action_probs = F.softmax(actions, dim = 1)

        return action_probs

    def select_action(self, s):
        s = torch.from_numpy(s).float().unsqueeze(0).to(DEVICE)
        action_probs = self.forward(s)
        a = Categorical(action_probs)
        action = a.sample()

        return action.item(), a.log_prob(action), a

In [4]:
class critic_net(nn.Module):
    def __init__(self, observation_space):
        super(critic_net, self).__init__()
        self.input_layer = nn.Linear(observation_space, 128)
        self.output_layer = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.input_layer(x))
        value = self.output_layer(x)

        return value

In [5]:
# Make environment
env = gym.make('CartPole-v1')

# Initialize networks
actor = actor_net(env.observation_space.shape[0], env.action_space.n).to(DEVICE)
critic = critic_net(env.observation_space.shape[0]).to(DEVICE)

# Initialize optimizer
actor_optimizer = optim.Adam(actor.parameters(), lr = lr)
critic_optimizer = optim.Adam(critic.parameters(), lr = lr)

In [6]:
score = []
recent_score = deque(maxlen = 100)

# Training loop
for epoch in range(5000):
    state, _ = env.reset()
    done = False
    score = 0

    for step in range(MAX_STEPS):
        action, log_prob, m = actor.select_action(state)
        next_state, reward, done, trancated, _ = env.step(action)
        score += reward

        state_val_current = critic(torch.from_numpy(state).float().unsqueeze(0).to(DEVICE))
        state_val_next = critic(torch.from_numpy(next_state).float().unsqueeze(0).to(DEVICE))

        if done:
            state_val_next = torch.tensor([0]).float().unsqueeze(0).to(DEVICE)

        val_loss = F.mse_loss(reward + gamma * state_val_next.detach(), state_val_current) # critic_loss
        advantage = reward + gamma * state_val_next.item() - state_val_current.item()
        entropy = -m.entropy()
        policy_loss = 0.1 * (-log_prob * advantage) + 0.1 * entropy # actor_loss

        actor_optimizer.zero_grad()
        policy_loss.backward(retain_graph = True)
        actor_optimizer.step()

        critic_optimizer.zero_grad()
        val_loss.backward()
        critic_optimizer.step()

        if done:
            break
    
    if (epoch % print_step == 0 and epoch != 0):
        print(f"# an Episode : {epoch}, avg_score : {score}")
        score = 0

# an Episode : 100, avg_score : 9.0
# an Episode : 200, avg_score : 21.0
# an Episode : 300, avg_score : 46.0
# an Episode : 400, avg_score : 13.0
# an Episode : 500, avg_score : 9.0
# an Episode : 600, avg_score : 13.0
# an Episode : 700, avg_score : 28.0
# an Episode : 800, avg_score : 15.0
# an Episode : 900, avg_score : 22.0
# an Episode : 1000, avg_score : 14.0
# an Episode : 1100, avg_score : 8.0
# an Episode : 1200, avg_score : 12.0
# an Episode : 1300, avg_score : 14.0
# an Episode : 1400, avg_score : 9.0
# an Episode : 1500, avg_score : 10.0
# an Episode : 1600, avg_score : 8.0
# an Episode : 1700, avg_score : 10.0
# an Episode : 1800, avg_score : 9.0
# an Episode : 1900, avg_score : 9.0
# an Episode : 2000, avg_score : 9.0
# an Episode : 2100, avg_score : 9.0
# an Episode : 2200, avg_score : 9.0
# an Episode : 2300, avg_score : 9.0
# an Episode : 2400, avg_score : 8.0
# an Episode : 2500, avg_score : 10.0
# an Episode : 2600, avg_score : 9.0
# an Episode : 2700, avg_score : 1

In [7]:
import time
env = gym.make("CartPole-v1", render_mode = "human")
state, info = env.reset()

for i in range(500):
    action, lp, m = actor.select_action(state)
    state, reward, done, truncated, _ = env.step(action)

    env.render()

    time.sleep(0.01)

    if done:
        state, info = env.reset()

env.close()