# In [None]:
import os
import tensorflow as tf 
import math
import numpy as np
import pickle
import random
import sys
from skimage import transform
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Side length images are resized to before caching (see cache_split).
IMAGE_SIZE = 384
HIDDEN_DIM = 768
PATCH_SIZE = 16
# Hidden width for ResMLP-style heads.
MLP_DIM = 3072
# Channel-mixing and token-mixing widths for MLP-Mixer-style heads.
CHANNELS_MLP_DIM = 3072
TOKENS_MLP_DIM = 384
# Number of IMAGE_SIZE sub-images along each axis of one frame (crowd counting).
VIDEO_PATCHES = (2, 3)
# Full frame size: VIDEO_PATCHES scaled by IMAGE_SIZE on each axis.
VIDEO_SIZE = tuple(patch_count * IMAGE_SIZE for patch_count in VIDEO_PATCHES)

def get_params(model):
    """Return the trainable-parameter count reported by ``model.summary()``.

    The summary text is captured through ``print_fn`` and scanned for the
    ``Trainable params:`` line. Only the first token after the colon is
    parsed, so both ``Trainable params: 1,234`` and the newer
    ``Trainable params: 1,234 (4.82 KB)`` formats are handled.

    Returns:
        The count as an ``int``, or ``None`` if no such line was printed.
    """
    chunks = []

    def _collect(text, *args, **kwargs):
        # Newer Keras releases may call print_fn with extra keyword arguments
        # (e.g. ``line_break``) and may pass several lines in one call.
        chunks.append(str(text))

    model.summary(print_fn=_collect)
    for line in '\n'.join(chunks).splitlines():
        line = line.strip()
        if line.startswith('Trainable params:'):
            tokens = line.split(':', 1)[1].split()
            if tokens:
                return int(tokens[0].replace(',', ''))
    return None

def get_flops(model):
    """Count the float operations of one forward pass with batch size 1.

    The model is traced into a concrete function whose inputs have their
    batch dimension pinned to 1. Its variables are frozen into constants, and
    the resulting GraphDef is run through the TF1 profiler's float-operation
    counter.

    Adapted from
    https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280
    """
    input_specs = [tf.TensorSpec([1, *tensor.shape[1:]]) for tensor in model.inputs]
    forward = tf.function(lambda inputs: model(inputs))
    concrete_func = forward.get_concrete_function(input_specs)
    _, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)

    profile_graph = tf.Graph()
    with profile_graph.as_default():
        tf.graph_util.import_graph_def(graph_def, name='')
        float_op_options = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        report = tf.compat.v1.profiler.profile(
            graph=profile_graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd="op",
            options=float_op_options,
        )
    return report.total_float_ops

def get_branch_id(branch_number):
    """Map a 1-based transformer-block number to its layer name.

    Block 1 maps to ``transformer_block``; block ``n > 1`` maps to
    ``transformer_block_<n-1>``, presumably matching Keras' automatic layer
    naming in the loaded backbone.
    """
    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)
    return 'transformer_block' + suffix

def _dense_layer(units, activation):
    """Build a Dense layer with its own He-normal initializer and L2 regularizer.

    Each layer gets a fresh initializer instance. Newer Keras versions may warn
    when a single unseeded initializer instance is reused across layers.
    """
    return tf.keras.layers.Dense(
        units=units,
        activation=activation,
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(),
    )


