In [None]:
!pip install stable_baselines3

In [1]:
import gym
import torch 
import torch.nn as nn
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from environments.wildfire_gym import WildFireGym

ModuleNotFoundError: No module named 'stable_baselines3'

In [None]:
class CustomCNN(BaseFeaturesExtractor):
    """Features extractor for WildFireGym's dict observations.

    Two branches are fused into one feature vector of length ``features_dim``:

    * ``observations['bank_angle']`` -> 5-layer MLP (5 inputs -> 50 features).
    * ``observations['belief_map']`` -> 3 conv layers (3x3 kernels) with two
      2x2 max-pools, flattened, then a 2-layer MLP (-> 100 features).

    The concatenated 50 + 100 = 150 features pass through one more 150-unit
    hidden layer and a final ``Linear(150, features_dim) + ReLU`` projection.

    :param observation_space: observation space indexed by the keys
        ``'bank_angle'`` and ``'belief_map'`` (i.e. a ``gym.spaces.Dict``).
        The ``'belief_map'`` sub-space must have a 3-D shape, assumed to be
        channel-first ``(C, H, W)`` as ``nn.Conv2d`` expects -- TODO confirm
        against WildFireGym.
    :param features_dim: length of the output feature vector.
    :raises ValueError: if the ``'belief_map'`` sub-space is not 3-D.
    """

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 4):
        super().__init__(observation_space, features_dim)

        # Belief-map geometry comes from the observation space. _get_conv_out
        # needs it to build a dummy input; without it construction would fail
        # with AttributeError.
        belief_shape = observation_space["belief_map"].shape
        if len(belief_shape) != 3:
            raise ValueError(
                f"expected 'belief_map' observation of shape (C, H, W), got {belief_shape}"
            )
        self.channels, self.height, self.width = (int(d) for d in belief_shape)

        # Bank-angle branch: 5 inputs -> 50 features. The input size is
        # hard-coded; presumably it matches WildFireGym's 'bank_angle' space.
        self.fc1 = nn.Sequential(
            nn.Linear(5, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
        )

        # Belief-map branch; in_channels follows the observation space.
        self.conv = nn.Sequential(
            nn.Conv2d(self.channels, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Flatten(),  # -> (batch, conv_out_size)
        )

        # Infer the flattened conv output size with a dummy forward pass.
        conv_out_size = self._get_conv_out()

        self.fc2 = nn.Sequential(
            nn.Linear(conv_out_size, 500),
            nn.ReLU(),
            nn.Linear(500, 100),
            nn.ReLU(),
        )

        # Fusion layer: 50 (bank angle) + 100 (belief map) = 150 inputs.
        self.fc3 = nn.Sequential(
            nn.Linear(150, 150),
            nn.ReLU(),
        )

        # Final projection to features_dim (attribute name kept so existing
        # state_dicts still load).
        self.flatten = nn.Sequential(nn.Linear(150, features_dim), nn.ReLU())

    def _get_conv_out(self) -> int:
        """Return the number of features ``self.conv`` produces per sample.

        Pushes a zero tensor of shape ``(1, channels, height, width)`` through
        the conv stack without recording gradients.
        """
        with torch.no_grad():
            out = self.conv(torch.zeros(1, self.channels, self.height, self.width))
        return int(np.prod(out.size()))

    def forward(self, observations):
        """Encode a batch of dict observations into ``(batch, features_dim)``.

        :param observations: mapping with batched tensors under
            ``'bank_angle'`` and ``'belief_map'``.
        """
        bank_features = self.fc1(observations["bank_angle"])
        # self.conv already ends with nn.Flatten, so no extra flatten is needed.
        map_features = self.fc2(self.conv(observations["belief_map"]))
        fused = self.fc3(torch.cat((bank_features, map_features), dim=1))
        return self.flatten(fused)



# Extra arguments for the SB3 policy constructor: swap in CustomCNN as the
# features extractor and have it emit 128-dimensional feature vectors.
policy_kwargs = {
    "features_extractor_class": CustomCNN,
    "features_extractor_kwargs": {"features_dim": 128},
}

In [None]:
# PPO hyperparameters (values unchanged from the original run).
PPO_HYPERPARAMS = dict(
    gamma=0.95,
    n_steps=256,
    ent_coef=0.0905168,
    learning_rate=0.00062211,
    vf_coef=0.042202,
    max_grad_norm=0.9,
    gae_lambda=0.99,
    n_epochs=5,
    clip_range=0.3,
    batch_size=256,
)
TOTAL_TIMESTEPS = 2_000_000

wildFireGym = WildFireGym()
# The first PPO argument must be a *policy* (class or registered name), not
# the features extractor. Observations are a dict, so use SB3's
# "MultiInputPolicy" and plug CustomCNN in through policy_kwargs.
# verbose=2 is SB3's most detailed level (the previous 3 behaved the same).
model = PPO(
    "MultiInputPolicy",
    wildFireGym,
    policy_kwargs=policy_kwargs,
    verbose=2,
    **PPO_HYPERPARAMS,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save("policy")