In [1]:
import torch
import numpy as np
from torch.distributions import MultivariateNormal
from torch import nn
import torch.nn.functional as F
import gymnasium as gym
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.distributions.categorical import Categorical



In [10]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class neural_network(nn.Module):
    def __init__(self, envs):
        super(neural_network, self).__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

In [22]:
class PPO_discrete_action:
    def __init__(self,env,max_timesteps,epochs,max_iteration_episodes,gamma,clip,lr,device,num_envs,writer,minibatch_size,vf_coef,ent_coef):
        self.env=env
        self.actor=neural_network(self.env).actor.to(device)
        self.cretic=neural_network(self.env).critic.to(device)
        self.max_timesteps=max_timesteps
        self.epochs=epochs
        self.max_iteration_episodes=max_iteration_episodes
        self.gamma=gamma
        self.clip=clip
        self.lr=lr
        self.actor_optimization= Adam(self.actor.parameters() , lr=self.lr)
        self.critic_optimization= Adam(self.cretic.parameters() , lr=self.lr)
       
        self.device=device
        self.num_envs=num_envs
        self.global_step=0
        self.minibatch_size=minibatch_size
        self.writer=writer
        self.ent_coef=ent_coef
        self.vf_coef=vf_coef
        
    def get_action(self,state):
        mean = self.actor(state)
        dist = Categorical(logits=mean)
        action=dist.sample()
    
        log_probs = dist.log_prob(action)
 
        return action, log_probs
        
        
    def train(self):
        
        t=0
        batch_states = torch.zeros((self.max_timesteps, self.num_envs) + self.env.single_observation_space.shape).to(self.device)
        batch_actions = torch.zeros((self.max_timesteps, self.num_envs) + self.env.single_action_space.shape).to(self.device)
        batch_rew = torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        batch_done = torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        batch_values=torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        # TRY NOT TO MODIFY: start the game
        state,_=self.env.reset()
        next_obs = torch.Tensor(state).to(self.device)
        next_done = torch.zeros(self.num_envs).to(self.device)
        
            
        ep_lenght=0
        ep_rew=np.zeros((self.num_envs))
        step=0
        for i in range(self.max_iteration_episodes):
            step+=1
            self.global_step+=1
            batch_states[i]=next_obs
            V=self.evaluate_V(next_obs).flatten()
            batch_values[i]=V
            batch_done[i]=next_done
            
            action,_=self.get_action(next_obs)
            
            next_obs,reward,done,truncated,info=self.env.step(action.cpu().numpy())
            ep_rew+=reward
            
            next_obs, next_done = torch.Tensor(next_obs).to(self.device), torch.Tensor(done).to(self.device)
            batch_actions[i]=action
            batch_rew[i] = torch.tensor(reward).to(self.device).view(-1)
            # for item in info:
            #     if "episode" in item.keys():
            #         print(f"global_step={t}, episodic_return={item['episode']['r']}")
            #         self.writer.add_scalar("charts/episodic_return", item["episode"]["r"], t)
            #         self.writer.add_scalar("charts/episodic_length", item["episode"]["l"], t)
            #         break
            
            if (1 in done) or (1 in truncated) or (i == self.max_iteration_episodes-1) :
                print(ep_rew[0]/step)
                self.writer.add_scalar('rollout/ep_rew_mean',ep_rew[0]/step,self.global_step)
                eps_rew=np.zeros((self.num_envs))
                step=0
            state=next_obs
            # L=[]
            # L.append(action)
        batch_states = batch_states.reshape((-1,) + self.env.single_observation_space.shape)
        batch_actions = batch_actions.reshape((-1,) + self.env.single_action_space.shape)
        

        return batch_rew,batch_states,batch_actions,next_obs,next_done,batch_done,batch_values
    
    def evaluate_prob(self,batch_states,batch_actions):
        mean=self.actor(batch_states)
        dist=Categorical(logits=mean)
        log_prob=dist.log_prob(batch_actions)
        V=self.cretic(batch_states).squeeze()
        return V,log_prob,dist.entropy()
    
    def evaluate_V(self,batch_states):
        V=self.cretic(batch_states)
        return V.detach()
    
    def evaluate_Q(self,batch_rew,next_obs,next_done,batch_done):
        
        V=self.evaluate_V(next_obs).reshape(1, -1)
       
        batch_Q = torch.zeros_like(batch_rew).to(self.device)
     
        for j in reversed(range(self.max_timesteps)):
            
            if j==self.max_timesteps - 1:
                nextnonterminal = 1.0 - next_done
                next_return = V
            else:
                nextnonterminal = 1.0 - batch_done[j + 1]
                next_return = batch_rew[j + 1]
        
            
            batch_Q[j] = batch_rew[j] + self.gamma * nextnonterminal * next_return
        return batch_Q.detach()
    
    def scale(self,A):
        return (A-A.mean())/(A.std()+1e-10)
        
   
    def learn(self,x):
        k=0
        while k<x:
            batch_rew,batch_states,batch_actions,next_obs,next_done,batch_done,batch_values=self.train()
            _,batch_log_prob,_=self.evaluate_prob(batch_states,batch_actions)
            batch_log_prob=batch_log_prob.detach()
            V_k = batch_values.detach()
            Q=self.evaluate_Q(batch_rew,next_obs,next_done,batch_done)
            A_k= Q -V_k
            
            A_k=A_k.reshape(-1)
            batch_size=self.max_timesteps*self.num_envs
            b_inds = np.arange(batch_size)
            

            Q=Q.reshape(-1)
            clipfracs = []
            for epoch in range(self.epochs):
                    np.random.shuffle(b_inds)
                    for start in range(0, self.max_iteration_episodes, self.minibatch_size):
                        end = start + self.minibatch_size
                        mb_inds = b_inds[start:end]

                        V, newlogprob ,entropy= self.evaluate_prob(batch_states[mb_inds], batch_actions.long()[mb_inds])
                        logratio = newlogprob - batch_log_prob[mb_inds]
                        ratio = logratio.exp()

                        with torch.no_grad():
                            # calculate approx_kl http://joschu.net/blog/kl-approx.html
                            old_approx_kl = (-logratio).mean()
                            approx_kl = ((ratio - 1) - logratio).mean()
                            clipfracs += [((ratio - 1.0).abs() > self.clip).float().mean().item()]

                        mb_advantages = A_k[mb_inds]
                        
                        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                        # Policy loss
                        pg_loss1 = -mb_advantages * ratio
                        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - self.clip, 1 + self.clip)
                        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                        # Value loss
                        newvalue = V.view(-1)
                        # if args.clip_vloss:
                        v_loss_unclipped = (newvalue - Q[mb_inds]) ** 2
                        v_clipped = batch_values[mb_inds] + torch.clamp(
                            newvalue - batch_values[mb_inds],
                            -self.clip,
                            self.clip,
                        )
                        v_loss_clipped = (v_clipped - Q[mb_inds]) ** 2
                        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                        v_loss = 0.5 * v_loss_max.mean()
                        # else:
                        #     v_loss = 0.5 * ((newvalue - Q[mb_inds]) ** 2).mean()

                        entropy_loss = entropy.mean()
                        loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef
                        
                        self.actor_optimization.zero_grad()
                        self.critic_optimization.zero_grad()
                        loss.backward()
                        self.actor_optimization.step()
                        
                        
                        
                        self.critic_optimization.step()
            k+=Q.shape[0]
    def test(self,episodes):
        env=gym.make('CartPole-v1',render_mode='human')
        for episode in range(episodes):
            ep_len = 0            # episodic length
            ep_ret = 0            # episodic retur
            done=False
            obs,_=env.reset()
            print(obs)
            env.render()
            while not done:

                # Render environment if specified, off by default
                
                obs=torch.tensor(obs, dtype=torch.float).to(self.device)

                # Query deterministic action from policy and run it
                
                action = self.actor(obs)
                # print(action)
                # dist=Categorical(logits=action)
               
                # action=dist.sample()
           
                
                
                obs, rew, done, _,_ = env.step(action.detach().cpu().numpy().argmax())

                # Sum all episodic rewards as we go along
                ep_ret += rew
                
            # Track episodic length

            # returns episodic length and return in this iteration
                print('reward :',ep_ret)
        env.close()
            
        
                
                
            
        


