diff --git a/examples/rlexplore/0 quick_start.ipynb b/examples/rlexplore/0 quick_start.ipynb new file mode 100644 index 00000000..52172336 --- /dev/null +++ b/examples/rlexplore/0 quick_start.ipynb @@ -0,0 +1,486 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "import numpy as np\n", + "import torch as th\n", + "\n", + "from rllte.env.utils import Gymnasium2Torch\n", + "from rllte.xplore.reward import ICM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Create a fake Atari environment with image observations**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class FakeAtari(gym.Env):\n", + " def __init__(self):\n", + " self.action_space = gym.spaces.Discrete(7)\n", + " self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))\n", + " self.count = 0\n", + "\n", + " def reset(self):\n", + " self.count = 0\n", + " return self.observation_space.sample(), {}\n", + "\n", + " def step(self, action):\n", + " self.count += 1\n", + " if self.count > 100 and np.random.rand() < 0.1:\n", + " term = trunc = True\n", + " else:\n", + " term = trunc = False\n", + " return self.observation_space.sample(), 0, term, trunc, {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Synchronous Mode, the `.update()` will be automatically invoked in the `.compute()` function, usually for on-policy RL algorithms.**" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Box(0.0, 1.0, (4, 84, 84), float32)\n", + "Discrete(7)\n", + "tensor([[5.3476, 7.1356, 6.8465, ..., 7.4179, 7.0747, 5.7406],\n", + " [6.2589, 7.1989, 5.0291, ..., 6.7728, 7.8519, 6.2785],\n", + " [6.5028, 7.3675, 5.8179, ..., 6.5310, 5.7410, 6.5957],\n", + " ...,\n", + " [7.3416, 6.4600, 5.5094, ..., 7.5757, 8.3019, 6.7766],\n", + " [7.3124, 6.6850, 6.6613, ..., 6.3896, 7.5636, 7.0359],\n", + " [8.4999, 6.5634, 7.4811, ..., 7.7395, 7.5860, 7.3720]],\n", + " device='mps:0')\n", + "torch.Size([128, 8])\n" + ] + } + ], + "source": [ + "# set the parameters\n", + "device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'\n", + "n_steps = 128\n", + "n_envs = 8\n", + "# create the vectorized environments\n", + "envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n", + "# wrap the environments to convert the observations to torch tensors\n", + "envs = Gymnasium2Torch(envs, device)\n", + "# create the intrinsic reward module\n", + "irs = ICM(envs, device)\n", + "# reset the environments and get the initial observations\n", + "obs, infos = envs.reset()\n", + "# create a dictionary to store the samples\n", + "samples = {'observations':[], \n", + " 'actions':[], \n", + " 'rewards':[],\n", + " 'terminateds':[],\n", + " 'truncateds':[],\n", + " 'next_observations':[]}\n", + "# sampling loop\n", + "for _ in range(n_steps):\n", + " # sample random actions\n", + " actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])\n", + " # environment step\n", + " next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)\n", + " # watch the interactions and get necessary information for the intrinsic reward computation\n", + " irs.watch(observations=obs, \n", + " actions=actions, \n", + " rewards=rewards,\n", + " terminateds=terminateds,\n", + 
" truncateds=truncateds,\n", + " next_observations=next_obs)\n", + " # store the samples\n", + " samples['observations'].append(obs)\n", + " samples['actions'].append(actions)\n", + " samples['rewards'].append(rewards)\n", + " samples['terminateds'].append(terminateds)\n", + " samples['truncateds'].append(truncateds)\n", + " samples['next_observations'].append(next_obs)\n", + " obs = next_obs\n", + "# compute the intrinsic rewards\n", + "samples = {k: th.stack(v) for k, v in samples.items()}\n", + "intrinsic_rewards = irs.compute(samples=samples)\n", + "print(intrinsic_rewards)\n", + "print(intrinsic_rewards.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Asynchronous Mode, the `.update()` must be invoked separately, usually for off-policy RL algorithms.**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1.8394, 1.8430, 2.5142, 2.0302, 2.1765, 1.6593, 1.8448, 1.6650]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0185, 0.0136, 0.0185, 0.0149, 0.0150, 0.0161, 0.0186, 0.0180]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0159, 0.0131, 0.0197, 0.0171, 0.0144, 0.0159, 0.0196, 0.0190]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0156, 0.0186, 0.0207, 0.0172, 0.0214, 0.0157, 0.0186, 0.0208]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0224, 0.0202, 0.0201, 0.0224, 0.0201, 0.0202, 0.0224, 0.0187]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0240, 0.0248, 0.0165, 0.0183, 0.0242, 0.0183, 0.0182, 0.0249]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0231, 0.0264, 0.0264, 0.0265, 0.0265, 0.0175, 0.0257, 0.0265]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0243, 0.0187, 0.0280, 0.0187, 0.0206, 0.0225, 0.0187, 0.0206]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0295, 0.0197, 0.0294, 0.0238, 0.0197, 0.0285, 0.0295, 0.0216]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0269, 0.0248, 0.0267, 0.0227, 0.0309, 0.0226, 0.0206, 0.0248]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0237, 0.0281, 0.0322, 0.0280, 0.0313, 0.0236, 0.0214, 0.0279]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0290, 0.0324, 0.0335, 0.0224, 0.0335, 0.0335, 0.0291, 0.0245]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0303, 0.0336, 0.0349, 0.0279, 0.0232, 0.0302, 0.0336, 0.0280]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0312, 0.0312, 0.0313, 0.0359, 0.0359, 0.0291, 0.0240, 0.0290]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0359, 0.0360, 0.0371, 0.0361, 0.0273, 0.0322, 0.0372, 0.0273]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0382, 0.0383, 0.0332, 0.0308, 0.0281, 0.0281, 0.0333, 0.0384]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0288, 0.0262, 0.0343, 0.0263, 0.0394, 0.0393, 0.0382, 0.0393]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0405, 0.0326, 0.0353, 0.0351, 0.0404, 0.0352, 0.0391, 0.0405]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0305, 0.0414, 0.0304, 0.0336, 0.0305, 0.0402, 0.0414, 0.0278]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0411, 0.0410, 0.0313, 0.0313, 0.0311, 0.0424, 0.0425, 0.0343]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0289, 0.0379, 0.0349, 0.0437, 0.0419, 0.0290, 0.0436, 0.0289]],\n", + " 
device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0297, 0.0296, 0.0445, 0.0326, 0.0387, 0.0326, 0.0325, 0.0445]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0367, 0.0438, 0.0395, 0.0452, 0.0454, 0.0366, 0.0395, 0.0396]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0447, 0.0340, 0.0463, 0.0373, 0.0463, 0.0309, 0.0465, 0.0402]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0459, 0.0380, 0.0458, 0.0410, 0.0382, 0.0458, 0.0473, 0.0316]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0480, 0.0419, 0.0481, 0.0482, 0.0480, 0.0388, 0.0468, 0.0417]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0490, 0.0476, 0.0489, 0.0490, 0.0477, 0.0327, 0.0475, 0.0360]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0500, 0.0484, 0.0434, 0.0404, 0.0436, 0.0332, 0.0402, 0.0403]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0507, 0.0339, 0.0373, 0.0441, 0.0507, 0.0493, 0.0410, 0.0492]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0450, 0.0499, 0.0500, 0.0379, 0.0448, 0.0517, 0.0517, 0.0415]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0524, 0.0507, 0.0526, 0.0525, 0.0525, 0.0422, 0.0348, 0.0386]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0353, 0.0429, 0.0355, 0.0534, 0.0515, 0.0428, 0.0517, 0.0392]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0436, 0.0523, 0.0540, 0.0469, 0.0538, 0.0468, 0.0540, 0.0398]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0532, 0.0477, 0.0402, 0.0365, 0.0547, 0.0443, 0.0401, 0.0547]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0555, 0.0409, 0.0556, 0.0538, 0.0369, 0.0555, 0.0556, 0.0539]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0546, 0.0455, 0.0543, 0.0455, 0.0563, 0.0415, 0.0562, 0.0454]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0497, 0.0570, 0.0551, 0.0571, 0.0574, 0.0461, 0.0420, 0.0496]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0559, 0.0425, 0.0559, 0.0385, 0.0502, 0.0426, 0.0577, 0.0579]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0509, 0.0432, 0.0586, 0.0391, 0.0429, 0.0432, 0.0430, 0.0585]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0395, 0.0513, 0.0594, 0.0593, 0.0573, 0.0574, 0.0594, 0.0571]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0600, 0.0522, 0.0601, 0.0524, 0.0442, 0.0598, 0.0441, 0.0582]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0490, 0.0406, 0.0608, 0.0404, 0.0606, 0.0586, 0.0606, 0.0405]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0595, 0.0410, 0.0615, 0.0409, 0.0614, 0.0410, 0.0496, 0.0615]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0502, 0.0455, 0.0621, 0.0456, 0.0621, 0.0623, 0.0502, 0.0620]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0462, 0.0421, 0.0609, 0.0630, 0.0463, 0.0606, 0.0545, 0.0547]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0512, 0.0425, 0.0513, 0.0634, 0.0466, 0.0425, 0.0634, 0.0635]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0470, 0.0642, 0.0474, 0.0621, 0.0557, 0.0472, 0.0625, 0.0471]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0648, 0.0523, 0.0431, 0.0648, 0.0565, 0.0521, 0.0526, 0.0566]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0437, 0.0529, 0.0568, 0.0655, 0.0479, 0.0437, 0.0633, 0.0529]],\n", + " device='mps:0') torch.Size([1, 8])\n", + 
"tensor([[0.0662, 0.0659, 0.0533, 0.0485, 0.0575, 0.0642, 0.0660, 0.0660]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0667, 0.0669, 0.0669, 0.0581, 0.0538, 0.0447, 0.0667, 0.0540]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0674, 0.0672, 0.0496, 0.0677, 0.0674, 0.0674, 0.0544, 0.0448]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0682, 0.0679, 0.0680, 0.0498, 0.0453, 0.0456, 0.0681, 0.0592]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0663, 0.0597, 0.0686, 0.0552, 0.0554, 0.0665, 0.0686, 0.0665]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0689, 0.0603, 0.0511, 0.0669, 0.0560, 0.0604, 0.0668, 0.0602]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0701, 0.0676, 0.0566, 0.0515, 0.0610, 0.0676, 0.0564, 0.0467]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0569, 0.0708, 0.0471, 0.0573, 0.0518, 0.0473, 0.0613, 0.0472]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0619, 0.0573, 0.0573, 0.0475, 0.0522, 0.0471, 0.0575, 0.0691]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0714, 0.0718, 0.0477, 0.0717, 0.0527, 0.0716, 0.0477, 0.0624]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0481, 0.0584, 0.0723, 0.0629, 0.0630, 0.0724, 0.0534, 0.0699]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0484, 0.0705, 0.0636, 0.0486, 0.0538, 0.0590, 0.0586, 0.0536]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0641, 0.0541, 0.0593, 0.0595, 0.0591, 0.0737, 0.0736, 0.0639]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0743, 0.0741, 0.0545, 0.0543, 0.0599, 0.0596, 0.0739, 0.0497]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0722, 0.0495, 0.0724, 0.0648, 0.0724, 0.0601, 0.0747, 0.0603]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0551, 0.0752, 0.0606, 0.0552, 0.0749, 0.0753, 0.0754, 0.0753]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0560, 0.0508, 0.0611, 0.0759, 0.0556, 0.0761, 0.0612, 0.0611]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0510, 0.0559, 0.0766, 0.0762, 0.0507, 0.0665, 0.0766, 0.0664]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0512, 0.0744, 0.0619, 0.0742, 0.0667, 0.0565, 0.0668, 0.0745]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0772, 0.0774, 0.0774, 0.0776, 0.0749, 0.0674, 0.0776, 0.0750]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0572, 0.0629, 0.0680, 0.0754, 0.0678, 0.0575, 0.0777, 0.0677]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0524, 0.0761, 0.0636, 0.0785, 0.0684, 0.0788, 0.0786, 0.0788]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0766, 0.0767, 0.0639, 0.0686, 0.0792, 0.0582, 0.0792, 0.0793]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0692, 0.0798, 0.0644, 0.0796, 0.0771, 0.0583, 0.0642, 0.0692]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0696, 0.0647, 0.0804, 0.0645, 0.0533, 0.0648, 0.0804, 0.0801]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0807, 0.0807, 0.0653, 0.0652, 0.0537, 0.0537, 0.0652, 0.0650]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0812, 0.0813, 0.0706, 0.0706, 0.0816, 0.0788, 0.0707, 0.0543]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0548, 0.0816, 0.0815, 0.0792, 0.0709, 0.0660, 0.0600, 0.0789]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0665, 0.0602, 0.0825, 0.0825, 0.0820, 
0.0794, 0.0821, 0.0822]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0802, 0.0831, 0.0801, 0.0667, 0.0828, 0.0829, 0.0607, 0.0667]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0833, 0.0613, 0.0557, 0.0612, 0.0555, 0.0809, 0.0610, 0.0554]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0616, 0.0811, 0.0840, 0.0562, 0.0729, 0.0838, 0.0836, 0.0621]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0847, 0.0560, 0.0565, 0.0622, 0.0846, 0.0732, 0.0843, 0.0840]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0850, 0.0737, 0.0846, 0.0624, 0.0851, 0.0567, 0.0568, 0.0851]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0687, 0.0745, 0.0855, 0.0570, 0.0569, 0.0852, 0.0855, 0.0855]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0692, 0.0859, 0.0694, 0.0856, 0.0745, 0.0747, 0.0745, 0.0632]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0865, 0.0577, 0.0752, 0.0865, 0.0862, 0.0868, 0.0575, 0.0575]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0755, 0.0640, 0.0868, 0.0640, 0.0756, 0.0755, 0.0580, 0.0842]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0874, 0.0879, 0.0706, 0.0763, 0.0846, 0.0761, 0.0582, 0.0875]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0765, 0.0876, 0.0851, 0.0762, 0.0767, 0.0586, 0.0878, 0.0878]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0772, 0.0649, 0.0712, 0.0856, 0.0589, 0.0649, 0.0646, 0.0711]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0714, 0.0594, 0.0649, 0.0889, 0.0887, 0.0592, 0.0718, 0.0591]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0893, 0.0867, 0.0894, 0.0595, 0.0720, 0.0893, 0.0891, 0.0660]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0897, 0.0896, 0.0868, 0.0868, 0.0894, 0.0724, 0.0662, 0.0723]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0874, 0.0901, 0.0786, 0.0873, 0.0600, 0.0784, 0.0781, 0.0904]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0667, 0.0605, 0.0792, 0.0662, 0.0734, 0.0665, 0.0910, 0.0908]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0914, 0.0608, 0.0912, 0.0914, 0.0609, 0.0882, 0.0883, 0.0671]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0890, 0.0739, 0.0677, 0.0739, 0.0609, 0.0796, 0.0916, 0.0802]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0679, 0.0890, 0.0614, 0.0680, 0.0742, 0.0677, 0.0921, 0.0798]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0896, 0.0897, 0.0923, 0.0925, 0.0807, 0.0678, 0.0925, 0.0618]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0683, 0.0683, 0.0751, 0.0683, 0.0682, 0.0683, 0.0928, 0.0929]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0622, 0.0813, 0.0756, 0.0939, 0.0624, 0.0815, 0.0625, 0.0905]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0818, 0.0910, 0.0909, 0.0906, 0.0940, 0.0816, 0.0908, 0.0685]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0942, 0.0822, 0.0945, 0.0914, 0.0821, 0.0628, 0.0914, 0.0632]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0825, 0.0947, 0.0945, 0.0695, 0.0946, 0.0946, 0.0918, 0.0698]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0771, 0.0766, 0.0954, 0.0827, 0.0953, 0.0769, 0.0638, 0.0699]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0775, 0.0634, 0.0832, 0.0703, 0.0957, 0.0706, 0.0928, 0.0958]],\n", + " device='mps:0') 
torch.Size([1, 8])\n", + "tensor([[0.0640, 0.0966, 0.0836, 0.0958, 0.0642, 0.0838, 0.0706, 0.0776]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0962, 0.0711, 0.0967, 0.0936, 0.0840, 0.0782, 0.0779, 0.0841]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0974, 0.0844, 0.0649, 0.0969, 0.0715, 0.0786, 0.0645, 0.0973]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0650, 0.0651, 0.0977, 0.0941, 0.0786, 0.0788, 0.0715, 0.0784]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0653, 0.0945, 0.0655, 0.0718, 0.0720, 0.0792, 0.0656, 0.0720]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0856, 0.0721, 0.0655, 0.0955, 0.0985, 0.0985, 0.0985, 0.0799]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0730, 0.0991, 0.0658, 0.0857, 0.0723, 0.0797, 0.0725, 0.0797]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0732, 0.0661, 0.0867, 0.0804, 0.0802, 0.0959, 0.0996, 0.0994]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0666, 0.0732, 0.0803, 0.0967, 0.0730, 0.0731, 0.0996, 0.1000]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0811, 0.1001, 0.1002, 0.0971, 0.0999, 0.0966, 0.0735, 0.0669]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0973, 0.0813, 0.0874, 0.0740, 0.0671, 0.0970, 0.0673, 0.0874]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0674, 0.0978, 0.0878, 0.0673, 0.1011, 0.0742, 0.1007, 0.0672]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0983, 0.0675, 0.0744, 0.1014, 0.0676, 0.1015, 0.1013, 0.1013]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0883, 0.1018, 0.0821, 0.1019, 0.1017, 0.1019, 0.1019, 0.0824]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.1028, 0.0751, 0.0825, 0.0989, 0.0828, 0.1021, 0.0684, 0.0682]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0892, 0.1024, 0.1026, 0.0995, 0.1027, 0.0753, 0.1025, 0.0755]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0835, 0.1032, 0.0999, 0.0686, 0.0998, 0.1036, 0.0999, 0.1033]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0759, 0.0836, 0.1000, 0.0901, 0.0691, 0.1036, 0.0761, 0.1002]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0839, 0.1039, 0.0691, 0.0908, 0.0761, 0.1041, 0.0837, 0.0900]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.1011, 0.1008, 0.0696, 0.1040, 0.1042, 0.1007, 0.1015, 0.0769]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.0908, 0.1050, 0.1051, 0.1049, 0.1050, 0.0768, 0.1047, 0.0770]],\n", + " device='mps:0') torch.Size([1, 8])\n", + "tensor([[0.1018, 0.0701, 0.0701, 0.0846, 0.1051, 0.1052, 0.0701, 0.0699]],\n", + " device='mps:0') torch.Size([1, 8])\n" + ] + } + ], + "source": [ + "# set the parameters\n", + "device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'\n", + "n_steps = 128\n", + "n_envs = 8\n", + "# create the vectorized environments\n", + "envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n", + "# wrap the environments to convert the observations to torch tensors\n", + "envs = Gymnasium2Torch(envs, device)\n", + "# create the intrinsic reward module\n", + "irs = ICM(envs, device)\n", + "# reset the environments and get the initial observations\n", + "obs, infos = envs.reset()\n", + "# create a dictionary to store the samples\n", + "samples = {'observations':[], \n", + " 'actions':[], \n", + " 'rewards':[],\n", + " 'terminateds':[],\n", + " 
'truncateds':[],\n", + " 'next_observations':[]}\n", + "# sampling loop\n", + "for _ in range(n_steps):\n", + " # sample random actions\n", + " actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])\n", + " # environment step\n", + " next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)\n", + " # watch the interactions and get necessary information for the intrinsic reward computation\n", + " irs.watch(observations=obs, \n", + " actions=actions, \n", + " rewards=rewards,\n", + " terminateds=terminateds,\n", + " truncateds=truncateds,\n", + " next_observations=next_obs)\n", + " # compute the intrinsic rewards at each step\n", + " intrinsic_rewards = irs.compute(samples={'observations':obs.unsqueeze(0), \n", + " 'actions':actions.unsqueeze(0), \n", + " 'rewards':rewards.unsqueeze(0),\n", + " 'terminateds':terminateds.unsqueeze(0),\n", + " 'truncateds':truncateds.unsqueeze(0),\n", + " 'next_observations':next_obs.unsqueeze(0)}, \n", + " sync=False)\n", + " print(intrinsic_rewards, intrinsic_rewards.shape)\n", + " # store the samples\n", + " samples['observations'].append(obs)\n", + " samples['actions'].append(actions)\n", + " samples['rewards'].append(rewards)\n", + " samples['terminateds'].append(terminateds)\n", + " samples['truncateds'].append(truncateds)\n", + " samples['next_observations'].append(next_obs)\n", + " obs = next_obs\n", + "# update the intrinsic reward module\n", + "samples = {k: th.stack(v) for k, v in samples.items()}\n", + "irs.update(samples=samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rllte", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/rlexplore/1 rlexplore_with_rllte.ipynb b/examples/rlexplore/1 rlexplore_with_rllte.ipynb new file mode 100644 index 00000000..a52875d2 --- /dev/null +++ b/examples/rlexplore/1 rlexplore_with_rllte.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**With RLLTE, you can use these intrinsic reward modules in a simple and elegant way.**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Example with PPO**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 
'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/envs/registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n", + " logger.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mps Box(0, 255, (3, 84, 84), uint8) Discrete(7)\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Invoking RLLTE Engine...\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - ================================================================================\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Tag : default\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Device : MacOS MPS\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Agent : PPO\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Encoder : MnihCnnEncoder\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Policy : OnPolicySharedActorCritic\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Storage : VanillaRolloutStorage\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Distribution : Categorical\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Augmentation : None\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Intrinsic Reward : ICM\n", + "[02/28/2024 09:15:03 PM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - ================================================================================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. 
(Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n", + "/Users/yuanmingqi/miniconda3/envs/rllte/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. 
(Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[02/28/2024 09:15:07 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 1024 | E: 8 | L: 56 | R: 2.230 | FPS: 246.207 | T: 0:00:04 \n", + "[02/28/2024 09:15:10 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 2048 | E: 16 | L: 103 | R: 3.890 | FPS: 294.277 | T: 0:00:06 \n", + "[02/28/2024 09:15:13 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 3072 | E: 24 | L: 115 | R: 4.305 | FPS: 318.981 | T: 0:00:09 \n", + "[02/28/2024 09:15:15 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 4096 | E: 32 | L: 134 | R: 4.456 | FPS: 332.621 | T: 0:00:12 \n", + "[02/28/2024 09:15:18 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 5120 | E: 40 | L: 134 | R: 4.456 | FPS: 340.885 | T: 0:00:15 \n", + "[02/28/2024 09:15:21 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 6144 | E: 48 | L: 134 | R: 4.456 | FPS: 346.692 | T: 0:00:17 \n", + "[02/28/2024 09:15:23 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 7168 | E: 56 | L: 163 | R: 4.607 | FPS: 350.687 | T: 0:00:20 \n", + "[02/28/2024 09:15:26 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 8192 | E: 64 | L: 393 | R: 5.556 | FPS: 354.132 | T: 0:00:23 \n", + "[02/28/2024 09:15:29 PM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 9216 | E: 72 | L: 389 | R: 5.585 | FPS: 356.831 | T: 0:00:25 \n", + "[02/28/2024 09:15:29 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Training Accomplished!\n", + "[02/28/2024 09:15:29 PM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Model saved at: /Users/yuanmingqi/code/rllte/misc/logs/default/2024-02-28-09-15-03/model\n" + ] + } + ], + "source": [ + "from rllte.xplore.reward import ICM\n", + "from rllte.env import make_mario_env\n", + "from rllte.agent import PPO\n", + "import torch as th\n", + "\n", + "# create the vectorized environments\n", + "device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'\n", + "envs = make_mario_env('SuperMarioBros-1-1-v3', device=device)\n", + "print(device, envs.observation_space, envs.action_space)\n", + "# create the intrinsic reward module\n", + "irs = ICM(envs, device=device)\n", + "# create the PPO agent\n", + "agent = PPO(envs, device=device)\n", + "# set the intrinsic reward module\n", + "agent.set(reward=irs)\n", + "# train the agent\n", + "agent.train(10000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Example with DDPG**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from rllte.xplore.reward import PseudoCounts\n", + "from rllte.env import make_dmc_env\n", + "from rllte.agent import DDPG\n", + "import torch as th\n", + "\n", + "# create the vectorized environments\n", + "device = 'mps' if th.backends.mps.is_available() else 'cuda' if th.cuda.is_available() else 'cpu'\n", + "envs = make_dmc_env('hopper_hop', device=device)\n", + "print(device, envs.observation_space, envs.action_space)\n", + "# create the intrinsic reward module\n", + "irs = PseudoCounts(envs, device=device)\n", + "# create the PPO agent\n", + "agent = DDPG(envs, device=device)\n", + "# set the intrinsic reward module\n", + "agent.set(reward=irs)\n", + "# train the agent\n", + "agent.train(10000, log_interval=1000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rllte", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/rlexplore/2 rlexplore_with_sb3.ipynb b/examples/rlexplore/2 rlexplore_with_sb3.ipynb new file mode 100644 index 00000000..2d3da3a6 --- /dev/null +++ b/examples/rlexplore/2 rlexplore_with_sb3.ipynb @@ -0,0 +1,109 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.base_class import BaseAlgorithm\n", + "from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm\n", + "from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm\n", + "from stable_baselines3.common.callbacks import BaseCallback\n", + "from stable_baselines3.common.env_util import make_vec_env\n", + "from stable_baselines3 import PPO\n", + "\n", + "import gymnasium as gym\n", + "import torch as th\n", + "\n", + "# ===================== load the reward module ===================== #\n", + "from rllte.xplore.reward import RE3\n", + "# ===================== load the reward module ===================== #\n", + "\n", + "class RLeXploreCallback(BaseCallback):\n", + " \"\"\"\n", + " A custom callback for the RLeXplore toolkit. \n", + " \"\"\"\n", + " def __init__(self, irs, verbose=0):\n", + " super(RLeXploreCallback, self).__init__(verbose)\n", + " self.irs = irs\n", + " self.buffer = None\n", + "\n", + " def init_callback(self, model: BaseAlgorithm) -> None:\n", + " super().init_callback(model)\n", + " if isinstance(self.model, OnPolicyAlgorithm):\n", + " self.buffer = self.model.rollout_buffer\n", + " # TODO: support for off-policy algorithms will be added soon!!!\n", + "\n", + " def _on_step(self) -> bool:\n", + " \"\"\"\n", + " This method will be called by the model after each call to `env.step()`.\n", + "\n", + " :return: (bool) If the callback returns False, training is aborted early.\n", + " \"\"\"\n", + " observations = self.locals[\"obs_tensor\"]\n", + " device = observations.device\n", + " actions = th.as_tensor(self.locals[\"actions\"], device=device)\n", + " rewards = th.as_tensor(self.locals[\"rewards\"], device=device)\n", + " dones = th.as_tensor(self.locals[\"dones\"], device=device)\n", + " next_observations = th.as_tensor(self.locals[\"new_obs\"], device=device)\n", + "\n", + " # ===================== watch the interaction ===================== #\n", + " self.irs.watch(observations, actions, rewards, dones, dones, next_observations)\n", + " # ===================== watch the interaction ===================== #\n", + " return True\n", + "\n", + " def _on_rollout_end(self) -> None:\n", + " # ===================== compute the intrinsic rewards ===================== #\n", + " obs = th.as_tensor(self.buffer.observations, device=device)\n", + " actions = th.as_tensor(self.buffer.actions, device=device)\n", + " rewards = th.as_tensor(self.buffer.rewards, device=device)\n", + " dones = th.as_tensor(self.buffer.episode_starts, device=device)\n", + " print(obs.shape, actions.shape, rewards.shape, dones.shape, obs.shape)\n", + " intrinsic_rewards = irs.compute(samples=dict(observations=obs, \n", + " actions=actions, \n", + " rewards=rewards, \n", + " terminateds=dones,\n", + " truncateds=dones, \n", + " next_observations=obs\n", + " ))\n", + " self.buffer.advantages += intrinsic_rewards.cpu().numpy()\n", + " self.buffer.returns += intrinsic_rewards.cpu().numpy()\n", + " # 
===================== compute the intrinsic rewards ===================== #\n", + "\n", + "# Parallel environments\n", + "device = 'cuda'\n", + "n_envs = 4\n", + "envs = make_vec_env(\"Pendulum-v1\", n_envs=n_envs)\n", + "\n", + "# ===================== build the reward ===================== #\n", + "irs = RE3(envs, device=device)\n", + "# ===================== build the reward ===================== #\n", + "\n", + "model = PPO(\"MlpPolicy\", envs, verbose=1, device=device)\n", + "model.learn(total_timesteps=25000, callback=RLeXploreCallback(irs))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rllte", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/rlexplore/3 rlexplore_with_cleanrl.py b/examples/rlexplore/3 rlexplore_with_cleanrl.py new file mode 100644 index 00000000..36f264b0 --- /dev/null +++ b/examples/rlexplore/3 rlexplore_with_cleanrl.py @@ -0,0 +1,363 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_ataripy +import os +import random +import time +from dataclasses import dataclass + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + +from stable_baselines3.common.atari_wrappers import ( # isort:skip + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, +) + +# ===================== load the reward module ===================== # +import sys +sys.path.append("../") +from rllte.xplore.reward import RE3 +# ===================== load the reward module ===================== # + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + + # Algorithm specific arguments + env_id: str = "BreakoutNoFrameskip-v4" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 8 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 4 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages 
normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, idx, capture_video, run_name): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = ClipRewardEnv(env) + env = gym.wrappers.ResizeObservation(env, (84, 84)) + env = gym.wrappers.GrayScaleObservation(env) + env = gym.wrappers.FrameStack(env, 4) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.network = nn.Sequential( + layer_init(nn.Conv2d(4, 32, 8, stride=4)), + nn.ReLU(), + layer_init(nn.Conv2d(32, 64, 4, stride=2)), + nn.ReLU(), + layer_init(nn.Conv2d(64, 64, 3, stride=1)), + nn.ReLU(), + nn.Flatten(), + layer_init(nn.Linear(64 * 7 * 7, 512)), + nn.ReLU(), + ) + self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(512, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = 
gym.vector.SyncVectorEnv( + [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)], + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + # ===================== build the reward ===================== # + irs = RE3(envs=envs, device=device) + # ===================== build the reward ===================== # + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset(seed=args.seed) + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + + for iteration in range(1, args.num_iterations + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. 
+ next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) + + # ===================== watch the interaction ===================== # + irs.watch(observations=obs[step], + actions=actions[step], + rewards=rewards[step], + terminateds=dones[step], + truncateds=dones[step], + next_observations=next_obs + ) + # ===================== watch the interaction ===================== # + + next_done = np.logical_or(terminations, truncations) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) + + if "final_info" in infos: + for info in infos["final_info"]: + if info and "episode" in info: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + + # ===================== compute the intrinsic rewards ===================== # + intrinsic_rewards = irs.compute(samples=dict(observations=obs, + actions=actions, + rewards=rewards, + terminateds=dones, + truncateds=dones, + next_observations=obs + )) + rewards += intrinsic_rewards + # ===================== compute the intrinsic rewards ===================== # + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + 
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() \ No newline at end of file
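
For readers skimming this diff, all four examples revolve around the same three calls on a reward module: `.watch()`, `.compute()`, and `.update()`. The sketch below condenses the quick-start notebook above into one place; every call, argument, and the `FakeAtari` stand-in environment are taken from that notebook rather than from any additional API, so treat it as a recap under the same assumptions, not a new example.

import gymnasium as gym
import numpy as np
import torch as th

from rllte.env.utils import Gymnasium2Torch
from rllte.xplore.reward import ICM

class FakeAtari(gym.Env):
    """Image-observation stand-in, identical to the one defined in the quick-start notebook."""
    def __init__(self):
        self.action_space = gym.spaces.Discrete(7)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))
        self.count = 0
    def reset(self):
        self.count = 0
        return self.observation_space.sample(), {}
    def step(self, action):
        self.count += 1
        done = self.count > 100 and np.random.rand() < 0.1
        return self.observation_space.sample(), 0, done, done, {}

device = 'cuda' if th.cuda.is_available() else 'cpu'
n_steps, n_envs = 128, 8
# vectorize and wrap so observations arrive as torch tensors, as in the notebook
envs = Gymnasium2Torch(gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)]), device)
irs = ICM(envs, device)

obs, infos = envs.reset()
samples = {k: [] for k in ('observations', 'actions', 'rewards',
                           'terminateds', 'truncateds', 'next_observations')}
for _ in range(n_steps):
    actions = th.stack([th.as_tensor(envs.action_space.sample()) for _ in range(n_envs)])
    next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
    # 1) watch every transition
    irs.watch(observations=obs, actions=actions, rewards=rewards,
              terminateds=terminateds, truncateds=truncateds, next_observations=next_obs)
    for key, value in zip(samples, (obs, actions, rewards, terminateds, truncateds, next_obs)):
        samples[key].append(value)
    obs = next_obs

batch = {k: th.stack(v) for k, v in samples.items()}
# 2) synchronous mode (on-policy): compute() also updates the module internally
intrinsic_rewards = irs.compute(samples=batch)
# 2') asynchronous mode (off-policy): call irs.compute(samples=..., sync=False) per step
#     and invoke 3) irs.update(samples=batch) separately, as in the quick-start notebook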