In [0]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler, History
from tensorflow.contrib.tpu.python.tpu import keras_support
import tensorflow.keras.backend as K
from keras.datasets import cifar10
from keras.utils import to_categorical
import pickle, os, time
import numpy as np
from tqdm import tqdm

Using TensorFlow backend.


In [0]:
class OctConv2D(layers.Layer):
    """Octave convolution (OctConv) layer for ``channels_last`` feature maps.

    Consumes a pair ``[high_input, low_input]`` where ``low_input`` has half the
    spatial size of ``high_input`` and produces ``[high_out, low_out]``:

    * ``high_out = conv(high) + repeat2x(conv(low))``
    * ``low_out  = conv(avgpool2x(high)) + conv(low)``

    ``low_out`` has ``int(filters * alpha)`` channels and ``high_out`` the
    remaining ``filters - int(filters * alpha)``.

    Args:
        filters: total number of output channels (positive ``int``).
        alpha: fraction of ``filters`` assigned to the low-frequency output,
            in ``[0, 1]``.
        kernel_size: ``int`` or pair of ints, shared by all four kernels.
        strides: ``int`` or pair of ints, shared by all four convolutions.
        padding: passed to ``K.conv2d``. NOTE(review): only ``"same"`` keeps
            the two summed branches the same size; with ``"valid"`` the
            repeated low->high branch generally will not match high->high.
        kernel_initializer, kernel_regularizer, kernel_constraint: forwarded
            to ``add_weight`` for every kernel.
    """

    def __init__(self, filters, alpha, kernel_size=(3,3), strides=(1,1),
                    padding="same", kernel_initializer='glorot_uniform',
                    kernel_regularizer=None, kernel_constraint=None,
                    **kwargs):

        assert 0 <= alpha <= 1
        # Check the type first so a non-int `filters` fails the assert instead
        # of raising TypeError from the comparison.
        assert isinstance(filters, int) and filters > 0
        super().__init__(**kwargs)

        self.alpha = alpha
        self.filters = filters

        # Accept a bare int like Conv2D does; internally always use 2-tuples.
        self.kernel_size = self._to_pair(kernel_size)
        self.strides = self._to_pair(strides)
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer
        self.kernel_constraint = kernel_constraint

        # Channel split between the two output branches.
        self.low_channels = int(self.filters * self.alpha)
        self.high_channels = self.filters - self.low_channels

    @staticmethod
    def _to_pair(value):
        """Normalize an int or a length-2 sequence to a 2-tuple."""
        if isinstance(value, int):
            return (value, value)
        value = tuple(value)
        assert len(value) == 2
        return value

    def build(self, input_shape):
        """Validate the ``[high, low]`` input shapes and create the four kernels."""
        assert len(input_shape) == 2
        assert len(input_shape[0]) == 4 and len(input_shape[1]) == 4

        # The pooled high branch (half resolution) must still fit the kernel.
        assert input_shape[0][1] // 2 >= self.kernel_size[0]
        assert input_shape[0][2] // 2 >= self.kernel_size[1]

        # The low branch must be half the spatial size of the high branch.
        assert input_shape[0][1] // input_shape[1][1] == 2
        assert input_shape[0][2] // input_shape[1][2] == 2

        assert K.image_data_format() == "channels_last"

        high_in = int(input_shape[0][3])
        low_in = int(input_shape[1][3])

        self.high_to_high_kernel = self.add_weight(name="high_to_high_kernel",
                                    shape=(*self.kernel_size, high_in, self.high_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

        self.high_to_low_kernel  = self.add_weight(name="high_to_low_kernel",
                                    shape=(*self.kernel_size, high_in, self.low_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

        self.low_to_high_kernel  = self.add_weight(name="low_to_high_kernel",
                                    shape=(*self.kernel_size, low_in, self.high_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

        self.low_to_low_kernel   = self.add_weight(name="low_to_low_kernel",
                                    shape=(*self.kernel_size, low_in, self.low_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
        super().build(input_shape)

    def call(self, inputs):
        """Return ``[high_out, low_out]`` for ``inputs = [high_input, low_input]``."""
        assert len(inputs) == 2
        high_input, low_input = inputs

        high_to_high = K.conv2d(high_input, self.high_to_high_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")

        # High -> low: 2x2 average pool down to the low resolution, then conv.
        high_to_low  = K.pool2d(high_input, (2,2), strides=(2,2), pool_mode="avg")
        high_to_low  = K.conv2d(high_to_low, self.high_to_low_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")

        # Low -> high: conv, then duplicate every pixel 2x along both spatial
        # axes to reach the high resolution.
        low_to_high  = K.conv2d(low_input, self.low_to_high_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")
        low_to_high = K.repeat_elements(low_to_high, 2, axis=1)
        low_to_high = K.repeat_elements(low_to_high, 2, axis=2)

        low_to_low   = K.conv2d(low_input, self.low_to_low_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")

        high_add = high_to_high + low_to_high
        low_add = high_to_low + low_to_low
        return [high_add, low_add]

    def _conv_out_spatial(self, spatial):
        """Spatial (rows, cols) produced by ``K.conv2d`` on ``spatial``; None stays None."""
        out = []
        for size, k, s in zip(spatial, self.kernel_size, self.strides):
            if size is None:
                out.append(None)
                continue
            if self.padding == "valid":
                size = size - k + 1
            out.append(-(-size // s))  # ceil(size / s)
        return tuple(out)

    def compute_output_shape(self, input_shapes):
        """Static shapes of ``[high_out, low_out]``, honoring strides and padding."""
        high_in_shape, low_in_shape = (tf.TensorShape(s).as_list() for s in input_shapes)
        high_out_shape = (high_in_shape[0],
                          *self._conv_out_spatial(high_in_shape[1:3]),
                          self.high_channels)
        low_out_shape = (low_in_shape[0],
                         *self._conv_out_spatial(low_in_shape[1:3]),
                         self.low_channels)
        return [high_out_shape, low_out_shape]

    def get_config(self):
        """Layer config for ``from_config``; kernel options are Keras-serialized."""
        base_config = super().get_config()
        out_config = {
            **base_config,
            "filters": self.filters,
            "alpha": self.alpha,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            # Serialize to plain Keras config dicts (None stays None) so the
            # config is JSON-friendly even when objects were passed in.
            "kernel_initializer": tf.keras.initializers.serialize(
                tf.keras.initializers.get(self.kernel_initializer)),
            "kernel_regularizer": tf.keras.regularizers.serialize(
                tf.keras.regularizers.get(self.kernel_regularizer)),
            "kernel_constraint": tf.keras.constraints.serialize(
                tf.keras.constraints.get(self.kernel_constraint)),
        }
        return out_config

In [0]:
def create_normal_residual_block(inputs, ch, N):
    """Stack N plain residual units, every convolution producing `ch` channels.

    Each unit is two (3x3 Conv2D -> BatchNorm -> ReLU) stages added to a
    shortcut. The first unit's shortcut is a 1x1 Conv2D -> BatchNorm -> ReLU
    projection; later units use the identity.

    Args:
        inputs: input feature-map tensor.
        ch: channel count of every convolution in the stack.
        N: number of residual units (0 returns `inputs` unchanged).

    Returns:
        Output tensor of the last unit.
    """
    def conv_bn_relu(tensor, kernel, **conv_kwargs):
        tensor = layers.Conv2D(ch, kernel, **conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Activation("relu")(tensor)

    out = inputs
    for unit in range(N):
        # Only the first unit needs a projection to change the channel count.
        shortcut = conv_bn_relu(out, 1) if unit == 0 else out
        residual = conv_bn_relu(out, 3, padding="same")
        residual = conv_bn_relu(residual, 3, padding="same")
        out = layers.Add()([residual, shortcut])
    return out

def create_octconv_residual_block(inputs, ch, N, alpha):
    """Stack N octave residual units on a ``[high, low]`` feature-map pair.

    Each unit applies two OctConv2D -> (BatchNorm -> ReLU per branch) stages
    and adds a per-branch shortcut. The first unit's shortcuts are 1x1
    Conv2D -> BatchNorm -> ReLU projections; later units use the identity.

    Args:
        inputs: pair ``[high, low]`` of feature-map tensors.
        ch: total channels of every OctConv2D in the stack.
        N: number of residual units (0 returns the input tensors unchanged).
        alpha: low-frequency channel fraction passed to OctConv2D.

    Returns:
        list: ``[high, low]`` after the last unit.
    """
    # Split channels exactly like OctConv2D (low = int(ch * alpha),
    # high = ch - low). Computing high as int(ch * (1 - alpha)) can come out
    # one short, e.g. ch=10, alpha=0.25 gives 7 instead of 8, and the Add
    # below then fails on mismatched channel counts.
    low_ch = int(ch * alpha)
    high_ch = ch - low_ch

    high, low = inputs
    for i in range(N):
        if i == 0:
            skip_high = layers.Conv2D(high_ch, 1)(high)
            skip_high = layers.BatchNormalization()(skip_high)
            skip_high = layers.Activation("relu")(skip_high)

            skip_low = layers.Conv2D(low_ch, 1)(low)
            skip_low = layers.BatchNormalization()(skip_low)
            skip_low = layers.Activation("relu")(skip_low)
        else:
            skip_high, skip_low = high, low

        high, low = OctConv2D(filters=ch, alpha=alpha)([high, low])
        high = layers.BatchNormalization()(high)
        high = layers.Activation("relu")(high)
        low = layers.BatchNormalization()(low)
        low = layers.Activation("relu")(low)

        high, low = OctConv2D(filters=ch, alpha=alpha)([high, low])
        high = layers.BatchNormalization()(high)
        high = layers.Activation("relu")(high)
        low = layers.BatchNormalization()(low)
        low = layers.Activation("relu")(low)

        high = layers.Add()([high, skip_high])
        low = layers.Add()([low, skip_low])
    return [high, low]

def create_octconv_last_residual_block(inputs, ch, alpha):
    """Final octave stage: one OctConv2D, then merge both branches into a single map.

    The high branch gets a 3x3 Conv2D. The low branch gets a 3x3 Conv2D
    followed by 2x pixel repetition along axes 1 and 2 (the spatial axes in
    channels_last). The two results are summed, then BatchNorm -> ReLU.

    Args:
        inputs: pair ``[high, low]`` of feature-map tensors.
        ch: channel count of the OctConv2D and of the merged output.
        alpha: low-frequency channel fraction passed to OctConv2D.

    Returns:
        A single merged feature-map tensor with `ch` channels.
    """
    def repeat_2x(t):
        # Duplicate every pixel twice along axis 1, then along axis 2.
        return K.repeat_elements(K.repeat_elements(t, 2, axis=1), 2, axis=2)

    def bn_relu(t):
        t = layers.BatchNormalization()(t)
        return layers.Activation("relu")(t)

    high_in, low_in = inputs
    high_oct, low_oct = OctConv2D(filters=ch, alpha=alpha)([high_in, low_in])
    high_oct = bn_relu(high_oct)
    low_oct = bn_relu(low_oct)

    from_high = layers.Conv2D(ch, 3, padding="same")(high_oct)
    from_low = layers.Conv2D(ch, 3, padding="same")(low_oct)
    from_low = layers.Lambda(repeat_2x)(from_low)

    merged = layers.Add()([from_high, from_low])
    return bn_relu(merged)

In [0]:
def create_normal_wide_resnet(N=4, k=10):
    """Build the baseline (non-octave) Wide ResNet for 32x32x3 inputs and 10 classes.

    Args:
        N: residual units per stage.
        k: widening factor; the three stages use 16*k, 32*k and 64*k channels.

    Returns:
        An uncompiled keras ``Model``.
    """
    image = layers.Input((32,32,3))
    features = layers.Conv2D(16, 3, padding="same")(image)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    # Three stages; a 2x2 average pool sits between consecutive stages.
    stage_widths = (16*k, 32*k, 64*k)
    for stage, width in enumerate(stage_widths):
        if stage > 0:
            features = layers.AveragePooling2D(2)(features)
        features = create_normal_residual_block(features, width, N)

    features = layers.GlobalAveragePooling2D()(features)
    probs = layers.Dense(10, activation="softmax")(features)
    return Model(image, probs)

In [0]:
def create_octconv_wide_resnet(alpha, N=4, k=10):
    """Build the octave-convolution Wide ResNet for 32x32x3 inputs and 10 classes.

    Args:
        alpha: low-frequency channel fraction used by every OctConv2D.
        N: residual units in the first two stages; the last stage runs N-1
            octave units followed by ``create_octconv_last_residual_block``.
        k: widening factor; stages use 16*k, 32*k and 64*k channels.

    Returns:
        An uncompiled keras ``Model``.
    """
    image = layers.Input((32,32,3))
    # The low-frequency stream starts as a 2x2 average-pooled copy of the input.
    low = layers.AveragePooling2D(2)(image)

    high, low = OctConv2D(filters=16, alpha=alpha)([image, low])
    high = layers.BatchNormalization()(high)
    high = layers.Activation("relu")(high)
    low = layers.BatchNormalization()(low)
    low = layers.Activation("relu")(low)

    def downsample(pair):
        # Halve the resolution of both branches (high first, then low).
        return [layers.AveragePooling2D(2)(t) for t in pair]

    features = create_octconv_residual_block([high, low], 16*k, N, alpha)
    features = create_octconv_residual_block(downsample(features), 32*k, N, alpha)
    features = create_octconv_residual_block(downsample(features), 64*k, N-1, alpha)

    merged = create_octconv_last_residual_block(features, 64*k, alpha)

    pooled = layers.GlobalAveragePooling2D()(merged)
    probs = layers.Dense(10, activation="softmax")(pooled)
    return Model(image, probs)

In [11]:
def lr_scheduler(epoch):
    """Step learning-rate schedule: 0.1, divided by 5 at epoch 100 and again at 150.

    Args:
        epoch: zero-based epoch index.

    Returns:
        float: learning rate for that epoch.
    """
    lr = 0.1
    for milestone in (100, 150):
        if epoch >= milestone:
            lr /= 5.0
    return lr

def train(alpha, epochs=20, batch_size=128):
    """Train a Wide ResNet on CIFAR-10 and return its training history.

    Args:
        alpha: low-frequency channel fraction. ``alpha <= 0`` trains the
            baseline ``create_normal_wide_resnet()``; otherwise
            ``create_octconv_wide_resnet(alpha)``.
        epochs: number of training epochs. NOTE: ``lr_scheduler`` only decays
            at epochs 100 and 150, so with the default 20 epochs the learning
            rate stays at 0.1 throughout.
        batch_size: mini-batch size for both training and validation.

    Returns:
        dict: the ``History`` callback's ``history`` dict (per-epoch metrics)
        plus an ``"elapsed"`` key with wall-clock training time in seconds.
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Augmentation: random horizontal flips and shifts of up to 4/32 of the image size.
    train_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                    width_shift_range=4.0/32.0, height_shift_range=4.0/32.0)
    test_gen = ImageDataGenerator(rescale=1.0/255)
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    tf.logging.set_verbosity(tf.logging.FATAL)

    if alpha <= 0:
        model = create_normal_wide_resnet()
    else:
        model = create_octconv_wide_resnet(alpha)
    model.compile(SGD(0.1, momentum=0.9), "categorical_crossentropy", ["acc"])
    model.summary()

    scheduler = LearningRateScheduler(lr_scheduler)
    hist = History()

    start_time = time.time()
    model.fit_generator(train_gen.flow(X_train, y_train, batch_size, shuffle=True),
                        steps_per_epoch=X_train.shape[0] // batch_size,
                        validation_data=test_gen.flow(X_test, y_test, batch_size, shuffle=False),
                        # Ceil division so the final partial test batch is evaluated too.
                        validation_steps=-(-X_test.shape[0] // batch_size),
                        callbacks=[scheduler, hist], max_queue_size=5, epochs=epochs)
    elapsed = time.time() - start_time
    print(elapsed)

    history = hist.history
    history["elapsed"] = elapsed
    return history

# Bind the result so later cells can inspect it instead of it being discarded.
octconv_history = train(0.25)

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
average_pooling2d_12 (AveragePo (None, 16, 16, 3)    0           input_4[0][0]                    
__________________________________________________________________________________________________
oct_conv2d_48 (OctConv2D)       [(None, 32, 32, 12), 864         input_4[0][0]                    
                                                                 average_pooling2d_12[0][0]       
__________________________________________________________________________________________________
batch_normalization_138 (BatchN (None, 32, 32, 12)   48          oct_conv2d_48[0][0]        