In [None]:
#Adapted from: https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
def ResNet50Custom(input_shape=(229, 229, 1), classes=None):
    """Build, compile and return a ResNet-50-style image classifier.

    The convolutional trunk mirrors keras-applications' ResNet50: a 7x7
    stem followed by four stages of 3/4/6/3 bottleneck blocks. The head is
    custom: Flatten -> Dense(2048) -> Dense(2048) -> softmax.

    Args:
        input_shape: Channels-last input shape ``(height, width, channels)``.
            Defaults to ``(229, 229, 1)``, the shape previously hard-coded in
            ``Input``.
        classes: Number of output classes. When ``None`` (the default), the
            global ``num_classes`` is used, matching the original behaviour.

    Returns:
        A ``Model`` compiled with SGD (learning rate 0.01, momentum 0.9),
        categorical cross-entropy loss and an accuracy metric. The model
        summary is printed as a side effect.
    """

    def identity_block(input_tensor, kernel_size, filters, ax):
        """Bottleneck residual block whose shortcut is the unmodified input.

        ``filters[2]`` must equal the channel count of ``input_tensor`` so
        that the element-wise ``add`` is shape-compatible.
        """
        filters1, filters2, filters3 = filters
        x = Conv2D(filters1, (1, 1), kernel_initializer='he_normal')(input_tensor)
        x = BatchNormalization(axis=ax)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer='he_normal')(x)
        x = BatchNormalization(axis=ax)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters3, (1, 1), kernel_initializer='he_normal')(x)
        x = BatchNormalization(axis=ax)(x)
        x = add([x, input_tensor])
        return Activation('relu')(x)

    def conv_block(input_tensor, kernel_size, filters, strides=(2, 2), ax=-1):
        """Bottleneck residual block with a projection (1x1 conv) shortcut.

        The projection lets the block change the channel count and apply
        ``strides``. This is used at the start of every stage.
        """
        filters1, filters2, filters3 = filters
        x = Conv2D(filters1, (1, 1), strides=strides, kernel_initializer='he_normal')(input_tensor)
        x = BatchNormalization(axis=ax)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer='he_normal')(x)
        x = BatchNormalization(axis=ax)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters3, (1, 1), kernel_initializer='he_normal')(x)
        x = BatchNormalization(axis=ax)(x)
        shortcut = Conv2D(filters3, (1, 1), strides=strides, kernel_initializer='he_normal')(input_tensor)
        shortcut = BatchNormalization(axis=ax)(shortcut)
        x = add([x, shortcut])
        return Activation('relu')(x)

    if classes is None:
        classes = num_classes  # backward-compatible fallback to the global

    # The input is channels-last, so BatchNormalization must normalise over
    # the last axis. (The previous axis=1 normalised over image rows.)
    bn_axis = -1

    img_input = Input(shape=input_shape)
    # Stem. The former ``input_shape=`` kwarg on ZeroPadding2D was ignored
    # (the shape comes from ``Input``) and disagreed with it (299 vs 229).
    x = ZeroPadding2D(padding=(3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), padding='valid', kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=bn_axis)(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Each stage is (filters, number of identity blocks, strides of the
    # leading conv_block). Stage 1 keeps the resolution because the stem's
    # max-pool has already downsampled.
    stages = [
        ([64, 64, 256], 2, (1, 1)),
        ([128, 128, 512], 3, (2, 2)),
        ([256, 256, 1024], 5, (2, 2)),
        ([512, 512, 2048], 2, (2, 2)),
    ]
    for filters, n_identity, strides in stages:
        x = conv_block(x, 3, filters, strides=strides, ax=bn_axis)
        for _ in range(n_identity):
            x = identity_block(x, 3, filters, ax=bn_axis)

    # NOTE(review): with the default 229x229 input, the final feature map is
    # 8x8x2048. Flattening it gives 131072 features, so the first Dense(2048)
    # alone has roughly 268M weights. Upstream ResNet50 uses global average
    # pooling here; consider it if memory is a concern.
    x = Flatten()(x)
    x = Dense(units=2048, activation="relu")(x)
    x = Dense(units=2048, activation="relu")(x)
    out = Dense(classes, activation="softmax")(x)

    # The learning rate is passed positionally because the first parameter is
    # named ``lr`` in old Keras and ``learning_rate`` in newer Keras. Newer
    # optimizers reject the ``lr=`` keyword.
    opt = SGD(0.01, momentum=0.9)
    model = Model(inputs=img_input, outputs=out)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=["accuracy"])
    model.summary()
    return model