In [None]:
# Physical GPU ids (as numbered on the host) that this notebook is allowed to use.
SELECTED_GPUS = [4, 5]

import os

# Restrict CUDA to the selected GPUs. This is done before TensorFlow is imported
# so the restriction is already in place when TensorFlow enumerates devices.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])

import tensorflow as tf 

tf.get_logger().setLevel('INFO')

# Fail fast when no GPU is visible to TensorFlow.
assert len(tf.config.list_physical_devices('GPU')) > 0

# Enable on-demand memory growth on every visible GPU.
GPUS = tf.config.experimental.list_physical_devices('GPU')
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# Mirror the model over one logical device per selected GPU. The logical names
# are built from 0..len(SELECTED_GPUS)-1 rather than from the physical ids,
# because only the selected GPUs are exposed through CUDA_VISIBLE_DEVICES.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

# Number of replicas kept in sync by the strategy; used later to scale batch sizes.
NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import json
import math
import numpy as np
import pickle
import sys
from skimage import transform
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Side length (in pixels) of the square input image and of each square patch.
IMAGE_SIZE = 384
PATCH_SIZE = 16
# Sequence length of the precomputed features: one token per patch plus one
# extra (class) token. Derived from IMAGE_SIZE so the two cannot drift apart.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1
# Width of each token embedding.
HIDDEN_DIM = 768
# Number of image tiles along each axis of a video frame, and the resulting frame size.
VIDEO_PATCHES = (2, 3)
VIDEO_SIZE = (VIDEO_PATCHES[0] * IMAGE_SIZE, VIDEO_PATCHES[1] * IMAGE_SIZE)
MLP_DIM = 3072  # ResMLP
CHANNELS_MLP_DIM = 3072  # MLP-Mixer
TOKENS_MLP_DIM = 384  # MLP-Mixer
# Directory tree that holds the precomputed per-sample feature files.
PRECOMPUTE_DIR = 'precompute'
PRECOMPUTE_FASHION_MNIST_DIR = os.path.join(PRECOMPUTE_DIR, 'fashion_mnist')

In [None]:
def get_params(model):
    """Return the number of trainable parameters of `model`.

    The count is taken directly from `model.trainable_weights` instead of being
    scraped from the text printed by `model.summary()`, whose layout (leading
    whitespace, trailing size suffixes) is not stable across Keras versions.

    Args:
        model: object exposing a `trainable_weights` iterable whose items have
            a `shape` attribute (e.g. a Keras model).

    Returns:
        int: total number of scalar entries over all distinct trainable weights.
    """
    # A weight object shared by several layers must only be counted once.
    unique_weights = {id(weight): weight for weight in model.trainable_weights}
    return int(sum(
        np.prod([int(dim) for dim in weight.shape])
        for weight in unique_weights.values()
    ))

def get_flops(model):
    """Return the float-operation count reported by the TF1 profiler for one forward pass.

    The model is traced with a batch size of 1, its variables are frozen into
    constants, and the resulting graph is profiled.

    from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280
    """
    # One spec per model input, keeping every non-batch dimension and fixing batch to 1.
    input_specs = [tf.TensorSpec([1, *tensor.shape[1:]]) for tensor in model.inputs]
    forward = tf.function(lambda inputs: model(inputs))
    concrete_forward = forward.get_concrete_function(input_specs)
    _, frozen_graph_def = convert_variables_to_constants_v2_as_graph(concrete_forward)

    with tf.Graph().as_default() as profiling_graph:
        tf.graph_util.import_graph_def(frozen_graph_def, name='')
        profile = tf.compat.v1.profiler.profile(
            graph=profiling_graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd="op",
            options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        )
    return profile.total_float_ops

In [None]:
# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py

