In [13]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.utils import probs_to_logits
import math, random
import numpy as np

env_name = "CartPole-v1"
env = gym.make(env_name)
s_size = env.observation_space.shape[0]
a_size = env.action_space.n
env.close()

learning_rate = 3e-4
gamma         = 0.99
lmbda         = 0.95
eps_clip      = 0.1
K_epoch       = 3
T_horizon     = 500

class Categorical:
    def __init__(self, probs_shape):
        # NOTE: probs_shape is supposed to be
        #       the shape of probs that will be
        #       produced by policy network
        if len(probs_shape) < 1:
            raise ValueError("`probs_shape` must be at least 1.")
        self.probs_dim = len(probs_shape)
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
        self._event_shape=torch.Size()

    def set_probs(self, probs):
        # normalized the probs
        self.probs = probs / probs.sum(-1, keepdim=True)
        # log probabilities
        # domain range changed from [0, 1] -> [-inf, inf]
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        # reshape the probs to 2D
        probs_2d = self.probs.reshape(-1, self._num_events)
        # for each row, return n results with replacement, n == 1 result in my case
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        # reshape the results to specified shape
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        # make value and logits have matched shape
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        # for each row, return log_pmf[value[row]]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        # to avoid large negative log probability when log(0) occurred
        # we use "eps" instead of "min" here
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        # entropy
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)
    
class Memory:
    def __init__(self, max_seq_len, exps=None):
        self.max_seq_len = max_seq_len
        if exps is None:
            self.init_exps()
        else:
            self.exps = exps

    def store(self, state, action, prob, reward, state_prime, h_in, h_out, done, timestep):
        self.exps["states"].append(state)
        self.exps["actions"].append(action)
        self.exps["probs"].append(prob)
        self.exps["rewards"].append(reward)
        self.exps["states_prime"].append(state_prime)
        self.exps["h_ins"].append(h_in)
        self.exps["h_outs"].append(h_out)
        self.exps["dones"].append(done)
        self.exps["timesteps"].append(timestep)

    def init_exps(self):
        self.exps = {
            "states": [],
            "actions": [], # sampled action at s_t
            "probs": [],
            "rewards": [], # immediate reward when leaving current s_t
            "states_prime": [],
            "h_ins": [],
            "h_outs": [],
            "dones": [],
            "timesteps": []
        }

    def get_current_size(self):
        return len(self.exps["states"])
    
class Policy(nn.Module):
    def __init__(self):
        super().__init__()        
        self.fc1   = nn.Linear(s_size,64)
        self.lstm = nn.LSTM(64,32)
        self.fc_pi = nn.Linear(32,a_size)
        self.fc_v  = nn.Linear(32,1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, hidden):
        x = F.relu(self.fc1(x))
        x = x.view(-1, 1, 64)
        x, lstm_hidden = self.lstm(x, hidden)
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=-1)
        return prob, lstm_hidden
    
    def v(self, x, hidden):
        x = F.relu(self.fc1(x))
        x = x.view(-1, 1, 64)
        x, lstm_hidden = self.lstm(x, hidden)
        v = self.fc_v(x)
        return v

class Agent:
    def __init__(self):
        self.policy = Policy()
        self.memory = Memory(T_horizon)
        self.dist = Categorical((a_size, ))
    
    def sample_action(self, s, h_in):
        pi, h_out = self.policy.pi(torch.from_numpy(s).float(), h_in)
        self.dist.set_probs(pi)
        action = self.dist.sample()
        return action.item(), pi, h_out
    
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst, h_in_lst, h_out_lst = \
        self.memory.exps["states"], self.memory.exps["actions"], self.memory.exps["rewards"], \
        self.memory.exps["states_prime"], self.memory.exps["probs"], self.memory.exps["dones"], \
        self.memory.exps["h_ins"], self.memory.exps["h_outs"]
            
        done_lst = [0 if done else 1 for done in done_lst]            
        s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float).view(-1, s_size), torch.tensor(a_lst).view(-1, 1), \
                                          torch.tensor(r_lst).view(-1, 1), torch.tensor(s_prime_lst, dtype=torch.float).view(-1, s_size), \
                                          torch.tensor(done_lst, dtype=torch.float).view(-1, 1), torch.tensor(prob_a_lst).view(-1, 1)
        return s, a, r, s_prime, done_mask, prob_a, h_in_lst[0], h_out_lst[0]

    def cal_advantage(self, s, r, s_prime, done_mask, first_hidden, second_hidden):
        v_prime = self.policy.v(s_prime, second_hidden).squeeze(1)
        td_target = r + gamma * v_prime * done_mask
        v_s = self.policy.v(s, first_hidden).squeeze(1)
        delta = td_target - v_s
        delta = delta.detach().numpy()

        advantage_lst = []
        advantage = 0.0
        for delta_t in delta[::-1]:
            advantage = gamma * lmbda * advantage + delta_t[0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)
        return advantage, td_target
        
    def learn(self):
        s,a,r,s_prime,done_mask, prob_a, (h1_in, c1_in), (h2_out, c2_out) = self.make_batch()
        first_hidden  = (h1_in.detach(), c1_in.detach())
        second_hidden = (h2_out.detach(), c2_out.detach())
        for i in range(K_epoch):
            advantage, td_target = self.cal_advantage(s, r, s_prime, done_mask, first_hidden, second_hidden)

            pi, _ = self.policy.pi(s, first_hidden)
            print(pi.shape)
            pi_a = pi.squeeze(1).gather(1,a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.policy.v(s, first_hidden).squeeze(1) , td_target.detach())

            self.policy.optimizer.zero_grad()
            loss.mean().backward(retain_graph=True)
            self.policy.optimizer.step()
        
        self.memory.init_exps()

In [14]:
num_eps = 500
score_avg_interval = 20
delay = 0

env = gym.make(env_name)
model = Agent()
score = 0.0

for ep in range(1, num_eps + 1):
    s, info = env.reset()
    h_out = (torch.zeros([1, 1, 32], dtype=torch.float), torch.zeros([1, 1, 32], dtype=torch.float)) # hidden_state, cell_state
    action_queue = [env.action_space.sample() for _ in range(delay)]
    done = False
    
    while not done:
        for t in range(T_horizon):
            h_in = h_out
            a, prob, h_out = model.sample_action(s, h_in)
            prob = prob.view(-1)
            action_queue.append(a)

            delay_a = action_queue.pop(0)
            s_prime, r, terminated, truncated, info = env.step(delay_a)
            done = terminated or truncated
            model.memory.store(s, delay_a, prob[delay_a].item(), r/100.0, s_prime, h_in, h_out, done, t)
            s = s_prime
            score += r
            
            if done:
                break
        model.learn()

    if ep % score_avg_interval == 0:
        print(f"Ep. {ep - score_avg_interval + 1} ~ {ep}", end=", ")
        print(f"avg score : {score / score_avg_interval:.1f}")
        score = 0

env.close()
print("Finished.")

torch.Size([14, 1, 2])
torch.Size([14, 1, 2])
torch.Size([14, 1, 2])
torch.Size([39, 1, 2])
torch.Size([39, 1, 2])
torch.Size([39, 1, 2])
torch.Size([22, 1, 2])
torch.Size([22, 1, 2])
torch.Size([22, 1, 2])
torch.Size([13, 1, 2])
torch.Size([13, 1, 2])
torch.Size([13, 1, 2])
torch.Size([13, 1, 2])
torch.Size([13, 1, 2])
torch.Size([13, 1, 2])
torch.Size([24, 1, 2])
torch.Size([24, 1, 2])
torch.Size([24, 1, 2])
torch.Size([11, 1, 2])
torch.Size([11, 1, 2])
torch.Size([11, 1, 2])
torch.Size([12, 1, 2])
torch.Size([12, 1, 2])
torch.Size([12, 1, 2])
torch.Size([14, 1, 2])
torch.Size([14, 1, 2])
torch.Size([14, 1, 2])
torch.Size([27, 1, 2])
torch.Size([27, 1, 2])
torch.Size([27, 1, 2])
torch.Size([21, 1, 2])
torch.Size([21, 1, 2])
torch.Size([21, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([34, 1, 2])
torch.Size([34, 1, 2])
torch.Size([34, 1, 2])
torch.Size([20, 1, 2])
torch.Size([20, 1, 2])
torch.Size([20, 1, 2])
torch.Size([12, 1, 2])
torch.Size(

torch.Size([35, 1, 2])
torch.Size([35, 1, 2])
torch.Size([35, 1, 2])
torch.Size([42, 1, 2])
torch.Size([42, 1, 2])
torch.Size([42, 1, 2])
torch.Size([60, 1, 2])
torch.Size([60, 1, 2])
torch.Size([60, 1, 2])
torch.Size([66, 1, 2])
torch.Size([66, 1, 2])
torch.Size([66, 1, 2])
torch.Size([84, 1, 2])
torch.Size([84, 1, 2])
torch.Size([84, 1, 2])
torch.Size([56, 1, 2])
torch.Size([56, 1, 2])
torch.Size([56, 1, 2])
torch.Size([127, 1, 2])
torch.Size([127, 1, 2])
torch.Size([127, 1, 2])
torch.Size([34, 1, 2])
torch.Size([34, 1, 2])
torch.Size([34, 1, 2])
torch.Size([49, 1, 2])
torch.Size([49, 1, 2])
torch.Size([49, 1, 2])
torch.Size([60, 1, 2])
torch.Size([60, 1, 2])
torch.Size([60, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([145, 1, 2])
torch.Size([145, 1, 2])
torch.Size([145, 1, 2])
torch.Size([15, 1, 2])
torch.Size([15, 1, 2])
torch.Size([15, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([18, 1, 2])
torch.Size([11, 1, 2])
torch

KeyboardInterrupt: 

In [None]:
num_test_eps = 10
env = gym.make(env_name)
h_out = (torch.zeros([1, 1, 32], dtype=torch.float), torch.zeros([1, 1, 32], dtype=torch.float)) # hidden_state, cell_state

for ep in range(1, num_test_eps + 1):
    s, info = env.reset()
    action_queue = [env.action_space.sample() for _ in range(delay)]
    done = False
    score = 0
    
    while not done:
        h_in = h_out
        a, prob, h_out = model.sample_action(s, h_in)
        prob = prob.view(-1)
        action_queue.append(a)

        delay_a = action_queue.pop(0)
        s_prime, r, terminated, truncated, info = env.step(delay_a)
        done = terminated or truncated
        s = s_prime
        score += r

        if done:
            break
    
    print(f"Ep. {ep}, score : {score}")

env.close()
print("Finished.")