In [33]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.utils import probs_to_logits
import torch.multiprocessing as mp
import math, random
import numpy as np

env_name = "CartPole-v1"
env = gym.make(env_name)
s_size = env.observation_space.shape[0]
a_size = env.action_space.n
env.close()

gamma         = 0.99
lmbda         = 0.95
eps_clip      = 0.1
T_horizon     = 500
K_epoch_policy = 3
K_epoch_pred_model = 10

class Categorical:
    def __init__(self, probs_shape):
        # NOTE: probs_shape is supposed to be
        #       the shape of probs that will be
        #       produced by policy network
        if len(probs_shape) < 1:
            raise ValueError("`probs_shape` must be at least 1.")
        self.probs_dim = len(probs_shape)
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
        self._event_shape=torch.Size()

    def set_probs(self, probs):
        # normalized the probs
        self.probs = probs / probs.sum(-1, keepdim=True)
        # log probabilities
        # domain range changed from [0, 1] -> [-inf, inf]
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        # reshape the probs to 2D
        probs_2d = self.probs.reshape(-1, self._num_events)
        # for each row, return n results with replacement, n == 1 result in my case
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        # reshape the results to specified shape
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        # make value and logits have matched shape
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        # for each row, return log_pmf[value[row]]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        # to avoid large negative log probability when log(0) occurred
        # we use "eps" instead of "min" here
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        # entropy
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)

class PredictiveModel(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        state_dim = s_size
        action_dim = 1
        self.in_dim = state_dim + action_dim + hidden_dim
        self.out_dim = state_dim
        
        self.D = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.out_dim)
        )
        self.N = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.out_dim)
        )
        self.F = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.out_dim),
            nn.Sigmoid()
        )
        self.criterion = nn.MSELoss()
    
    def forward(self, state, action, gru_out):
        state = state.view(-1)
        action = action.view(-1)
        gru_out = gru_out.view(-1)
        
        x = torch.cat([state, action, gru_out], dim=-1)
        delta = self.D(x) # D(s, a)
        adjusted_state = state + delta # s + D(s, a)
        
        new_state = self.N(x) # N(s, a)
        
        forget_weights = self.F(x) # F(s, a), in [0,1]
        
        pred_s = forget_weights * adjusted_state + (1 - forget_weights) * new_state
        return pred_s.view(-1, 1, s_size)

class Memory:
    def __init__(self, max_seq_len, exps=None):
        self.keys = [
            "states", "actions", "probs", "rewards", 
            "states_prime", "h_ins", "h_outs", "dones", 
            "timesteps", "pred_s_tis", "a_lsts"
        ]
        self.max_seq_len = max_seq_len
        
        # used when copying other memory
        self.init_exps() if exps is None else exps

    def store(self, **kwargs):
        for key, value in kwargs.items():
            if key in self.exps:
                self.exps[key].append(value)
            else:
                raise KeyError(f"Invalid key '{key}' provided to store.")

    def init_exps(self):
        self.exps = {key: [] for key in self.keys}

    def get_current_size(self):
        return len(next(iter(self.exps.values()), []))
    

