In [2]:
import torch
import numpy as np
from torch.distributions import MultivariateNormal
from torch import nn
import torch.nn.functional as F
import gymnasium as gym
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.distributions.categorical import Categorical



In [None]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

In [3]:
class neural_network(nn.Module):
    def __init__(self,input_dims,output_dims):
        super(neural_network,self).__init__()
        self.layer1=nn.Linear(input_dims,64)
        self.layer2=nn.Linear(64,64)
        self.layer3=nn.Linear(64,output_dims)
    def forward(self,state):
        if isinstance(state,np.ndarray):
            state=torch.tensor(state,dtype=torch.float)
            
        act1=F.relu(self.layer1(state))
        act2=F.relu(self.layer2(act1))
        output=self.layer3(act2)
        return output
    

In [5]:
class PPO_discrete_action:
    def __init__(self,env,max_timesteps,epochs,max_iteration_episodes,gamma,clip,lr,device,num_envs,writer,minibatch_size):
        self.env=env
        self.actor=neural_network(self.env.single_observation_space.shape[0],self.env.single_action_space.n).to(device)
        self.cretic=neural_network(self.env.single_observation_space.shape[0],1).to(device)
        self.max_timesteps=max_timesteps
        self.epochs=epochs
        self.max_iteration_episodes=max_iteration_episodes
        self.gamma=gamma
        self.clip=clip
        self.lr=lr
        self.actor_optimization= Adam(self.actor.parameters() , lr=self.lr)
        self.critic_optimization= Adam(self.cretic.parameters() , lr=self.lr)
       
        self.device=device
        self.num_envs=num_envs
        self.global_step=0
        self.minibatch_size=minibatch_size
        self.writer=writer
    def get_action(self,state):
        mean = self.actor(state)
        dist = Categorical(logits=mean)
        action=dist.sample()
    
        log_probs = dist.log_prob(action)
 
        return action, log_probs
        
        
    def train(self):
        
        t=0
        batch_states = torch.zeros((self.max_timesteps, self.num_envs) + self.env.single_observation_space.shape).to(self.device)
        batch_actions = torch.zeros((self.max_timesteps, self.num_envs) + self.env.single_action_space.shape).to(self.device)
        batch_rew = torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        batch_done = torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        batch_values=torch.zeros((self.max_timesteps, self.num_envs)).to(self.device)
        # TRY NOT TO MODIFY: start the game
        state,_=self.env.reset()
        next_obs = torch.Tensor(state).to(self.device)
        next_done = torch.zeros(self.num_envs).to(self.device)
        
            
        ep_lenght=0
        ep_rew=np.zeros((self.num_envs))
        step=0
        for i in range(self.max_iteration_episodes):
            step+=1
            self.global_step+=1
            batch_states[i]=next_obs
            V=self.evaluate_V(next_obs).flatten()
            batch_values[i]=V
            batch_done[i]=next_done
            
            action,_=self.get_action(next_obs)
            
            next_obs,reward,done,truncated,info=self.env.step(action.cpu().numpy())
            ep_rew+=reward
            
            next_obs, next_done = torch.Tensor(next_obs).to(self.device), torch.Tensor(done).to(self.device)
            batch_actions[i]=action
            batch_rew[i] = torch.tensor(reward).to(self.device).view(-1)
            # for item in info:
            #     if "episode" in item.keys():
            #         print(f"global_step={t}, episodic_return={item['episode']['r']}")
            #         self.writer.add_scalar("charts/episodic_return", item["episode"]["r"], t)
            #         self.writer.add_scalar("charts/episodic_length", item["episode"]["l"], t)
            #         break
            
            if (1 in done) or (1 in truncated) or (i == self.max_iteration_episodes-1) :
                print(ep_rew[0]/step)
                self.writer.add_scalar('rollout/ep_rew_mean',ep_rew[0]/step,self.global_step)
                eps_rew=np.zeros((self.num_envs))
                step=0
            state=next_obs
            # L=[]
            # L.append(action)
        batch_states = batch_states.reshape((-1,) + self.env.single_observation_space.shape)
        batch_actions = batch_actions.reshape((-1,) + self.env.single_action_space.shape)
        

        return batch_rew,batch_states,batch_actions,next_obs,next_done,batch_done,batch_values
    
    def evaluate_prob(self,batch_states,batch_actions):
        mean=self.actor(batch_states)
        dist=Categorical(logits=mean)
        log_prob=dist.log_prob(batch_actions)
        V=self.cretic(batch_states).squeeze()
        return V,log_prob
    
    def evaluate_V(self,batch_states):
        V=self.cretic(batch_states)
        return V.detach()
    
    def evaluate_Q(self,batch_rew,next_obs,next_done,batch_done):
        
        V=self.evaluate_V(next_obs).reshape(1, -1)
       
        batch_Q = torch.zeros_like(batch_rew).to(self.device)
     
        for j in reversed(range(self.max_timesteps)):
            
            if j==self.max_timesteps - 1:
                nextnonterminal = 1.0 - next_done
                next_return = V
            else:
                nextnonterminal = 1.0 - batch_done[j + 1]
                next_return = batch_rew[j + 1]
        
            
            batch_Q[j] = batch_rew[j] + self.gamma * nextnonterminal * next_return
        return batch_Q.detach()
    
    def scale(self,A):
        return (A-A.mean())/(A.std()+1e-10)
        
   
    def learn(self,x):
        k=0
        while k<x:
            batch_rew,batch_states,batch_actions,next_obs,next_done,batch_done,batch_values=self.train()
            _,batch_log_prob=self.evaluate_prob(batch_states,batch_actions)
            batch_log_prob=batch_log_prob.detach()
            V_k = batch_values.detach()
            Q=self.evaluate_Q(batch_rew,next_obs,next_done,batch_done)
            A_k= Q -V_k
            
            A_k=A_k.reshape(-1)
            batch_size=self.max_timesteps*self.num_envs
            b_inds = np.arange(batch_size)
            

            Q=Q.reshape(-1)
            for epoch in range(self.epochs):
                np.random.shuffle(b_inds)
                for start in range(0, batch_size, self.minibatch_size):
                    end = start + self.minibatch_size
                    mb_inds = b_inds[start:end]
                    V,curr_log_prob=self.evaluate_prob(batch_states[mb_inds],batch_actions[mb_inds])
                    ratio=torch.exp(curr_log_prob - batch_log_prob[mb_inds]).to(self.device)
                    A_km=self.scale(A_k[mb_inds])
                    loss1=ratio*A_km
                    loss2=(torch.clamp(ratio, 1-self.clip ,1+self.clip ).to(self.device)) *A_km
                    loss_pi=(-torch.min(loss1,loss2)).mean()
                    
                    self.actor_optimization.zero_grad()
                    loss_pi.backward()
                    self.actor_optimization.step()
                
                
                    loss_v=nn.MSELoss()(V, Q[mb_inds])
                    self.critic_optimization.zero_grad()
                    loss_v.backward()
                    self.critic_optimization.step()
            k+=Q.shape[0]
    def test(self,episodes):
        env=gym.make('CartPole-v1',render_mode='human')
        for episode in range(episodes):
            ep_len = 0            # episodic length
            ep_ret = 0            # episodic retur
            done=False
            obs,_=env.reset()
            print(obs)
            env.render()
            while not done:

                # Render environment if specified, off by default
                
                obs=torch.tensor(obs, dtype=torch.float).to(self.device)

                # Query deterministic action from policy and run it
                
                action = self.actor(obs)
                # print(action)
                # dist=Categorical(logits=action)
               
                # action=dist.sample()
           
                
                
                obs, rew, done, _,_ = env.step(action.detach().cpu().numpy().argmax())

                # Sum all episodic rewards as we go along
                ep_ret += rew
                
            # Track episodic length

            # returns episodic length and return in this iteration
                print('reward :',ep_ret)
        env.close()
            
        
                
                
            
        


