Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import gym as optional package to build docs successfully #458

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 19 additions & 12 deletions pl_bolts/models/rl/common/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,28 @@
"""
import collections

import gym
import gym.spaces
import numpy as np
import torch

from pl_bolts.utils import _OPENCV_AVAILABLE
from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
import gym.spaces
from gym import ObservationWrapper, Wrapper
from gym import make as gym_make
else: # pragma: no-cover
warn_missing_pkg('gym')
Wrapper = object
ObservationWrapper = object

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover


class ToTensor(gym.Wrapper):
class ToTensor(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
Expand All @@ -34,7 +41,7 @@ def reset(self):
return torch.tensor(self.env.reset())


class FireResetEnv(gym.Wrapper):
class FireResetEnv(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
Expand All @@ -58,7 +65,7 @@ def reset(self):
return obs


class MaxAndSkipEnv(gym.Wrapper):
class MaxAndSkipEnv(Wrapper):
"""Return only every `skip`-th frame"""

def __init__(self, env=None, skip=4):
Expand Down Expand Up @@ -88,7 +95,7 @@ def reset(self):
return obs


class ProcessFrame84(gym.ObservationWrapper):
class ProcessFrame84(ObservationWrapper):
"""preprocessing images from env"""

def __init__(self, env=None):
Expand Down Expand Up @@ -121,7 +128,7 @@ def process(frame):
return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
class ImageToPyTorch(ObservationWrapper):
"""converts image to pytorch format"""

def __init__(self, env):
Expand All @@ -142,15 +149,15 @@ def observation(observation):
return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
class ScaledFloatFrame(ObservationWrapper):
"""scales the pixels"""

@staticmethod
def observation(obs):
return np.array(obs).astype(np.float32) / 255.0


class BufferWrapper(gym.ObservationWrapper):
class BufferWrapper(ObservationWrapper):
""""Wrapper for image stacking"""

def __init__(self, env, n_steps, dtype=np.float32):
Expand All @@ -176,7 +183,7 @@ def observation(self, observation):
return self.buffer


class DataAugmentation(gym.ObservationWrapper):
class DataAugmentation(ObservationWrapper):
"""
Carries out basic data augmentation on the env observations
- ToTensor
Expand All @@ -197,7 +204,7 @@ def observation(self, obs):

def make_environment(env_name):
"""Convert environment with wrappers"""
env = gym.make(env_name)
env = gym_make(env_name)
env = MaxAndSkipEnv(env)
env = FireResetEnv(env)
env = ProcessFrame84(env)
Expand Down
6 changes: 4 additions & 2 deletions pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
from pl_bolts.losses.rl import dqn_loss
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.gym_wrappers import make_environment
from pl_bolts.models.rl.common.memory import MultiStepBuffer
from pl_bolts.models.rl.common.networks import CNN
from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
from pl_bolts.models.rl.common.gym_wrappers import gym, make_environment
from gym import Env
else:
warn_missing_pkg('gym') # pragma: no-cover
Env = object


class DQN(pl.LightningModule):
Expand Down Expand Up @@ -336,7 +338,7 @@ def test_dataloader(self) -> DataLoader:
return self._dataloader()

@staticmethod
def make_environment(env_name: str, seed: Optional[int] = None) -> gym.Env:
def make_environment(env_name: str, seed: Optional[int] = None) -> Env:
"""
Initialise gym environment

Expand Down