In [24]:
def make_env(gym_id):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
      
        return env

    return thunk
envs = gym.vector.SyncVectorEnv(
        [make_env("CartPole-v1")for i in range(1)])
device = torch.device("cuda")
writer = SummaryWriter("raed12/PPO_discreate_action9")

model=PPO_discrete_action(envs,1_000,10,500,0.99,0.2,3e-4,device,1,writer,4,0.5,0.01)
model.learn(1_000_000)
writer.close()

1.0
3.5384615384615383
3.5555555555555554
5.923076923076923
6.5
3.757575757575758
10.538461538461538
14.7
6.444444444444445
11.235294117647058
3.6527777777777777
18.533333333333335
5.149253731343284
19.157894736842106
19.2
28.428571428571427
21.94736842105263
28.8
25.0
41.90909090909091
26.61111111111111
23.80952380952381
1.0
2.5
2.764705882352941
4.916666666666667
4.470588235294118
4.166666666666667
7.666666666666667
6.2272727272727275
8.61111111111111
7.739130434782608
14.692307692307692
15.692307692307692
8.555555555555555
4.3478260869565215
10.375
19.444444444444443
11.606060606060606
15.185185185185185
20.523809523809526
34.15384615384615
50.333333333333336
24.842105263157894
32.46666666666667
38.46153846153846
1.0
2.2857142857142856
4.2
3.3333333333333335
2.935483870967742
8.0
6.777777777777778
9.133333333333333
7.85
15.272727272727273
13.0
17.545454545454547
7.655172413793103
14.058823529411764
13.578947368421053
22.5
19.0
17.764705882352942
15.380952380952381
12.535714285714286

KeyboardInterrupt: 

In [25]:
model.test(20)

[ 0.03542046  0.00746696 -0.03154691  0.02794357]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
reward : 11.0
reward : 12.0
reward : 13.0
reward : 14.0
reward : 15.0
reward : 16.0
reward : 17.0
[ 0.04452743 -0.02255522 -0.04520242  0.04775022]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
reward : 11.0
reward : 12.0
reward : 13.0
reward : 14.0
reward : 15.0
reward : 16.0
[-0.01627425 -0.04229369 -0.03697155 -0.03021166]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
reward : 11.0
reward : 12.0
reward : 13.0
reward : 14.0
reward : 15.0
[-0.04228294 -0.01312296  0.01685164  0.03488052]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
reward : 11.0
reward : 1