In [1]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers as ly

In [2]:
def BasicBlock(x, planes, stride=1, downsample=False):
    """Two-convolution residual block (the ResNet-18/34 building block).

    Main path: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, then the shortcut is
    added and a final ReLU is applied. The block has no channel expansion:
    it outputs `planes` channels.

    Args:
        x: Input Keras tensor. The `BatchNormalization(axis=3)` calls imply a
            4-D channels-last tensor.
        planes: Number of filters of both 3x3 convolutions and of the
            projection shortcut.
        stride: Stride applied to the first 3x3 convolution and to the 1x1
            projection shortcut. Only honoured when `downsample` is True;
            otherwise it is forced to 1 so the identity shortcut keeps
            matching the main path (same rule as `BottleneckBlock`).
        downsample: When True, the shortcut goes through a 1x1 convolution
            followed by batch normalisation instead of being the identity.

    Returns:
        The output Keras tensor of the block.

    Note:
        Earlier revisions ignored `stride` and always used a stride of 2 when
        `downsample` was True, which halved the resolution even for
        `stride=1` (e.g. in the first ResNet stage).
    """
    if not downsample:
        # An identity shortcut cannot follow a spatial reduction.
        stride = 1

    x_skip = x
    if downsample:
        # Projection shortcut: match the channel count / resolution of the main path.
        x_skip = ly.Conv2D(planes, (1, 1), strides=stride, use_bias=False)(x_skip)
        x_skip = ly.BatchNormalization()(x_skip)

    x = ly.Conv2D(planes, (3, 3), strides=stride, padding='same', use_bias=False)(x)
    x = ly.BatchNormalization(axis=3)(x)
    x = ly.Activation('relu')(x)
    x = ly.Conv2D(planes, (3, 3), strides=1, padding='same', use_bias=False)(x)
    x = ly.BatchNormalization(axis=3)(x)
    # The residual sum happens before the final non-linearity.
    x = ly.Add()([x, x_skip])
    x = ly.Activation('relu')(x)
    return x

In [3]:
def BottleneckBlock(x, planes, stride=1, downsample=False):
    """Three-convolution bottleneck residual block (1x1 -> 3x3 -> 1x1).

    The first two convolutions use `planes` filters; the last one expands to
    `planes * 4` filters, which is also the width of the projection shortcut.

    Args:
        x: Input Keras tensor. The `BatchNormalization(axis=3)` calls imply a
            4-D channels-last tensor.
        planes: Filter count of the 1x1 reduction and the 3x3 convolution.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
            Ignored (treated as 1) unless `downsample` is True.
        downsample: When True, the shortcut is a 1x1 convolution followed by
            batch normalisation; otherwise the input is added unchanged.

    Returns:
        The output Keras tensor of the block (`planes * 4` channels).
    """
    expansion = 4
    wide = planes * expansion
    # The stride only takes effect together with a projection shortcut.
    spatial_stride = stride if downsample else 1

    shortcut = x
    if downsample:
        shortcut = ly.Conv2D(wide, (1, 1), strides=spatial_stride, use_bias=False)(shortcut)
        shortcut = ly.BatchNormalization()(shortcut)

    # 1x1 reduction
    out = ly.Conv2D(planes, (1, 1), strides=1, use_bias=False)(x)
    out = ly.BatchNormalization(axis=3)(out)
    out = ly.Activation('relu')(out)

    # 3x3 convolution, carrying the (optional) spatial stride
    out = ly.Conv2D(planes, (3, 3), strides=spatial_stride, padding='same', use_bias=False)(out)
    out = ly.BatchNormalization(axis=3)(out)
    out = ly.Activation('relu')(out)

    # 1x1 expansion
    out = ly.Conv2D(wide, (1, 1), strides=1, use_bias=False)(out)
    out = ly.BatchNormalization(axis=3)(out)

    merged = ly.Add()([out, shortcut])
    return ly.Activation('relu')(merged)

In [32]:
def StemBlock(x, planes):
    """Network stem: padded 7x7 stride-2 convolution, BN, ReLU, padded 3x3 stride-2 max pool.

    Args:
        x: Input Keras tensor.
        planes: Number of filters of the 7x7 convolution.

    Returns:
        The Keras tensor produced by the max-pooling layer.

    Note:
        The padding, convolution and pooling layers carry fixed names
        ('conv1_pad', 'conv1', 'pool1_pad', 'pool1_pool'), so this helper is
        presumably meant to be called once per model.
    """
    # Explicit 3-pixel zero padding; the convolution itself uses the default padding mode.
    padded = ly.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(x)
    features = ly.Conv2D(planes, (7, 7), strides=2, name='conv1')(padded)
    features = ly.BatchNormalization()(features)
    features = ly.Activation('relu')(features)

    # Explicit 1-pixel zero padding ahead of the 3x3 pooling window.
    padded = ly.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(features)
    pooled = ly.MaxPooling2D(3, strides=2, name='pool1_pool')(padded)
    return pooled

