In [1]:
from xshinnosuke.layers import Dense, Flatten, Conv2D, MaxPooling2D, AvgPooling2D, BatchNormalization, Activation, Add, Input
from xshinnosuke.models import Model
import cupy as np

In [2]:
def identity_block(X, filters, stage, block, s=1):
    """Basic two-convolution residual block whose shortcut is the identity.

    Main path: 3x3 conv (stride ``s``) -> BN -> ReLU -> 3x3 conv (stride 1) -> BN.
    The result is added to the unchanged input ``X`` and passed through a
    final ReLU.

    Args:
        X: input tensor (the output of a previous xshinnosuke layer).
        filters: pair ``(F1, F2)`` of output channel counts for the two convs.
            Because the shortcut is not projected, ``F2`` has to equal the
            channel count of ``X`` -- NOTE(review): this is not checked here.
        stage: stage number; used only to build layer names.
        block: block letter within the stage; used only to build layer names.
        s: stride of the first conv. Must be 1: the identity shortcut keeps the
            input's spatial size, so any other stride would make ``Add`` combine
            tensors of different shapes. Use ``convolutional_block`` to
            downsample.

    Returns:
        The block's output tensor.

    Raises:
        ValueError: if ``s != 1``.
    """
    # Fail fast with a clear message instead of a shape mismatch inside Add.
    if s != 1:
        raise ValueError(
            "identity_block requires s=1 (got s=%r); use convolutional_block "
            "to change the spatial size" % (s,))

    conv_name_base = "res" + str(stage) + block + "_branch"
    bn_name_base = "bn" + str(stage) + block + "_branch"

    F1, F2 = filters

    X_shortcut = X

    X = Conv2D(out_channels=F1, kernel_size=(3, 3), stride=(s, s), padding=1,
               name=conv_name_base + "2a")(X)
    X = BatchNormalization(name=bn_name_base + "2a")(X)
    X = Activation("relu")(X)
    X = Conv2D(out_channels=F2, kernel_size=(3, 3), stride=(1, 1), padding=1,
               name=conv_name_base + "2b")(X)
    X = BatchNormalization(name=bn_name_base + "2b")(X)
    X = Add()([X, X_shortcut])
    X = Activation("relu")(X)

    return X


def convolutional_block(X, filters, stage, block, s=2):
    """Basic two-convolution residual block with a projection shortcut.

    Main path: 3x3 conv (stride ``s``, padding 1) -> BN -> ReLU -> 3x3 conv
    (stride 1, padding 1) -> BN.
    Shortcut: 1x1 conv (stride ``s``, no padding) -> BN, giving it ``F2``
    channels and the same spatial size as the main path. The two are added
    and passed through a final ReLU.

    Main-path convs are named ``res<stage><block>_branch2a`` / ``..._branch2b``
    (the same scheme ``identity_block`` uses for its convs). The shortcut
    layers use the ``..._branch1`` suffix.

    Args:
        X: input tensor (the output of a previous xshinnosuke layer).
        filters: pair ``(F1, F2)`` of output channel counts for the two
            main-path convs; ``F2`` is also the shortcut's channel count.
        stage: stage number; used only to build layer names.
        block: block letter within the stage; used only to build layer names.
        s: stride of the first main-path conv and of the shortcut conv.

    Returns:
        The block's output tensor.
    """
    conv_name_base = "res" + str(stage) + block + "_branch"
    bn_name_base = "bn" + str(stage) + block + "_branch"

    F1, F2 = filters

    X_shortcut = X

    # Main path.
    X = Conv2D(out_channels=F1, kernel_size=(3, 3), stride=(s, s), padding=1,
               name=conv_name_base + "2a")(X)
    X = BatchNormalization(name=bn_name_base + "2a")(X)
    X = Activation("relu")(X)
    X = Conv2D(out_channels=F2, kernel_size=(3, 3), stride=(1, 1), padding=1,
               name=conv_name_base + "2b")(X)
    X = BatchNormalization(name=bn_name_base + "2b")(X)

    # Projection shortcut: a 1x1 conv with the same stride matches the main
    # path's channels (F2) and spatial size, so the Add below is well-formed.
    X_shortcut = Conv2D(out_channels=F2, kernel_size=(1, 1), stride=(s, s),
                        name=conv_name_base + "1")(X_shortcut)
    X_shortcut = BatchNormalization(name=bn_name_base + "1")(X_shortcut)

    X = Add()([X, X_shortcut])
    X = Activation("relu")(X)

    return X


def conv_output_size(size, kernel_size, stride, padding):
    """Length of a conv/pool output along one spatial axis.

    Uses the floor formula ``(size + 2*padding - kernel_size) // stride + 1``.
    """
    return (size + 2 * padding - kernel_size) // stride + 1


