-------------

# Imports

In [9]:

from PIL import Image  # To transform the image in the Processor
import numpy as np
import gym

# Convolutional Backbone Network
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D, Permute
from tensorflow.keras.optimizers import Adam

# Keras-RL
from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint
from tensorflow.keras.models import load_model


In [10]:
# Create the custom environment. The "poison:" prefix presumably names the
# module that registers `poison-v0` with gym -- confirm against the package.
env = gym.make("poison:poison-v0")
# Size of the discrete action space; used later as the width of the network's
# output layer and passed to the agent.
nb_actions = env.action_space.n


-----
# Image Processing

In [11]:
# Show the number of available actions (the recorded output for this cell is 4).
nb_actions

4

In [12]:
# Size every observation is resized to before it is stored / fed to the network.
IMG_SHAPE = (84, 84)
# Number of consecutive frames that make up one network input (also the
# window length of the replay memory).
WINDOW_LENGTH = 4


In [13]:
class ImageProcessor(Processor):
    """Processor that shrinks frames to grayscale and normalises batches and rewards."""

    def process_observation(self, observation):
        """Return `observation` resized to IMG_SHAPE as a grayscale uint8 array."""
        # numpy array -> PIL image -> resize -> single luminance channel ("L").
        frame = Image.fromarray(observation).resize(IMG_SHAPE).convert("L")
        # Back to numpy; uint8 keeps the experience memory small.
        return np.array(frame).astype('uint8')

    def process_state_batch(self, batch):
        """Cast a batch of stored frames to float32 and divide by 255."""
        # Scaling into [0, 1] helps training. It is done here, per batch,
        # rather than per observation so the stored frames can stay uint8.
        return batch.astype('float32') / 255.

    def process_reward(self, reward):
        """Clamp the reward to the range [-1, 1]."""
        lower, upper = -1., 1.
        return np.clip(reward, lower, upper)


## Model Creation

**NOTE: The right model architecture depends heavily on your custom environment. Try reading papers that solve problems similar to your own!**

In [14]:
# Network input: WINDOW_LENGTH stacked frames, each with the two IMG_SHAPE dimensions.
input_shape = (WINDOW_LENGTH, *IMG_SHAPE[:2])
input_shape

(4, 84, 84)

In [15]:
# MODEL LOADING (optional)
# NOTE(review): the file named below looks like a weights-only checkpoint, and
# the error output recorded for this cell shows load_model() failing to open
# it. To resume from such a file, build the architecture below first and then
# call model.load_weights(...) instead of load_model(...).
#model = load_model("test_dqn_snake_weights_100000.h5f")

# MODEL FROM SCRATCH
model = Sequential()
# Inputs arrive as (window, height, width); move the frame-window axis last so
# the convolutions treat the stacked frames as channels (Keras' default
# channels-last layout).
model.add(Permute((2, 3, 1), input_shape=input_shape))

# Three strided convolutions extract features at decreasing resolution.
model.add(Convolution2D(32, (8, 8), strides=(4, 4), kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Convolution2D(64, (4, 4), strides=(2, 2), kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Convolution2D(64, (3, 3), strides=(1, 1), kernel_initializer='he_normal'))
model.add(Activation('relu'))

# Fully connected head: one linear output (Q-value) per action.
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))

# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
model.summary()

OSError: Unable to open file (file signature not found)

----
## Creating the Agent

In [None]:
# Replay memory holding up to 1,000,000 entries; window_length matches the
# number of frames stacked into one network input.
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)


In [None]:
# Processor that resizes/grayscales observations, scales batches and clips rewards.
processor = ImageProcessor()


In [None]:
# Epsilon-greedy exploration whose `eps` attribute is annealed linearly from
# value_max to value_min over nb_steps (per keras-rl's LinearAnnealedPolicy);
# value_test is the epsilon used when the agent is evaluated.
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(),
    attr='eps',
    value_max=1.,
    value_min=.1,
    value_test=.05,
    nb_steps=1000000,
)


In [None]:
# Assemble the training agent from the network, exploration policy, replay
# memory and observation processor defined in the previous cells.
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    policy=policy,
    memory=memory,
    processor=processor,
    nb_steps_warmup=50000,
    gamma=.99,
    target_model_update=10000,
    train_interval=4,
    delta_clip=1,
)

In [None]:
# Compile the agent with an Adam optimizer (learning rate 2.5e-4) and track
# mean absolute error as an additional metric.
dqn.compile(Adam(learning_rate=.00025), metrics=['mae'])


In [None]:
# Final weights file written once training has finished.
weights_filename = 'test_dqn_poison_weights.h5f'
# Template for the periodic checkpoints. The placeholder must be `{step}`:
# ModelIntervalCheckpoint fills the template via str.format(step=...), so the
# previous `{steps}` placeholder would raise KeyError at the first checkpoint
# and abort training.
checkpoint_weights_filename = 'test_dqn_poison_weights_{step}.h5f'
# Save a checkpoint every 100,000 training steps.
checkpoint_callback = ModelIntervalCheckpoint(checkpoint_weights_filename, interval=100000)


In [None]:
# Train for 1.5M environment steps without rendering; the checkpoint callback
# saves intermediate weights along the way and progress is logged every
# 100,000 steps.
dqn.fit(env, nb_steps=1500000, callbacks=[checkpoint_callback], log_interval=100000, visualize=False)

# After training is done, we save the final weights one more time.
dqn.save_weights(weights_filename, overwrite=True)



In [None]:
# Load the weights
model.load_weights("test_dqn_poison_weights.h5f")


memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1, value_min=.1, value_test=.05,
                              nb_steps=100000)

processor = ImageProcessor()

# Initialize the DQNAgent with the new model and updated policy and compile it
dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
               processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000)
dqn.compile(Adam(learning_rate=.00025), metrics=['mae'])


In [None]:
# Slow down rendering for the evaluation run. Presumably the environment reads
# `sleep` as a per-frame delay in seconds -- confirm against the poison env.
# gym.make() may return the environment inside a Wrapper, and assigning an
# attribute on a wrapper is not forwarded to the wrapped env, so set it on the
# underlying environment directly (a no-op difference when nothing wraps it).
env.unwrapped.sleep = 0.2

In [None]:
# Run 10 evaluation episodes with rendering enabled to watch the trained agent.
dqn.test(env, nb_episodes=10, visualize=True)