In [45]:
def ResNet(block, n_classes=1000, depths=(2, 2, 2, 2), include_top=True,
           input_shape=(224, 224, 3)):
    """Assemble a four-stage ResNet as a Keras functional model.

    The network is: stem -> 4 stages of residual blocks (64, 128, 256, 512
    base filters with stage strides 1, 2, 2, 2) -> global average pooling ->
    optional Dense classifier.

    Args:
        block: Block builder called as
            `block(x, planes, stride=..., downsample=...)` and returning a
            tensor (e.g. `BasicBlock` or `BottleneckBlock`).
        n_classes: Number of units of the final Dense layer. Only used when
            `include_top` is True.
        depths: Sequence with the number of blocks of each of the four
            stages. Every stage contains at least its first (projection)
            block, even for an entry below 1.
        include_top: When True, append `Dense(n_classes)` (no activation is
            passed to it); when False, the model returns the pooled features.
        input_shape: Shape of a single input sample, without the batch
            dimension. Defaults to `(224, 224, 3)`.

    Returns:
        A `keras.Model` mapping the input to the class scores or to the
        pooled feature vector.

    Raises:
        ValueError: If `depths` does not contain exactly one entry per stage.
    """
    dims = (64, 128, 256, 512)
    strides = (1, 2, 2, 2)
    if len(depths) != len(dims):
        # Extra entries used to be ignored silently; fail loudly instead.
        raise ValueError(
            "depths must have {} entries (one per stage), got {}".format(
                len(dims), len(depths)))

    inplanes = 64
    inputs = keras.Input(shape=input_shape)
    x = StemBlock(inputs, inplanes)

    for dim, stride, depth in zip(dims, strides, depths):
        # The first block of a stage changes width/resolution, hence the
        # projection shortcut.
        x = block(x, dim, stride=stride, downsample=True)
        # The remaining blocks keep the shape, so they must not stride.
        for _ in range(1, depth):
            x = block(x, dim, stride=1, downsample=False)

    x = ly.GlobalAveragePooling2D()(x)
    if include_top:
        x = ly.Dense(n_classes)(x)
    return keras.Model(inputs=inputs, outputs=x)

In [46]:
def resnet18(n_classes=1000, include_top=True):
    """Build a ResNet from `BasicBlock` using the builder's default stage depths."""
    return ResNet(BasicBlock, include_top=include_top, n_classes=n_classes)
def resnet34(n_classes=1000, include_top=True):
    """Build a ResNet from `BasicBlock` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        BasicBlock,
        n_classes=n_classes,
        depths=stage_depths,
        include_top=include_top,
    )
def resnet50(n_classes=1000, include_top=True):
    """Build a ResNet from `BottleneckBlock` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        BottleneckBlock,
        n_classes=n_classes,
        depths=stage_depths,
        include_top=include_top,
    )
def resnet101(n_classes=1000, include_top=True):
    """Build a ResNet from `BottleneckBlock` with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(
        BottleneckBlock,
        n_classes=n_classes,
        depths=stage_depths,
        include_top=include_top,
    )

In [47]:
# Build each variant once (shallowest first) and keep both the individual
# handles and an ordered list for the comparisons below.
models = [resnet18(), resnet34(), resnet50(), resnet101()]
r18, r34, r50, r101 = models

In [36]:
%%time
# Smoke test: push one random 224x224x3 sample through the 18-layer model.
r18(np.random.random((1, 224, 224, 3))).shape # expected: [1, 1000] (batch of 1, default n_classes=1000)

CPU times: user 13.6 ms, sys: 8.84 ms, total: 22.4 ms
Wall time: 15 ms


TensorShape([1, 1000])

In [49]:
# Same check without the classifier: include_top=False skips the final Dense
# layer, so the model returns the globally average-pooled features.
r18_headless = resnet18(include_top=False)
r18_headless(np.random.random((1, 224, 224, 3))).shape

TensorShape([1, 512])

In [37]:
def fmat(n):
    """Render a raw count in millions with two decimals, e.g. 1500000 -> '1.50M'."""
    millions = n / 1e6
    return f"{millions:.2f}M"

In [38]:
def params(model, f = True):
    """Count the scalar entries of all of a model's trainable variables.

    Args:
        model: Object exposing `trainable_variables`, an iterable of items
            that each have a `shape`.
        f: When True (default) return the count formatted by `fmat`;
            when False return it as a plain `int`.

    Returns:
        The formatted string or the integer count, depending on `f`.
    """
    sizes = [np.prod(variable.shape) for variable in model.trainable_variables]
    total = int(np.sum(sizes))
    if not f:
        return total
    return fmat(total)

In [39]:
# Print the formatted trainable-parameter count of every model, in list order.
for m in models:
    print(params(m))

11.69M
21.80M
25.56M
44.55M


In [43]:
# Stock keras.applications ResNet50 built with weights=None (no pretrained
# weights requested) at the same input shape used above.
# NOTE(review): presumably a reference to compare against r50 — confirm intended use.
r50_keras = keras.applications.ResNet50(weights=None, input_shape=(224, 224, 3))