.. _cql:

Conservative Q-Learning (CQL)
=============================

CQL is an extension of Q-learning that addresses the overestimation of values typically induced by the distributional shift between
the dataset and the learned policy in offline RL algorithms. It learns a conservative Q-function such that the expected value of a
policy under this Q-function lower-bounds the policy's true value.

* CQL paper: https://arxiv.org/abs/2006.04779
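
For discrete actions, the conservative term penalizes a soft maximum (log-sum-exp) of the Q-values over all actions while pushing up the Q-values of the actions actually taken in the dataset. Below is a minimal PyTorch sketch of this CQL(H) regularizer; it is illustrative only, not AgileRL's internal implementation, and ``cql_alpha`` is a hypothetical weighting coefficient.

.. code-block:: python

    import torch

    def cql_regularizer(q_values, dataset_actions):
        """CQL(H) conservative term for discrete actions.

        q_values: [batch, n_actions] Q-values from the current network.
        dataset_actions: [batch, 1] integer actions taken in the dataset.
        """
        # Push down a soft maximum of the Q-values over all actions...
        logsumexp_q = torch.logsumexp(q_values, dim=1, keepdim=True)
        # ...while pushing up the Q-values of the dataset's actions
        dataset_q = q_values.gather(1, dataset_actions)
        return (logsumexp_q - dataset_q).mean()

    # Hypothetical total loss: the standard TD loss plus the weighted conservative term
    # loss = td_loss + cql_alpha * cql_regularizer(q_values, dataset_actions)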
Can I use it?
-------------

.. list-table::
   :widths: 20 20 20
   :header-rows: 1

   * -
     - Action
     - Observation
   * - Discrete
     - ✔️
     - ✔️
   * - Continuous
     - ❌
     - ✔️

So far, we have implemented CQN (CQL applied to DQN), which cannot be used on continuous action spaces. We will soon be
adding CQL extensions of other algorithms for offline RL.
Example
-------

.. code-block:: python

    import gymnasium as gym
    import h5py
    import numpy as np

    from agilerl.algorithms.cqn import CQN
    from agilerl.components.replay_buffer import ReplayBuffer

    # Create environment and Experience Replay Buffer, and load dataset
    env = gym.make('CartPole-v1')
    try:
        state_dim = env.observation_space.n       # Discrete observation space
        one_hot = True                            # Requires one-hot encoding
    except Exception:
        state_dim = env.observation_space.shape   # Continuous observation space
        one_hot = False                           # Does not require one-hot encoding
    try:
        action_dim = env.action_space.n           # Discrete action space
    except Exception:
        action_dim = env.action_space.shape[0]    # Continuous action space

    channels_last = False  # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
    if channels_last:
        state_dim = (state_dim[2], state_dim[0], state_dim[1])

    field_names = ["state", "action", "reward", "next_state", "done"]
    memory = ReplayBuffer(action_dim=action_dim, memory_size=10000, field_names=field_names)

    dataset = h5py.File('data/cartpole/cartpole_random_v1.1.0.h5', 'r')  # Load dataset

    # Save transitions to replay buffer
    dataset_length = dataset['rewards'].shape[0]
    for i in range(dataset_length - 1):
        state = dataset['observations'][i]
        next_state = dataset['observations'][i + 1]
        if channels_last:
            state = np.moveaxis(state, [3], [1])
            next_state = np.moveaxis(next_state, [3], [1])
        action = dataset['actions'][i]
        reward = dataset['rewards'][i]
        done = bool(dataset['terminals'][i])
        memory.save2memory(state, action, reward, next_state, done)

    agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)  # Create CQN agent

    state = env.reset()[0]  # Reset environment at start of episode
    while True:
        experiences = memory.sample(agent.batch_size)  # Sample replay buffer
        # Learn according to agent's RL algorithm
        agent.learn(experiences)
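
Once trained, the agent can be evaluated online in the environment. The following is a minimal evaluation sketch; it assumes the agent exposes a DQN-style ``getAction(state, epsilon)`` method, where ``epsilon=0`` selects actions greedily.

.. code-block:: python

    # Minimal evaluation sketch (assumes a DQN-style getAction(state, epsilon) method)
    state = env.reset()[0]
    episode_reward, done = 0, False
    while not done:
        action = agent.getAction(state, epsilon=0)  # Greedy action from the learned Q-function
        state, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        done = terminated or truncated
    print(f"Episode reward: {episode_reward}")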
To configure the network architecture, pass a dict to the CQN ``net_config`` field. For an MLP, this can be as simple as:

.. code-block:: python

    NET_CONFIG = {
        'arch': 'mlp',           # Network architecture
        'hidden_size': [32, 32]  # Network hidden size
    }

Or for a CNN:

.. code-block:: python

    NET_CONFIG = {
        'arch': 'cnn',             # Network architecture
        'hidden_size': [128],      # Network hidden size
        'channel_size': [32, 32],  # CNN channel size
        'kernel_size': [8, 4],     # CNN kernel size
        'stride_size': [4, 2],     # CNN stride size
        'normalize': True          # Normalize image from range [0,255] to [0,1]
    }

.. code-block:: python

    agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot, net_config=NET_CONFIG)  # Create CQN agent
Saving and loading agents
-------------------------

To save an agent, use the ``saveCheckpoint`` method:

.. code-block:: python

    from agilerl.algorithms.cqn import CQN

    agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)  # Create CQN agent

    checkpoint_path = "path/to/checkpoint"
    agent.saveCheckpoint(checkpoint_path)

To load a saved agent, use the ``load`` method:

.. code-block:: python

    from agilerl.algorithms.cqn import CQN

    checkpoint_path = "path/to/checkpoint"
    agent = CQN.load(checkpoint_path)
Parameters
----------

.. autoclass:: agilerl.algorithms.cqn.CQN
   :members:
   :inherited-members: