IMPLEMENTING THE RESNET-34 ARCHITECTURE

In [1]:
import tensorflow as tf

In [2]:
class ResidualUnit(tf.keras.layers.Layer):
    """One ResNet residual unit: two 3x3 conv + BN stages plus a skip connection.

    Main path:  Conv2D(3x3, strides) -> BN -> activation -> Conv2D(3x3, stride 1) -> BN
    Skip path:  identity when `strides == 1`, otherwise Conv2D(1x1, strides) -> BN
    Output:     activation(main_path + skip_path)

    Args:
        filters: number of output channels of every convolution in the unit.
        strides: stride of the first 3x3 convolution and of the 1x1 skip
            projection. Only the first convolution downsamples; the second
            always uses stride 1.
        activation: anything accepted by `tf.keras.activations.get`
            (default "relu").
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).

    NOTE(review): when `strides == 1` the input is added unchanged, so it must
    already have `filters` channels — true for the ResNet-34 stage layout used
    in this notebook, where channels only change together with stride 2.
    """

    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)
        self.main_layers = [
            tf.keras.layers.Conv2D(filters, 3, strides=strides, padding="same", use_bias=False),
            tf.keras.layers.BatchNormalization(),
            self.activation,
            # Stride 1 here: downsampling already happened in the first conv.
            # Striding twice would shrink the main path by strides**2 while the
            # skip path shrinks only by `strides`, so the sum in call() would
            # not line up.
            tf.keras.layers.Conv2D(filters, 3, strides=1, padding="same", use_bias=False),
            tf.keras.layers.BatchNormalization(),
        ]
        self.skip_layers = []
        if strides > 1:
            # 1x1 projection so the shortcut matches the main path's spatial
            # size and channel count.
            self.skip_layers = [
                tf.keras.layers.Conv2D(filters, 1, strides=strides, padding="same", use_bias=False),
                tf.keras.layers.BatchNormalization(),
            ]

    def call(self, inputs):
        """Return activation(main_path(inputs) + skip_path(inputs))."""
        z = inputs
        for layer in self.main_layers:
            z = layer(z)
        skip_z = inputs
        for layer in self.skip_layers:
            skip_z = layer(skip_z)
        return self.activation(z + skip_z)
        

CREATING THE RESNET-34 MODEL

In [7]:
# ResNet-34 "stem": 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
# An explicit Input fixes the image shape up front, so the model is built
# immediately (and `model.summary()` works) instead of on its first call.
INPUT_SHAPE = (224, 224, 3)  # assumed ImageNet-style RGB input — adjust to your data
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=INPUT_SHAPE),
    # No bias, matching the bias-free convolutions inside ResidualUnit:
    # each of them, like this one, is immediately followed by BatchNormalization.
    tf.keras.layers.Conv2D(64, 7, padding="same", strides=2, use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.MaxPool2D(pool_size=3, padding="same", strides=2),
])

ADDING THE RESIDUAL UNITS

In [8]:
# ResNet-34 body: 3, 4, 6 and 3 residual units with 64, 128, 256 and 512 filters.
# The first unit of each new stage uses stride 2 (halving the spatial size with
# 'same' padding, and giving ResidualUnit a 1x1 projection on its skip path);
# every other unit uses stride 1.
# `filters` / `prev_filters` are used instead of `filter` so the built-in
# filter() is not shadowed for the rest of the notebook session.
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
    strides = 1 if filters == prev_filters else 2
    model.add(ResidualUnit(filters, strides=strides))
    prev_filters = filters

ADDING THE FLATTENING LAYER AND A FULLY CONNECTED LAYER

In [9]:
# Classification head. Global average pooling reduces the last stage's feature
# maps to one value per channel (512 features), as in the ResNet architecture,
# instead of flattening every spatial position into a very wide Dense input.
model.add(tf.keras.layers.GlobalAvgPool2D())
model.add(tf.keras.layers.Flatten())  # no-op after global pooling (already 2-D); kept for clarity
N_CLASSES = 10  # number of target classes — set to match your dataset
model.add(tf.keras.layers.Dense(N_CLASSES, activation="softmax"))

In [11]:
# Keras' built-in ResNet-50, randomly initialised (weights=None), to compare
# against the hand-built ResNet-34. It gets its own name so `model` still refers
# to the ResNet-34 built above. `tf.keras.applications.ResNet50` is an alias of
# `tf.keras.applications.resnet50.ResNet50`, so a single instance is enough.
resnet50_model = tf.keras.applications.resnet50.ResNet50(weights=None)
# summary() prints the table itself and returns None, so it is called directly
# rather than wrapped in print() (which would only add a stray "None").
resnet50_model.summary()

None None


In [None]:
def preprocess_for_resnet50(images, size=(224, 224)):
    """Resize images and apply Keras' ResNet-50 input preprocessing.

    Args:
        images: a single image (height, width, channels) or a batch
            (batch, height, width, channels) — anything `tf.image.resize`
            accepts. NOTE(review): `resnet50.preprocess_input` is documented as
            expecting 3-channel RGB values in [0, 255]; confirm your data matches.
        size: target (height, width). Defaults to 224x224 — presumably the
            size wanted for `ResNet50(weights=None)` with its default top;
            verify against the Keras docs if you change the model.

    Returns:
        The resized images passed through
        `tf.keras.applications.resnet50.preprocess_input`.
    """
    resized = tf.image.resize(images, size)
    return tf.keras.applications.resnet50.preprocess_input(resized)
