Skip to content

Commit 9977f7f

Browse files
ensemble distillation
1 parent ac8c9f3 commit 9977f7f

File tree

10 files changed

+323
-11
lines changed

10 files changed

+323
-11
lines changed

Agent/DQNEnsembleAgent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def combineModels(models, combiner):
1515
) for x in models ]
1616

1717
res = layers.Lambda(combiner)( layers.Concatenate(axis=1)(predictions) )
18+
res = MaskedSoftmax()( res, actionsMask )
1819
return keras.Model(inputs=[inputs, actionsMask], outputs=res)
1920

2021
@tf.function
@@ -45,11 +46,10 @@ def processBatch(self, states, actionsMask):
4546
actions[rndIndexes] = np.random.random_sample(actions.shape)[rndIndexes]
4647

4748
if not (self._noise is None):
48-
# softmax
49-
e_x = np.exp(actions - actions.max(axis=-1, keepdims=True))
50-
normed = e_x / e_x.sum(axis=-1, keepdims=True)
51-
# add noise
5249
actions = normed + (np.random.random_sample(actions.shape) * self._noise)
5350

5451
actions[np.where(~(1 == np.array(actionsMask)))] = -math.inf
55-
return actions.argmax(axis=-1)
52+
return actions.argmax(axis=-1)
53+
54+
def predict(self, states, actionsMask):
55+
return self._model.predict([states, actionsMask])

README.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,35 @@
4141

4242
Опять же, прямо ощутимого улучшения нет, но ансамбль немного стабильнее открывает 20-25% карты.
4343

