In [None]:
# GPU indices, as numbered by the driver / nvidia-smi, that this notebook may use.
SELECTED_GPUS = [7]

import os

# Must be set before TensorFlow is imported so the CUDA runtime only exposes the
# selected GPUs; inside this process they are renumbered starting from 0.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])

import tensorflow as tf 

tf.get_logger().setLevel('INFO')

# Fail fast when no GPU is visible (e.g. a wrong index in SELECTED_GPUS).
assert len(tf.config.list_physical_devices('GPU')) > 0

# Grow GPU memory on demand instead of reserving it all up front; this must run
# before the GPUs are initialised by any TensorFlow op.
GPUS = tf.config.experimental.list_physical_devices('GPU')
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# One replica per selected GPU. Device names use the renumbered indices 0..N-1
# because CUDA_VISIBLE_DEVICES has already remapped the physical GPUs.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import numpy as np
import pickle
import sys
from skimage import transform
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Backbone geometry. These presumably must match the saved ViT checkpoints
# (TODO confirm). 384 // 16 = 24, so the backbone's token sequence is treated
# as a 24x24 grid of patch tokens plus one class token.
IMAGE_SIZE = 384
PATCH_SIZE = 16
# Width of each backbone token; used as the channel count of the branches.
HIDDEN_DIM = 768
# Hidden MLP width used by the branches (e.g. the ResMLP channel MLP).
MLP_DIM = 3072
# Hidden widths of the MLP-Mixer branch's channel-mixing and token-mixing MLPs.
CHANNELS_MLP_DIM = 3072
TOKENS_MLP_DIM = 384

In [None]:
def get_flops(model):
    """Count the floating-point operations of one forward pass of ``model``.

    Steps:
      1. Trace the model for a batch size of 1.
      2. Fold its variables into constants.
      3. Import the frozen GraphDef into a fresh graph.
      4. Let the TF v1 profiler count the float ops.

    Adapted from
    https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280

    Args:
        model: a Keras model. Each entry of ``model.inputs`` supplies the
            per-sample input shape; its leading (batch) dimension is replaced by 1.

    Returns:
        The profiler's ``total_float_ops`` for the frozen graph.
    """
    input_specs = [tf.TensorSpec([1, *model_input.shape[1:]]) for model_input in model.inputs]
    forward = tf.function(lambda inputs: model(inputs))
    concrete_func = forward.get_concrete_function(input_specs)
    # Only the GraphDef is needed; the frozen ConcreteFunction is discarded.
    _, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)

    with tf.Graph().as_default() as frozen_graph:
        tf.graph_util.import_graph_def(graph_def, name='')
        profile_options = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        profile = tf.compat.v1.profiler.profile(
            graph=frozen_graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd="op",
            options=profile_options,
        )
    return profile.total_float_ops

In [None]:
# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py

