In [29]:
import gym

import torch
from torch import nn
import torch.nn.functional as F

from models.ppo import PPO
from commons.basemodel import BaseModel

In [30]:
class Actor(nn.Module):
    def __init__(self, nstats: int, nactions: int, hid_dim:int):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, nactions)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = F.softmax(self.fc2(x), -1)
        return x


In [31]:
class Critic(nn.Module):
    def __init__(self, nstats: int, hid_dim:int):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, 1)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        return x


In [143]:
import numpy as np

class BaseModel():
    def __init__(self, max_len = 9999):

        self.datas = []
        self.max_len = max_len

    def put(self, data):
        if len(self.datas) > self.max_len:
            self.datas.pop(0)
        
        self.datas.append(data)

    def sample(self):
        lst = [[] for i in range(len(self.datas[0]))]
        for i in range(len(self.datas)):
            for j in range(len(self.datas[i])):
                d = self.datas[i][j]
                d = d if isinstance(d, np.ndarray) else np.array([d])
                lst[j].append(d)
        
        self.datas = []

        return [torch.tensor(np.array(d), dtype=torch.int64 if d[0].dtype == 'int' else torch.float) for d in lst]
            

In [None]:
# import numpy as np

# class BaseModel():
#     def __init__(self, max_len = 9999):

#         self.datas = []
#         self.max_len = max_len

#     def put(self, data):
#         if len(self.datas) > self.max_len:
#             self.datas.pop(0)
        
#         self.datas.append(data)

#     def sample(self, n_batchs):
#         samples = []
#         sample = [[] for i in range(len(self.datas[0]))]
#         for i in range(len(self.datas)):
#             for j in range(len(self.datas[i])):
#                 d = self.datas[i][j]
#                 d = d if isinstance(d, np.ndarray) else np.array([d])
#                 sample[j].append(d)

#             if i == n_batchs:
#                 sample = [torch.tesnor(s, dtype=torch.int64 if s.dtype == 'int' else torch.float) for s in sample]
#                 samples.append(sample)
#                 sample = [[] for k in range(len(self.datas[0]))]
        
#         self.datas = []

#         return samples
            

In [156]:
from torch.distributions import Categorical
from commons.basemodel import BaseModel

class PPO(BaseModel):
    def __init__(self, env, n_acts: int, act: nn.Module, act_opt: torch.optim, 
                 cri: nn.Module, cri_opt:torch.optim, gamma = 0.98, gae=0.95, eps=0.2):
        super().__init__()

        self.env = env
        self.n_acts = n_acts
        
        self.act = act
        self.act_opt = act_opt
        
        self.cri = cri
        self.cri_opt = cri_opt
        
        self.gamma = gamma
        self.gae = gae
        self.eps = eps


    def train(self, n_epis, n_rollout, n_update, print_interval=20):
        env = self.env
        score = 0.0

        for epi in range(n_epis):
            done = False
            s = env.reset()[0]
            
            while not done:
                for t in range(n_rollout):
                    a, a_prob = self.get_action(s)
                    s_p, r, done, _, _ = env.step(a)
                    d_mask = 0 if done else 1
                    self.put((s, a, r/100, s_p, d_mask, a_prob))
                    env.render()

                    s = s_p
                    score += r

                    if done:
                        break
                
                n_batch = n_rollout//n_update
                samples = self.sample(n_batch)

                for i in range(0, n_rollout, n_batch):
                    sample = [s[i:i+n_batch] for s in samples]
                    self.update(sample)
            
            if epi % print_interval == 0 and epi != 0:
                print(f"epi: {epi}, score: {score / print_interval}")
                score = 0

    def get_action(self, s):
        prob = self.act(torch.from_numpy(s).type(torch.float))
        m = Categorical(prob)
        a = m.sample().item()
        return a, prob[a].item()

    def update(self, sample):
        s, a, r, s_p, d_mask, a_prob = sample

        td_target = r + self.gamma * self.cri(s_p) * d_mask
        advs = (td_target - self.cri(s)).view(-1)

        a_hats = []
        a_hat = 0
        for adv in advs.detach().numpy()[::-1]:
            a_hat = adv + self.gamma * self.gae * a_hat
            a_hats.append(a_hat)
        
        a_hats.reverse()
        a_hats = torch.tensor(a_hats)
        
        cri_loss = F.mse_loss(self.cri(s), td_target.detach())
        self.cri_opt.zero_grad()
        cri_loss.backward()
        self.cri_opt.step()
        
        pi = self.act(s)
        pi_a = pi.gather(1, a) / a_prob
        
        act_loss = - torch.min(pi_a * a_hats, torch.clamp(pi_a, 1 - self.eps, 1 + self.eps) * a_hats)
        act_loss = act_loss.mean()
        self.act_opt.zero_grad()
        act_loss.backward()
        self.act_opt.step()


In [157]:
learning_rate = 0.0005
nstates = 4
nactions = 2
hid_dim = 128
gamma = 0.98
gae = 0.95

In [158]:
env = gym.make("CartPole-v1", render_mode='human')

In [159]:
act = Actor(nstates, nactions, hid_dim)
act_opt = torch.optim.Adam(act.parameters(), lr=learning_rate)

cri = Critic(nstates, hid_dim)
cri_opt = torch.optim.Adam(cri.parameters(), lr=learning_rate)

In [160]:
ppo = PPO(env, 2, act, act_opt, cri, cri_opt)

In [161]:
ppo.train(5000, 10, 3)

TypeError: sample() takes 1 positional argument but 2 were given

In [9]:
env.close()