44-
Следующим шагом будет дистилляция ансамбля в единую сеть, а так же использование полноценной сети для комбинации предсказаний подсетей. Есть большая вероятность того, что это позволит уловить более глубокие корреляции т. к. обучаемая сеть будет уже иметь представление о соотношение Q-values (сами значения индивидуальны для каждой сети).
44+
# Дистилляция ансамбля
45+
46+
Новая сеть обучалась с дополнительным лоссом, который определяет сходство распределения Q-values обучаемой сети с предсказанием ансамбля.
47+
48+
![](img/20210106-distilled.jpg)
49+
50+
Обучаемая с учителем нейронная сеть практически сразу же достигает более высоких результатов, чем обучаемая без учителя. Более того, в полноценных тестах она показывает себя немного лучше ансамбля:
51+
52+
![](img/20210106-high.jpg)
53+
54+
![](img/20210106-low.jpg)
55+
56+
Некоторые наблюдения:
57+
58+
- Новая сеть какое-то время (10-30 эпох) способна показывать хорошие результаты, если "копировать" только распределение и не контролировать сами значения. Это вполне ожидаемо, но всё же интересно.
59+
- Сеть лишь копирует распределение, поэтому не способна улучшить результаты. Вполне возможно, что необходимо более длительное обучение, чтоб сеть полностью адаптировала Q-values к распределению диктуемому ансамблем, а затем она смогла бы продолжить обучение. Целесообразно ли это? Не лучше ли тогда обучить сеть полностью с нуля?
60+
- Ансамбль усредняет поведение агентов, выделяя общее поведение. Новая сеть копирует усреднённое поведение, тоже сглаживая нюансы поведения, стратегию. Таким образом, новая сеть теряет особенности, которые позволяли агентам демонстрировать более хорошие результаты в особых ситуациях. Как тогда эффективно объединять "знания" агентов? Полезные материалы по данной теме:
61+
- [Distill and transfer learning for robust multitask RL (YouTube)](https://www.youtube.com/watch?v=scf7Przmh7c)
62+
- [Teacher-Student Framework: A Reinforcement Learning Approach](https://www.researchgate.net/publication/280255927_Teacher-Student_Framework_A_Reinforcement_Learning_Approach)
63+
- [Progressive Reinforcement Learning with Distillation for Multi-Skilled Motion Control](https://arxiv.org/abs/1802.04765)
64+
65+
# Идеи и эксперименты
66+
67+
- [ ] Заменить разделение памяти/эпизодов на основные и после попадания в цикл.
68+
- [ ] Реализовать дистилляцию нескольких политик, используя доп. награды или иные методы.
69+
- [ ] Сравнить обученного без учителя агента с обученным с учителем. (500 эпох)
70+
- [ ] Обучить агента, который не получает информацию о своих перемещениях (только с данным об окружении).
71+
- [ ] Реализовать полноценного агента с памятью.
72+
- [ ] Использовать A2C или иные методы, фундаментально отличающиеся от DQN.
4573

4674
# Области применения
4775

Utils/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
import numpy as np
33
import math
44

5+
def softmax(x, mask=None):
6+
e_x = np.exp(x - x.max(axis=-1, keepdims=True))
7+
if not (mask is None):
8+
e_x *= mask
9+
e_sum = e_x.sum(axis=-1, keepdims=True)
10+
return np.divide(e_x, e_sum, out=np.zeros_like(e_x), where=(e_sum != 0))
11+
512
def emulateBatch(testEnvs, agent, maxSteps):
613
replays = [[] for _ in testEnvs]
714
steps = 0

distillation.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# -*- coding: utf-8 -*-
2+
import sys
3+
import os
4+
import tensorflow as tf
5+
from Agent.MaskedSoftmax import MaskedSoftmax
6+
7+
if 'COLAB_GPU' in os.environ:
8+
# fix resolve modules
9+
from os.path import dirname
10+
sys.path.append(dirname(dirname(dirname(__file__))))
11+
else: # local GPU
12+
gpus = tf.config.experimental.list_physical_devices('GPU')
13+
tf.config.experimental.set_virtual_device_configuration(
14+
gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=3 * 1024)]
15+
)
16+
17+
from tensorflow.keras.optimizers import Adam
18+
from tensorflow.keras.losses import Huber
19+
import tensorflow.keras as keras
20+
21+
from model import createModel
22+
from Core.MazeRLWrapper import MazeRLWrapper
23+
from Utils.ExperienceBuffers.CebPrioritized import CebPrioritized
24+
from Agent.DQNAgent import DQNAgent
25+
from Agent.DQNEnsembleAgent import DQNEnsembleAgent
26+
import time
27+
import Utils
28+
from Utils.ExperienceBuffers.CebLinear import CebLinear
29+
import glob
30+
import numpy as np
31+
32+
#######################################
33+
def train(model, trainableModel, memory, params):
34+
modelClone = tf.keras.models.clone_model(model)
35+
modelClone.set_weights(model.get_weights()) # use clone model for stability
36+
37+
BOOTSTRAPPED_STEPS = params['steps']
38+
GAMMA = params['gamma']
39+
ALPHA = params.get('alpha', 1.0)
40+
rows = np.arange(params['batchSize'])
41+
lossSum = 0
42+
for _ in range(params['episodes']):
43+
allStates, actions, rewards, actionsMask, teacherPredictions, nextStateScoreMultiplier = memory.sampleSequenceBatch(
44+
batch_size=params['batchSize'],
45+
maxSamplesFromEpisode=params.get('maxSamplesFromEpisode', 16),
46+
sequenceLen=BOOTSTRAPPED_STEPS + 1
47+
)
48+
49+
states = allStates[:, :-1]
50+
rewards = rewards[:, :-1]
51+
actions = actions[:, 0]
52+
53+
futureScores = modelClone.predict(allStates[:, -1]).max(axis=-1) * nextStateScoreMultiplier[:, -1]
54+
totalRewards = (rewards * (GAMMA ** np.arange(BOOTSTRAPPED_STEPS))).sum(axis=-1)
55+
targets = modelClone.predict(states[:, 0])
56+
57+
targets[rows, actions] += ALPHA * (
58+
totalRewards + futureScores * (GAMMA ** BOOTSTRAPPED_STEPS) - targets[rows, actions]
59+
)
60+
61+
lossSum += trainableModel.fit(
62+
[states[:, 0], teacherPredictions[:, 0], actionsMask[:, 0], targets],
63+
epochs=1, verbose=0
64+
).history['loss'][0]
65+
###
66+
67+
return lossSum / params['episodes']
68+
69+
def complexLoss(valueLoss, teacherPower, distributions, actionsMasks, y_true, y_pred, y_pred_softmax):
70+
# mask out invalid actions
71+
lossValues = valueLoss(y_true * actionsMasks, y_pred * actionsMasks)
72+
73+
lossDistribution = keras.losses.kl_divergence(distributions * actionsMasks, y_pred_softmax * actionsMasks)
74+
return lossValues + (lossDistribution * teacherPower)
75+
76+
def wrapStudentModel(student):
77+
inputA = keras.layers.Input(shape=student.layers[0].input_shape[0][1:])
78+
inputDistributions = keras.layers.Input(shape=(4, ))
79+
inputMasks = keras.layers.Input(shape=(4, ))
80+
inputTargets = keras.layers.Input(shape=(4, ))
81+
teacherPower = tf.Variable(1.0, tf.float32)
82+
83+
res = student(inputA)
84+
resSoftmax = MaskedSoftmax()(res, inputMasks)
85+
86+
model = keras.Model(inputs=[inputA, inputDistributions, inputMasks, inputTargets], outputs=[res, resSoftmax])
87+
model.add_loss(complexLoss(
88+
Huber(delta=1),
89+
teacherPower,
90+
inputDistributions, inputMasks, inputTargets,
91+
res, resSoftmax
92+
))
93+
model.compile(optimizer=Adam(lr=1e-3), loss=None )
94+
return model, teacherPower
95+
96+
def learn_environment(teacher, model, params):
97+
NAME = params['name']
98+
BATCH_SIZE = params['batch size']
99+
GAMMA = params['gamma']
100+
BOOTSTRAPPED_STEPS = params['bootstrapped steps']
101+
LOOP_LIMIT = params['maze']['loop limit']
102+
metrics = {}
103+
104+
environments = [
105+
MazeRLWrapper(params['maze']) for _ in range(params['test episodes'])
106+
]
107+
108+
memory = CebPrioritized(maxSize=5000, sampleWeight='abs')
109+
doomMemory = CebLinear(
110+
maxSize=params.get('max steps after loop', 16) * 10000,
111+
sampleWeight='abs'
112+
)
113+
trainableModel, teacherPower = wrapStudentModel(model)
114+
######################################################
115+
def withTeacherPredictions(replay):
116+
prevStates, actions, rewards, actionsMasks = zip(*replay)
117+
teacherPredictions = teacher.predict(np.array(prevStates), np.array(actionsMasks))
118+
return list(zip(prevStates, actions, rewards, actionsMasks, teacherPredictions))
119+
120+
def testModel(EXPLORE_RATE):
121+
for e in environments: e.reset()
122+
replays = Utils.emulateBatch(
123+
environments,
124+
DQNAgent(model, exploreRate=EXPLORE_RATE, noise=params.get('agent noise', 0)),
125+
maxSteps=params.get('max test steps')
126+
)
127+
for replay, _ in replays:
128+
if params.get('clip replay', False):
129+
replay = Utils.clipReplay(replay, loopLimit=LOOP_LIMIT)
130+
if BOOTSTRAPPED_STEPS < len(replay):
131+
memory.addEpisode(withTeacherPredictions(replay), terminated=True)
132+
133+
scores = [x.score for x in environments]
134+
################
135+
# collect bad experience
136+
envs = [e for e in environments if e.hitTheLoop]
137+
if envs:
138+
for e in envs: e.Continue()
139+
replays = Utils.emulateBatch(
140+
envs,
141+
DQNAgent(
142+
model,
143+
exploreRate=params.get('explore rate after loop', 1),
144+
noise=params.get('agent noise after loop', 0)
145+
),
146+
maxSteps=params.get('max steps after loop', 16)
147+
)
148+
for replay, _ in replays:
149+
if BOOTSTRAPPED_STEPS < len(replay):
150+
doomMemory.addEpisode(withTeacherPredictions(replay), terminated=True)
151+
################
152+
return scores
153+
######################################################
154+
# collect some experience
155+
for _ in range(2):
156+
testModel(EXPLORE_RATE=0)
157+
#######################
158+
bestModelScore = -float('inf')
159+
for epoch in range(params['epochs']):
160+
T = time.time()
161+
162+
EXPLORE_RATE = params['explore rate'](epoch)
163+
alpha = params.get('alpha', lambda _: 1)(epoch)
164+
teacherP = max((0, params.get('teacher power', lambda _: 1)(epoch) ))
165+
teacherPower.assign(teacherP)
166+
print(
167+
'[%s] %d/%d epoch. Explore rate: %.3f. Alpha: %.5f. Teacher power: %.3f' % (
168+
NAME, epoch, params['epochs'], EXPLORE_RATE, alpha, teacherP
169+
)
170+
)
171+
##################
172+
# Training
173+
trainLoss = train(
174+
model, trainableModel, memory,
175+
{
176+
'gamma': GAMMA,
177+
'batchSize': BATCH_SIZE,
178+
'steps': BOOTSTRAPPED_STEPS,
179+
'episodes': params['train episodes'](epoch),
180+
'alpha': alpha
181+
}
182+
)
183+
print('Avg. train loss: %.4f' % trainLoss)
184+
185+
if BATCH_SIZE < len(doomMemory):
186+
trainLoss = train(
187+
model, trainableModel, doomMemory,
188+
{
189+
'gamma': GAMMA,
190+
'batchSize': BATCH_SIZE,
191+
'steps': BOOTSTRAPPED_STEPS,
192+
'episodes': params['train doom episodes'](epoch),
193+
'alpha': params.get('doom alpha', lambda _: alpha)(epoch)
194+
}
195+
)
196+
print('Avg. train doom loss: %.4f' % trainLoss)
197+
##################
198+
# test
199+
print('Testing...')
200+
scores = testModel(EXPLORE_RATE)
201+
Utils.trackScores(scores, metrics)
202+
##################
203+
204+
scoreSum = sum(scores)
205+
print('Scores sum: %.5f' % scoreSum)
206+
if (bestModelScore < scoreSum) and (params['warm up epochs'] < epoch):
207+
print('save best model (%.2f => %.2f)' % (bestModelScore, scoreSum))
208+
bestModelScore = scoreSum
209+
model.save_weights('weights/%s.h5' % NAME)
210+
##################
211+
os.makedirs('charts', exist_ok=True)
212+
Utils.plotData2file(metrics, 'charts/%s.jpg' % NAME)
213+
print('Epoch %d finished in %.1f sec.' % (epoch, time.time() - T))
214+
print('------------------')
215+
216+
#######################################
217+
MAZE_FOV = 3
218+
MAZE_MINIMAP_SIZE = 8
219+
MAZE_LOOPLIMIT = 32
220+
#######################################
221+
222+
if __name__ == "__main__":
223+
DEFAULT_MAZE_PARAMS = {
224+
'size': 40,
225+
'FOV': MAZE_FOV,
226+
'minimapSize': MAZE_MINIMAP_SIZE,
227+
'loop limit': MAZE_LOOPLIMIT,
228+
}
229+
230+
MODEL_INPUT_SHAPE = MazeRLWrapper(DEFAULT_MAZE_PARAMS).input_size
231+
232+
models = []
233+
for x in glob.iglob('weights/agent-*.h5'):
234+
filename = os.path.abspath(x)
235+
model = createModel(shape=MODEL_INPUT_SHAPE)
236+
model.load_weights(filename)
237+
models.append(model)
238+
239+
teacher = DQNEnsembleAgent(models)
240+
#######################
241+
DEFAULT_LEARNING_PARAMS = {
242+
'maze': DEFAULT_MAZE_PARAMS,
243+
244+
'batch size': 256,
245+
'gamma': 0.95,
246+
'bootstrapped steps': 3,
247+
248+
'epochs': 100,
249+
'warm up epochs': 0,
250+
'test episodes': 128,
251+
'train episodes': lambda _: 128,
252+
'train doom episodes': lambda _: 32,
253+
254+
'alpha': lambda _: 1,
255+
'explore rate': lambda _: 0,
256+
257+
'agent noise': 0.01,
258+
'clip replay': True,
259+
260+
'explore rate after loop': 0.2,
261+
'agent noise after loop': 0.1,
262+
263+
'max test steps': 1000
264+
}
265+
#######################
266+
# just transfer distributions from teacher
267+
learn_environment(
268+
teacher,
269+
createModel(shape=MODEL_INPUT_SHAPE),
270+
{
271+
**DEFAULT_LEARNING_PARAMS,
272+
'name': 'distilled',
273+
'teacher power': lambda epoch: 1,
274+
}
275+
)

img/20210106-distilled.jpg

405 KB
Loading

img/20210106-high.jpg

93.5 KB
Loading

img/20210106-low.jpg

116 KB
Loading

test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,13 @@ def testAgent(environments, agent, name, metrics, N=20):
7979
'Worst scores (top 90%)': {},
8080
'Best scores (top 10%)': {}
8181
}
82-
models = []
82+
agents = []
8383
for i, x in enumerate(glob.iglob('weights/*.h5')):
8484
filename = os.path.abspath(x)
8585
model = createModel(shape=MODEL_INPUT_SHAPE)
8686
model.load_weights(filename)
87-
models.append(model)
87+
if os.path.basename(filename).startswith('agent-'):
88+
agents.append(model)
8889

8990
testAgent(
9091
environments,
@@ -95,7 +96,7 @@ def testAgent(environments, agent, name, metrics, N=20):
9596

9697
testAgent(
9798
environments,
98-
DQNEnsembleAgent(models),
99+
DQNEnsembleAgent(agents),
99100
name='ensemble',
100101
metrics=metrics
101102
)

view_maze.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ def _createNewAgent(self):
104104
filename = os.path.abspath(x)
105105
model = createModel(shape=self._maze.input_size)
106106
model.load_weights(filename)
107-
models.append(model)
108-
agent = DQNAgent(model)
109107
name = os.path.basename(filename)
108+
if name.startswith('agent-'):
109+
models.append(model)
110+
agent = DQNAgent(model)
110111

111112
self._agents.append(RLAgent(
112113
name[:-3], agent, None, None

weights/distilled.h5

1.92 MB
Binary file not shown.

0 commit comments

Comments
 (0)