def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):
    """Apply a learnable per-channel scale (and optional bias) to `inputs`.

    A dummy axis is inserted at position 1, a 1x1 `DepthwiseConv2D` is applied,
    and the dummy axis is removed again.

    Args:
        inputs: tensor to transform.
        use_bias: whether the depthwise convolution also learns a bias.
        weight_init_value: initial value of every scale weight; 1 uses the
            built-in 'ones' initializer, any other value a constant initializer.
        name: prefix for the created layer, which is named `name + 'affine'`.

    Returns:
        Tensor with the same rank as `inputs`.
    """
    if weight_init_value != 1:
        # The original referenced an undefined `tfkeras` module here, which
        # raised NameError for every non-default initial value.
        ww_init = tf.keras.initializers.Constant(weight_init_value)
    else:
        ww_init = 'ones'
    nn = tf.keras.backend.expand_dims(inputs, 1)
    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)
    return tf.keras.backend.squeeze(nn, 1)

def mlp_block(inputs, mlp_dim, activation='gelu', name=''):
    """Build one ResMLP-style block on top of `inputs`.

    The block has two residual sub-layers, each wrapped in `channel_affine`
    transforms: a dense layer applied along axis 1 (via a transpose), then a
    two-layer MLP with hidden width `mlp_dim` applied along the last axis.

    Args:
        inputs: rank-3 tensor (the `Permute((2, 1))` layers require two non-batch axes).
        mlp_dim: hidden width of the second sub-layer.
        activation: activation used between the two dense layers of the second sub-layer.
        name: prefix prepended to every layer name created here.

    Returns:
        Output tensor of the block.
    """
    layers = tf.keras.layers

    # Sub-layer 1: mix information along axis 1.
    scaled_inputs = channel_affine(inputs, use_bias=True, name=name + '1_')
    mixed = layers.Permute((2, 1), name=name + 'permute_1')(scaled_inputs)
    mixed = layers.Dense(mixed.shape[-1], name=name + 'dense_1')(mixed)
    mixed = layers.Permute((2, 1), name=name + 'permute_2')(mixed)
    mixed = channel_affine(mixed, use_bias=False, name=name + '1_gamma_')
    # NOTE: the residual branch is the affine-transformed input, not the raw input.
    first_residual = layers.Add(name=name + 'add_1')([mixed, scaled_inputs])

    # Sub-layer 2: two-layer MLP along the last axis.
    scaled_residual = channel_affine(first_residual, use_bias=True, name=name + '2_')
    hidden = layers.Dense(mlp_dim, name=name + 'dense_2_1')(scaled_residual)
    # The layer keeps the name suffix 'gelu' whatever `activation` is passed.
    hidden = layers.Activation(activation, name=name + 'gelu')(hidden)
    projected = layers.Dense(inputs.shape[-1], name=name + 'dense_2_2')(hidden)
    projected = channel_affine(projected, use_bias=False, name=name + '2_gamma_')
    return layers.Add(name=name + 'add_2')([projected, scaled_residual])

In [None]:
# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py

