Merge pull request #1 from BBloggsbott/initial_set
Conceptualizing work
Showing 9 changed files with 89 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,24 @@ | ||
# workout | ||
# Workout | ||
|
||
### What is Workout? | ||
|
||
Workout is an API for importing `OpenAI-Gym` environments and using them with `PyTorch` with minimal glue code | ||
|
||
### Why is it required? | ||
|
||
`PyTorch`: A flexible framework for implementing deep neural networks, with strong GPU integration | ||
`OpenAI-Gym`: Provides a wide variety of ready-to-use Reinforcement Learning environments | ||
|
||
![workout](docs/workout.jpg) | ||
|
||
However, the integration between the two is limited. Many projects implement deep-network-based | ||
Reinforcement Learning algorithms in `PyTorch` separately, hand control over to a `Gym` environment | ||
to obtain the reward, the state of the system, the possible actions for the next step, etc., | ||
and then pass the results back to the `PyTorch` model. To avoid this back-and-forth, `Workout` provides a higher-level abstraction over `Gym` environments, with an interface that is more `PyTorch` oriented. | ||
Users can then work with `Gym` environments without giving up `PyTorch`'s syntactic sugar. Keeping | ||
the whole pipeline in `PyTorch` should also make the underlying code more uniform and help with | ||
parallelization on GPUs. | ||
|
||
### How is it done? | ||
|
||
`Workout` provides several classes that act as an interface between `Gym` and `PyTorch`. The package is centered on Q-Learning. It lets users define their own policies and models, preprocess the inputs to and outputs from the model, and either define their own training loops or use the default ones. |
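The README says the package is centered on Q-Learning but does not say which variant it targets. Purely as a point of reference, the standard tabular Q-learning update can be sketched as below. The names `q_update`, `alpha` and `gamma` are illustrative assumptions and are not part of `Workout`.

```python
def q_update(q_table, state, action, reward, next_state, done,
             alpha=0.1, gamma=0.99):
    """Apply one tabular Q-learning update in place and return the new value.

    q_table is a list of rows, one row of action values per state.
    alpha (learning rate) and gamma (discount) are illustrative defaults.
    """
    # A terminal transition has no future value to bootstrap from.
    best_next = 0.0 if done else max(q_table[next_state])
    td_target = reward + gamma * best_next
    q_table[state][action] += alpha * (td_target - q_table[state][action])
    return q_table[state][action]


# Two states, two actions, all values initialised to zero.
q_table = [[0.0, 0.0], [0.0, 0.0]]
new_value = q_update(q_table, state=0, action=1, reward=1.0,
                     next_state=1, done=False)
print(new_value)
```

A deep Q-learning variant would replace the table lookup with a model forward pass and the in-place update with a gradient step on the same temporal-difference error.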
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import gym | ||
import numpy as np | ||
import torch.nn as nn | ||
|
||
from workout.utils.env_utils import get_space_size | ||
|
||
|
||
class StateVectorEnvironemtInterface(object): | ||
    def __init__(self, env_name: str): | ||
        self.env_name = env_name | ||
        self.__env = gym.make(env_name) | ||
        self.__action_space_size = get_space_size(self.__env.action_space) | ||
        self.__observation_space_size = get_space_size(self.__env.observation_space) | ||
        self.__model = nn.Linear( | ||
            self.__observation_space_size, self.__action_space_size | ||
        ) | ||
|
||
        self.__train_fn = ( | ||
            lambda env, model, input_preprocessor, output_preprocessor, episodes: None | ||
        ) | ||
|
||
    def __input_preprocessor(self, model_input): | ||
        return model_input  # returns processed data | ||
|
||
    def __output_preprocessor(self, model_output): | ||
        return model_output | ||
|
||
    def train(self, episodes: int): | ||
        self.__train_fn( | ||
            self.__env, | ||
            self.__model, | ||
            self.__input_preprocessor, | ||
            self.__output_preprocessor, | ||
            episodes, | ||
        ) | ||
|
||
    def get_train_fn(self): | ||
        return self.__train_fn | ||
|
||
    def set_train_fn(self, train_fn): | ||
        self.__train_fn = train_fn | ||
|
||
    def get_env(self): | ||
        return self.__env | ||
|
||
    def set_env(self, env): | ||
        self.__env = env | ||
|
||
    def get_model(self): | ||
        return self.__model | ||
|
||
    def set_model(self, model): | ||
        self.__model = model |
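The default training function in this file is a no-op lambda, so a caller is expected to supply one with the signature `(env, model, input_preprocessor, output_preprocessor, episodes)`. The sketch below shows what such a function might look like. To stay runnable without `gym` or `torch`, it uses `ToyEnv` and `toy_model`, which are hypothetical stand-ins that only mimic the classic `reset()` / `step()` contract and a model returning one score per action. It performs greedy rollouts only and does not update any parameters.

```python
class ToyEnv:
    """Hypothetical stand-in for a gym environment (classic 4-tuple step API)."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0]

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0  # action 1 is always rewarded
        done = self.t >= self.horizon
        return [float(self.t)], reward, done, {}


def toy_model(observation):
    """Hypothetical stand-in for a model: one fixed score per action."""
    return [0.0, 1.0]


def greedy_rollout_fn(env, model, input_preprocessor, output_preprocessor, episodes):
    """Run `episodes` greedy rollouts and return the total reward of each one."""
    episode_returns = []
    for _ in range(episodes):
        observation = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            scores = output_preprocessor(model(input_preprocessor(observation)))
            # Pick the index of the highest score as the action.
            action = max(range(len(scores)), key=scores.__getitem__)
            observation, reward, done, _ = env.step(action)
            total_reward += reward
        episode_returns.append(total_reward)
    return episode_returns


identity = lambda value: value
episode_returns = greedy_rollout_fn(ToyEnv(), toy_model, identity, identity, episodes=3)
print(episode_returns)  # [5.0, 5.0, 5.0]
```

With the class in this diff, the same function would presumably be registered through `set_train_fn(greedy_rollout_fn)` and run with `train(episodes=3)`. Note that `train` as written does not return the value produced by the training function, so results would have to be collected some other way.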
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# this package has default training methods for the interfaces |
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from gym.spaces import Space, Discrete, Box | ||
|
||
import numpy as np | ||
|
||
|
||
def get_space_size(space: Space): | ||
    # Flattened size of a space; other space types fall through and return None | ||
    if isinstance(space, Discrete): | ||
        return int(space.n) | ||
    elif isinstance(space, Box): | ||
        return int(np.prod(space.shape)) |
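To illustrate what `get_space_size` computes, here is a minimal sketch that runs without `gym` installed. The `Discrete` and `Box` classes below are simplified stand-ins (assumptions) carrying only the `n` and `shape` attributes the helper reads. Unlike the function in the diff, this version raises on unsupported space types instead of returning `None`, which is a design choice of the sketch rather than the commit's behavior.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Discrete:
    """Stand-in for gym.spaces.Discrete: `n` possible values."""
    n: int


@dataclass
class Box:
    """Stand-in for gym.spaces.Box: only the array shape is modelled."""
    shape: Tuple[int, ...]


def get_space_size(space) -> int:
    """Return the number of scalar entries needed to represent the space."""
    if isinstance(space, Discrete):
        return int(space.n)
    if isinstance(space, Box):
        return int(math.prod(space.shape))
    # Assumption of this sketch: fail loudly rather than return None.
    raise TypeError(f"Unsupported space type: {type(space).__name__}")


print(get_space_size(Discrete(2)))       # 2
print(get_space_size(Box((4,))))         # 4
print(get_space_size(Box((84, 84, 3))))  # 21168
```

These sizes are what the interface class feeds to `nn.Linear` as input and output dimensions, so a multi-dimensional `Box` observation would need to be flattened before it reaches the model.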