<a href="https://colab.research.google.com/github/TBKHori/Music-Recon13/blob/main/Multiple_Inputs_and_Dictionary_Observations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install gymnasium
!pip install torch
!pip install stable_baselines3

Collecting stable_baselines3
  Downloading stable_baselines3-2.0.0-py3-none-any.whl (178 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m178.4/178.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gymnasium==0.28.1 (from stable_baselines3)
  Downloading gymnasium-0.28.1-py3-none-any.whl (925 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m925.5/925.5 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
Collecting jax-jumpy>=1.0.0 (from gymnasium==0.28.1->stable_baselines3)
  Downloading jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Installing collected packages: jax-jumpy, gymnasium, stable_baselines3
  Attempting uninstall: gymnasium
    Found existing installation: gymnasium 0.29.0
    Uninstalling gymnasium-0.29.0:
      Successfully uninstalled gymnasium-0.29.0
Successfully installed gymnasium-0.28.1 jax-jumpy-1.0.0 stable_baselines3-2.0.0


In [8]:
import gymnasium as gym
import torch as th
from torch import nn
from gymnasium import spaces


In [9]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [10]:
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Features extractor for a Dict observation space with "image" and "vector" keys.

    Each entry is encoded by its own sub-network and the results are
    concatenated along the feature axis:

    * ``"image"``  -> ``MaxPool2d(4)`` followed by ``Flatten``
    * ``"vector"`` -> ``Linear(n, 16)`` where ``n`` is the vector length

    :param observation_space: Dict space that must contain the keys
        ``"image"`` and ``"vector"``.
        NOTE(review): the image entry is assumed to be laid out channel-first
        (C, H, W) so that ``MaxPool2d`` pools over the spatial axes — confirm
        against the environment.
    """

    def __init__(self, observation_space: spaces.Dict):
        # The base class checks `features_dim > 0`, so `None` cannot be passed.
        # The real size is only known once the sub-extractors exist, so a
        # placeholder is given here and `_features_dim` is overwritten below.
        super().__init__(observation_space, features_dim=1)

        # One sub-network per observation key.
        extractors = {
            "image": nn.Sequential(nn.MaxPool2d(4), nn.Flatten()),
            "vector": nn.Linear(observation_space["vector"].shape[0], 16),
        }

        # Measure each sub-network's output width by pushing a zero batch of
        # size 1 through it; no gradients are needed for a shape probe.
        total_concat_size = 0
        with th.no_grad():
            for key, extractor in extractors.items():
                dummy_input = self._get_dummy_tensor(observation_space[key])
                dummy_output = extractor(dummy_input)
                total_concat_size += dummy_output.flatten(start_dim=1).shape[1]

        # Register the sub-networks so their parameters are tracked.
        self.extractors = nn.ModuleDict(extractors)

        # Replace the placeholder with the real concatenated feature size.
        self._features_dim = total_concat_size

    def forward(self, observations: dict) -> th.Tensor:
        """Encode every observation entry and concatenate the results.

        :param observations: mapping from observation key to a batched tensor;
            it is indexed with the same keys as ``self.extractors``.
        :return: tensor of shape (B, self._features_dim), B being the batch size.
        """
        encoded_tensor_list = [
            extractor(observations[key])
            for key, extractor in self.extractors.items()
        ]
        return th.cat(encoded_tensor_list, dim=1)

    def _get_dummy_tensor(self, subspace):
        """Return an all-zero batch of size 1 shaped like ``subspace``.

        :param subspace: a space exposing a ``shape`` tuple.
        :return: float tensor of shape ``(1, *subspace.shape)``.
        """
        return th.zeros((1,) + tuple(subspace.shape))