Nirnai/DeepRL

DeepRL

DeepRL is a PyTorch implementation of deep reinforcement learning algorithms. The project contains the following modules:

  • models
  • algorithms
  • evaluator

Models

The models module implements a set of common policies and value functions.

Implemented policies are:

  • Gaussian
  • Gaussian Bounded
  • Gaussian Clipped
  • Deterministic
  • Cross Entropy Method (CEM)
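
For readers unfamiliar with the last entry, the cross-entropy method can be sketched in plain Python as shown here. This is a generic illustration of the technique, not the repository's CEM policy: the function name `cem_maximize`, the toy score function, and all parameter defaults are assumptions made for this example.

```python
import random
import statistics


def cem_maximize(score_fn, mean=0.0, std=2.0, iterations=20,
                 population=50, elite_frac=0.2, seed=0):
    """Maximize a 1-D score function with the cross-entropy method."""
    rng = random.Random(seed)
    n_elite = max(2, int(population * elite_frac))
    for _ in range(iterations):
        # Sample candidate actions from the current Gaussian.
        candidates = [rng.gauss(mean, std) for _ in range(population)]
        # Keep the highest-scoring candidates (the "elite" set).
        elite = sorted(candidates, key=score_fn, reverse=True)[:n_elite]
        # Refit the Gaussian to the elite set.
        mean = statistics.fmean(elite)
        std = statistics.stdev(elite) + 1e-6  # small floor to avoid a zero std
    return mean


# Toy stand-in for a learned score with its maximum at action = 1.5.
best_action = cem_maximize(lambda a: -(a - 1.5) ** 2)
```

In a policy setting the score function would typically be a learned action-value estimate for a fixed state, so the returned mean plays the role of the selected action.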

Implemented value functions are:

  • ValueFunction
  • ActionValueFunction

Algorithms

The algorithms module implements a base class (BaseRL) for RL agents. Implemented algorithms are:

  • Trust Region Policy Optimization (TRPO)
  • Proximal Policy Optimization (PPO)
  • Twin Delayed Deep Deterministic Policy Gradient (TD3)
  • Soft Actor-Critic (SAC)
  • Cross Entropy Guided Policy (CGP)

The following example shows the exposed interface:

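import gym
from algorithms import TD3

env = gym.make('Pendulum-v1')
agent = TD3(env)
timesteps = 10000  # number of environment steps; example value, adjust as needed
done = True
for t in range(timesteps):
    # Start a new episode whenever the previous one has ended.
    if done:
        state = env.reset()
    # Act in the environment, then update the agent.
    state, reward, done = agent.act(state)
    metrics = agent.learn()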
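
To make the calling convention above concrete without installing Gym or MuJoCo, the following sketch drives a stand-in agent through the same loop. `ToyEnv` and `RandomAgent` are hypothetical classes written for this illustration and are not part of DeepRL. The sketch assumes that `act` steps the environment internally and returns `(state, reward, done)`, and that `learn` returns a dictionary of metrics; neither detail is confirmed by this README.

```python
import random


class ToyEnv:
    """Hypothetical 1-D environment: an episode ends after `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        state = float(self.t)
        reward = -abs(action)  # reward is highest for actions near 0
        done = self.t >= self.horizon
        return state, reward, done


class RandomAgent:
    """Hypothetical agent mimicking the act/learn interface shown above."""

    def __init__(self, env, seed=0):
        self.env = env
        self.rng = random.Random(seed)
        self.rewards = []

    def act(self, state):
        action = self.rng.uniform(-1.0, 1.0)  # placeholder policy
        next_state, reward, done = self.env.step(action)
        self.rewards.append(reward)
        return next_state, reward, done

    def learn(self):
        # A real agent would update its networks here; this one only reports.
        return {"mean_reward": sum(self.rewards) / len(self.rewards)}


env = ToyEnv(horizon=5)
agent = RandomAgent(env)
episodes = 0
done = True
for t in range(20):
    if done:
        state = env.reset()
        episodes += 1
    state, reward, done = agent.act(state)
    metrics = agent.learn()
```

Note that the reset check and the `act`/`learn` calls sit at the same indentation level inside the loop, so the agent acts on every timestep and the environment is reset only at episode boundaries.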

Evaluator

The evaluator module addresses the issue of reproducible RL by separating the evaluation process from the core algorithms. The Evaluator class makes it easy to generate a statistical evaluation of an algorithm:

import gym
import dm_control2gym
from algorithms import PPO, TRPO, SAC, CGP, TD3
from evaluator import Evaluator
from evaluator.plot import plot_learning_curves, load_dataset

# (domain, task) pairs from the DeepMind Control Suite
envs = [
    ('cartpole', 'balance'),
    ('cartpole', 'swingup'),
    ('acrobot', 'swingup'),
    ('cheetah', 'run'),
    ('hopper', 'hop'),
    ('walker', 'run')
]

algs = [TD3, CGP, SAC, TRPO, PPO]

# Evaluate every algorithm on every environment.
for alg in algs:
    for domain, task in envs:
        env = dm_control2gym.make(domain_name=domain, task_name=task)
        agent = alg(env)
        evl = Evaluator(agent, 'path/to/output/data')
        evl.run_statistic(samples=20, seed=0)

# Load the saved returns and plot the learning curves.
returns = load_dataset('path/to/returns_offline.npz')
fig, ax = plot_learning_curves(returns, interval='bs')
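
The `interval='bs'` argument is not explained in this README; one plausible reading is a bootstrap confidence interval over the sampled runs. Under that assumption, the sketch below shows how a percentile-bootstrap interval on the mean return could be computed with the standard library alone. The function `bootstrap_ci` is hypothetical, the returns are synthetic placeholder values, and none of this is taken from the evaluator's actual code.

```python
import random
import statistics


def bootstrap_ci(values, n_resamples=2000, confidence=0.95, seed=0):
    """Percentile-bootstrap confidence interval for the mean of `values`."""
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement and record the mean of each resample.
    means = sorted(
        statistics.fmean(rng.choices(values, k=n)) for _ in range(n_resamples)
    )
    alpha = (1.0 - confidence) / 2.0
    lower = means[int(alpha * n_resamples)]
    upper = means[int((1.0 - alpha) * n_resamples) - 1]
    return lower, upper


# Synthetic final returns for 20 runs (placeholder data, not real results).
rng = random.Random(1)
final_returns = [rng.gauss(100.0, 15.0) for _ in range(20)]
mean_return = statistics.fmean(final_returns)
lower, upper = bootstrap_ci(final_returns)
```

Reporting an interval of this kind alongside the mean makes it easier to judge whether a difference between two algorithms exceeds the run-to-run variation.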

Requirements

  • Python3
  • PyTorch
  • Gym
  • MuJoCo
  • dm_control
  • dm_control2gym
