In [1]:
try:
    %tensorflow_version 2.x
except Exception:
    pass

In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Input, BatchNormalization, Activation, Add, MaxPool2D, AveragePooling2D, Flatten, ZeroPadding2D
from tensorflow.keras.models import Model
from tensorflow.python.keras.utils.vis_utils import plot_model
import tensorflow_datasets as tfds

In [3]:
class IdentityBlock(Model):
    def __init__(self, filters=None, activation=None, kernel_size=None):
        super(IdentityBlock , self ).__init__(name  ='')
        self.conv1 = Conv2D(filters, kernel_size, padding='same')
        self.bn = BatchNormalization()
        self.act = Activation(activation)
        self.add = Add()
    def call(self, input):
        x = self.conv1(input)
        x = self.bn(x)
        x = self.act(x)

        x = self.conv1(x)
        x = self.bn(x)
        
        x = self.add([x, input])
        
        x = self.act(x)

        return x

In [4]:
class DownSamplingBlock(Model):
    def __init__(self, filters=None, activation=None, kernel_size=None, stride=None):
        super(DownSamplingBlock, self).__init__()
        self.conv1 = Conv2D(filters, kernel_size, strides=stride)
        self.bn = BatchNormalization()
        self.conv2 = Conv2D(filters, kernel_size, padding='same')
        self.conv3 = Conv2D(filters, 3, strides=stride)
        self.act = Activation(activation)
        self.add = Add()
        self.pad = ZeroPadding2D()
    def call(self, input):
        # print("-----")
        # print(input.shape)
        x_skip = input
        x = self.conv1(input)
        x = self.bn(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.bn(x)

        x_skip = self.conv3(x_skip)
        # print(x.shape, x_skip)
        x = self.add([x, x_skip])
        
        x = self.act(x)

        return x

In [5]:
class ResNet(Model):
    def __init__(self, units=None, kernel_size=None, filters=None, stride=None, activation=None, num_classes=None):
        super(ResNet, self).__init__()
        self.conv1 = Conv2D(filters, kernel_size, strides=stride)
        self.bn = BatchNormalization()
        self.act = Activation(activation)
        self.pool1 = MaxPool2D(pool_size=(3, 3), strides=stride)
        self.id1 = IdentityBlock(64, activation, 3)
        self.ds1 = DownSamplingBlock(128, activation, 3, stride)
        self.id2 = IdentityBlock(128, activation, 3)
        self.ds2 = DownSamplingBlock(256, activation, 3, stride)
        self.id3 = IdentityBlock(256, activation, 3)
        self.ds3 = DownSamplingBlock(512, activation, 3, stride)
        self.id4 = IdentityBlock(512, activation, 3)
        self.pool2 = AveragePooling2D()
        self.dense1 = Flatten()
        self.dense2 = Dense(units, activation=activation)
        self.dense3 = Dense(num_classes, activation='softmax')
    def call(self, input):
        # print(input.shape)
        x = self.conv1(input)
        
        # print(x.shape)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool1(x)
        # print(x.shape)
        for _ in range(3):
            x = self.id1(x)
        
        # print(x.shape)
        x = self.ds1(x)
        
        # print(x.shape)
        for _ in range(3):
            x = self.id2(x)
        
        x = self.ds2(x)
        for _ in range(5):
            x = self.id3(x)
        
        x = self.ds3(x)
        for _ in range(2):
            x = self.id4(x)
        
        x = self.pool2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)

        return x

In [6]:
resnet = ResNet(1000, 7, 64, (2,2), 'relu', 2)

In [7]:
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [8]:
def preprocess(f):
    return tf.cast(f['image'] , 'float32') / 255.0  , f["label"]

dataset = tfds.load('horses_or_humans', split='train')
# print(dataset)
dataset = dataset.map(preprocess).batch(32)

In [9]:
history = resnet.fit(dataset , epochs=15)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


In [10]:
resnet.summary()

Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  9472      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
identity_block (IdentityBloc multiple                  37184     
_________________________________________________________________
down_sampling_block (DownSam multiple                  295808    
_________________________________________________________________
identity_block_1 (IdentityBl multiple                  1480