def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):
    """Apply a learnable per-channel affine transform (scale, optionally plus bias).

    Implemented as a 1x1 ``DepthwiseConv2D``:
      1. Insert a dummy spatial axis at position 1.
      2. Apply the depthwise conv, which learns one scale (and bias) per channel.
      3. Squeeze the dummy axis away again.

    Args:
        inputs: Keras tensor of rank 3, presumably (batch, tokens, channels).
            ``DepthwiseConv2D`` needs the 4-D tensor produced by ``expand_dims``.
        use_bias: whether to learn a per-channel bias as well.
        weight_init_value: initial value of every per-channel scale.
        name: prefix for the created layer; the layer is named ``name + 'affine'``.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # 'ones' is the built-in equivalent of Constant(1). Any other value needs an
    # explicit Constant initializer from tf.keras (the bare name ``tfkeras`` is undefined).
    ww_init = tf.keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else 'ones'
    nn = tf.keras.backend.expand_dims(inputs, 1)
    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)
    return tf.keras.backend.squeeze(nn, 1)

def mlp_block(inputs, mlp_dim, activation='gelu', name=''):
    """Build one ResMLP-style residual block with the Keras functional API.

    The block has two residual sub-blocks:
      1. Cross-patch: affine -> Dense over the token axis (via Permute) ->
         affine (no bias) -> Add.
      2. Cross-channel: affine -> Dense(mlp_dim) -> activation ->
         Dense(back to the input width) -> affine (no bias) -> Add.

    NOTE(review): each Add uses the affine-transformed tensor as its skip branch
    rather than the untransformed block input used in the ResMLP paper. This is
    kept as-is; confirm it is intended.

    Args:
        inputs: rank-3 Keras tensor; ``Permute((2, 1))`` swaps its two non-batch axes.
        mlp_dim: hidden width of the cross-channel MLP.
        activation: activation of the cross-channel MLP. The layer is always
            named ``name + 'gelu'`` regardless of this value.
        name: prefix for every created layer name.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # Cross-patch (token-mixing) sub-block.
    token_in = channel_affine(inputs, use_bias=True, name=name + '1_')
    token_mix = tf.keras.layers.Permute((2, 1), name=name + 'permute_1')(token_in)
    token_mix = tf.keras.layers.Dense(token_mix.shape[-1], name=name + 'dense_1')(token_mix)
    token_mix = tf.keras.layers.Permute((2, 1), name=name + 'permute_2')(token_mix)
    token_mix = channel_affine(token_mix, use_bias=False, name=name + '1_gamma_')
    token_out = tf.keras.layers.Add(name=name + 'add_1')([token_mix, token_in])

    # Cross-channel sub-block.
    channel_in = channel_affine(token_out, use_bias=True, name=name + '2_')
    hidden = tf.keras.layers.Dense(mlp_dim, name=name + 'dense_2_1')(channel_in)
    hidden = tf.keras.layers.Activation(activation, name=name + 'gelu')(hidden)
    channel_mix = tf.keras.layers.Dense(inputs.shape[-1], name=name + 'dense_2_2')(hidden)
    channel_mix = channel_affine(channel_mix, use_bias=False, name=name + '2_gamma_')
    return tf.keras.layers.Add(name=name + 'add_2')([channel_mix, channel_in])

In [None]:
# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py

class MlpBlock(tf.keras.layers.Layer):
    """Two-layer feed-forward block: Dense(hidden_dim) -> activation -> Dense(dim).

    Both Dense layers act on the last axis, so all leading axes are preserved.

    Args:
        dim: size of the output's last axis.
        hidden_dim: width of the hidden layer.
        activation: activation of the hidden layer; defaults to GELU.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, dim, hidden_dim, activation=None, **kwargs):
        super(MlpBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        self.dim = dim
        self.dense1 = tf.keras.layers.Dense(hidden_dim)
        self.activation = tf.keras.layers.Activation(activation)
        self.dense2 = tf.keras.layers.Dense(dim)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.activation(x)
        return self.dense2(x)

    def compute_output_shape(self, input_signature):
        """Return ``input_signature`` with its last axis replaced by ``self.dim``.

        Dense only changes the last axis. Every leading axis (batch and, for
        token tensors, the sequence axis) must therefore be kept, not only axis 0.
        """
        return tuple(input_signature)[:-1] + (self.dim,)

class MixerBlock(tf.keras.layers.Layer):
    """One MLP-Mixer block: residual token mixing followed by residual channel mixing.

    Expects inputs shaped ``(batch, num_patches, channel_dim)``. ``Permute((2, 1))``
    swaps the two non-batch axes so the token-mixing MLP acts across patches.
    The output has the same shape as the input.

    Args:
        num_patches: number of tokens (size of axis 1).
        channel_dim: number of channels (size of the last axis).
        token_mixer_hidden_dim: hidden width of the token-mixing MLP.
        channel_mixer_hidden_dim: hidden width of the channel-mixing MLP;
            defaults to ``token_mixer_hidden_dim``.
        activation: hidden activation of both MLPs; defaults to GELU.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(
        self,
        num_patches,
        channel_dim,
        token_mixer_hidden_dim,
        channel_mixer_hidden_dim=None,
        activation=None,
        **kwargs
    ):
        super(MixerBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        if channel_mixer_hidden_dim is None:
            channel_mixer_hidden_dim = token_mixer_hidden_dim

        # LayerNorm normalises each token over its channels (last axis), as in
        # MLP-Mixer. axis=1 would instead normalise across patches.
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1)
        self.permute1 = tf.keras.layers.Permute((2, 1))
        # Forward the activation so a caller-supplied one is actually used.
        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, activation=activation, name='token_mixer')

        self.permute2 = tf.keras.layers.Permute((2, 1))
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1)
        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, activation=activation, name='channel_mixer')

        self.skip_connection1 = tf.keras.layers.Add()
        self.skip_connection2 = tf.keras.layers.Add()

    def call(self, inputs):
        # Token mixing: an MLP across patches, shared over channels.
        x = self.norm1(inputs)
        x = self.permute1(x)
        x = self.token_mixer(x)
        x = self.permute2(x)
        tokens_mixed = self.skip_connection1([x, inputs])

        # Channel mixing: an MLP across channels, shared over patches.
        x = self.norm2(tokens_mixed)
        x = self.channel_mixer(x)
        return self.skip_connection2([x, tokens_mixed])

    def compute_output_shape(self, input_shape):
        # Both residual sub-blocks preserve the input shape.
        return input_shape

