dqn_world_model.py
import os
from datetime import datetime

import numpy as np
import tensorflow as tf
from joblib import Parallel, delayed
from tensorflow.keras import models, layers, optimizers

from environment import Environment
from reward_plot import RewardPlot

# hyperparameters
GAMMA = 0.7                  # discount factor for future rewards
LEARNING_RATE = 0.001
EPSILON_DECAY = 0.975        # multiplicative decay of the exploration rate per iteration
MIN_EPSILON = 0.03           # exploration rate floor
SIMULATION_EPOCHS = 32       # parallel rollouts per training iteration
TRAINING_EPOCHS = 3
NUM_ACTIONS = 8
SHOW_AFTER_ITERATIONS = 15   # plot a simulation every N iterations (0 disables)
BATCH_SIZE = 512
SIMULATION_STEPS = 150
WORLD_MODEL_PATH = os.path.join('world_models', '2020_01_30-21_24_24', 'model_4.h5')
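
# The Q-network consumes the environment state concatenated with a 2-dimensional
# latent vector produced by a pretrained world model. The 4-dimensional input
# below implies a 2-dimensional state; this is inferred from the code, not stated
# explicitly anywhere in the file. Training targets follow the standard one-step
# Q-learning update, computed over complete rollouts in simulate():
#
#     Q(s, a) <- r + GAMMA * max_a' Q(s', a')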


def get_model():
    # return a compiled model to be used as the deep Q network;
    # the network is linear: a single Dense layer (no activation) mapping the
    # 4-dimensional input to one Q value per action
    model = models.Sequential()
    model.add(layers.Input(shape=(4,)))
    # model.add(layers.Dense(units=4, activation='sigmoid'))
    model.add(layers.Dense(units=NUM_ACTIONS))
    model.compile(loss='mse', optimizer=optimizers.RMSprop(learning_rate=LEARNING_RATE))
    return model
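
# Sanity check (can be run in an interactive session; the expected output shape
# is one row of NUM_ACTIONS Q values):
#     >>> m = get_model()
#     >>> m.predict(np.zeros((1, 4))).shape
#     (1, 8)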


def simulate(weights, epsilon):
    # instantiate a new model and set the weights (lets each joblib worker
    # process build its own copy instead of sharing one across processes)
    model = get_model()
    model.set_weights(weights)
    # load the world model from a given checkpoint file
    world_model = models.load_model(WORLD_MODEL_PATH)
    # set up the environment
    env = Environment(simulation_steps=SIMULATION_STEPS)
    state = env.state()
    done = False
    observations = []
    latent_world = np.zeros(2, dtype=np.float32)
    while not done:
        # get the latent vector from the world model (the model is fed its own
        # previous output, making it effectively recurrent)
        x = np.concatenate((state, latent_world))
        latent_world = world_model(x.reshape(1, -1)).numpy()[0]
        # get the model's predicted Q values
        x = np.concatenate((state, latent_world))
        prediction = model(x.reshape(1, -1)).numpy()[0]
        # choose an action (epsilon-greedy)
        if np.random.uniform() < epsilon:
            # random action (exploration)
            action = np.random.randint(NUM_ACTIONS)
        else:
            # best predicted action (exploitation)
            action = np.argmax(prediction)
        # complete the previous observation with the reward it earned and the
        # current input, then start a new one: [x, prediction, action]
        if len(observations) > 0:
            observations[-1].extend([reward, x])
        observations.append([x, prediction, action])
        # simulate one step
        state, reward, done = env.step(action)
    # drop the last observation, which never received a reward and successor input
    observations = observations[:-1]
    # turn the predicted Q values into training targets using the one-step
    # Q-learning update: Q(s, a) <- r + GAMMA * max_a' Q(s', a')
    for current in observations:
        action = current[2]
        reward = current[3]
        next_max_q = np.max(model(current[4].reshape((1, -1))).numpy())
        current[1][action] = reward + GAMMA * next_max_q
    return [obs[0] for obs in observations], [obs[1] for obs in observations], env.cumulative_reward
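
# simulate() and plot_simulation() rely on the following Environment interface
# (as used throughout this file; the implementation lives in environment.py):
#     env.state()             -> current state vector
#     env.step(action)        -> (next_state, reward, done)
#     env.reset()             -> restart the episode
#     env.cumulative_reward   -> total reward collected so far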


def plot_simulation(env, model, world_model):
    # set up the environment
    env.reset()
    state = env.state()
    # run the simulation greedily (no exploration)
    done = False
    latent_world = np.zeros(2, dtype=np.float32)
    while not done:
        # get the latent vector from the world model
        x = np.concatenate((state, latent_world))
        latent_world = world_model(x.reshape(1, -1)).numpy()[0]
        x = np.concatenate((state, latent_world))
        action = np.argmax(model(x.reshape((1, -1)))[0])
        state, _, done = env.step(action)


if __name__ == '__main__':
    model = get_model()
    plot_env = None
    reward_plot = RewardPlot()
    world_model = models.load_model(WORLD_MODEL_PATH)
    # create a folder for model checkpoints
    timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
    os.mkdir(os.path.join('models_combined', timestamp))
    epsilon = 1.0
    iteration = 0
    while True:
        # reset Keras' global state to keep memory from accumulating across iterations
        tf.keras.backend.clear_session()
        iteration += 1
        print(f'Iteration {iteration}:')
        # run simulations in parallel worker processes
        print(f'Running {SIMULATION_EPOCHS} simulations...')
        params = model.get_weights(), epsilon
        delayed_calls = (delayed(simulate)(*params) for _ in range(SIMULATION_EPOCHS))
        results = Parallel(n_jobs=-1)(delayed_calls)
        # extract state-Q value pairs from the results
        x = np.array(sum([res[0] for res in results], []))
        y = np.array(sum([res[1] for res in results], []))
        # get the cumulative rewards from all simulations
        cumulative_rewards = [res[2] for res in results]
        mean_reward = np.mean(cumulative_rewards)
        reward_plot.update(cumulative_rewards)
        # decay the exploration rate
        epsilon = max(epsilon * EPSILON_DECAY, MIN_EPSILON)
        # train the model and save the current weights
        print(f'Training for {TRAINING_EPOCHS} epochs...')
        for epoch in range(TRAINING_EPOCHS):
            # shuffle the training data
            permutation = np.random.permutation(len(x))
            x = x[permutation]
            y = y[permutation]
            # train for one epoch; round up so a final partial batch is included
            # without ever producing an empty batch
            for i in range((len(x) + BATCH_SIZE - 1) // BATCH_SIZE):
                start = i * BATCH_SIZE
                end = (i + 1) * BATCH_SIZE
                model.train_on_batch(x[start:end], y[start:end])
        # save the current model to disk
        model.save(os.path.join('models_combined', timestamp, f'model_{iteration}_{int(mean_reward)}.h5'))
        # evaluate the training state
        loss = model.evaluate(x, y, verbose=0)
        print(f'Current epsilon: {epsilon:.2f}, loss: {loss:.4f}, avg. cumulative reward: {mean_reward:.2f}')
        # visualize the model's progress by plotting a simulation
        if SHOW_AFTER_ITERATIONS > 0 and iteration % SHOW_AFTER_ITERATIONS == 0:
            if plot_env is None:
                plot_env = Environment(draw=True, simulation_steps=SIMULATION_STEPS)
            plot_simulation(plot_env, model, world_model)
        print()
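
# Example (hypothetical paths): to replay a saved checkpoint with rendering in a
# separate session, something along these lines should work:
#     model = models.load_model('models_combined/<timestamp>/model_<iteration>_<reward>.h5')
#     world_model = models.load_model(WORLD_MODEL_PATH)
#     plot_simulation(Environment(draw=True, simulation_steps=SIMULATION_STEPS), model, world_model)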