class Policy(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1   = nn.Linear(input_dim, hidden_dim)
        self.fc_pi = nn.Linear(hidden_dim, a_size)
        self.fc_v  = nn.Linear(hidden_dim, 1)

    def pi(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=-1)
        return prob
    
    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

class RNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, out_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.fc1   = nn.Linear(input_dim, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, out_dim)
        
    def forward(self, x, h):
        x = F.relu(self.fc1(x))
        x = x.view(-1, 1, self.hidden_dim)
        out, hidden = self.rnn(x, h)
        return out, hidden
    
class Agent:
    def __init__(self, p_iters=0, num_memos=10):
        self.p_iters = p_iters
        self.num_memos = num_memos
        
        self.rnn = RNN(s_size, 128, 64)
        self.pred_model = PredictiveModel(64)
        self.policy = Policy(64, 32)
        self.memory = [Memory(T_horizon) for _ in range(num_memos)]
        self.dist = Categorical((a_size, ))
        self.optim_pred_model = optim.Adam(
            [
                {"params": self.rnn.parameters()}, 
                {"params": self.pred_model.parameters()}
            ], 
            lr=3e-4
        )
        self.optim_policy = optim.Adam(
            [
                {"params": self.rnn.parameters()}, 
                {"params": self.policy.parameters()}
            ], 
            lr=3e-4
        )
    
    def sample_action(self, s, a_lst, h_in):
        s = torch.from_numpy(s).float()
        a_lst = torch.tensor(a_lst)
        o, h_out, pred_s = self.pred_present(s, a_lst, h_in)
        pi = self.policy.pi(o)
        self.dist.set_probs(pi)
        action = self.dist.sample()
        return action.item(), pi, h_out, pred_s
    
    def pred_present(self, s, a_lst, h_in):
        o_ti, h_first = self.rnn(s, h_in)
        
        s_ti = []
        pred_s = s
        h_ti = h_first
        for i in range(self.p_iters):
            pred_s = self.pred_model(pred_s, a_lst[i], o_ti)
            s_ti.append(pred_s.view(-1))
            o_ti, h_ti = self.rnn(pred_s, h_ti)
        s_ti = torch.stack(s_ti) if len(s_ti) > 0 else torch.tensor([])
            
        return o_ti, h_first, s_ti
    
    def make_batch(self, i):
        # retrieve from memory
        s_lst, a_lst, prob_a_lst, r_lst, s_prime_lst, h_in_lst, h_out_lst, done_lst, _, _, a_lst_lst = \
            map(lambda key: self.memory[i].exps[key], self.memory[i].keys)
        done_lst = [0 if done else 1 for done in done_lst]
        
        # reshape then return
        s,a,r,s_prime,done_mask, prob_a = \
            torch.tensor(s_lst, dtype=torch.float).view(-1, s_size), torch.tensor(a_lst).view(-1, 1), \
            torch.tensor(r_lst).view(-1, 1), torch.tensor(s_prime_lst, dtype=torch.float).view(-1, s_size), \
            torch.tensor(done_lst, dtype=torch.float).view(-1, 1), torch.tensor(prob_a_lst).view(-1, 1)
        a_lst = torch.tensor(a_lst_lst).view(-1, self.p_iters) if self.p_iters > 0 else torch.tensor([])
        return s, a, r, s_prime, done_mask, prob_a, h_in_lst[0], h_out_lst[0], a_lst

    def _pi(self, s, h):
        s = s.view(-1, 1, s_size)
        x, h = self.rnn(s, h)
        pi = self.policy.pi(x)
        return pi.squeeze(1), h
    
    def _v(self, s, h):
        s = s.view(-1, 1, s_size)
        x, _ = self.rnn(s, h)
        v = self.policy.v(x)
        return v.squeeze(1)
    
    def cal_advantage(self, s, r, s_prime, done_mask, first_hidden, second_hidden):
        v_prime = self._v(s_prime, second_hidden)
        td_target = r + gamma * v_prime * done_mask
        v_s = self._v(s, first_hidden)
        delta = td_target - v_s
        delta = delta.detach().numpy()

        advantage_lst = []
        advantage = 0.0
        for delta_t in delta[::-1]:
            advantage = gamma * lmbda * advantage + delta_t[0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)
        return advantage, td_target
        
    def learn_policy(self, s,a,r,s_prime,done_mask, prob_a, first_hidden, second_hidden):
        for i in range(K_epoch_policy):
            advantage, td_target = self.cal_advantage(s, r, s_prime, done_mask, first_hidden, second_hidden)

            pi, h = self._pi(s, first_hidden)
            pi_a = pi.squeeze(1).gather(1,a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self._v(s, first_hidden), td_target.detach())

            self.optim_policy.zero_grad()
            loss.mean().backward()
            self.optim_policy.step()
        
    def make_pred_s_tis(self, s, a_lst, h_in, limit):
        s_ti = []
        for i in range(limit):
            _, h_out, pred_s = self.pred_present(s[i], a_lst[i], h_in)
            s_ti.append(pred_s)
            h_in = h_out
        return torch.stack(s_ti)
        
    def learn_pred_model(self):
        if self.p_iters == 0: return
        
        loss_log = []
        for _ in range(K_epoch_pred_model):
            total_loss = 0
            for i in range(self.num_memos):
                s,a,r,s_prime,done_mask, prob_a, h_ins, h_outs, a_lst = self.make_batch(i)
                (h_in, c_in), (h_out, c_out) = h_ins, h_outs
                first_hidden  = (h_in.detach(), c_in.detach())
                second_hidden = (h_out.detach(), c_out.detach())

                target = []
                limit = len(s) -self.p_iters
                for i in range(limit):
                    start, end = i + 1, min(i + self.p_iters, len(s) - 1) + 1
                    before_done = s[start : end].tolist()
                    after_done = [s[-1] for _ in range(self.p_iters - (end - start))]
                    target.append(before_done + after_done)
                target = torch.tensor(target, dtype=torch.float).view(-1, self.p_iters, s_size)

                pred = self.make_pred_s_tis(s, a_lst, first_hidden, limit)
                loss = self.pred_model.criterion(pred, target)
                total_loss += loss
            total_loss /= self.num_memos
            
            self.optim_pred_model.zero_grad()
            total_loss.mean().backward()
            self.optim_pred_model.step()
            loss_log.append(total_loss)
        print(f"Loss: {torch.mean(torch.stack(loss_log))}")
        
#         s_lst = self.memory.exps["states"]
#         target = []
#         for i in range(len(s_lst)):
#             start, end = i + 1, min(i + self.p_iters, len(s_lst) - 1) + 1
#             after_done = [s_lst[-1] for _ in range(self.p_iters - (end - start))]
#             target.append(s_lst[start : end] + after_done)
#         target = torch.tensor(target, dtype=torch.float).view(-1, self.p_iters, s_size)
        
#         mean_loss = 0
#         # to-do 
#         for _ in range(K_epoch_pred_model):
#             pred = self.make_pred_s_tis(s, a_lst, h_in)
#             loss = self.pred_model.criterion(pred, target)
        
#             self.optim_pred_model.zero_grad()
#             loss.mean().backward()
#             self.optim_pred_model.step()
#             mean_loss += loss.item()
            
#         print(mean_loss / K_epoch_pred_model)
            
    def learn(self):
        self.learn_pred_model()
#         self.learn_policy(s,a,r,s_prime,done_mask, prob_a, first_hidden, second_hidden)
        
        for memo in self.memory:
            memo.init_exps()

In [34]:
score_avg_interval = 20
delay = 4

env = gym.make(env_name)
model = Agent(p_iters=delay)
score = 0.0
num_eps = 10 * model.num_memos

for ep in range(1, num_eps + 1):
    s, info = env.reset()
    h0 = torch.zeros([1, 1, 64], dtype=torch.float)
    h_out = (h0, h0)
    a_lst = [i % 2 for i in range(delay)]
    done = False
    
    while not done:
        for t in range(T_horizon):
            h_in = h_out
            a, prob, h_out, pred_s_ti = model.sample_action(s, a_lst, h_in)
            prob = prob.view(-1)
            a_lst.append(a)

            delay_a = a_lst.pop(0)
            s_prime, r, terminated, truncated, info = env.step(delay_a)
            done = terminated or truncated
            
            exp = {
                "states": s,
                "actions": delay_a,
                "probs": prob[delay_a].item(),
                "rewards": r / 100.0,
                "states_prime": s_prime,
                "h_ins": h_in,
                "h_outs": h_out,
                "dones": done,
                "timesteps": t,
                "pred_s_tis": pred_s_ti,
                "a_lsts": a_lst
            }
            model.memory[(ep - 1) % model.num_memos].store(**exp)
            s = s_prime
            score += r
            if done:
                break
                
    if ep % model.num_memos == 0:
        model.learn()
        
    if ep % score_avg_interval == 0:
        print(f"Ep. {ep - score_avg_interval + 1} ~ {ep}", end=", ")
        print(f"avg score : {score / score_avg_interval:.1f}")
        score = 0

env.close()
print("Finished.")

Loss: 0.20449385046958923
Loss: 0.1440919190645218
Ep. 1 ~ 20, avg score : 25.2
Loss: 0.13292978703975677
Loss: 0.08256419003009796
Ep. 21 ~ 40, avg score : 30.8
Loss: 0.06558500230312347
Loss: 0.07339993119239807
Ep. 41 ~ 60, avg score : 29.5
Loss: 0.08932764083147049
Loss: 0.06470133364200592
Ep. 61 ~ 80, avg score : 24.5
Loss: 0.06744922697544098
Loss: 0.07532991468906403
Ep. 81 ~ 100, avg score : 24.1
Finished.


In [35]:
num_test_eps = 10
env = gym.make(env_name)
h0 = torch.zeros([1, 1, 32], dtype=torch.float)
h_out = (h0, h0)
total_score = []

for ep in range(1, num_test_eps + 1):
    s, info = env.reset()
    a_lst = [env.action_space.sample() for _ in range(delay)]
    done = False
    score = 0
    
    while not done:
        h_in = h_out
        a, prob, h_out, _ = model.sample_action(s, a_lst, h_in)
        prob = prob.view(-1)
        a_lst.append(a)

        delay_a = a_lst.pop(0)
        s_prime, r, terminated, truncated, info = env.step(delay_a)
        done = terminated or truncated
        s = s_prime
        score += r

        if done:
            break
    
    print(f"Ep. {ep}, score : {score}")
    total_score.append(score)

env.close()
print("Finished.")
print(f"Average score : {sum(total_score) / len(total_score)}")

RuntimeError: Expected hidden[0] size (1, 1, 64), got [1, 1, 32]