class MlpBlock(tf.keras.layers.Layer):
    """Two-layer perceptron: Dense(hidden_dim) -> activation -> Dense(dim).

    Args:
        dim: size of the last axis of the output.
        hidden_dim: width of the hidden dense layer.
        activation: activation applied between the two dense layers, given as
            anything `tf.keras.activations.get` accepts; defaults to GELU.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, dim, hidden_dim, activation=None, **kwargs):
        super(MlpBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        self.dim = dim
        self.hidden_dim = hidden_dim
        # Keep the activation *function* on `self.activation`; the layer that
        # applies it lives on its own attribute. Previously the layer
        # overwrote the function, so `get_config` emitted a layer object.
        self.activation = tf.keras.activations.get(activation)
        self.dense1 = tf.keras.layers.Dense(hidden_dim)
        self.activation_layer = tf.keras.layers.Activation(self.activation)
        self.dense2 = tf.keras.layers.Dense(dim)

    def call(self, inputs):
        """Apply dense -> activation -> dense to `inputs`."""
        x = inputs
        x = self.dense1(x)
        x = self.activation_layer(x)
        x = self.dense2(x)
        return x

    def compute_output_shape(self, input_signature):
        """Return the input shape with its last axis replaced by `dim`."""
        # Dense layers act on the last axis only, so every leading axis is
        # preserved (the original kept only the batch axis).
        return tuple(input_signature[:-1]) + (self.dim,)

    def get_config(self):
        """Return constructor arguments in a form `from_config` can consume."""
        config = super(MlpBlock, self).get_config().copy()
        config.update({
            'dim': self.dim,
            'hidden_dim': self.hidden_dim,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config

class MixerBlock(tf.keras.layers.Layer):
    """MLP-Mixer style block: a token-mixing MLP followed by a channel-mixing MLP.

    Each MLP is preceded by a layer normalization and followed by a residual add.

    Args:
        num_patches: size of axis 1 of the input (output width of the token mixer).
        channel_dim: size of the last input axis (output width of the channel mixer).
        token_mixer_hidden_dim: hidden width of the token-mixing MLP.
        channel_mixer_hidden_dim: hidden width of the channel-mixing MLP;
            defaults to `token_mixer_hidden_dim`.
        activation: activation used inside both MLPs, given as anything
            `tf.keras.activations.get` accepts; defaults to GELU.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(
        self,
        num_patches,
        channel_dim,
        token_mixer_hidden_dim,
        channel_mixer_hidden_dim=None,
        activation=None,
        **kwargs
    ):
        super(MixerBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        if channel_mixer_hidden_dim is None:
            channel_mixer_hidden_dim = token_mixer_hidden_dim

        self.num_patches = num_patches
        self.channel_dim = channel_dim
        self.token_mixer_hidden_dim = token_mixer_hidden_dim
        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim
        self.activation = tf.keras.activations.get(activation)

        # NOTE(review): both normalizations use axis=1, the same axis whose size
        # is `num_patches`, rather than the last (channel) axis - confirm this
        # is intended before changing it, since it alters the weight shapes.
        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)
        self.permute1 = tf.keras.layers.Permute((2, 1))
        # The requested activation is forwarded to both MLPs; previously it was
        # stored but never used, so the MLPs always ran with their default.
        self.token_mixer = MlpBlock(
            num_patches,
            token_mixer_hidden_dim,
            activation=self.activation,
            name='token_mixer'
        )

        self.permute2 = tf.keras.layers.Permute((2, 1))
        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)
        self.channel_mixer = MlpBlock(
            channel_dim,
            channel_mixer_hidden_dim,
            activation=self.activation,
            name='channel_mixer'
        )

        self.skip_connection1 = tf.keras.layers.Add()
        self.skip_connection2 = tf.keras.layers.Add()

    def get_config(self):
        """Return constructor arguments in a form `from_config` can consume."""
        config = super(MixerBlock, self).get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'channel_dim': self.channel_dim,
            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,
            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config

    def call(self, inputs):
        """Run token mixing then channel mixing, each with a residual connection."""
        x = inputs
        skip_x = x
        x = self.norm1(x)
        # Swap the two non-batch axes so the MLP's dense layers act along axis 1.
        x = self.permute1(x)
        x = self.token_mixer(x)

        # Restore the original axis order before adding the residual.
        x = self.permute2(x)

        x = self.skip_connection1([x, skip_x])
        skip_x = x

        x = self.norm2(x)
        x = self.channel_mixer(x)

        x = self.skip_connection2([x, skip_x])

        return x

    def compute_output_shape(self, input_shape):
        """The block preserves the shape of its input."""
        return input_shape

In [None]:
def _conv_pool_flatten(feature_map):
    """Apply the Conv2D(16, 3x3, elu, same) -> 2x2 max-pool -> flatten trunk shared by the cnn_* heads."""
    y = tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=(3, 3),
        activation='elu',
        padding='same'
    )(feature_map)
    y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)
    return tf.keras.layers.Flatten()(y)