In [10]:
def make_env(gym_id):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
      
        return env

    return thunk
envs = gym.vector.SyncVectorEnv(
        [make_env("CartPole-v1")for i in range(1)])
device = torch.device("cuda")
writer = SummaryWriter("raed12/PPO_discreate_action6")

model=PPO_discrete_action(envs,1_000,10,500,0.99,0.2,3e-4,device,1,writer,4)
model.learn(1_000_000)
writer.close()

1.0
1.736842105263158
3.2
2.8461538461538463
5.625
4.6
6.75
11.384615384615385
4.609756097560975
10.0
17.153846153846153
18.153846153846153
6.2444444444444445
11.407407407407407
18.11111111111111
28.166666666666668
14.52
10.81081081081081
18.391304347826086
14.21875
21.681818181818183
30.8125
71.42857142857143
1.0
1.4516129032258065
2.4516129032258065
4.04
3.8857142857142857
7.181818181818182
7.32
5.357142857142857
15.0625
17.066666666666666
8.314285714285715
4.506024096385542
23.0
20.55
25.176470588235293
9.73469387755102
21.73913043478261
1.0
4.0
6.333333333333333
4.518518518518518
4.388888888888889
6.266666666666667
7.962962962962963
10.772727272727273
4.885245901639344
9.514285714285714
20.58823529411765
9.13953488372093
21.68421052631579
15.714285714285714
34.84615384615385
31.2
23.285714285714285
45.45454545454545
1.0
3.3846153846153846
3.0952380952380953
6.416666666666667
5.052631578947368
4.096774193548387
5.233333333333333
5.243243243243243
8.185185185185185
8.620689655172415


KeyboardInterrupt: 

In [11]:
model.test(20)

[-0.01285525 -0.02436538  0.04946968  0.0048667 ]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
reward : 11.0
[ 0.04661703  0.00735197  0.01744271 -0.04156348]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
[-0.03143214 -0.00679057  0.00839029 -0.04580721]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
[ 0.00466671 -0.01908648  0.01184514  0.02839329]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
reward : 10.0
[-0.01541157 -0.00594757 -0.04235777  0.04671635]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward : 5.0
reward : 6.0
reward : 7.0
reward : 8.0
reward : 9.0
[ 0.01199435 -0.04556238 -0.00168998  0.03495909]
reward : 1.0
reward : 2.0
reward : 3.0
reward : 4.0
reward 