In [18]:
import numpy as np
from keras.layers import Input, Conv2D, Dense, MaxPooling2D , Flatten
from keras.models import Model, Sequential, load_model
from collections import deque

In [20]:
class DQN:
    """Deep Q-Network agent: policy network, target network and bounded experience memory.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of a single observation, passed to the first ``Conv2D`` layer
        (e.g. ``[200, 128, 3]``).
    output_shape : int
        Number of discrete actions, i.e. units of the final ``Dense`` layer.
    discount : float, default 0.99
        Discount factor applied to future rewards.
    update_target_every : int, default 10
        Number of ``update_target_net()`` calls between copies of the policy
        weights into the target network.
    memory_size : int, default 2000
        Maximum number of entries kept in ``self.memory``; once full, the
        oldest entries are discarded first (``deque(maxlen=...)``).
    """

    def __init__(self, input_shape, output_shape, discount=0.99, update_target_every=10, memory_size=2000):
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.discount = discount
        self.update_target_every = update_target_every
        self.policy_net = self.create_model()
        # Target network: a separate copy of the policy net that is only
        # synchronised every `update_target_every` updates, so TD targets do
        # not shift on every gradient step.
        self.target_net = self.create_model()
        self.target_net.set_weights(self.policy_net.get_weights())
        # Bounded experience memory; presumably holds
        # (state, action, reward, next_state, done) tuples — confirm with the training loop.
        self.memory = deque(maxlen=memory_size)
        # Counts update_target_net() calls since the last weight sync.
        self.target_counter = 0

    def create_model(self, output_activation="linear"):
        """Build (but do not compile) the convolutional Q-network.

        Parameters
        ----------
        output_activation : str, default "linear"
            Activation of the final Dense layer. Q-values are unbounded real
            numbers, so the default is linear; the previous ``"softmax"``
            head squashed them into a probability simplex and cannot
            represent action-values.

        Returns
        -------
        keras.models.Sequential
            Uncompiled model with ``self.output_shape`` outputs.
        """
        model = Sequential()
        model.add(Conv2D(input_shape=self.input_shape, filters=128, kernel_size=(7, 7), strides=(1, 1),
                         padding="valid", activation="relu", use_bias=True))
        model.add(MaxPooling2D(pool_size=(3, 3), padding="valid"))
        model.add(Conv2D(filters=128, kernel_size=(7, 7), strides=(2, 2), padding="valid",
                         activation="relu", use_bias=True))
        model.add(MaxPooling2D(pool_size=(2, 2), padding="valid"))
        model.add(Conv2D(filters=128, kernel_size=(7, 7), strides=(2, 2), padding="valid",
                         activation="relu", use_bias=True))
        model.add(Flatten())
        model.add(Dense(512, activation="relu"))
        # One unbounded Q-value per action.
        model.add(Dense(self.output_shape, activation=output_activation))
        # NOTE(review): the model is returned uncompiled; compile it (optimizer +
        # MSE/Huber loss) before calling fit() in the training loop.
        return model

    def update_target_net(self):
        """Count one training update; copy policy weights into the target net every `update_target_every` calls.

        Returns
        -------
        bool
            True if the target network was synchronised on this call.
        """
        self.target_counter += 1
        if self.target_counter >= self.update_target_every:
            self.target_net.set_weights(self.policy_net.get_weights())
            self.target_counter = 0
            return True
        return False
       
        

In [25]:
# Build the agent for 200x128x3 observations (presumably RGB frames) and 6 discrete actions.
dqn = DQN(input_shape=[200, 128, 3], output_shape=6)

In [26]:
# Short alias for the policy network ("rete" is Italian for "network").
rete = dqn.policy_net

In [27]:
# Print the layer-by-layer architecture; summary() prints and returns None, so suppress the repr.
rete.summary();

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 194, 122, 128)     18944     
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 64, 40, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 29, 17, 128)       802944    
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 14, 8, 128)        0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 4, 1, 128)         802944    
_________________________________________________________________
flatten_2 (Flatten)          (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 512)               262656    
__________