def _patch_and_class_grids(tokens):
    """Split `tokens` into a grid of the tokens after the first and a same-shaped grid tiling the first token."""
    channels = HIDDEN_DIM
    width = height = IMAGE_SIZE // PATCH_SIZE

    patch_grid = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_x')(tokens)
    patch_grid = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(patch_grid)

    class_grid = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(tokens)
    class_grid = tf.keras.layers.RepeatVector(width * height)(class_grid)
    class_grid = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape')(class_grid)

    return patch_grid, class_grid


def get_model(branch_number, head_type):
    """Build and compile a classifier over precomputed token features.

    The model takes a `(NUM_PATCHES, HIDDEN_DIM)` token sequence, reduces it to
    a vector with the selected head, and classifies it into 10 classes with a
    shared MLP (two Dense(256, elu) + Dropout(0.5) stages, then a softmax).

    Args:
        branch_number: accepted for interface symmetry with the data pipeline;
            it does not influence the architecture built here.
        head_type: one of 'resmlp', 'mlp', 'vit', 'cnn_ignore', 'cnn_add',
            'cnn_project', 'mlp_mixer'.

    Returns:
        A compiled `tf.keras.Model` (Adam with learning rate 1e-4, categorical
        cross-entropy loss, accuracy metric).

    Raises:
        ValueError: if `head_type` is not one of the supported names.
    """
    model_input = tf.keras.Input(shape=(NUM_PATCHES, HIDDEN_DIM))
    y = model_input
    if head_type == 'resmlp':
        y = mlp_block(y, mlp_dim=MLP_DIM, name='mlp_mixer')
        y = tf.keras.layers.GlobalAveragePooling1D()(y)
    elif head_type == 'mlp':
        # Normalize, then keep only the first token of the sequence.
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name='Transformer/encoder_norm_x'
        )(y)
        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)
    elif head_type == 'vit':
        # One extra transformer block; its second output is discarded.
        y, _ = TransformerBlock(
            num_heads=12,
            mlp_dim=3072,
            dropout=0.1,
            name='Transformer/encoderblock_x'
        )(y)
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name='Transformer/encoder_norm_x'
        )(y)
        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)
    elif head_type == 'cnn_ignore':
        # Drop the first token and convolve over the remaining tokens as a 2-D grid.
        channels = HIDDEN_DIM
        width = height = IMAGE_SIZE // PATCH_SIZE
        y = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken')(y)
        y = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(y)
        y = _conv_pool_flatten(y)
    elif head_type == 'cnn_add':
        # Add the first token to every grid position before convolving.
        patch_grid, class_grid = _patch_and_class_grids(y)
        y = tf.keras.layers.Add()([patch_grid, class_grid])
        y = _conv_pool_flatten(y)
    elif head_type == 'cnn_project':
        # Concatenate the first token to every grid position (channel axis) before convolving.
        patch_grid, class_grid = _patch_and_class_grids(y)
        y = tf.keras.layers.Concatenate()([patch_grid, class_grid])
        y = _conv_pool_flatten(y)
    elif head_type == 'mlp_mixer':
        num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1
        y = MixerBlock(
            num_patches=num_patches,
            channel_dim=HIDDEN_DIM,
            token_mixer_hidden_dim=TOKENS_MLP_DIM,
            channel_mixer_hidden_dim=CHANNELS_MLP_DIM
        )(y)
        y = tf.keras.layers.GlobalAveragePooling1D()(y)
    else:
        # Without this guard the MLP head would be applied to the unreduced
        # rank-3 input, silently producing a per-token classifier.
        raise ValueError('Unknown head_type: %r' % (head_type,))

    # MLP head
    initializer = tf.keras.initializers.he_normal()
    regularizer = tf.keras.regularizers.l2()
    y = tf.keras.layers.Dense(
        units=256,
        activation='elu',
        kernel_initializer=initializer,
        kernel_regularizer=regularizer
    )(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    y = tf.keras.layers.Dense(
        units=256,
        activation='elu',
        kernel_initializer=initializer,
        kernel_regularizer=regularizer
    )(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    y = tf.keras.layers.Dense(
        units=10,
        activation='softmax',
        kernel_initializer=initializer,
        kernel_regularizer=regularizer
    )(y)

    model = tf.keras.models.Model(
        inputs=model_input,
        outputs=y
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

In [None]:
class FashionMNISTSequence(tf.keras.utils.Sequence):
    """Keras `Sequence` that serves batches of precomputed samples from pickle files.

    Files are looked up in `PRECOMPUTE_FASHION_MNIST_DIR` under the name
    `<split>_branch<branch_number>_sample<i>.pkl`; each file is expected to
    unpickle to a mapping with 'features' and 'label' entries. Sample order is
    a random permutation that is redrawn at the end of every epoch.

    Args:
        split: split name used as the file-name prefix (e.g. 'train').
        branch_number: integer branch id embedded in the file names.
        batch_size: per-replica batch size; it is multiplied by `NUM_GPUS` to
            obtain the number of samples returned per batch.
    """

    def __init__(self, split, branch_number, batch_size):
        self.split = split
        self.branch_number = branch_number
        self.batch_size = batch_size * NUM_GPUS
        self.dir = PRECOMPUTE_FASHION_MNIST_DIR
        # The dataset size is the number of files carrying this split/branch prefix.
        sample_prefix = '%s_branch%d_sample' % (self.split, self.branch_number)
        self.count = len([
            file_name for file_name in os.listdir(self.dir)
            if file_name.startswith(sample_prefix)
        ])
        self.random_permutation = np.random.permutation(self.count)

    def __len__(self):
        """Number of batches per epoch; the last batch may be smaller."""
        return math.ceil(self.count / self.batch_size)

    def on_epoch_end(self):
        """Draw a fresh sample order for the next epoch."""
        self.random_permutation = np.random.permutation(self.count)

    def _load_sample(self, sample_index):
        """Unpickle one sample file and return its `(features, label)` pair."""
        file_name = '%s_branch%d_sample%d.pkl' % (self.split, self.branch_number, sample_index)
        # NOTE: pickle can execute arbitrary code; only load files this project produced.
        with open(os.path.join(self.dir, file_name), 'rb') as cache_file:
            contents = pickle.load(cache_file)
        return contents['features'], contents['label']

    def __getitem__(self, index):
        """Return batch number `index` as a `(features, labels)` pair of arrays."""
        start = index * self.batch_size
        features = []
        labels = []
        for sample_index in self.random_permutation[start:start + self.batch_size]:
            sample_features, sample_label = self._load_sample(sample_index)
            features.append(sample_features)
            labels.append(sample_label)
        return np.array(features), np.array(labels)

In [None]:
def train(max_epochs, branch_number, head_type, batch_size=64):
    """Train one head on one branch's precomputed features and evaluate it on the test split.

    Args:
        max_epochs: upper bound on training epochs (early stopping may end sooner).
        branch_number: branch whose precomputed features are used.
        head_type: head architecture name, forwarded to `get_model`.
        batch_size: per-replica batch size handed to `FashionMNISTSequence`.

    Returns:
        Tuple `(model, test_accuracy, branch_params, total_flops)`:
        the trained model as it stands after the last epoch; the raw return
        value of `model.evaluate` on the test split (despite the name this is
        whatever Keras returns for the compiled loss and metrics, not
        necessarily a single accuracy number); the trainable parameter count
        in millions; and the profiled float operations in billions.
    """
    tf.keras.backend.clear_session()

    # The model is built and measured inside the distribution scope.
    with DISTRIBUTED_STRATEGY.scope():
        model = get_model(branch_number, head_type)
        branch_params = get_params(model) / 10 ** 6
        total_flops = get_flops(model) / 10 ** 9

    # Every callback watches validation accuracy, where higher is better.
    monitored_metric = 'val_accuracy'
    checkpoint_path = 'vit_shtb_cw_%d_%s_head_precomputed_v1.h5' % (branch_number, head_type)

    training_callbacks = [
        # Multiply the learning rate by 0.6 after 2 epochs without improvement.
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor=monitored_metric,
            factor=0.6,
            patience=2,
            verbose=1,
            mode='max',
            min_lr=1e-7
        ),
        # Stop after 5 epochs without improvement.
        tf.keras.callbacks.EarlyStopping(
            monitor=monitored_metric,
            patience=5,
            verbose=1,
            mode='max'
        ),
        # Save the full model whenever validation accuracy reaches a new best.
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path,
            monitor=monitored_metric,
            verbose=1,
            save_weights_only=False,
            save_best_only=True,
            mode='max',
            save_freq='epoch'
        ),
    ]

    train_batches = FashionMNISTSequence('train', branch_number, batch_size)
    val_batches = FashionMNISTSequence('val', branch_number, batch_size)
    model.fit(
        train_batches,
        validation_data=val_batches,
        epochs=max_epochs,
        shuffle=True,
        callbacks=training_callbacks,
        verbose=1
    )

    test_batches = FashionMNISTSequence('test', branch_number, batch_size)
    test_accuracy = model.evaluate(test_batches)

    return model, test_accuracy, branch_params, total_flops

In [None]:
def save_results(results, results_path):
    """Serialize `results` as JSON into `results_path`, replacing any existing file."""
    with open(results_path, 'w') as out_file:
        json.dump(results, out_file)

def print_results(results_path):
    """Read the JSON document stored at `results_path` and print the decoded object."""
    with open(results_path, 'r') as in_file:
        decoded = json.load(in_file)
    print(decoded)

def get_results_path(head_type):
    """Return the name of the JSON results file used for `head_type`."""
    file_name_template = 'shtb_%s.json'
    return file_name_template % head_type

def run_experiment(head_type):
    """Train `head_type` on branches 11 down to 1, recording metrics after each run.

    After every branch the accumulated list of result records is rewritten to
    the head's JSON file and echoed, so partial results survive an interruption.
    """
    results_path = get_results_path(head_type)
    records = []
    for branch in range(11, 0, -1):
        _, test_accuracy, branch_params, total_flops = train(100, branch, head_type)
        records.append({
            'exit': branch,
            'test_accuracy': test_accuracy,
            'branch_params': branch_params,
            'total_flops': total_flops,
        })
        save_results(records, results_path)
        print_results(results_path)

In [None]:
# 'vit' head: one extra TransformerBlock, layer norm, then the first token.
# Results are written to shtb_vit.json.
run_experiment('vit')

In [None]:
# 'resmlp' head: one mlp_block followed by global average pooling.
# Results are written to shtb_resmlp.json.
run_experiment('resmlp')

In [None]:
# 'mlp_mixer' head: one MixerBlock followed by global average pooling.
# Results are written to shtb_mlp_mixer.json.
run_experiment('mlp_mixer')

In [None]:
# 'mlp' head: layer norm and first-token extraction only, then the shared MLP classifier.
# Results are written to shtb_mlp.json.
run_experiment('mlp')

In [None]:
# 'cnn_ignore' head: drops the first token and convolves over the remaining token grid.
# Results are written to shtb_cnn_ignore.json.
run_experiment('cnn_ignore')

In [None]:
# 'cnn_add' head: adds the first token to every grid position before the convolution.
# Results are written to shtb_cnn_add.json.
run_experiment('cnn_add')

In [None]:
# 'cnn_project' head: concatenates the first token to every grid position before the convolution.
# Results are written to shtb_cnn_project.json.
run_experiment('cnn_project')