Commit 1b84a34: Create 4 mixed_intrinsic_rewards.ipynb
Co-Authored-By: Roger Creus <31919499+roger-creus@users.noreply.github.com>
yuanmingqi and roger-creus committed May 14, 2024 (1 parent: 7eb526e)
examples/rlexplore/4 mixed_intrinsic_rewards.ipynb (new file, 116 additions)

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**RLeXplore allows you to combine multiple intrinsic rewards to explore the potential assembly advantages.**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Load the libraries**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"import numpy as np\n",
"import torch as th\n",
"\n",
"from rllte.env.utils import Gymnasium2Torch\n",
"from rllte.xplore.reward import Fabric, RE3, ICM\n",
"from rllte.agent import PPO"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Create a fake Atari environment with image observations**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class FakeAtari(gym.Env):\n",
" def __init__(self):\n",
" self.action_space = gym.spaces.Discrete(7)\n",
" self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))\n",
" self.count = 0\n",
"\n",
" def reset(self):\n",
" self.count = 0\n",
" return self.observation_space.sample(), {}\n",
"\n",
" def step(self, action):\n",
" self.count += 1\n",
" if self.count > 100 and np.random.rand() < 0.1:\n",
" term = trunc = True\n",
" else:\n",
" term = trunc = False\n",
" return self.observation_space.sample(), 0, term, trunc, {}"
]
},
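{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Optionally, run a quick smoke test of the fake environment**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# instantiate a single environment and inspect one transition\n",
"env = FakeAtari()\n",
"obs, info = env.reset()\n",
"print(obs.shape)  # expected: (4, 84, 84)\n",
"obs, reward, term, trunc, info = env.step(env.action_space.sample())\n",
"print(reward, term, trunc)"
]
},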
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Use the `Fabric` class to create a mixed intrinsic reward**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# set the parameters\n",
"device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
"n_steps = 128\n",
"n_envs = 8\n",
"# create the vectorized environments\n",
"envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n",
"# wrap the environments to convert the observations to torch tensors\n",
"envs = Gymnasium2Torch(envs, device)\n",
"# create two intrinsic reward functions\n",
"irs1 = ICM(envs, device)\n",
"irs2 = RE3(envs, device)\n",
"# create the mixed intrinsic reward function\n",
"irs = Fabric([irs1, irs2])"
]
},
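{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Optionally, query the mixed reward directly. This is a minimal sketch: it assumes the standard `.watch()`/`.compute(samples=...)` interface of RLeXplore reward modules, and that `Fabric` forwards both calls to the wrapped modules.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# roll out random actions and collect one batch of transitions\n",
"obs, infos = envs.reset()\n",
"buffers = {key: [] for key in ('observations', 'actions', 'rewards', 'terminateds', 'truncateds', 'next_observations')}\n",
"for _ in range(n_steps):\n",
"    actions = th.randint(0, 7, (n_envs,), device=device)\n",
"    next_obs, rewards, terms, truncs, infos = envs.step(actions)\n",
"    # .watch() lets the wrapped modules update their internal statistics online\n",
"    irs.watch(observations=obs, actions=actions, rewards=rewards,\n",
"              terminateds=terms, truncateds=truncs, next_observations=next_obs)\n",
"    for key, value in zip(buffers, (obs, actions, rewards, terms, truncs, next_obs)):\n",
"        buffers[key].append(value)\n",
"    obs = next_obs\n",
"# stack to (n_steps, n_envs, ...) tensors and compute the mixed rewards\n",
"samples = {key: th.stack(values) for key, values in buffers.items()}\n",
"intrinsic_rewards = irs.compute(samples=samples)\n",
"print(intrinsic_rewards.shape)  # expected: (n_steps, n_envs)"
]
},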
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# start the training\n",
"device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
"# create the PPO agent\n",
"agent = PPO(envs, device=device)\n",
"# set the intrinsic reward module\n",
"agent.set(reward=irs)\n",
"# train the agent\n",
"agent.train(10000)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
