In [1]:
!pip install stable_baselines3
!pip install svgpath2mpl
!pip install gym

[0m

In [2]:
import gym
import torch 
import torch.nn as nn
import numpy as np
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from environments.shared_wildfire_gym import SharedWildFireGym

In [3]:
class CustomCNN(BaseFeaturesExtractor):
    """SB3 features extractor combining scalar observations with a belief map.

    The scalar entries named in ``SCALAR_KEYS`` are concatenated along dim 1 and
    passed through an MLP (``fc1``). ``observations['belief_map']`` goes through a
    CNN (``conv``) followed by an MLP head (``fc2``). Both embeddings are
    concatenated and projected to ``features_dim`` by ``flatten`` (Linear + ReLU).

    The submodule attribute names ``fc1``, ``conv``, ``fc2`` and ``flatten`` are
    kept unchanged because they form the ``state_dict`` keys of saved policies.

    :param observation_space: Dict observation space (``forward`` indexes the
        observations by key). If it has a ``'belief_map'`` subspace, that
        subspace's shape (channels first) sizes the CNN. Otherwise
        ``DEFAULT_BELIEF_MAP_SHAPE`` is used.
    :param features_dim: size of the output feature vector.
    """

    # Scalar observation keys, concatenated in this order. The fc1 input width is
    # len(SCALAR_KEYS), so each entry is assumed to be a (batch, 1) tensor
    # -- TODO confirm against SharedWildFireGym's observation space.
    SCALAR_KEYS = ("bank_angle", "rho", "theta", "psi", "other_bank_angle")

    # (channels, height, width) used when the space does not describe 'belief_map'.
    DEFAULT_BELIEF_MAP_SHAPE = (2, 100, 100)

    SCALAR_HIDDEN = 50  # width of every fc1 layer
    CONV_HIDDEN = 500  # width of the first fc2 layer
    CONV_OUT = 100  # output width of fc2

    def _get_conv_out(self, input_shape=None):
        """Return the flattened per-sample size of ``self.conv``'s output.

        :param input_shape: (channels, height, width) of one belief map;
            defaults to ``DEFAULT_BELIEF_MAP_SHAPE``.
        """
        if input_shape is None:
            input_shape = self.DEFAULT_BELIEF_MAP_SHAPE
        # Shape probe only: no need to build an autograd graph.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *input_shape))
        return int(np.prod(o.size()))

    @classmethod
    def _belief_map_shape(cls, observation_space):
        """Return the shape of the 'belief_map' subspace, or the default if unavailable."""
        # Look at `.spaces` directly: `key in DictSpace` tests sample membership, not keys.
        subspaces = getattr(observation_space, "spaces", None)
        if isinstance(subspaces, dict) and "belief_map" in subspaces:
            return tuple(subspaces["belief_map"].shape)
        return cls.DEFAULT_BELIEF_MAP_SHAPE

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 4):
        super().__init__(observation_space, features_dim)

        belief_shape = self._belief_map_shape(observation_space)

        # Input layer plus four hidden layers, each followed by ReLU. Built in the
        # original order, so the fc1.<i> parameter indices are unchanged.
        scalar_layers = [nn.Linear(len(self.SCALAR_KEYS), self.SCALAR_HIDDEN), nn.ReLU()]
        for _ in range(4):
            scalar_layers += [nn.Linear(self.SCALAR_HIDDEN, self.SCALAR_HIDDEN), nn.ReLU()]
        self.fc1 = nn.Sequential(*scalar_layers)

        self.conv = nn.Sequential(
            nn.Conv2d(belief_shape[0], 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Flatten(),
        )

        # Size fc2's input from the actual belief-map shape instead of a fixed 100x100.
        conv_out_size = self._get_conv_out(belief_shape)

        self.fc2 = nn.Sequential(
            nn.Linear(conv_out_size, self.CONV_HIDDEN),
            nn.ReLU(),
            nn.Linear(self.CONV_HIDDEN, self.CONV_OUT),
            nn.ReLU(),
        )

        # Projects concat(fc1_out, fc2_out) to features_dim.
        self.flatten = nn.Sequential(
            nn.Linear(self.SCALAR_HIDDEN + self.CONV_OUT, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        """Encode a dict of observation tensors into a feature tensor.

        :param observations: mapping with the ``SCALAR_KEYS`` entries and
            ``'belief_map'``. It is assumed to be batched along dim 0 (TODO confirm).
        :return: output of ``self.flatten`` with ``features_dim`` features per sample.
        """
        scalars = torch.cat([observations[key] for key in self.SCALAR_KEYS], dim=1)
        fc1_out = self.fc1(scalars)
        # self.conv already ends in nn.Flatten, so its output is (batch, features).
        conv_out = self.conv(observations["belief_map"])
        fc2_out = self.fc2(conv_out)
        return self.flatten(torch.cat((fc1_out, fc2_out), dim=1))



# Extra policy arguments for SB3: build CustomCNN as the features extractor
# and request a 256-dimensional feature vector from it.
policy_kwargs = {
    "features_extractor_class": CustomCNN,
    "features_extractor_kwargs": {"features_dim": 256},
}


In [4]:
# Training configuration.
TOTAL_TIMESTEPS = 2_000_000
N_EPOCHS = 5
POLICY_PATH = "policy"

wildFireGym = SharedWildFireGym()
# CustomCNN.forward reads observations by key, so the env exposes a Dict
# observation space. SB3 requires "MultiInputPolicy" for Dict spaces;
# "CnnPolicy"/"MlpPolicy" are rejected.
model = PPO(
    "MultiInputPolicy",
    wildFireGym,
    verbose=3,
    n_epochs=N_EPOCHS,
    policy_kwargs=policy_kwargs,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save(POLICY_PATH)

AssertionError: 