In [None]:
def get_branch_id(branch_number):
    """Map a 1-based transformer block number to the backbone's Keras layer name.

    Keras names the first TransformerBlock ``transformer_block`` and later ones
    ``transformer_block_1``, ``transformer_block_2``, ... so block n is suffixed with n - 1.
    """
    if branch_number != 1:
        return 'transformer_block_%d' % (branch_number - 1)
    return 'transformer_block'

# Per-dataset settings: (backbone checkpoint file, output units, output activation).
_DATASET_CONFIGS = {
    'disco': ('vit_cc_backbone_v2.h5', 1, None),
    'fashion_mnist': ('vit_fashion_mnist_v1.h5', 10, 'softmax'),
    'cifar10': ('vit_cifar10_v1.h5', 10, 'softmax'),
    'cifar100': ('vit_cifar100_v1.h5', 100, 'softmax'),
}

# Branch architectures that get_model knows how to build.
_BRANCH_TYPES = ('mlp', 'vit', 'cnn_ignore', 'cnn_add', 'cnn_project', 'resmlp', 'mlp_mixer')


def _load_frozen_backbone(model_file_name):
    """Load a saved ViT backbone (registering the vit_keras custom layers) and freeze all its layers."""
    backbone_model = tf.keras.models.load_model(model_file_name, custom_objects={
        'ClassToken': ClassToken,
        'AddPositionEmbs': AddPositionEmbs,
        'MultiHeadSelfAttention': MultiHeadSelfAttention,
        'TransformerBlock': TransformerBlock,
    })
    for layer in backbone_model.layers:
        layer.trainable = False
    return backbone_model


def _build_cnn_branch(branch_type, tokens):
    """Lay the patch tokens out on a 2-D grid, optionally merge in the class token, then Conv2D -> MaxPool -> Flatten."""
    channels = HIDDEN_DIM
    width = height = IMAGE_SIZE // PATCH_SIZE
    # Token 0 is the class token; the remaining width * height tokens are patches.
    patch_grid = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken')(tokens)
    patch_grid = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(patch_grid)

    if branch_type == 'cnn_ignore':
        y = patch_grid
    elif branch_type in ('cnn_add', 'cnn_project'):
        # Broadcast the class token to every grid cell so it can be merged with the patches.
        cls_grid = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken')(tokens)
        cls_grid = tf.keras.layers.RepeatVector(width * height)(cls_grid)
        cls_grid = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape')(cls_grid)
        if branch_type == 'cnn_add':
            y = tf.keras.layers.Add()([patch_grid, cls_grid])
        else:
            y = tf.keras.layers.Concatenate()([patch_grid, cls_grid])
    else:
        raise ValueError('Unknown branch type: %s' % branch_type)

    y = tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=(3, 3),
        activation='elu',
        padding='same'
    )(y)
    y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)
    return tf.keras.layers.Flatten()(y)