def ResNet18(input_shape=(3, 56, 56), classes=100):
    """Build an 18-layer ResNet from basic residual blocks.

    The network is a 7x7/2 conv stem plus a 3x3/2 max-pool, followed by four
    stages of two residual blocks each (64, 128, 256 and 512 channels).
    Stages 3-5 downsample with ``convolutional_block``. Global average pooling
    and a ``Dense`` layer with ``classes`` outputs finish the network.

    Args:
        input_shape: ``(channels, height, width)`` of one sample, passed to
            ``Input``.
        classes: number of output units of the final ``Dense`` layer.

    Returns:
        An uncompiled xshinnosuke ``Model`` from the input to the ``Dense``
        output.
    """
    # Track the spatial size through every layer that changes it, so the final
    # average pool spans the whole feature map for any input size. Previously
    # the pool was hard-coded to 2, which is only correct for 56x56 inputs.
    # Entries are (kernel, stride, padding) for conv1, the stem max-pool, and
    # the first 3x3 conv of stages 3, 4 and 5. The 1x1/stride-2 shortcut gives
    # the same size as a 3x3/stride-2/padding-1 conv. Keep in sync with the
    # layers below.
    downsampling_layers = [(7, 2, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1)]
    pool_h, pool_w = input_shape[1], input_shape[2]
    for kernel, stride, padding in downsampling_layers:
        pool_h = conv_output_size(pool_h, kernel, stride, padding)
        pool_w = conv_output_size(pool_w, kernel, stride, padding)

    X_input = Input(input_shape)

    # stage1
    X = Conv2D(out_channels=64, kernel_size=(7, 7), stride=(2, 2), name="conv1", padding=3)(X_input)
    X = BatchNormalization(name="bn1")(X)
    X = Activation("relu")(X)
    X = MaxPooling2D(kernel_size=3, stride=2, padding=1)(X)

    # stage2
    X = identity_block(X, filters=[64, 64], stage=2, block="b", s=1)
    X = identity_block(X, filters=[64, 64], stage=2, block="c", s=1)

    # stage3
    X = convolutional_block(X, filters=[128, 128], stage=3, block="a", s=2)
    X = identity_block(X, filters=[128, 128], stage=3, block="b", s=1)

    # stage4
    X = convolutional_block(X, filters=[256, 256], stage=4, block="a", s=2)
    X = identity_block(X, filters=[256, 256], stage=4, block="b", s=1)

    # stage5
    X = convolutional_block(X, filters=[512, 512], stage=5, block="a", s=2)
    X = identity_block(X, filters=[512, 512], stage=5, block="b", s=1)

    # Global average pooling. Square maps keep the original int argument.
    # NOTE(review): the tuple form for non-square maps assumes AvgPooling2D
    # accepts an (h, w) kernel like Conv2D's kernel_size -- confirm.
    pool_size = pool_h if pool_h == pool_w else (pool_h, pool_w)
    X = AvgPooling2D(pool_size)(X)

    X = Flatten()(X)
    X = Dense(classes, name="fc" + str(classes))(X)

    model = Model(inputs=X_input, outputs=X)

    return model

In [3]:
# Randomly generated stand-in data: images with values in [0, 1) and
# integer class labels. Seeding makes Restart & Run All reproduce the
# same data.
SEED = 0
N_SAMPLES = 500
INPUT_SHAPE = (3, 56, 56)   # must match ResNet18's input_shape
NUM_CLASSES = 100           # must match ResNet18's classes

np.random.seed(SEED)
x = np.random.rand(N_SAMPLES, *INPUT_SHAPE)
y = np.random.randint(0, NUM_CLASSES, (N_SAMPLES,))

In [4]:
# Build ResNet-18 for inputs and labels shaped like the synthetic data.
# Both values below equal ResNet18's defaults and are spelled out to make
# that coupling visible. Then attach the optimizer and loss by name and
# print the layer summary.
net = ResNet18(input_shape=(3, 56, 56), classes=100)
net.compile(loss='cross_entropy', optimizer='sgd')
print(net)

***************************************************************************
Layer(type)               Output Shape         Param      Connected to   
###########################################################################
input0 (Input)            (None, 3, 56, 56)    0          
              
---------------------------------------------------------------------------
conv1 (Conv2D)            (None, 64, 28, 28)   9408       input0         
---------------------------------------------------------------------------
bn1 (BatchNormalization)  (None, 64, 28, 28)   128        conv1          
---------------------------------------------------------------------------
activation0 (Activation)  (None, 64, 28, 28)   0          bn1            
---------------------------------------------------------------------------
maxpooling2d0 (MaxPooling2D) (None, 64, 14, 14)   0          activation0    
---------------------------------------------------------------------------
res2b_branch2a (Conv2

In [5]:
# Train on the synthetic data: 5 passes, mini-batches of 32 samples.
history = net.fit(x, y, epochs=5, batch_size=32)

[1;31m Epoch[1/5][0m
[1;31m Epoch[2/5][0m
[1;31m Epoch[3/5][0m
[1;31m Epoch[4/5][0m
[1;31m Epoch[5/5][0m
