In [0]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers

In [3]:
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
def ResNetUnit(x, filter_size, kernel_size, bn, act, strides=1):
  if bn:
    x = layers.BatchNormalization()(x)
  if act:
    x = layers.Activation('relu')(x)
  x = layers.Conv2D(filter_size, kernel_size, strides, 'same', kernel_initializer='he_normal')(x)
  return x

def ResNet(input_shape, block_num, num_classes):
  # 32, 32, 16
  # stage 0 32, 32, 64
  # stage 1 16, 16, 128
  # stage 2 8, 8, 256
  filter_in = 16
  filter_out = 64
  inputs = tf.keras.layers.Input(input_shape)
  x = layers.Conv2D(filter_in, 3, 1, 'same')(inputs)
  for stage in range(3):
    for block in range(block_num):
      strides = 1
      if block == 0 and stage != 0:
        strides = 2   #down sample
      y = ResNetUnit(x, filter_in, 1, True, True, strides)
      y = ResNetUnit(y, filter_in, 3, True, True)
      y = ResNetUnit(y, filter_out, 1, True, True)
      if block == 0:
        x = ResNetUnit(x, filter_out, 1, False, False, strides)    #match size of x to y
      x = layers.add([x, y])
      filter_in = filter_out
    filter_out = filter_in * 2
  x = layers.BatchNormalization()(x)
  x = layers.Activation('relu')(x)
  x = layers.AveragePooling2D(pool_size=8)(x)
  x = layers.Flatten()(x)
  outputs = layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(x)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)

In [0]:
epochs = 100
input_shape = train_x.shape[1:]
block_num = (110-2)//9   #9 layers in block loop (bn, act, conv(1)), (bn, act, conv(3)), (bn, act, conv(1))
num_classes = 10
train_step = len(train_x)//32
test_step = len(test_x)//32

In [0]:
model = ResNet(input_shape,block_num, num_classes)

In [0]:
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=['accuracy']
)

In [0]:
train_y_one = tf.keras.utils.to_categorical(train_y,num_classes)
test_y_one = tf.keras.utils.to_categorical(test_y, num_classes)

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    height_shift_range=0.1,
    width_shift_range=0.1,
    horizontal_flip=True,
    rescale=1./255.
)

test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255.)

train_generator = train_datagen.flow(train_x, train_y_one, shuffle=True)
test_generator = test_datagen.flow(test_x, test_y_one)

In [0]:
model.fit_generator(
    train_generator, train_step, epochs,
    validation_data=test_generator, validation_steps=test_step
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100

In [0]:
def lr_schedule(epoch):
  lr = 1e-3
  if epoch > 180:
      lr *= 0.5e-3
  elif epoch > 160:
      lr *= 1e-3
  elif epoch > 120:
      lr *= 1e-2
  elif epoch > 80:
      lr *= 1e-1
  print('Learning rate: ', lr)
  return lr

In [8]:
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(lr_schedule(0)),
    metrics=['accuracy']
)

Learning rate:  0.001


In [10]:
history = model.fit_generator(
    train_generator, train_step, epochs,
    validation_data=test_generator, validation_steps=test_step
)

Epoch 1/100
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
  18/1562 [..............................] - ETA: 19:16 - loss: 0.2176 - acc: 0.9288

KeyboardInterrupt: ignored