def _build_branch(branch_type, tokens):
    """Attach the requested branch to a transformer block's token output and return a flat feature tensor."""
    if branch_type == 'mlp':
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name="Transformer/encoder_norm"
        )(tokens)
        return tf.keras.layers.Lambda(lambda v: v[:, 0], name="ExtractToken")(y)

    if branch_type == 'vit':
        # One extra (trainable) transformer block on top of the frozen backbone.
        y, _ = TransformerBlock(
            num_heads=12,
            mlp_dim=MLP_DIM,
            dropout=0.1,
            name="Transformer/encoderblock_x",
        )(tokens)
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name="Transformer/encoder_norm"
        )(y)
        return tf.keras.layers.Lambda(lambda v: v[:, 0], name="ExtractToken")(y)

    if branch_type.startswith('cnn_'):
        return _build_cnn_branch(branch_type, tokens)

    if branch_type == 'resmlp':
        y = mlp_block(tokens, mlp_dim=MLP_DIM, name='mlp_mixer')
        return tf.keras.layers.GlobalAveragePooling1D()(y)

    if branch_type == 'mlp_mixer':
        # Patch tokens plus the class token.
        num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1
        y = MixerBlock(
            num_patches=num_patches,
            channel_dim=HIDDEN_DIM,
            token_mixer_hidden_dim=TOKENS_MLP_DIM,
            channel_mixer_hidden_dim=CHANNELS_MLP_DIM
        )(tokens)
        return tf.keras.layers.GlobalAveragePooling1D()(y)

    raise ValueError('Unknown branch type: %s' % branch_type)


def _build_head(features, output_units, output_activation):
    """Classification head: two Dense(256, elu) + Dropout(0.5) stages, then the output Dense layer."""
    initializer = tf.keras.initializers.he_normal()
    regularizer = tf.keras.regularizers.l2()
    y = features
    for _ in range(2):
        y = tf.keras.layers.Dense(
            units=256,
            activation='elu',
            kernel_initializer=initializer,
            kernel_regularizer=regularizer
        )(y)
        y = tf.keras.layers.Dropout(0.5)(y)
    return tf.keras.layers.Dense(
        units=output_units,
        activation=output_activation,
        kernel_initializer=initializer,
        kernel_regularizer=regularizer
    )(y)


def get_model(dataset_name, branch_type, branch_number):
    """Build a classifier that branches off a frozen ViT backbone after a given transformer block.

    Args:
        dataset_name: one of 'cifar10', 'cifar100', 'disco', 'fashion_mnist'.
            Selects the backbone checkpoint and the output layer's units/activation.
        branch_type: one of 'mlp', 'vit', 'cnn_ignore', 'cnn_add',
            'cnn_project', 'resmlp', 'mlp_mixer'.
        branch_number: 1-based index of the transformer block whose output feeds the branch.

    Returns:
        A ``tf.keras.Model`` from the backbone's input to the new head's output.

    Raises:
        ValueError: for an unknown ``dataset_name`` or ``branch_type``. Both are
            checked before the backbone checkpoint is loaded.
    """
    if dataset_name not in _DATASET_CONFIGS:
        raise ValueError('Unknown dataset name: %s' % dataset_name)
    if branch_type not in _BRANCH_TYPES:
        raise ValueError('Unknown branch type: %s' % branch_type)
    model_file_name, output_units, output_activation = _DATASET_CONFIGS[dataset_name]

    backbone_model = _load_frozen_backbone(model_file_name)
    # vit_keras TransformerBlock outputs (tokens, attention_weights); only the tokens are used.
    tokens, _ = backbone_model.get_layer(get_branch_id(branch_number)).output

    features = _build_branch(branch_type, tokens)
    outputs = _build_head(features, output_units, output_activation)

    return tf.keras.models.Model(
        inputs=backbone_model.get_layer(index=0).input,
        outputs=outputs
    )

In [None]:
# Branch architectures to attach, and datasets (i.e. backbones) to profile.
branch_types = ['mlp', 'vit', 'cnn_ignore', 'cnn_add', 'cnn_project', 'resmlp', 'mlp_mixer']

dataset_names = ['cifar10', 'cifar100', 'disco', 'fashion_mnist']

# For every (dataset, branch type) pair, print the GFLOPs of a single-sample
# forward pass when branching after transformer blocks 1..11.
# NOTE(review): range(1, 12) stops at block 11; confirm this matches the
# intended set of branch points for the backbone.
for dataset_name in dataset_names:
    for branch_type in branch_types:
        flops = []
        for branch_number in range(1, 12):
            # Reset Keras' global state so graphs and layer names do not pile up across models.
            tf.keras.backend.clear_session()
            gflops = get_flops(get_model(dataset_name, branch_type, branch_number)) / 10 ** 9
            flops.append(gflops)
        print('###', dataset_name, branch_type)
        print(flops)