In [2]:
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Dense, GlobalAveragePooling2D, MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam, RMSprop

In [7]:
# Image dimensions shared by the cells of this notebook.
IMG_HEIGHT = 300
IMG_WIDTH = 300

In [12]:
class ResNetBlock(Model):
    """Basic two-convolution residual block.

    Computes ``relu(bn(conv_1(conv_0(x))) + shortcut(x))``.

    ``shortcut`` is the identity when no downsampling is requested and the
    input already has ``filters`` channels. Otherwise it is a 1x1 projection
    convolution, so the two branches always have matching shapes.

    Args:
        filters: number of output channels of both convolutions.
        kernel_size: kernel size of both 'same'-padded convolutions.
        strides: stride of the first convolution and of the projection
            shortcut. The second convolution always uses stride 1.
    """

    def __init__(self, filters, kernel_size = (3, 3), strides=1):
        super(ResNetBlock, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

        # Only the first convolution downsamples. Striding both would shrink
        # the feature map by strides**2 and make it incompatible with the
        # shortcut branch.
        # No `input_shape` is passed: Keras ignores it for nested layers, and
        # Conv2D expects a rank-3 (H, W, C) shape, not (H, W).
        self.conv_0 = Conv2D(
            filters=self.filters, kernel_size=self.kernel_size, strides=self.strides,
            padding='same', activation='relu'
        )
        # Linear on purpose: the non-linearity is applied after the residual add.
        self.conv_1 = Conv2D(
            filters=self.filters, kernel_size=self.kernel_size, strides=1,
            padding='same'
        )
        # Created once so its gamma/beta and moving statistics are tracked,
        # trained and saved with the model. Instantiating it inside call()
        # built a fresh, untrained layer on every forward pass.
        self.bn = BatchNormalization()
        # The 1x1 projection shortcut is only built (and only gets weights)
        # if call() finds that the identity shortcut cannot be used.
        self.projection = Conv2D(
            filters=self.filters, kernel_size=1, strides=self.strides, padding='same'
        )

    def _needs_projection(self, inputs):
        """Return True when the identity shortcut's shape differs from the main branch."""
        downsamples = self.strides not in (1, (1, 1), [1, 1])
        return downsamples or inputs.shape[-1] != self.filters

    def call(self, inputs, training=None):
        x = self.conv_0(inputs)
        x = self.conv_1(x)
        # Forward `training` so BN uses batch statistics while fitting and its
        # moving averages at inference time.
        x = self.bn(x, training=training)
        if self._needs_projection(inputs):
            shortcut = self.projection(inputs)
        else:
            shortcut = inputs
        return tf.nn.relu(x + shortcut)

In [16]:
class ResNet(Model):
    """Small ResNet-style image classifier.

    The architecture is:

    - Stem: a strided 3x3 convolution followed by a 3x3 / stride-2 max-pool.
    - Body: ``n_blocks`` ResNetBlocks at constant width.
    - Head: global average pooling and a softmax Dense layer.

    Args:
        n_blocks: number of ResNetBlock layers stacked after the stem.
        n_classes: number of units of the final softmax layer.
        filters: channel width of the stem convolution and of every block
            (default 64, as before).
        kernel_size: kernel size passed to every block (default 3, as before).
    """

    def __init__(self, n_blocks, n_classes, filters=64, kernel_size=3):
        super(ResNet, self).__init__()
        # The stem uses the same width as the blocks, so the first block's
        # input channels match its output channels.
        self.conv_0 = Conv2D(kernel_size=(3, 3), filters=filters, strides=2, padding='SAME', activation='relu')
        self.maxpool = MaxPool2D(pool_size=(3, 3), strides=2)
        self.n_blocks = n_blocks
        self.n_classes = n_classes
        # Use setattr, not vars(self)[...]. Writing into __dict__ directly
        # bypasses Keras' __setattr__ tracking, which hid every block's weights
        # from model.trainable_variables, so they were never trained or saved.
        for i in range(1, n_blocks + 1):
            setattr(self, f"block_{i}", ResNetBlock(filters, kernel_size, 1))

        self.avgpool = GlobalAveragePooling2D()
        self.classifier = Dense(units=n_classes, activation='softmax')

    def call(self, inputs):
        x = self.conv_0(inputs)
        x = self.maxpool(x)
        for i in range(1, self.n_blocks + 1):
            x = getattr(self, f"block_{i}")(x)
        x = self.avgpool(x)
        return self.classifier(x)


In [17]:
# 5 residual blocks feeding a 10-unit softmax classifier.
model = ResNet(n_blocks=5, n_classes=10)

In [19]:
# Integer class labels -> sparse categorical cross-entropy on the softmax output.
model.compile(
    optimizer=SGD(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)