# In[4]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Conv2D, MaxPooling2D, Activation, Flatten
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
from collections import deque
import time

# In[5]:
# Label embedded in the TensorBoard log directory name built by DQNAgent.
MODEL_NAME = "256x2"

# Maximum number of transitions DQNAgent keeps in its replay deque; once the
# deque is full, the oldest entries are dropped as new ones are appended.
REPLAY_MEMORY_SIZE = 50_000

# In[8]:
class ModifiedTensorBoard(TensorBoard):
    """TensorBoard callback that logs every ``.fit()`` call into one file.

    The stock callback starts a new log file (and restarts the step counter)
    on each ``.fit()`` call. This subclass creates a single writer up front
    and tags every record with ``self.step``. Nothing in this class advances
    ``self.step``; the caller is expected to update it between writes.

    NOTE(review): relies on a module-level ``tf`` name and on the
    ``tf.summary.FileWriter`` / ``TensorBoard._write_logs`` APIs, which look
    like the TensorFlow 1.x / standalone-Keras versions -- confirm against
    the installed versions.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``TensorBoard`` and open the shared writer."""
        super().__init__(**kwargs)
        # Step attached to every record written through update_stats().
        self.step = 1
        self.writer = tf.summary.FileWriter(self.log_dir)

    def set_model(self, model):
        """Ignore the model.

        Overridden as a no-op, presumably so the base class does not create
        its own default writer.
        """

    def on_epoch_end(self, epoch, logs=None):
        """Write the epoch's metrics under ``self.step``.

        Without this override every ``.fit()`` call would start writing from
        step 0 again. ``logs`` may be ``None`` (its declared default), in
        which case an empty set of metrics is written.
        """
        self.update_stats(**(logs or {}))

    def on_batch_end(self, batch, logs=None):
        """Do nothing: each fit trains on one batch, so there is nothing
        extra to record at batch granularity."""

    def on_train_end(self, _=None):
        """Do nothing, so the shared writer is not closed after a fit.

        The argument is ignored; it defaults to ``None`` so the hook can also
        be called without arguments.
        """

    def update_stats(self, **stats):
        """Write arbitrary named metrics at the current ``self.step``.

        Args:
            **stats: Metric names mapped to the values to log.
        """
        self._write_logs(stats, self.step)

# In[10]:
class DQNAgent:
    def __init__(self):
        """Build both networks, the replay buffer and the TensorBoard logger."""
        # Main network: the one that gets trained every step.
        online_net = self.create_model()

        # Target network: what we .predict against every step. It starts out
        # as an exact weight-for-weight copy of the main network.
        frozen_net = self.create_model()
        frozen_net.set_weights(online_net.get_weights())

        self.model = online_net
        self.target_model = frozen_net

        # Bounded buffer of past transitions; the oldest entries fall off
        # once REPLAY_MEMORY_SIZE is reached.
        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)

        # One log directory per agent, made unique by the creation timestamp.
        log_path = f'logs/{MODEL_NAME}-{int(time.time())}'
        self.tensorboard = ModifiedTensorBoard(log_dir=log_path)

        # Starts at zero; nothing in this method changes it afterwards.
        self.target_update_counter = 0
    
    def create_model(self):
        """Build and compile the convolutional Q-network.

        The network is two Conv2D(256, 3x3) -> ReLU -> MaxPooling -> Dropout
        stages, followed by Flatten, Dense(64) and a linear output layer with
        one unit per action (``env.ACTION_SPACE_SIZE``). The input shape is
        taken from ``env.OBSERVATION_SPACE_VALUES``; ``env`` is a module-level
        name defined outside this method.

        Returns:
            The compiled Keras ``Sequential`` model (MSE loss, Adam with
            learning rate 0.001).
        """
        # Sequential takes its layers as ONE list argument; passing them as
        # separate positional arguments raises a TypeError.
        model = Sequential([
            Conv2D(256, (3, 3), input_shape=env.OBSERVATION_SPACE_VALUES),
            Activation('relu'),
            MaxPooling2D(2, 2),
            Dropout(0.2),

            Conv2D(256, (3, 3)),
            Activation('relu'),
            MaxPooling2D(2, 2),
            Dropout(0.2),

            # Collapse the feature maps to a vector for the dense layers.
            Flatten(),
            Dense(64),
            Dense(env.ACTION_SPACE_SIZE, activation='linear'),
        ])

        model.compile(loss='mse', optimizer=Adam(lr=0.001), metrics=['accuracy'])

        return model
    
    def update_replay_memory(self, transition):
        self.replay_memory.append(transition)
        
    def get_qs(self, state, step):
        return self.model.predict(np.array(state).reshape(-1, *state.shape) / 255)[0]