add GAIL #130

Merged
1 commit merged on Jun 29, 2023
1 change: 1 addition & 0 deletions Gallery.md
@@ -39,6 +39,7 @@ Users are also welcome to contribute their own training examples and demos to th
| [Dual-clip PPO](https://arxiv.org/abs/1912.09729) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MAPPO](https://arxiv.org/abs/2103.01955) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [GAIL](https://arxiv.org/abs/1606.03476) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/gail/) |
| [DQN](https://arxiv.org/abs/1312.5602) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![value](https://img.shields.io/badge/-value-orange) ![offpolicy](https://img.shields.io/badge/-offpolicy-blue) | [code](./examples/gridworld/) |
| [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/mpe/) |
| Self-Play | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
1 change: 1 addition & 0 deletions README.md
@@ -74,6 +74,7 @@ Algorithms currently supported by OpenRL (for more details, please refer to [Gal
- [Dual-clip PPO](https://arxiv.org/abs/1912.09729)
- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
- [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

1 change: 1 addition & 0 deletions README_zh.md
@@ -59,6 +59,7 @@ OpenRL目前支持的算法(更多详情请参考 [Gallery](Gallery.md)):
- [Dual-clip PPO](https://arxiv.org/abs/1912.09729)
- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
- [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

7 changes: 5 additions & 2 deletions examples/gail/README.md
@@ -1,7 +1,10 @@
## Prepare Dataset

Run the following command to generate the dataset for GAIL: `python gen_data.py`
Run the following command to generate the dataset for GAIL: `python gen_data.py`. This will produce a file named `data.pkl` in the current folder.

## Train

Run the following command to train GAIL: `python train_gail.py --config cartpole_gail.yaml`

With GAIL, we can even train the agent without expert actions!
Run the following command to train GAIL without expert actions: `python train_gail.py --config cartpole_gail_without_action.yaml`
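
For a quick sanity check of the `data.pkl` file generated above, a minimal sketch like the following can be used. It only assumes the file is a pickle; the exact layout is determined by OpenRL's `GenDataWrapper` and is not documented in this PR.

```python
# Minimal sketch: peek into the expert data file produced by gen_data.py.
# It only assumes data.pkl is a pickle file; the exact layout is determined
# by OpenRL's GenDataWrapper.
import pickle

with open("data.pkl", "rb") as f:
    data = pickle.load(f)

print("type:", type(data))
if isinstance(data, dict):
    for key, value in data.items():
        size = len(value) if hasattr(value, "__len__") else value
        print(key, "->", size)
```
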
1 change: 1 addition & 0 deletions examples/gail/cartpole_gail.yaml
@@ -1,3 +1,4 @@
seed: 0
expert_data: "./data.pkl"
reward_class:
id: "GAILReward"
5 changes: 5 additions & 0 deletions examples/gail/cartpole_gail_without_action.yaml
@@ -0,0 +1,5 @@
seed: 0
expert_data: "./data.pkl"
gail_use_action: false
reward_class:
id: "GAILReward"
6 changes: 3 additions & 3 deletions examples/gail/gen_data.py
@@ -35,7 +35,7 @@ def train():
return agent


def gen_data():
def gen_data(total_episode):
# begin to test
# Create an environment for testing with 9 parallel environments. Rendering is disabled here (set render_mode to "group_human" to watch).
render_mode = None
@@ -50,7 +50,7 @@ def gen_data():
agent = Agent(Net(env))
agent.load("ppo_agent")

env = GenDataWrapper(env, data_save_path="data.pkl", total_episode=5000)
env = GenDataWrapper(env, data_save_path="data.pkl", total_episode=total_episode)
obs, info = env.reset()
done = False
while not done:
@@ -62,4 +62,4 @@

if __name__ == "__main__":
train()
gen_data()
gen_data(total_episode=500)
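
`gen_data` now takes `total_episode` as a parameter instead of hard-coding 5000 episodes, and the `__main__` block collects 500. For intuition, the episode-based collection that `GenDataWrapper` drives could be written roughly as below. This is a single-environment stand-in that reuses the `agent.act` / `env.step` calling pattern from these scripts; the pickled `{"obs", "action"}` layout is an assumption made only for this sketch.

```python
# Illustrative single-environment stand-in for the episode collection that
# GenDataWrapper drives (not OpenRL's implementation). The pickled
# {"obs": [...], "action": [...]} layout is an assumption for this sketch.
import pickle


def collect_episodes(env, agent, total_episode, data_save_path="data.pkl"):
    buffer = {"obs": [], "action": []}
    for _ in range(total_episode):
        obs, info = env.reset()
        done = False
        while not done:
            # Predict the next action from the current observation.
            action, _ = agent.act(obs, deterministic=True)
            buffer["obs"].append(obs)
            buffer["action"].append(action)
            obs, reward, done, info = env.step(action)
    with open(data_save_path, "wb") as f:
        pickle.dump(buffer, f)
```
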
42 changes: 42 additions & 0 deletions examples/gail/gen_data_v1.py
@@ -0,0 +1,42 @@
"""
Used to generate offline data for GAIL.
"""

from openrl.envs.common import make
from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper_v1 as GenDataWrapper
from openrl.envs.wrappers.monitor import Monitor
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

env_wrappers = [
Monitor,
]


def gen_data(total_episode):
# begin to test
# Create an environment for testing with 9 parallel environments. Rendering is disabled here (set render_mode to "group_human" to watch).
render_mode = None
env = make(
"CartPole-v1",
render_mode=render_mode,
env_num=9,
asynchronous=True,
env_wrappers=env_wrappers,
)

agent = Agent(Net(env))
agent.load("ppo_agent")

env = GenDataWrapper(env, data_save_path="data_v1.pkl", total_episode=total_episode)
obs, info = env.reset()
done = False
while not done:
# Based on the environment's observations, predict the next action.
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
env.close()


if __name__ == "__main__":
gen_data(total_episode=50)
38 changes: 38 additions & 0 deletions examples/gail/test_dataset.py
@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import torch

from openrl.datasets.expert_dataset import ExpertDataset


def test_dataset():
dataset = ExpertDataset(file_name="data_small.pkl", seed=0)
print("data length:", len(dataset))
print("data[0]:", dataset[0][0])
print("data[1]:", dataset[1][0])
print("data[len(data)-1]:", dataset[len(dataset) - 1][0])

data_loader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=128, shuffle=False, drop_last=True
)
for batch_data in data_loader:
expert_obs, expert_action = batch_data


if __name__ == "__main__":
test_dataset()
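
`test_dataset.py` exercises `ExpertDataset` through a standard PyTorch `DataLoader`. A minimal map-style dataset over the pickled expert data might look like the sketch below; the real `ExpertDataset` lives in `openrl.datasets.expert_dataset`, and the `{"obs", "action"}` pickle layout assumed here is hypothetical.

```python
# Hypothetical minimal map-style dataset over pickled expert data. This is
# not OpenRL's ExpertDataset; the {"obs", "action"} pickle layout is assumed
# only for this sketch.
import pickle

import torch
from torch.utils.data import Dataset


class SimpleExpertDataset(Dataset):
    def __init__(self, file_name: str, seed: int = 0):
        # seed is accepted for API parity with the test above; shuffling is
        # left to the DataLoader in this sketch.
        with open(file_name, "rb") as f:
            data = pickle.load(f)
        self.obs = torch.as_tensor(data["obs"], dtype=torch.float32)
        self.actions = torch.as_tensor(data["action"], dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(self, idx):
        return self.obs[idx], self.actions[idx]
```

An instance of such a class can be passed to `torch.utils.data.DataLoader` exactly as in the test above.
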
16 changes: 11 additions & 5 deletions examples/gail/train_gail.py
@@ -3,6 +3,7 @@

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import ZeroRewardWrapper
from openrl.modules.common import GAILNet as Net
from openrl.runners.common import GAILAgent as Agent

@@ -12,18 +13,18 @@ def train():
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()

# We use ZeroRewardWrapper to make sure that we don't get any reward from the environment.
# create environment, set environment parallelism to 3
env = make("CartPole-v1", env_num=3, cfg=cfg)
env = make("CartPole-v1", env_num=3, cfg=cfg, env_wrappers=[ZeroRewardWrapper])

net = Net(
env,
cfg=cfg,
)
# initialize the trainer
agent = Agent(net)
# start training, set total number of training steps to 20000
# agent.train(total_time_steps=20000)
agent.train(total_time_steps=600)
# start training, set total number of training steps to 7500
agent.train(total_time_steps=7500)

env.close()
return agent
@@ -32,7 +33,12 @@ def train():
def evaluation(agent):
# begin to test
# Create an environment for testing and set the number of parallel environments to 9.
env = make("CartPole-v1", render_mode="group_human", env_num=9, asynchronous=True)
render_mode = None  # set to "group_human" if you want to see the rendering of the environment
env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)

# The trained agent sets up the interactive environment it needs.
agent.set_env(env)
# Initialize the environment and get initial observations and environmental information.
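
The key training-side change is wrapping the environment with `ZeroRewardWrapper`, so the agent never sees the environment reward and learns only from the `GAILReward` signal configured in the YAML. A reward-zeroing wrapper is essentially a one-liner; a generic gymnasium version might look like this (a sketch, not OpenRL's `extra_wrappers` implementation):

```python
# Sketch of a reward-zeroing wrapper in plain gymnasium terms; OpenRL's
# ZeroRewardWrapper (openrl.envs.wrappers.extra_wrappers) is not reproduced
# here. Everything is forwarded unchanged except the reward, which is zeroed
# so that only the discriminator-based GAIL reward drives learning.
import gymnasium as gym


class ZeroReward(gym.RewardWrapper):
    def reward(self, reward):
        return 0.0 * reward  # preserves dtype/shape if the reward is an array
```

Wrapping `gym.make("CartPole-v1")` with it makes every step report a reward of zero while leaving observations and termination untouched.
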