In [5]:
def image_display(img, figsize=(6, 6)):
    """Show an OpenCV-style image with matplotlib, with the axes hidden.

    Args:
        img: image array as used by OpenCV. Multi-channel images are assumed
            to be in BGR order and are converted to RGB for display.
            Single-channel images (2-D, or H x W x 1) are shown in grayscale.
        figsize: matplotlib figure size ``(width, height)`` in inches.
            Defaults to ``(6, 6)``.
    """
    plt.figure(figsize=figsize)

    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        # cv2.COLOR_BGR2RGB rejects single-channel input, so show it directly.
        gray = img if img.ndim == 2 else img[:, :, 0]
        # Pin the range for 8-bit images so matplotlib does not auto-stretch
        # the contrast to the image's own min/max.
        scale = {'vmin': 0, 'vmax': 255} if img.dtype == 'uint8' else {}
        plt.imshow(gray, cmap='gray', **scale)
    else:
        # OpenCV stores colour channels as BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    plt.axis('off')

    plt.show()

In [182]:
import torch

In [183]:
class PPO:
    """Rollout collection for a PPO agent over a vectorised environment.

    Args:
        vec_env: vectorised environment. This class relies on ``reset()``
            returning a batch of observations, on ``step(actions)`` returning
            ``(observations, rewards, dones, infos)``, and on a ``num_envs``
            attribute.
        model: callable mapping a float observation batch of shape
            [N, C, H, W] to ``(action_probs, values)``. ``action_probs`` is
            treated as per-action probabilities along its last dimension
            (selected entries are passed through ``torch.log``).
        n_steps: number of transitions, summed over all environments, to
            collect per call to ``data_collection``.
    """

    def __init__(self, vec_env, model, n_steps):

        self.vec_env = vec_env
        self.model = model
        self.n_steps = n_steps

        # Transitions of the most recent rollout, one entry per vectorised step:
        # [observations, actions, log_action_probs, values, rewards, dones].
        self.data = []
        # Total number of transitions collected across all calls.
        self.count_steps = 0

    def observation_preprocessing(self, observations):
        """Convert an [N, H, W, C] batch (assumed 0-255 pixels) to float [N, C, H, W] scaled by 1/255."""
        # as_tensor avoids an extra copy (and the warning torch.tensor emits
        # when given a tensor); .float() below produces a fresh tensor anyway.
        observations = torch.as_tensor(observations)
        observations = observations.permute(0, 3, 1, 2)  # from [N, H, W, C] to [N, C, H, W]
        observations = observations.float() / 255.0

        return observations

    def data_collection(self, *, deterministic=False):
        """Collect a fresh rollout of at least ``n_steps`` transitions into ``self.data``.

        The environment is reset first. Each vectorised step appends
        ``[observations, actions, log_action_probs, values, rewards, dones]``
        to ``self.data`` and adds ``num_envs`` to ``self.count_steps``.

        Args:
            deterministic: if True, take the most probable action (argmax).
                By default actions are sampled from ``action_probs``, as
                PPO's on-policy update requires.
        """
        # Every call is a new on-policy rollout: drop the previous rollout and
        # count this call's steps separately, so repeated calls keep collecting
        # instead of returning immediately.
        self.data = []
        collected = 0

        observations = self.vec_env.reset()

        # Stored log-probs and values are fixed targets for a later update, so
        # no autograd graph is needed; building one would keep every step's
        # graph alive through self.data.
        with torch.no_grad():
            while collected < self.n_steps:

                prev_observations = self.observation_preprocessing(observations)

                action_distros, values = self.model(prev_observations)

                if deterministic:
                    actions = action_distros.argmax(dim=-1, keepdim=True)
                else:
                    # One sampled action index per environment -> shape [N, 1].
                    actions = torch.multinomial(action_distros, num_samples=1)

                action_probs = action_distros.gather(dim=-1, index=actions)

                log_action_probs = torch.log(action_probs)

                # NOTE(review): actions are passed as a torch tensor of shape
                # [N, 1]; many VecEnv implementations expect a NumPy array of
                # shape [N] -- confirm against the environment in use.
                observations, rewards, dones, infos = self.vec_env.step(actions)  # consider using infos for training networks

                self.data.append([prev_observations, actions, log_action_probs, values, rewards, dones])

                # May overshoot n_steps when it is not a multiple of num_envs.
                collected += self.vec_env.num_envs
                self.count_steps += self.vec_env.num_envs
