Merge pull request #1 from BBloggsbott/initial_set
Conceptualizing work
BBloggsbott committed Jul 20, 2020
2 parents e753ef3 + 593eecb commit 0c56c6c
Showing 9 changed files with 89 additions and 1 deletion.
25 changes: 24 additions & 1 deletion README.md
@@ -1 +1,24 @@
# workout
# Workout

### What is Workout?

Workout is an API for importing and using `OpenAI-Gym` environments with `PyTorch` with minimal glue code.

### Why is it required?

`PyTorch`: A flexible framework for implementing deep neural networks, with strong GPU integration.
`OpenAI-Gym`: Provides an extensive and varied set of ready-to-use Reinforcement Learning environments.

![workout](docs/workout.jpg)

However, the integration between the two is not very extensive. Many deep-network-based Reinforcement Learning
algorithms are implemented in `PyTorch` separately; control is then handed over to the `Gym` environment
to obtain the reward, the state of the system, the possible actions for the next step, etc.,
and the results are passed back to the `PyTorch` model. To avoid this back-and-forth, `Workout` provides a higher level of abstraction over the `Gym` environment, with an interface that makes it more `PyTorch` oriented.
Users can then work with `Gym` environments without giving up `PyTorch`'s syntactic sugar. Keeping the
code in `PyTorch` also makes the underlying code more uniform and helps considerably with
parallelization on GPUs.

### How is it done?

`Workout` provides several classes that act as an interface between `Gym` and `PyTorch`. The package is centered on Q-Learning: it allows users to define their own Policies and Models, preprocess the inputs to and outputs from the model, and define their own training loops or use the default ones.
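
As an illustration of what a user-defined Q-Learning policy could look like, here is a minimal epsilon-greedy sketch. The function name `epsilon_greedy` and its parameters are hypothetical; this commit does not define a policy API, so treat it as one possible shape rather than the package's interface.

```python
import random
from typing import Optional, Sequence


def epsilon_greedy(
    q_values: Sequence[float],
    epsilon: float,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick an action index from a sequence of Q-values.

    Hypothetical helper for illustration; not part of Workout.
    """
    rng = rng or random  # fall back to the module-level generator
    if rng.random() < epsilon:
        # Explore: choose any action uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: choose the action with the highest estimated value.
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

With `epsilon=0.0` the policy is purely greedy, and with `epsilon=1.0` it is purely random. A function like this could plausibly serve as the step that turns model outputs into an action.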
Binary file added docs/gym.png
Binary file added docs/pytorch.png
Binary file added docs/workout.jpg
Empty file added workout/interfaces/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions workout/interfaces/base_interfaces.py
@@ -0,0 +1,53 @@
import gym
import numpy as np
import torch.nn as nn

from workout.utils.env_utils import get_space_size


class StateVectorEnvironemtInterface(object):
    def __init__(self, env_name: str):
        self.env_name = env_name
        self.__env = gym.make(env_name)
        self.__action_space_size = get_space_size(self.__env.action_space)
        self.__observation_space_size = get_space_size(self.__env.observation_space)
        self.__model = nn.Linear(
            self.__observation_space_size, self.__action_space_size
        )

        self.__train_fn = (
            lambda env, model, input_preprocessor, output_preprocessor, episodes: None
        )

    def __input_preprocessor(self, model_input):
        return model_input  # returns processed data

    def __output_preprocessor(self, model_output):
        return model_output

    def train(self, episodes: int):
        self.__train_fn(
            self.__env,
            self.__model,
            self.__input_preprocessor,
            self.__output_preprocessor,
            episodes,
        )

    def get_train_fn(self):
        return self.__train_fn

    def set_train_fn(self, train_fn):
        self.__train_fn = train_fn

    def get_env(self):
        return self.__env

    def set_env(self, env):
        self.__env = env

    def get_model(self):
        return self.__model

    def set_model(self, model):
        self.__model = model
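
The default `train_fn` above is a no-op, so a caller is expected to supply one through `set_train_fn`. The sketch below shows one function with the same five-argument signature. `ToyEnv` and `rollout_train_fn` are hypothetical names, and `ToyEnv` is a stand-in that only mimics the 2020-era Gym API (`reset()` returns an observation, `step()` returns a 4-tuple). It also assumes that `output_preprocessor` turns the model output into an action, which the identity default in this commit does not do.

```python
class ToyEnv:
    """Stand-in for a Gym environment (hypothetical, not part of Workout)."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0]

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0  # action 1 is the "good" one
        done = self.t >= self.horizon
        return [float(self.t)], reward, done, {}


def rollout_train_fn(env, model, input_preprocessor, output_preprocessor, episodes):
    """Run `episodes` rollouts and return the total reward of each.

    No learning happens here; a real train_fn would also update `model`.
    """
    returns = []
    for _ in range(episodes):
        state, done, total = env.reset(), False, 0.0
        while not done:
            scores = model(input_preprocessor(state))
            action = output_preprocessor(scores)  # assumed to yield an action
            state, reward, done, _ = env.step(action)
            total += reward
        returns.append(total)
    return returns
```

With the class in this commit, such a function would be registered with `interface.set_train_fn(rollout_train_fn)` and run with `interface.train(episodes)`. Note that `train` as written discards the function's return value, so any statistics would have to be recorded elsewhere.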
1 change: 1 addition & 0 deletions workout/training/methods/__init__.py
@@ -0,0 +1 @@
# this package has default training methods for the interfaces
Empty file added workout/utils/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions workout/utils/env_utils.py
@@ -0,0 +1,11 @@
from gym.spaces import Space, Discrete, Box

import numpy as np


def get_space_size(space: Space):
    space_type = type(space)
    if space_type == Discrete:
        return space.n
    elif space_type == Box:
        return int(np.prod(space.shape))  # cast the NumPy scalar to a plain int
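
`get_space_size` returns `None` implicitly for any space that is neither `Discrete` nor `Box`, and that `None` would then be passed to `nn.Linear` in the interface. A stricter variant could fail early instead. The sketch below uses small stand-in classes named `Discrete` and `Box` so that it runs without `gym` installed. These stand-ins and the name `get_space_size_strict` are illustrative assumptions, not part of this commit.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Discrete:
    """Stand-in for gym.spaces.Discrete (illustration only)."""
    n: int


@dataclass
class Box:
    """Stand-in for gym.spaces.Box, keeping only its shape (illustration only)."""
    shape: Tuple[int, ...]


def get_space_size_strict(space) -> int:
    """Like get_space_size, but raise on unsupported space types."""
    if isinstance(space, Discrete):
        return int(space.n)
    if isinstance(space, Box):
        # Number of scalar entries in one observation or action.
        return math.prod(space.shape)
    raise TypeError(f"Unsupported space type: {type(space).__name__}")
```

For a space like CartPole's, a `Discrete(2)` action space gives size 2 and a `Box` of shape `(4,)` gives size 4. `math.prod` requires Python 3.8 or later.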
