In [None]:
#!pip install gymnasium==0.29.1

In [None]:
import gymnasium as gym
import torch,sys
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import warnings
warnings.filterwarnings("ignore",category=UserWarning)

@dataclass
class Hypers:
    """Hyper-parameters and shared resources for the PPO CartPole run.

    Every attribute carries a type annotation so that ``@dataclass`` really
    registers it as a field.  Un-annotated class attributes are ignored by the
    decorator, which leaves the generated ``__init__`` without parameters and
    the ``__repr__`` empty.  All fields keep their previous values as
    defaults, so ``Hypers()`` behaves as before, and single values can now be
    overridden at construction time, e.g. ``Hypers(lr=3e-4)``.
    """

    # Vectorised environment.  The default is built once, when this class body
    # executes, and is shared by every instance that does not pass its own.
    env: "gym.vector.VectorEnv" = gym.make_vec(
        "CartPole-v1", num_envs=10, vectorization_mode="sync"
    )
    # Device the network is moved to.
    device: torch.device = torch.device("cpu")
    # Learning rate handed to the Adam optimiser.
    lr: float = 1e-4
    # Discount factor used in the TD error and in GAE.
    gamma: float = 0.99
    # GAE decay factor (multiplied with ``gamma`` in the advantage recursion).
    lambda_: float = 1.0
    # Clipping range for the PPO probability ratio: [1 - epsilon, 1 + epsilon].
    epsilon: float = 0.2

    # Weight of the entropy bonus in the total loss.
    entropy_coeff: float = 1e-1
    # Weight of the critic (value) MSE loss in the total loss.
    critic_coeff: float = 5e-1

    # Number of rollout/update iterations of the training loop.
    total_games: int = 600
    # Number of vectorised environment steps collected per rollout.
    batchsize: int = 256
    # Number of stored time steps returned by one buffer sample.
    minibatch: int = 64
    # Number of gradient steps taken on each sampled minibatch.
    optim_steps: int = 6


hypers = Hypers()

