In [53]:
from evaluation import evaluation_episode, image_display
from torch.utils.data import DataLoader, Dataset
import stable_baselines3 as sb3
from env_setup import init_env
import numpy as np
import torch

In [83]:
class Dataset(Dataset):
    
    def __init__(self, observations, actions, log_probs, values, returns, advantages):
        
        self.observations = observations
        
        self.actions = actions
        
        self.log_probs = log_probs
        
        self.values = values
        
        self.returns = returns
        
        self.advantages = advantages

    def __len__(self):
        
        return len(self.observations)

    def __getitem__(self, idx):
        
        obs = self.observations[idx]
        
        acts = self.actions[idx]
        
        log_probs = self.actions[idx]
        
        values = self.values[idx]
        
        returns = self.returns[idx]
        
        advantages = self.advantages[idx]

        return obs, acts, log_probs, values, returns, advantages

In [84]:
class PPO:

    def __init__(self, vec_env, model, n_steps, gamma, lambda_):
        
        self.vec_env = vec_env
        self.model = model
        self.n_steps = n_steps
        self.gamma = gamma
        self.lambda_ = lambda_

        self.data = {"observations": [], "actions": [], "log_probs": [], "values": [], "rewards": [], "dones": [], "infos": []}
        self.count_steps = 0

    def observation_preprocessing(self, observations):
        
        observations = torch.tensor(observations)
        observations = observations.permute(0, 3, 1, 2) # from [N, H, W, C] to [N, C, H, W]
        observations = observations.float() / 255.0
        
        return observations

    def data_collection(self):
        
        observations = self.vec_env.reset()
        
        while self.count_steps < self.n_steps:

            observations = self.observation_preprocessing(observations)
            
            with torch.no_grad():
                
                logits, values = self.model(observations)
                
                distros = torch.distributions.Categorical(logits = logits)
                
                actions = distros.sample().unsqueeze(-1)
                
                log_probs = distros.log_prob(actions.squeeze(-1)).unsqueeze(-1)
                
            next_observations, rewards, dones, infos = self.vec_env.step(actions.cpu().numpy().squeeze(-1))
            
            self.data["observations"].append(observations) 
            self.data["actions"].append(actions)
            self.data["log_probs"].append(log_probs)
            self.data["values"].append(values)
            self.data["rewards"].append(rewards)
            self.data["dones"].append(dones)
            self.data["infos"].append(infos)
            
            self.count_steps += self.vec_env.num_envs
            observations = next_observations

        with torch.no_grad():

            observations = self.observation_preprocessing(observations)

            _, values = self.model(observations)

            self.data["values"].append(values)


    def stack(self):

        obs = self.data["observations"][0].shape

        self.data["observations"] = torch.stack(self.data["observations"])
        
        self.data["actions"] = torch.cat(self.data["actions"]).view(len(self.data["actions"]), -1)

        self.data["log_probs"] = torch.cat(self.data["log_probs"]).view(len(self.data["log_probs"]), -1)

        self.data["values"] = torch.cat(self.data["values"]).view(len(self.data["values"]), -1)

        self.data["rewards"] = torch.tensor(np.stack(self.data["rewards"]), dtype = torch.float32)
        
        self.data["dones"] = torch.tensor(np.stack(self.data["dones"]), dtype = torch.float32)
    
    def advantages_collector(self):

        rewards = self.data["rewards"]

        values = self.data["values"]

        dones = self.data["dones"]

        T, n_envs = rewards.shape
        
        gae = torch.zeros(n_envs)
        
        advantages = torch.zeros_like(rewards)
        
        for t in reversed(range(T)):
            
            delta = rewards[t] + self.gamma * values[t + 1] * (1 - dones[t]) - values[t]
            
            gae = delta + self.gamma * self.lambda_ * (1 - dones[t]) * gae
            
            advantages[t] = gae
        
        self.data["advantages"] = advantages

        self.data["values"] = values[:-1]
        
        self.data["returns"] = advantages + values[:-1]

    def flatten(self):

        self.data["observations"] = torch.flatten(self.data["observations"], start_dim = 0, end_dim = 1)

        self.data["actions"] = torch.flatten(self.data["actions"], start_dim = 0, end_dim = 1)

        self.data["log_probs"] = torch.flatten(self.data["log_probs"], start_dim = 0, end_dim = 1)

        self.data["values"] = torch.flatten(self.data["values"], start_dim = 0, end_dim = 1)

        self.data["returns"] = torch.flatten(self.data["returns"], start_dim = 0, end_dim = 1)

        advantages = torch.flatten(self.data["advantages"], start_dim = 0, end_dim = 1)

        self.data["advantages"] = (advantages - advantages.mean()) / (advantages.std() + 1e-8)


    def training_loop(self, n_epochs, batch_size):

        dataset = Dataset(
            self.data["observations"],
            self.data["actions"],
            self.data["log_probs"],
            self.data["values"],
            self.data["returns"],
            self.data["advantages"]
        )
        
        dataloader = DataLoader(dataset, batch_size = batch, shuffle = True)

        self.model.train()
        
        for epoch in range(n_epochs):
            
            for obs, acts, log_probs, values, rewards, advantages in dataloader:
                
                new_logits, new_values = self.model(obs)
                
                new_distros = torch.distributions.Categorical(logits = logits)
                
                new_actions = new_distros.sample().unsqueeze(-1)
                
                new_log_probs = new_distros.log_prob(new_actions.squeeze(-1)).unsqueeze(-1)

                ratio = torch.exp(new_log_probs - log_probs)


In [85]:
env_vec = sb3.common.env_util.make_vec_env(env_id = init_env, n_envs = 5)

In [86]:
ac = Actor_Critic(n_actions = 4)
ppo = PPO(vec_env = env_vec, model = ac, n_steps = 15, gamma = 1, lambda_ = 3)
ppo.data_collection()
ppo.stack()
ppo.advantages_collector()
ppo.flatten()

In [92]:
n_epochs = 1
batch_size = 5

dataset = Dataset(
    ppo.data["observations"],
    ppo.data["actions"],
    ppo.data["log_probs"],
    ppo.data["values"],
    ppo.data["returns"],
    ppo.data["advantages"]
)

dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True)

ppo.model.train()

for epoch in range(n_epochs):

    for obs, acts, log_probs, values, returns, advantages in dataloader:
        
        new_logits, new_values = ac(obs)
        
        new_distros = torch.distributions.Categorical(logits = new_logits)
        
        new_actions = new_distros.sample()
        
        new_log_probs = new_distros.log_prob(new_actions)

        ratios = torch.exp(new_log_probs - log_probs)

        

ratios: tensor([0.2629, 0.0357, 0.0906, 0.2586, 0.0304], grad_fn=<ExpBackward0>)
ratios: tensor([0.2473, 0.1162, 0.0809, 0.0377, 0.2464], grad_fn=<ExpBackward0>)
ratios: tensor([0.0168, 0.2361, 0.0353, 0.0944, 0.0865], grad_fn=<ExpBackward0>)
