In [3]:
from evaluation import evaluation_episode, image_display
import stable_baselines3 as sb3
from env_setup import init_env
import actor_critic
import torch

In [81]:
from torch import nn
from torchvision import models

class Actor_Critic(nn.Module):
    """Actor-critic network: two linear heads on a shared ResNet-50 trunk.

    The trunk is built from ``models.resnet50`` loaded with the
    ``IMAGENET1K_V2`` weights, keeping every child module except the last one.
    Both heads take ``backbone.fc.in_features`` inputs.

    Args:
        n_actions: Number of outputs produced by the actor head.
    """

    def __init__(self, n_actions):
        super().__init__()

        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        embedding_dim = backbone.fc.in_features

        # Re-wrap all child modules but the final one as the shared trunk.
        trunk_layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Actor is created before critic so parameter initialisation order is stable.
        self.actor = nn.Linear(embedding_dim, n_actions)
        self.critic = nn.Linear(embedding_dim, 1)

    def forward(self, x):
        """Run the trunk on ``x`` and return ``(logits, values)`` from the two heads."""
        pooled = self.feature_extractor(x)

        # Collapse everything after the batch dimension into a single feature axis.
        flat = pooled.view(pooled.shape[0], -1)

        action_logits = self.actor(flat)
        state_values = self.critic(flat)
        return action_logits, state_values

In [336]:
class PPO:
    """Rollout collector and GAE advantage estimator for an actor-critic model.

    Intended call order is ``data_collection()`` -> ``stack()`` ->
    ``advantages_collector()``; every result is left in ``self.data``.

    Args:
        vec_env: Vectorised environment exposing ``num_envs``, ``reset()`` and
            ``step(actions)`` returning ``(observations, rewards, dones, infos)``.
        model: Callable mapping a preprocessed observation batch to
            ``(logits, values)``.
        n_steps: Minimum number of transitions (summed over all environments)
            to gather in one rollout.
        gamma: Discount factor used in the advantage computation.
        lambda_: GAE smoothing factor used in the advantage computation.
    """

    def __init__(self, vec_env, model, n_steps, gamma, lambda_):
        self.vec_env = vec_env
        self.model = model
        self.n_steps = n_steps
        self.gamma = gamma
        self.lambda_ = lambda_

        self._reset_buffers()

    def _reset_buffers(self):
        """Empty every rollout buffer and zero the transition counter."""
        self.data = {
            "observations": [],
            "actions": [],
            "log_probs": [],
            "values": [],
            "rewards": [],
            "dones": [],
            "advantages": [],
            "infos": [],
            "returns": [],
        }
        self.count_steps = 0

    def observation_preprocessing(self, observations):
        """Convert an ``[N, H, W, C]`` batch into a float ``[N, C, H, W]`` tensor scaled by 1/255."""
        observations = torch.as_tensor(observations)
        observations = observations.permute(0, 3, 1, 2)  # from [N, H, W, C] to [N, C, H, W]
        return observations.float() / 255.0

    def data_collection(self):
        """Gather one fresh rollout of at least ``n_steps`` transitions.

        Buffers and the step counter are reset first, so the method can be
        called repeatedly (also after ``stack()``) without mixing rollouts.
        One extra value estimate for the observation following the last step
        is appended to ``data["values"]`` for bootstrapping.
        """
        self._reset_buffers()

        observations = self.vec_env.reset()

        while self.count_steps < self.n_steps:
            observations = self.observation_preprocessing(observations)

            with torch.no_grad():
                logits, values = self.model(observations)
                distros = torch.distributions.Categorical(logits=logits)
                actions = distros.sample()
                log_probs = distros.log_prob(actions)

            next_observations, rewards, dones, infos = self.vec_env.step(actions.cpu().numpy())

            self.data["observations"].append(observations)
            # Stored with a trailing singleton axis; ``stack`` flattens it away.
            self.data["actions"].append(actions.unsqueeze(-1))
            self.data["log_probs"].append(log_probs.unsqueeze(-1))
            self.data["values"].append(values)
            self.data["rewards"].append(rewards)
            self.data["dones"].append(dones)
            self.data["infos"].append(infos)

            # Each call to ``step`` yields one transition per environment.
            self.count_steps += self.vec_env.num_envs
            observations = next_observations

        with torch.no_grad():
            observations = self.observation_preprocessing(observations)
            _, values = self.model(observations)
            self.data["values"].append(values)

    def stack(self):
        """Turn the per-step buffer lists into tensors with a leading time axis.

        ``actions``, ``log_probs`` and ``values`` are reshaped to
        ``[len(buffer), -1]``; ``rewards`` and ``dones`` become float32 tensors.
        ``infos`` is left as a list. Calling this again on already stacked
        buffers is a no-op.
        """
        if isinstance(self.data["observations"], torch.Tensor):
            return

        self.data["observations"] = torch.stack(self.data["observations"])

        for key in ("actions", "log_probs", "values"):
            per_step = self.data[key]
            self.data[key] = torch.stack(per_step).reshape(len(per_step), -1)

        # Built directly with torch so the method does not depend on numpy
        # being imported in the notebook.
        for key in ("rewards", "dones"):
            self.data[key] = torch.stack(
                [torch.as_tensor(step, dtype=torch.float32) for step in self.data[key]]
            )

    def advantages_collector(self):
        """Compute GAE advantages and returns from the stacked buffers.

        Requires ``stack()`` to have been called. Writes ``data["advantages"]``
        and ``data["returns"]`` (advantages plus the value estimates without
        the final bootstrap row), both shaped like ``data["rewards"]``.
        """
        rewards = self.data["rewards"]
        values = self.data["values"]
        not_done = 1 - self.data["dones"]

        T, n_envs = rewards.shape

        gae = torch.zeros(n_envs)
        advantages = torch.zeros_like(rewards)

        # Backward recursion; ``not_done`` cuts both the bootstrap value and
        # the accumulated advantage wherever ``dones`` is set.
        for t in reversed(range(T)):
            delta = rewards[t] + self.gamma * values[t + 1] * not_done[t] - values[t]
            gae = delta + self.gamma * self.lambda_ * not_done[t] * gae
            advantages[t] = gae

        self.data["advantages"] = advantages
        self.data["returns"] = advantages + values[:-1]

In [337]:
# Vectorised environment: 5 copies built from `init_env`, stepped together.
env_vec = sb3.common.env_util.make_vec_env(
    env_id=init_env,
    n_envs=5,
)

In [338]:
# Smoke run: collect one short rollout and compute its advantages.
ac = Actor_Critic(n_actions = 4)
# lambda_ must lie in [0, 1] for GAE; a value above 1 (previously 3) makes the
# backward recursion amplify the accumulated advantage at every step.
ppo = PPO(vec_env = env_vec, model = ac, n_steps = 15, gamma = 1, lambda_ = 0.95)
ppo.data_collection()
ppo.stack()
ppo.advantages_collector()

In [339]:
# Shapes, one per line and in this order:
#   observations -> [T, num_envs, C, H, W]
#   actions      -> [T, num_envs]
#   advantages   -> [T, num_envs]
#   returns      -> [T, num_envs]
print(
    *(ppo.data[name].shape for name in ("observations", "actions", "advantages", "returns")),
    sep="\n",
)


torch.Size([3, 5, 3, 64, 64])
torch.Size([3, 5])
torch.Size([3, 5])
torch.Size([3, 5])