def _build_exit_head(features, index, output_units=10, output_activation='softmax'):
    """Attach one early-exit classifier to an intermediate token sequence.

    Args:
        features: token tensor taken from a backbone ``TransformerBlock`` output.
        index: position of this exit; used to give the new layers unique names.
        output_units: size of the final Dense layer.
        output_activation: activation of the final Dense layer.

    Returns:
        The exit's output tensor.
    """
    y, _ = TransformerBlock(
        num_heads=12,
        mlp_dim=3072,
        dropout=0.1,
        name='Transformer/encoderblock_x_%d' % index,
    )(features)
    y = tf.keras.layers.LayerNormalization(
        epsilon=1e-6,
        name='Transformer/encoder_norm_x_%d' % index,
    )(y)
    # Keep only token 0 (presumably the class token prepended by vit_keras'
    # ClassToken layer) as the image representation.
    y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % index)(y)

    # MLP head: two ELU hidden layers with dropout, then the classifier layer.
    y = _dense_layer(256, 'elu')(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    y = _dense_layer(256, 'elu')(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    return _dense_layer(output_units, output_activation)(y)


def get_model(branch_numbers, head_type, dataset):
    """Build and compile an early-exit model on the ViT backbone in ``vit_cifar10_v1.h5``.

    Each entry of ``branch_numbers`` is a 1-based transformer-block index (see
    ``get_branch_id``). An exit head (extra transformer block, layer norm,
    token-0 extraction and a 10-way softmax MLP) is attached to that block's
    output.

    The model's outputs are the exits in ``branch_numbers`` order, followed by
    the output of the backbone's last layer. Every output uses categorical
    cross-entropy; exits are weighted 1 and the final output 2, with Adam at
    lr=1e-4.

    ``head_type`` and ``dataset`` are accepted but currently unused: every exit
    uses the same MLP head with 10 outputs.
    """
    model_file_name = 'vit_cifar10_v1.h5'
    backbone_model = tf.keras.models.load_model(model_file_name, custom_objects={
        'ClassToken': ClassToken,
        'AddPositionEmbs': AddPositionEmbs,
        'MultiHeadSelfAttention': MultiHeadSelfAttention,
        'TransformerBlock': TransformerBlock,
    })

    outputs = []
    for index, branch_number in enumerate(branch_numbers):
        # vit_keras TransformerBlock outputs (tokens, attention_weights).
        features, _ = backbone_model.get_layer(get_branch_id(branch_number)).output
        outputs.append(_build_exit_head(features, index))

    outputs.append(backbone_model.get_layer(index=-1).output)
    model = tf.keras.models.Model(
        inputs=backbone_model.get_layer(index=0).input,
        outputs=outputs,
    )

    loss_type = 'categorical_crossentropy'
    metric_type = 'accuracy'
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=[loss_type] * (len(branch_numbers) + 1),
        loss_weights=[1] * len(branch_numbers) + [2],
        metrics=[metric_type],
    )
    return model

def cache_split(cache_dir, images, labels, split):
    """Resize each image of one split and pickle it to its own file.

    Sample ``i`` is written to ``<cache_dir>/<split>_<i>.pkl`` as the dict
    ``{'image': ..., 'label': labels[i]}``. The image is resized to
    ``IMAGE_SIZE x IMAGE_SIZE`` with ``skimage.transform.resize`` and stored
    as float32. A running sample count is printed every 100 images.
    """
    # open() below does not create parent directories.
    os.makedirs(cache_dir, exist_ok=True)
    for i, image in enumerate(images):
        if (i + 1) % 100 == 0:
            sys.stdout.write('\r%d' % (i + 1))
            sys.stdout.flush()
        # resize() returns a float array (float64 for uint8 input). Keras
        # computes in float32 by default, and float32 halves this very large cache.
        resized = transform.resize(image, (IMAGE_SIZE, IMAGE_SIZE)).astype(np.float32)
        with open(os.path.join(cache_dir, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:
            pickle.dump({
                'image': resized,
                'label': labels[i],
            }, cache_file)
    print()  # newline

def cache_all(dataset):
    """Load a CIFAR dataset and cache its train/val/test splits under ``dataset``.

    ``dataset`` is both the cache directory name and the dataset selector.
    ``'cifar100'`` loads CIFAR-100, and any other value loads CIFAR-10.
    Labels are one-hot encoded. The last 20% of the official training set,
    unshuffled, is held out as the ``val`` split.
    """
    if dataset == 'cifar100':
        loader = tf.keras.datasets.cifar100
    else:
        loader = tf.keras.datasets.cifar10
    (train_images, train_labels), (test_images, test_labels) = loader.load_data()

    train_labels = tf.keras.utils.to_categorical(train_labels)
    test_labels = tf.keras.utils.to_categorical(test_labels)

    val_index = int(len(train_images) * 0.8)
    splits = {
        'train': (train_images[:val_index], train_labels[:val_index]),
        'val': (train_images[val_index:], train_labels[val_index:]),
        'test': (test_images, test_labels),
    }

    # Files are written directly under this directory, so make sure it exists.
    os.makedirs(dataset, exist_ok=True)
    for split, (images, labels) in splits.items():
        cache_split(dataset, images, labels, split)

class CIFARSequence(tf.keras.utils.Sequence):
    """Shuffled batches of one cached split, loaded lazily from pickle files.

    Sample ``i`` of ``split`` is read from ``<dataset>/<split>_<i>.pkl`` (the
    layout ``cache_split`` writes). The sample order is reshuffled at the end
    of every epoch.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in a file; only
    point this at a cache directory you created yourself.
    """

    def __init__(self, split, batch_size, dataset):
        # Newer Keras base classes expect this call; it is harmless when the
        # base class defines no __init__ of its own.
        super().__init__()
        self.split = split
        self.batch_size = batch_size
        self.cache_dir = dataset
        # Match "<split>_*.pkl" exactly, so unrelated files that merely share
        # the prefix (e.g. "train.log") don't inflate the count and produce
        # sample indices with no backing file.
        prefix = split + '_'
        self.count = sum(
            1
            for file_name in os.listdir(self.cache_dir)
            if file_name.startswith(prefix) and file_name.endswith('.pkl')
        )
        self.random_permutation = np.random.permutation(self.count)

    def __len__(self):
        """Number of batches per epoch; the last batch may be partial."""
        return math.ceil(self.count / self.batch_size)

    def on_epoch_end(self):
        """Draw a fresh random sample order for the next epoch."""
        self.random_permutation = np.random.permutation(self.count)

    def _sample_path(self, sample_index):
        """Path of the pickle file holding sample ``sample_index`` of this split."""
        return os.path.join(self.cache_dir, '%s_%d.pkl' % (self.split, sample_index))

    def __getitem__(self, index):
        """Load batch ``index`` and return it as ``(images, labels)`` NumPy arrays."""
        start = index * self.batch_size
        images = []
        labels = []
        for sample_index in self.random_permutation[start:start + self.batch_size]:
            with open(self._sample_path(sample_index), 'rb') as cache_file:
                contents = pickle.load(cache_file)
            images.append(contents['image'])
            labels.append(contents['label'])
        return np.array(images), np.array(labels)

def train(max_epochs, branch_numbers, head_type, dataset, version, temporary):
    """Train an early-exit model on a cached dataset and measure it.

    Args:
        max_epochs: upper bound on epochs; early stopping may end sooner.
        branch_numbers: 1-based backbone transformer blocks to attach exits to.
        head_type: forwarded to ``get_model`` and used in the checkpoint name.
        dataset: cache directory holding ``<split>_<i>.pkl`` samples.
        version: integer used in the checkpoint file name.
        temporary: if True, no checkpoint file is written.

    Returns:
        ``(model, test_accuracy, model_params, model_flops)``:

        - ``test_accuracy`` is the test accuracy of the model's last output
          (the one ``get_model`` appends after all exits).
        - ``model_params`` is trainable parameters in millions.
        - ``model_flops`` is GFLOPs of one batch-1 forward pass.
    """
    model = get_model(branch_numbers, head_type, dataset)

    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.6,
        patience=2,
        verbose=1,
        mode='min',
        min_lr=1e-7
    )

    # NOTE: restore_best_weights is left at its default, so the returned model
    # (and the test metrics below) use the last epoch's weights, not the
    # best-val_loss checkpoint.
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=1,
        mode='min'
    )

    save_model_checkpoint_file = 'bmvc_rebuttal_ee_v%d_%s_%s_%s.h5' % (
        version,
        head_type,
        dataset,
        '-'.join(str(branch_number) for branch_number in branch_numbers)
    )

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        save_model_checkpoint_file,
        monitor='val_loss',
        verbose=1,
        save_weights_only=False,
        save_best_only=True,
        mode='min',
        save_freq='epoch'
    )

    callbacks = [lr_reduce, early_stop]
    if not temporary:
        callbacks.append(checkpoint)

    batch_size = 4
    train_sequence = CIFARSequence('train', batch_size, dataset)
    val_sequence = CIFARSequence('val', batch_size, dataset)
    test_sequence = CIFARSequence('test', batch_size, dataset)

    model.fit(
        train_sequence,
        validation_data=val_sequence,
        epochs=max_epochs,
        shuffle=True,
        callbacks=callbacks,
        verbose=1
    )

    # With several outputs, evaluate() returns the total loss, then per-output
    # losses, then per-output metrics. Index 1 is therefore a loss, not an
    # accuracy, so look the accuracy up by name instead.
    results = model.evaluate(test_sequence, return_dict=True)
    final_key = '%s_accuracy' % model.output_names[-1]
    if final_key not in results:
        # Fallback assumes metric keys are ordered like the outputs
        # (TODO confirm for the Keras version in use).
        final_key = [key for key in results if key.endswith('accuracy')][-1]
    test_accuracy = results[final_key]

    model_params = get_params(model) / 10 ** 6
    model_flops = get_flops(model) / 10 ** 9

    return model, test_accuracy, model_params, model_flops

if __name__ == '__main__':
    # The backbone (vit_cifar10_v1.h5) and every exit head are 10-way
    # classifiers, so train on the CIFAR-10 cache. That is also the directory
    # cache_all() fills; nothing in this script populates any other training
    # directory, and no CIFAR-100 cache is needed for this run.
    dataset_name = 'cifar10'

    # Caching writes one large pickle per image, so skip it when the cache
    # directory already has content. Delete the directory to force a rebuild.
    os.makedirs(dataset_name, exist_ok=True)
    if not os.listdir(dataset_name):
        cache_all(dataset_name)

    model, test_accuracy, model_params, model_flops = train(
        max_epochs=100,
        branch_numbers=[3, 6, 9],
        head_type='resmlp',
        dataset=dataset_name,
        version=5,
        temporary=False
    )
    print(test_accuracy)