In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Layer, Input, Flatten, concatenate, Lambda, ZeroPadding2D, MaxPool2D, AveragePooling2D
import tensorflow_datasets as tfds
import numpy as np
import math
from tensorflow.keras.utils import plot_model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D, BatchNormalization, Dropout
from tensorflow.keras.layers.experimental import preprocessing
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

# Adding Memory to resnet50

In [2]:
def find_closest_perfect_square(n):
    i = n**0.5
    if i%1 !=0: i = int(i)+1
    else: i = int(i)
    return (i*i, i)

In [3]:
memory_channels = 1024
memory_shape = (7, 7, memory_channels)
total_m = memory_shape[0] * memory_shape[1] * memory_shape[2]
closest_perfect_square = find_closest_perfect_square(total_m)
memory_input = Input(shape = memory_shape)
Memory = Conv2D(memory_channels, (1,1), activation='linear')(memory_input)

In [4]:
total_m, closest_perfect_square

(50176, (50176, 224))

In [5]:
def memory_reshape_block(memory, target_shape, mem_shape, closest_perfect_square = closest_perfect_square):
    p_sq, sq = closest_perfect_square
    _, n_rows, n_cols, n_c = target_shape #14, 14, 1
    m_rows, m_cols, m_c = mem_shape #7, 7, 6
    
    total_m = m_rows * m_cols * m_c #minimum total space for memory # 294
    total_t = n_rows * n_cols  #total space in target #196

    if n_rows > sq: 

        x = layers.Reshape((total_m, 1))(memory)
        x = layers.ZeroPadding1D(padding = (p_sq - total_m)//2)(x)
        x = layers.Reshape((sq, sq, 1))(x)
        x = layers.ZeroPadding2D(padding = (n_rows - sq)//2 )(x)
        return x
    
    if total_t < total_m: 
        buckets = math.ceil(total_m/(n_rows*n_rows)) 
        x = layers.Reshape((n_rows, n_cols, buckets))(memory)
        return x

In [6]:
def memory_update_block(input_tensor, memory_shape, Memory, filters, kernel_size, stage, block):
    
    filters1, filters2, filters3 = filters
    bn_axis = 3

    memory_name_base = 'Memory' + str(stage) + block + '_branch'
    memory_pad_base = 'mem_pad' + str(stage) + block + '_branch'
    concat_name_base = 'mem_concat' + str(stage) + block + '_branch'
    conv_name_base = 'mem_conv' + str(stage) + block + '_branch'
    pool_name_base = 'mem_pool' + str(stage) + block + '_branch'
    bn_name_base = 'mem_bn' + str(stage) + block + '_branch'

    _, n_rows, n_cols, n_c = input_tensor.shape #224, 224, 3
    m_rows, m_cols, m_c = memory_shape # 7, 7, 3

    memory_pad = memory_reshape_block(Memory, input_tensor.shape, memory_shape)

    concat = concatenate([input_tensor, memory_pad], name = concat_name_base) # 56 x 56 x 129

    x = Conv2D(filters = filters1, kernel_size = kernel_size, 
                           padding = 'same', kernel_initializer='he_normal', 
                           name = conv_name_base + '_a')(concat) 
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '_2a')(x)
    x = layers.Activation('sigmoid')(x)

    x = Conv2D(filters = filters2, kernel_size = kernel_size,
                           padding = 'same', kernel_initializer='he_normal', 
                           name = conv_name_base + '_b')(x) 
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '_2b')(x)
    x = layers.Activation('sigmoid')(x)

    x = Conv2D(filters = filters3, kernel_size = kernel_size, padding = 'same',
                           kernel_initializer='he_normal',
                           name = conv_name_base + '_c')(x) 
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '_2c')(x)
    x = layers.Activation('sigmoid')(x)

    factor = (n_rows//m_rows, n_cols//m_cols)
    x = AveragePooling2D(pool_size=factor,strides=factor)(x)
    Memory = Conv2D(filters=m_c, kernel_size=(1, 1), strides=(1,1), padding='same', name = memory_name_base)(x) 
    return Memory

In [7]:
def identity_block(input_tensor, kernel_size, filters, stage, block, memory, memory_shape=memory_shape):
    """The identity block is the block that has no conv layer at shortcut.
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of
            middle conv layer at main path
        filters: list of integers, the filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters
    
    memory_name_base = 'memory' + str(stage) + block + '_branch'
    memory_pad_base = 'memory_pad' + str(stage) + block + '_branch'
    concat_name_base = 'concat' + str(stage) + block + '_branch'
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    
    bn_axis = 3
    '''   
    _, n_rows, n_cols, n_c = input_tensor.shape #224, 224, 3
    m_rows, m_cols, m_c = memory_shape # 7, 7, 3
    odd_even_balance = 0 if m_rows%2 == 0 else 1
    pad1 = (n_rows//2 - m_rows//2 - odd_even_balance, n_rows//2 - m_rows//2) if n_rows!=m_rows else 0 #(108, 109)
    pad2 = (n_rows//2 - m_rows//2, n_cols//2 - m_cols//2 - odd_even_balance) if n_rows!=m_rows else 0 #(108, 109)
    '''
    

    memory_pad = memory_reshape_block(Memory, input_tensor.shape, memory_shape)
                                      
    concat = layers.concatenate([input_tensor, memory_pad])

    x = layers.Conv2D(filters1, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2a')(concat)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size,
                      padding='same',
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = layers.Activation('relu')(x)
    return x

In [8]:
def conv_block(input_tensor,kernel_size,
               filters,stage,block,memory,
               strides=(2, 2),
               memory_shape=memory_shape):

    """A block that has a conv layer at shortcut.
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of
            middle conv layer at main path
        filters: list of integers, the filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        strides: Strides for the first conv layer in the block.
    # Returns
        Output tensor for the block.
    Note that from stage 3,
    the first conv layer at main path is with strides=(2, 2)
    And the shortcut should have strides=(2, 2) as well
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3

    memory_name_base = 'memory' + str(stage) + block + '_branch'
    memory_pad_base = 'memory_pad' + str(stage) + block + '_branch'
    concat_name_base = 'concat' + str(stage) + block + '_branch'
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    memory_pad = memory_reshape_block(memory, input_tensor.shape, memory_shape)                                  
    concat = layers.concatenate([input_tensor, memory_pad], name = concat_name_base)

    x = layers.Conv2D(filters1, (1, 1), strides=strides,
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2a')(concat)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size, padding='same',
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
                             kernel_initializer='he_normal',
                             name=conv_name_base + '1')(input_tensor)
    shortcut = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = layers.Activation('relu')(x)
    return x

In [9]:
def ResNet50(include_top=True,
             input_shape=None,
             memory_input=None,
             classes=100,
             **kwargs):

    # Determine proper input shape
    input_shape = input_shape


    img_input = layers.Input(shape=input_shape, name = 'image_input')
    bn_axis = 3

    MEMORY = Conv2D(memory_channels, (1,1), activation='linear', name='mem_init')(memory_input)

    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(1, 1),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)


    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), memory = MEMORY)
    MEMORY = memory_update_block(x, memory_shape, MEMORY, [16, 16, 32], (3,3), stage=2, block='a')

    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', memory = MEMORY)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', memory = MEMORY)
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', memory = MEMORY)
    MEMORY = memory_update_block(x, memory_shape, MEMORY, [16, 16, 32], (3,3), stage=3, block='a')

    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', memory = MEMORY)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', memory = MEMORY)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', memory = MEMORY)
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', memory = MEMORY)
    MEMORY = memory_update_block(x, memory_shape, MEMORY, [16, 16, 32], (3,3), stage=4, block='a')

    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', memory = MEMORY)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', memory = MEMORY)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', memory = MEMORY)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', memory = MEMORY)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', memory = MEMORY)
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', memory = MEMORY)
    MEMORY = memory_update_block(x, memory_shape, MEMORY, [16, 16, 32], (3,3), stage=5, block='a')

    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', memory = MEMORY)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', memory = MEMORY)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D()(x)
        else:
            warnings.warn('The output shape of `ResNet50(include_top=False)` '
                          'has been changed since Keras 2.2.0.')

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    inputs = [img_input, memory_input]
    # Create model.
    model = tf.keras.Model(inputs, x, name='resnet50')

    return model

In [10]:
model = ResNet50(include_top=True, input_shape=(224,224,3), classes=100, memory_input=memory_input)

In [11]:
model.summary()

Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image_input (InputLayer)        [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           image_input[0][0]                
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 224, 224, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 224, 224, 64) 256         conv1[0][0]                      
___________________________________________________________________________________________

Number of parameters almost double

In [12]:
plot_model(model, show_shapes=True,expand_nested = True)

Output hidden; open in https://colab.research.google.com to view.

In [13]:
len(model.layers)

251

In [14]:
len(Resnet.layers)

NameError: ignored