class network(nn.Module):
    """Actor-critic network with a shared three-layer ReLU trunk.

    The trunk feeds two heads: ``policy`` (2 outputs, turned into action
    probabilities with a softmax) and ``value`` (1 output).  All layers are
    ``nn.LazyLinear``, so their input width is not fixed at construction.
    The module also owns its Adam optimiser as ``self.optim``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (attribute name, output width) in registration order.
        layer_spec = (
            ("l1", 32),
            ("l2", 32),
            ("l3", 32),
            ("policy", 2),
            ("value", 1),
        )
        for attr_name, out_features in layer_spec:
            setattr(self, attr_name, nn.LazyLinear(out_features))
        self.optim = torch.optim.Adam(self.parameters(), lr=hypers.lr)

    def forward(self, x):
        """Return ``(action_probabilities, state_value)`` for the input ``x``."""
        hidden = x
        for trunk_layer in (self.l1, self.l2, self.l3):
            hidden = F.relu(trunk_layer(hidden))
        logits = self.policy(hidden)
        state_value = self.value(hidden)
        # Softmax over the last dimension, i.e. across the action logits.
        return F.softmax(logits, -1), state_value
    
def init_w(w):
    """Initialise a linear layer in place; meant for ``nn.Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and, when they have a
    bias, a zero bias.  Layers created with ``bias=False`` expose
    ``bias is None``; those are handled by skipping the bias step instead of
    crashing.  Any other module type is left untouched.
    """
    if not isinstance(w, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(w.weight)
    if w.bias is not None:
        torch.nn.init.constant_(w.bias, 0.0)

# Build the actor-critic network and move it to the configured device.
model = network().to(hypers.device)
# Dry-run forward pass on a random (5, 4) batch: the LazyLinear layers infer
# their input width on the first call, which has to happen before init_w can
# touch their weights.
model(torch.rand((5,4),dtype=torch.float32,device=hypers.device))
# Run init_w on every sub-module of the model.
model.apply(init_w)

In [None]:
from torch.distributions import Categorical
from collections import deque
import numpy as np

class replay_buffer:
    """On-policy rollout storage with GAE advantages for a vectorised env.

    ``env`` must provide ``num_envs``, ``reset() -> (obs, info)`` and
    ``step(actions) -> (obs, reward, terminated, truncated, info)`` with
    numpy arrays batched over the sub-environments.  ``model`` is called on a
    float32 observation tensor and must return ``(action_probs, value)`` where
    ``value`` has a trailing dimension of size 1.

    ``self.data`` holds one list per collected time step:
    ``[obs, reward, value, action, log_prob, done, advantage]``; every entry
    is a tensor whose leading dimension is ``num_envs``.
    """

    def __init__(self, env: "gym.vector.VectorEnv", model: "network"):
        self.data = []
        self.env = env
        self.model = model
        # Index of the next time step handed out by ``sample``.
        self.pointer = 0
        # Running undiscounted return of the episode in progress, per env.
        self.reward = np.zeros(self.env.num_envs, dtype=np.float32)
        # Returns of the most recently finished episodes (at most 25).
        self.reward_data = deque(maxlen=25)

    @torch.no_grad()
    def rollout(self, batchsize, gamma=None, lambda_=None):
        """Collect ``batchsize`` vectorised steps and attach GAE advantages.

        The buffer is cleared and the environment reset first.

        Args:
            batchsize: number of ``env.step`` calls to record (>= 1).
            gamma: discount factor; defaults to ``hypers.gamma``.
            lambda_: GAE decay factor; defaults to ``hypers.lambda_``.

        Raises:
            ValueError: if ``batchsize`` is smaller than 1.
        """
        if batchsize < 1:
            raise ValueError("batchsize must be at least 1")
        gamma = hypers.gamma if gamma is None else gamma
        lambda_ = hypers.lambda_ if lambda_ is None else lambda_

        self.clear()
        obs, _ = self.env.reset()
        # The reset discards every episode in progress, so the partial returns
        # must go too; otherwise they leak into the next episode's score.
        self.reward[:] = 0.0

        for _ in range(batchsize):
            obs_t = torch.from_numpy(obs).to(torch.float32)
            probs, value = self.model(obs_t)
            action_dist = Categorical(probs=probs)
            action = action_dist.sample()
            log_prob = action_dist.log_prob(action)

            next_obs, reward, terminated, truncated, _ = self.env.step(
                action.tolist()
            )
            # An episode is over on termination OR truncation (time limit).
            # Ignoring truncation would let GAE run across the boundary into
            # the next episode and would never log time-limit episodes.
            # NOTE(review): assumes the vector env auto-resets finished
            # sub-environments, so ``next_obs`` already belongs to the new
            # episode - confirm for the gymnasium version in use.
            episode_end = np.logical_or(terminated, truncated)

            self.reward += np.asarray(reward, dtype=np.float32)
            for env_idx in np.flatnonzero(episode_end):
                self.reward_data.append(float(self.reward[env_idx]))
                self.reward[env_idx] = 0.0

            self.data.append([
                obs_t,
                torch.from_numpy(np.asarray(reward)).to(torch.float32),
                value,
                action,
                log_prob,
                torch.from_numpy(episode_end).to(torch.float32),
            ])
            obs = next_obs

        # Value of the observation after the last step, used for bootstrapping.
        _, next_value = self.model(torch.from_numpy(obs).to(torch.float32))
        _, rewards, values, _, _, dones = zip(*self.data)
        rewards = torch.stack(rewards)
        dones = torch.stack(dones)
        # Drop only the trailing value dimension so that ``num_envs == 1`` or
        # a single-step rollout keeps its (time, env) layout.
        values = torch.cat((
            torch.stack(values).squeeze(-1),
            next_value.squeeze(-1).unsqueeze(0),
        ))

        advantage = torch.zeros_like(rewards)
        gae = torch.zeros_like(rewards[0])
        for step in reversed(range(len(rewards))):
            not_done = 1.0 - dones[step]
            td = rewards[step] + gamma * values[step + 1] * not_done - values[step]
            gae = td + lambda_ * gamma * gae * not_done
            advantage[step] = gae

        for row, step_advantage in zip(self.data, advantage):
            row.append(step_advantage)

    def sample(self, minibatch):
        """Return the next ``minibatch`` stored time steps, in collection order.

        Returns:
            ``(states, values, actions, old_log_probs, advantages)``, each
            stacked along a new leading time dimension; ``values`` has its
            trailing size-1 dimension removed.

        Raises:
            ValueError: if no stored steps are left.
        """
        rows = self.data[self.pointer:self.pointer + minibatch]
        if not rows:
            raise ValueError("replay buffer is empty")
        self.pointer += minibatch
        states, _, values, actions, old_probs, _, advantages = zip(*rows)
        # Explicit ``squeeze(-1)`` on the values only: a blanket ``squeeze()``
        # would also drop the time or env axis whenever it has size 1.
        return (
            torch.stack(states),
            torch.stack(values).squeeze(-1),
            torch.stack(actions),
            torch.stack(old_probs),
            torch.stack(advantages),
        )

    def clear(self):
        """Drop all stored steps and rewind the sampling pointer."""
        self.pointer = 0
        self.data = []

    def reward_fn(self):
        """Return the mean of the logged episode returns as a 0-d tensor.

        The result is NaN when no episode has finished yet.
        """
        if not self.reward_data:
            return torch.tensor(float("nan"))
        return torch.tensor(list(self.reward_data), dtype=torch.float32).mean()

In [None]:
from tqdm import tqdm

class PPO:
    """PPO training driver built on the module-level ``model`` and ``hypers``.

    Rollouts are collected through a ``replay_buffer``; each sampled
    minibatch is optimised for several steps with the clipped surrogate
    objective, a critic MSE term and an entropy bonus.
    """

    def __init__(self):
        self.model = model
        self.env = hypers.env
        self.memory = replay_buffer(self.env, self.model)

    def save(self):
        """Write model and optimiser state dicts to ``./CartPole.pth``."""
        checkpoint = {}
        checkpoint["model_state"] = self.model.state_dict()
        checkpoint["optim_state"] = self.model.optim.state_dict()
        torch.save(checkpoint, "./CartPole.pth")

    def _loss(self, states, actions, old_log_probs, advantages, value_target):
        """Return the total PPO loss for one minibatch."""
        probs, values = self.model(states)
        dist = Categorical(probs)
        log_probs = dist.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = ratio * advantages
        clipped = (
            torch.clamp(ratio, 1 - hypers.epsilon, 1 + hypers.epsilon) * advantages
        )
        loss_policy = -torch.mean(torch.min(unclipped, clipped))
        entropy = dist.entropy().mean() * hypers.entropy_coeff
        loss_critic = (
            F.mse_loss(values.squeeze(-1), value_target) * hypers.critic_coeff
        )
        return loss_policy + loss_critic - entropy

    def run(self, total_games, batchsize, minibatch, optim_step):
        """Train for ``total_games`` rollout/update iterations.

        Args:
            total_games: number of iterations of the outer loop.
            batchsize: environment steps collected per rollout.
            minibatch: stored time steps drawn per buffer sample.
            optim_step: gradient steps taken on each minibatch.
        """
        minibatches_per_rollout = batchsize // minibatch
        for traj in tqdm(range(total_games), total=total_games):
            self.memory.rollout(batchsize)
            for _ in range(minibatches_per_rollout):
                (
                    states,
                    old_values,
                    actions,
                    old_log_probs,
                    advantages,
                ) = self.memory.sample(minibatch)
                # Critic target: advantage plus the value predicted at rollout time.
                value_target = advantages + old_values
                for _ in range(optim_step):
                    total_loss = self._loss(
                        states, actions, old_log_probs, advantages, value_target
                    )
                    self.model.optim.zero_grad()
                    total_loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                    self.model.optim.step()

            # Report the mean logged return and checkpoint every 25th iteration.
            if traj % 25 != 0:
                continue
            print(self.memory.reward_fn())
            self.save()
            
# Build the trainer (it picks up the module-level model and hypers.env) and
# start training with the iteration count, rollout length, minibatch size and
# per-minibatch optimisation steps taken from hypers.
t = PPO()
t.run(hypers.total_games,hypers.batchsize,hypers.minibatch,hypers.optim_steps)