-
-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create 4 mixed_intrinsic_rewards.ipynb
Co-Authored-By: Roger Creus <31919499+roger-creus@users.noreply.github.com>
- Loading branch information
1 parent
7eb526e
commit 1b84a34
Showing
1 changed file
with
116 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**RLeXplore allows you to combine multiple intrinsic rewards, so you can explore the potential benefits of using them together.**" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Load the libraries**" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import gymnasium as gym\n", | ||
"import numpy as np\n", | ||
"import torch as th\n", | ||
"\n", | ||
"from rllte.env.utils import Gymnasium2Torch\n", | ||
"from rllte.xplore.reward import Fabric, RE3, ICM\n", | ||
"from rllte.agent import PPO" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Create a fake Atari environment with image observations**" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class FakeAtari(gym.Env):\n", | ||
"    # Stand-in for an Atari environment: it returns randomly sampled\n", | ||
"    # observations of shape (4, 84, 84) and a constant reward of 0.\n", | ||
"\n", | ||
"    def __init__(self):\n", | ||
"        self.action_space = gym.spaces.Discrete(7)\n", | ||
"        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))\n", | ||
"        # number of steps taken in the current episode\n", | ||
"        self.count = 0\n", | ||
"\n", | ||
"    def reset(self, *, seed=None, options=None):\n", | ||
"        # Accept the standard Gymnasium keywords: vectorized environments forward\n", | ||
"        # `seed` and `options` to every sub-environment, so a bare `reset(self)`\n", | ||
"        # would raise a TypeError there.\n", | ||
"        # NOTE: observations come from `observation_space.sample()` and episode\n", | ||
"        # termination from `np.random`, so `seed` alone does not make rollouts\n", | ||
"        # reproducible; `options` is accepted but unused.\n", | ||
"        super().reset(seed=seed)\n", | ||
"        self.count = 0\n", | ||
"        return self.observation_space.sample(), {}\n", | ||
"\n", | ||
"    def step(self, action):\n", | ||
"        # `action` is ignored: the next observation is always a random sample\n", | ||
"        self.count += 1\n", | ||
"        # once more than 100 steps have elapsed, end the episode with\n", | ||
"        # probability 0.1 per step; both flags are always set together\n", | ||
"        if self.count > 100 and np.random.rand() < 0.1:\n", | ||
"            term = trunc = True\n", | ||
"        else:\n", | ||
"            term = trunc = False\n", | ||
"        return self.observation_space.sample(), 0, term, trunc, {}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Use the `Fabric` class to create a mixed intrinsic reward**" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# configuration\n", | ||
"device = 'cuda' if th.cuda.is_available() else 'cpu'\n", | ||
"n_steps = 128  # NOTE(review): not referenced by any other visible cell\n", | ||
"n_envs = 8\n", | ||
"\n", | ||
"# build `n_envs` FakeAtari copies running in parallel, then wrap them with\n", | ||
"# Gymnasium2Torch, which receives the target device\n", | ||
"envs = Gymnasium2Torch(\n", | ||
"    gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)]),\n", | ||
"    device,\n", | ||
")\n", | ||
"\n", | ||
"# instantiate ICM and RE3 on the wrapped environments, in that order,\n", | ||
"# and combine them into a single mixed intrinsic reward with Fabric\n", | ||
"irs = Fabric([reward_cls(envs, device) for reward_cls in (ICM, RE3)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# number of steps passed to `agent.train`\n", | ||
"num_train_steps = 10000\n", | ||
"\n", | ||
"# create the PPO agent; `device`, `envs` and `irs` come from the previous cell,\n", | ||
"# so the agent uses the same device the environments were wrapped with\n", | ||
"agent = PPO(envs, device=device)\n", | ||
"# set the intrinsic reward module\n", | ||
"agent.set(reward=irs)\n", | ||
"# train the agent\n", | ||
"agent.train(num_train_steps)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |