In [None]:
# Physical GPU indices to expose to this process. This must be set before
# TensorFlow is imported, because CUDA_VISIBLE_DEVICES is read when the CUDA
# runtime initializes.
SELECTED_GPUS = [4, 5, 6, 7]

import os

# Hide every other GPU; inside this process the selected GPUs are renumbered
# 0..len(SELECTED_GPUS)-1.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])

import tensorflow as tf 

tf.get_logger().setLevel('INFO')

# Fail fast when no GPU is visible (wrong indices, missing drivers, ...).
assert len(tf.config.list_physical_devices('GPU')) > 0

# Let each GPU allocate memory on demand instead of reserving it all up front.
# This has to run before the GPUs are initialized, i.e. before the strategy below.
GPUS = tf.config.experimental.list_physical_devices('GPU')
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# Synchronous data-parallel replication over the selected GPUs, using NCCL
# all-reduce for cross-device reductions. Device names use the renumbered
# (post-masking) indices.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

# Replica count; the data sequences multiply their per-replica batch size by it.
NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import numpy as np
import pickle
import random
import sys
from skimage import transform
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Side length (pixels) of the square images fed to the ViT backbone.
IMAGE_SIZE = 384
# Width of each backbone token embedding.
HIDDEN_DIM = 768
# Side length (pixels) of the square patches the backbone tokenizes.
PATCH_SIZE = 16
# Hidden width of the ResMLP head's channel MLP.
MLP_DIM = 3072
# Hidden widths of the MLP-Mixer head's channel-mixing and token-mixing MLPs.
CHANNELS_MLP_DIM = 3072
TOKENS_MLP_DIM = 384
# Grid of IMAGE_SIZE x IMAGE_SIZE sub-images that make up one crowd-counting
# frame (presumably rows x columns).
VIDEO_PATCHES = (2, 3)
# Full frame size implied by that grid.
VIDEO_SIZE = tuple(count * IMAGE_SIZE for count in VIDEO_PATCHES)

In [None]:
def get_params(model):
    """Return the number of trainable parameters of ``model``.

    Counts the elements of every trainable weight directly instead of parsing
    the "Trainable params:" line printed by ``model.summary()``. That text
    format differs between Keras versions (newer ones append a size such as
    "(4.82 KB)"), so parsing it could crash or return None.

    Args:
        model: A Keras model, or any object whose ``trainable_weights`` items
            expose a fully defined ``shape``.

    Returns:
        The trainable parameter count as a Python int.
    """
    # A weight shared by several layers may be listed more than once; count it once.
    unique_weights = {id(weight): weight for weight in model.trainable_weights}
    return int(sum(
        np.prod(tuple(weight.shape), dtype=np.int64)  # scalar weight -> prod(()) == 1
        for weight in unique_weights.values()
    ))

def get_flops(model):
    """
    Count the floating-point operations of one forward pass at batch size 1.

    The model is traced into a concrete function whose input signatures copy the
    model's input shapes with the batch dimension pinned to 1. Its variables are
    frozen into constants, and the resulting GraphDef is imported into a fresh
    graph and run through the TF1 profiler. The profiler's total float-op count
    is returned.

    Adapted from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280
    """
    input_specs = [tf.TensorSpec([1, *model_input.shape[1:]]) for model_input in model.inputs]
    forward = tf.function(lambda inputs: model(inputs))
    traced = forward.get_concrete_function(input_specs)
    _, frozen_graph_def = convert_variables_to_constants_v2_as_graph(traced)

    with tf.Graph().as_default() as profiling_graph:
        tf.graph_util.import_graph_def(frozen_graph_def, name='')
        profile = tf.compat.v1.profiler.profile(
            graph=profiling_graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd="op",
            options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation(),
        )
        return profile.total_float_ops

In [None]:
# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py

def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):
    """Learnable per-channel affine transform ``x * w (+ b)`` over the last axis.

    Implemented as a 1x1 DepthwiseConv2D applied on a dummy axis inserted at
    position 1, so every channel gets its own scale (and optional shift).

    Args:
        inputs: 3-D tensor, e.g. (batch, length, channels). It must be 3-D because
            DepthwiseConv2D needs rank 4 after the dummy axis is added.
        use_bias: Whether to learn a per-channel shift as well.
        weight_init_value: Initial value of every per-channel scale.
        name: Prefix of the layer name; the layer is called ``name + 'affine'``.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # A non-default initial scale needs an explicit Constant initializer
    # (tf.keras, not the undefined `tfkeras` used before).
    if weight_init_value != 1:
        ww_init = tf.keras.initializers.Constant(weight_init_value)
    else:
        ww_init = 'ones'
    nn = tf.keras.backend.expand_dims(inputs, 1)
    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)
    return tf.keras.backend.squeeze(nn, 1)

def mlp_block(inputs, mlp_dim, activation='gelu', name=''):
    """One ResMLP block: token-mixing linear layer, then a channel MLP.

    Each half follows ``affine -> op -> affine (gamma, no bias) -> Add`` and every
    layer name is prefixed with ``name`` so several blocks can coexist.

    Args:
        inputs: 3-D tensor (batch, tokens, channels); axes 1 and 2 are swapped
            for token mixing.
        mlp_dim: Hidden width of the channel MLP.
        activation: Activation of the channel MLP's hidden layer. The layer is
            always named ``name + 'gelu'`` whatever activation is passed.
        name: Prefix for all layer names in this block.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    num_channels = inputs.shape[-1]

    # Token mixing: transpose so a Dense acts across tokens (its width equals the
    # token count), then transpose back.
    token_in = channel_affine(inputs, use_bias=True, name=name + '1_')
    mixed = tf.keras.layers.Permute((2, 1), name=name + 'permute_1')(token_in)
    mixed = tf.keras.layers.Dense(mixed.shape[-1], name=name + 'dense_1')(mixed)
    mixed = tf.keras.layers.Permute((2, 1), name=name + 'permute_2')(mixed)
    mixed = channel_affine(mixed, use_bias=False, name=name + '1_gamma_')
    # NOTE(review): the residual adds the affine-transformed input, not the raw
    # block input as in the ResMLP paper's formulation — confirm this is intended.
    token_out = tf.keras.layers.Add(name=name + 'add_1')([mixed, token_in])

    # Channel MLP: expand to mlp_dim, activate, project back to num_channels.
    channel_in = channel_affine(token_out, use_bias=True, name=name + '2_')
    hidden = tf.keras.layers.Dense(mlp_dim, name=name + 'dense_2_1')(channel_in)
    hidden = tf.keras.layers.Activation(activation, name=name + 'gelu')(hidden)
    projected = tf.keras.layers.Dense(num_channels, name=name + 'dense_2_2')(hidden)
    projected = channel_affine(projected, use_bias=False, name=name + '2_gamma_')
    return tf.keras.layers.Add(name=name + 'add_2')([projected, channel_in])

In [None]:
# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py

class MlpBlock(tf.keras.layers.Layer):
    """Feed-forward block ``Dense(hidden_dim) -> activation -> Dense(dim)``.

    Both Dense layers act on the last axis, so any leading axes are preserved.

    Args:
        dim: Size of the output's last axis.
        hidden_dim: Width of the hidden layer.
        activation: Activation name or callable for the hidden layer; GELU when None.
    """

    def __init__(self, dim, hidden_dim, activation=None, **kwargs):
        super(MlpBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        self.dim = dim
        self.hidden_dim = hidden_dim
        # Keep the activation identifier apart from `self.activation`, which holds
        # the Activation *layer* used in call(); get_config must not emit the layer.
        self.activation_identifier = activation
        self.dense1 = tf.keras.layers.Dense(hidden_dim)
        self.activation = tf.keras.layers.Activation(activation)
        self.dense2 = tf.keras.layers.Dense(dim)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.activation(x)
        return self.dense2(x)

    def compute_output_shape(self, input_signature):
        # Only the last axis changes (to `dim`); e.g. (batch, tokens, c) -> (batch, tokens, dim).
        return tf.TensorShape(input_signature)[:-1].concatenate([self.dim])

    def get_config(self):
        config = super(MlpBlock, self).get_config()
        config.update({
            'dim': self.dim,
            'hidden_dim': self.hidden_dim,
            # Serialize to the activation's registered name so the config is plain data.
            'activation': tf.keras.activations.serialize(
                tf.keras.activations.get(self.activation_identifier)),
        })
        return config

class MixerBlock(tf.keras.layers.Layer):
    """One MLP-Mixer block: token mixing, then channel mixing, each with a residual.

    Args:
        num_patches: Number of tokens (size of axis 1 of the input).
        channel_dim: Number of channels (size of the last axis of the input).
        token_mixer_hidden_dim: Hidden width of the token-mixing MLP.
        channel_mixer_hidden_dim: Hidden width of the channel-mixing MLP;
            defaults to ``token_mixer_hidden_dim``.
        activation: Activation name or callable used inside both MLPs; GELU when None.
    """

    def __init__(
        self,
        num_patches,
        channel_dim,
        token_mixer_hidden_dim,
        channel_mixer_hidden_dim=None,
        activation=None,
        **kwargs
    ):
        super(MixerBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        if channel_mixer_hidden_dim is None:
            channel_mixer_hidden_dim = token_mixer_hidden_dim

        self.num_patches = num_patches
        self.channel_dim = channel_dim
        self.token_mixer_hidden_dim = token_mixer_hidden_dim
        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim
        self.activation = activation

        # NOTE(review): axis=1 normalizes over the token axis of a
        # (batch, tokens, channels) input, whereas the MLP-Mixer paper normalizes
        # each token over its channels (axis=-1). Changing it alters weight shapes
        # of saved models, so confirm before touching.
        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)
        self.permute1 = tf.keras.layers.Permute((2, 1))
        # Forward `activation` so both MLPs honor it (they fall back to GELU otherwise).
        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, activation=activation, name='token_mixer')

        self.permute2 = tf.keras.layers.Permute((2, 1))
        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)
        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, activation=activation, name='channel_mixer')

        self.skip_connection1 = tf.keras.layers.Add()
        self.skip_connection2 = tf.keras.layers.Add()

    def get_config(self):
        config = super(MixerBlock, self).get_config()
        config.update({
            'num_patches': self.num_patches,
            'channel_dim': self.channel_dim,
            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,
            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,
            # Serialize to the activation's registered name so the config is plain data.
            'activation': tf.keras.activations.serialize(
                tf.keras.activations.get(self.activation)),
        })
        return config

    def call(self, inputs):
        # Token mixing: normalize, apply the MLP across tokens, add the residual.
        x = self.norm1(inputs)
        x = self.permute1(x)
        x = self.token_mixer(x)
        x = self.permute2(x)
        tokens_mixed = self.skip_connection1([x, inputs])

        # Channel mixing on that result, with its own residual.
        x = self.norm2(tokens_mixed)
        x = self.channel_mixer(x)
        return self.skip_connection2([x, tokens_mixed])

    def compute_output_shape(self, input_shape):
        return input_shape

In [None]:
def get_branch_id(branch_number):
    """Map a 1-based TransformerBlock number to its layer name in the backbone.

    Block 1 is 'transformer_block' and block k (k >= 2) is
    'transformer_block_{k-1}'. This follows Keras' default auto-naming and is
    assumed to match the saved backbone.
    """
    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)
    return 'transformer_block' + suffix

# Per-dataset settings:
# (backbone checkpoint, output units, output activation, loss, metric).
_DATASET_SETTINGS = {
    'cifar10': ('vit_cifar10_v1.h5', 10, 'softmax', 'categorical_crossentropy', 'accuracy'),
    'cifar100': ('vit_cifar100_v1.h5', 100, 'softmax', 'categorical_crossentropy', 'accuracy'),
    'disco': ('vit_cc_backbone_v2.h5', 1, None, 'mean_absolute_error', 'mean_absolute_error'),
}

# Exit-head architectures understood by _build_branch_head.
_HEAD_TYPES = ('resmlp', 'mlp', 'vit', 'cnn_ignore', 'cnn_add', 'cnn_project', 'mlp_mixer')


def _patch_token_grid(tokens, remove_layer_name, i):
    """Drop token 0 and reshape the remaining tokens into a side x side x HIDDEN_DIM grid.

    Assumes ``tokens`` is (batch, 1 + side * side, HIDDEN_DIM) with
    side = IMAGE_SIZE // PATCH_SIZE.
    """
    side = IMAGE_SIZE // PATCH_SIZE
    grid = tf.keras.layers.Lambda(lambda v: v[:, 1:], name=remove_layer_name)(tokens)
    return tf.keras.layers.Reshape((side, side, HIDDEN_DIM), name='cnn_reshape_%d' % i)(grid)


def _class_token_grid(tokens, i):
    """Copy token 0 (the class token) into every cell of a side x side x HIDDEN_DIM grid."""
    side = IMAGE_SIZE // PATCH_SIZE
    cls_token = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(tokens)
    cls_token = tf.keras.layers.RepeatVector(side * side)(cls_token)
    return tf.keras.layers.Reshape((side, side, HIDDEN_DIM), name='cls_reshape_%d' % i)(cls_token)


def _conv_features(y):
    """16-filter 3x3 ELU convolution, 2x2 max-pooling, then flatten."""
    y = tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=(3, 3),
        activation='elu',
        padding='same'
    )(y)
    y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)
    return tf.keras.layers.Flatten()(y)


def _build_branch_head(y, head_type, i):
    """Turn the token sequence ``y`` tapped for exit ``i`` into a feature tensor.

    Layer names and creation order are fixed so auto-generated names stay stable.

    Raises:
        ValueError: If ``head_type`` is not one of _HEAD_TYPES.
    """
    if head_type == 'resmlp':
        y = mlp_block(y, mlp_dim=MLP_DIM, name='mlp_mixer_%d' % i)
        return tf.keras.layers.GlobalAveragePooling1D()(y)
    if head_type == 'mlp':
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name='Transformer/encoder_norm_x_%d' % i
        )(y)
        return tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)
    if head_type == 'vit':
        y, _ = TransformerBlock(
            num_heads=12,
            mlp_dim=3072,
            dropout=0.1,
            name='Transformer/encoderblock_x_%d' % i
        )(y)
        y = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name='Transformer/encoder_norm_x_%d' % i
        )(y)
        return tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)
    if head_type == 'cnn_ignore':
        return _conv_features(_patch_token_grid(y, 'RemoveToken_%d' % i, i))
    if head_type == 'cnn_add':
        patches = _patch_token_grid(y, 'RemoveToken_x_%d' % i, i)
        cls_grid = _class_token_grid(y, i)
        return _conv_features(tf.keras.layers.Add()([patches, cls_grid]))
    if head_type == 'cnn_project':
        patches = _patch_token_grid(y, 'RemoveToken_x_%d' % i, i)
        cls_grid = _class_token_grid(y, i)
        return _conv_features(tf.keras.layers.Concatenate()([patches, cls_grid]))
    if head_type == 'mlp_mixer':
        y = MixerBlock(
            num_patches=(IMAGE_SIZE // PATCH_SIZE) ** 2 + 1,
            channel_dim=HIDDEN_DIM,
            token_mixer_hidden_dim=TOKENS_MLP_DIM,
            channel_mixer_hidden_dim=CHANNELS_MLP_DIM
        )(y)
        return tf.keras.layers.GlobalAveragePooling1D()(y)
    raise ValueError('Unknown head type: %s' % head_type)


def _build_output_mlp(y, output_units, output_activation):
    """Dense(256, ELU) -> Dropout(0.5) -> Dense(256, ELU) -> Dropout(0.5) -> Dense(output_units)."""
    def dense(units, activation):
        # Fresh initializer per layer: some Keras versions warn about reusing an
        # unseeded initializer instance (it can repeat its draws for equal shapes).
        return tf.keras.layers.Dense(
            units=units,
            activation=activation,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2()
        )

    y = dense(256, 'elu')(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    y = dense(256, 'elu')(y)
    y = tf.keras.layers.Dropout(0.5)(y)
    return dense(output_units, output_activation)(y)


def get_model(branch_numbers, head_type, dataset):
    """Build and compile a multi-exit model on top of a saved ViT backbone.

    For every entry of ``branch_numbers``, the token output of that backbone
    TransformerBlock goes through a ``head_type`` head and a dense MLP, giving
    one exit. The backbone's own last-layer output is appended as the final
    exit. All exits use the dataset's loss; the final exit is weighted 2 and
    every early exit 1.

    Args:
        branch_numbers: 1-based numbers of the backbone blocks to attach exits to.
        head_type: One of 'resmlp', 'mlp', 'vit', 'cnn_ignore', 'cnn_add',
            'cnn_project', 'mlp_mixer'.
        dataset: 'cifar10', 'cifar100' or 'disco'. Selects the backbone
            checkpoint, output layer, loss and metric.

    Returns:
        A compiled tf.keras Model with ``len(branch_numbers) + 1`` outputs.

    Raises:
        ValueError: For an unknown ``head_type`` or ``dataset``. Both are
            checked before the backbone file is loaded.
    """
    if head_type not in _HEAD_TYPES:
        raise ValueError('Unknown head type: %s' % head_type)
    if dataset not in _DATASET_SETTINGS:
        raise ValueError('Unknown dataset: %s' % dataset)
    model_file_name, output_units, output_activation, loss_type, metric_type = _DATASET_SETTINGS[dataset]

    backbone_model = tf.keras.models.load_model(model_file_name, custom_objects={
        'ClassToken': ClassToken,
        'AddPositionEmbs': AddPositionEmbs,
        'MultiHeadSelfAttention': MultiHeadSelfAttention,
        'TransformerBlock': TransformerBlock,
    })

    outputs = []
    for i, branch_number in enumerate(branch_numbers):
        # The tapped layer outputs a pair. Only the first element (the token
        # sequence) is used; the second is presumably the attention weights.
        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output
        y = _build_branch_head(y, head_type, i)
        outputs.append(_build_output_mlp(y, output_units, output_activation))

    outputs.append(backbone_model.get_layer(index=-1).output)
    model = tf.keras.models.Model(
        inputs=backbone_model.get_layer(index=0).input,
        outputs=outputs
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=[loss_type] * (len(branch_numbers) + 1),
        loss_weights=[1] * len(branch_numbers) + [2],
        metrics=[metric_type]
    )

    return model

In [None]:
def cache_split(cache_dir, images, labels, split):
    """Resize each image of one split and pickle it together with its label.

    Sample ``i`` is written to ``<cache_dir>/<split>_<i>.pkl`` as
    ``{'image': resized_image, 'label': labels[i]}``. The image is resized to
    IMAGE_SIZE x IMAGE_SIZE with ``skimage.transform.resize`` (default options).
    A running count is printed every 100 samples.

    Args:
        cache_dir: Output directory; created if it does not exist.
        images: Array of images indexed along its first axis.
        labels: Labels aligned with ``images``.
        split: File-name prefix such as 'train', 'val' or 'test'.
    """
    # Create the directory up front; otherwise the first open() fails with
    # FileNotFoundError on a fresh machine.
    os.makedirs(cache_dir, exist_ok=True)
    for i in range(images.shape[0]):
        if (i + 1) % 100 == 0:
            sys.stdout.write('\r%d' % (i + 1))
            sys.stdout.flush()
        with open(os.path.join(cache_dir, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:
            pickle.dump({
                'image': transform.resize(images[i], (IMAGE_SIZE, IMAGE_SIZE)),
                'label': labels[i],
            }, cache_file)
    print()  # newline

def cache_all(dataset):
    """Load a CIFAR dataset, carve out a validation split and cache every split.

    The last 20% of the official training set becomes 'val'. Labels are one-hot
    encoded with ``to_categorical``. Each split is written by ``cache_split``
    into a directory named after the dataset.

    Args:
        dataset: 'cifar10' or 'cifar100'.

    Raises:
        Exception: For any other dataset name.
    """
    if dataset not in ('cifar10', 'cifar100'):
        raise Exception('Unknown dataset: %s' % dataset)
    loader = getattr(tf.keras.datasets, dataset)
    (train_images, train_labels), (test_images, test_labels) = loader.load_data()

    train_onehot = tf.keras.utils.to_categorical(train_labels)
    test_onehot = tf.keras.utils.to_categorical(test_labels)

    split_at = int(len(train_images) * 0.8)
    splits = (
        ('train', train_images[:split_at], train_onehot[:split_at]),
        ('val', train_images[split_at:], train_onehot[split_at:]),
        ('test', test_images, test_onehot),
    )
    for split_name, split_images, split_labels in splits:
        cache_split(dataset, split_images, split_labels, split_name)

class CIFARSequence(tf.keras.utils.Sequence):
    """Keras Sequence streaming one cached CIFAR split from disk.

    Samples are the '<split>_<i>.pkl' files written by ``cache_split``; each
    holds a dict with an 'image' and a one-hot 'label'. Batches follow a random
    permutation of the sample indices that is re-drawn after every epoch.

    Args:
        split: File-name prefix of the split ('train', 'val' or 'test').
        batch_size: Per-replica batch size; the global batch is this times NUM_GPUS.
        dataset: Dataset name, which is also the cache directory.
    """

    def __init__(self, split, batch_size, dataset):
        self.split = split
        self.batch_size = batch_size * NUM_GPUS
        self.cache_dir = dataset
        # Every file whose name starts with the split prefix counts as one sample.
        matching = [name for name in os.listdir(self.cache_dir) if name.startswith(split)]
        self.count = len(matching)
        self.random_permutation = np.random.permutation(self.count)

    def __len__(self):
        # Number of batches, including a final partial one.
        return math.ceil(self.count / self.batch_size)

    def on_epoch_end(self):
        # New visiting order for the next epoch.
        self.random_permutation = np.random.permutation(self.count)

    def __getitem__(self, index):
        first = index * self.batch_size
        sample_ids = self.random_permutation[first:first + self.batch_size]
        batch_images, batch_labels = [], []
        for sample_id in sample_ids:
            path = os.path.join(self.cache_dir, '%s_%d.pkl' % (self.split, sample_id))
            # pickle can execute arbitrary code: only read caches produced locally.
            with open(path, 'rb') as cache_file:
                record = pickle.load(cache_file)
            batch_images.append(record['image'])
            batch_labels.append(record['label'])
        return np.array(batch_images), np.array(batch_labels)

In [None]:
def horizontal_flip(image):
    """Reverse ``image`` along axis 1 (left-right for an (H, W, ...) image)."""
    return np.fliplr(image)

class DISCOSequence(tf.keras.utils.Sequence):
    """Keras Sequence over the pickled DISCO cache in ``disco/vit_cache``.

    Each '<split>_<i>.pkl' file holds a dict with an 'image' and a
    'density_map'; the regression target is the sum of the density map.
    'train' and 'val' batches follow a random permutation re-drawn after every
    epoch. 'test' is read in file order so that test_cc can regroup consecutive
    sub-images into frames. Training images are flipped horizontally with
    probability 0.5.

    Args:
        split: 'train', 'val' or 'test'.
        batch_size: Per-replica batch size; the global batch is this times NUM_GPUS.
    """

    def __init__(self, split, batch_size):
        self.split = split
        self.cache_dir = os.path.join('disco', 'vit_cache')
        self.split_len = sum([
            1 if file_name.startswith(self.split) else 0 for file_name in os.listdir(self.cache_dir)
        ])
        self.batch_size = batch_size * NUM_GPUS
        self.random_permutation = np.random.permutation(self.split_len)

    def __len__(self):
        return math.ceil(self.split_len / self.batch_size)

    def on_epoch_end(self):
        self.random_permutation = np.random.permutation(self.split_len)

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = start + self.batch_size
        if self.split == 'test':
            # Sequential order. The bound is split_len so the final test sample is
            # included and no batch comes out empty.
            index_generator = range(start, min(stop, self.split_len))
        else:
            index_generator = self.random_permutation[start:stop]

        images = []
        density_maps = []
        for sample_index in index_generator:
            sample_path = os.path.join(self.cache_dir, '%s_%d.pkl' % (self.split, sample_index))
            # pickle can execute arbitrary code: only read caches produced locally.
            with open(sample_path, 'rb') as sample_file:
                data = pickle.load(sample_file)
            if self.split == 'train' and random.random() < 0.5:  # flip augmentation
                images.append(horizontal_flip(data['image']))
            else:
                images.append(data['image'])
            density_maps.append(np.sum(data['density_map']))

        return np.array(images), np.array(density_maps)

In [None]:
def test_cc(model, test_sequence, total_branches):
    gt = None
    outs = []
    for i, (images, density_maps) in enumerate(test_sequence):
        sys.stdout.write('\r%d' % (i + 1))
        sys.stdout.flush()
        if gt is not None:
            gt = np.concatenate((gt, density_maps))
        else:
            gt = density_maps
        output = model(images)
        for j in range(total_branches):
            if i == 0:
                outs.append(output[j].numpy().flatten())
            else:
                outs[j] = np.concatenate((outs[j], output[j].numpy().flatten()))
    print()  # newline
    maes = []
    img_patches = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]
    for i in range(0, gt.shape[0], img_patches):
        gt_subset = gt[i:i + img_patches]
        for j in range(total_branches):
            if i == 0:
                maes.append([np.abs(np.sum(gt_subset) - np.sum(outs[j][i:i + img_patches]))])
            else:
                maes[j].append(np.abs(np.sum(gt_subset) - np.sum(outs[j][i:i + img_patches])))
    return [np.mean(np.array(item)) for item in maes]

In [None]:
def train(max_epochs, branch_numbers, head_type, dataset, version, temporary):
    """Train a multi-exit model, evaluate every exit on the test split and measure size.

    Args:
        max_epochs: Upper bound on epochs; early stopping (patience 5 on
            val_loss) may end training sooner.
        branch_numbers: 1-based backbone block numbers at which exits are attached.
        head_type: Exit head architecture, forwarded to ``get_model``.
        dataset: 'cifar10', 'cifar100' or 'disco'.
        version: Integer tag embedded in the checkpoint file name.
        temporary: If True, no checkpoint file is written.

    Returns:
        ``(model, test_accuracy, model_params, model_flops)``, where:
        - ``test_accuracy`` is a per-exit list (accuracies for CIFAR, frame-level
          MAEs for DISCO);
        - ``model_params`` is trainable parameters in millions;
        - ``model_flops`` is batch-1 FLOPs in billions.

    Raises:
        Exception: If ``dataset`` is unknown. This is checked before any model
            is built.
    """
    if dataset not in ('cifar10', 'cifar100', 'disco'):
        raise Exception('Unknown dataset: %s' % dataset)
    is_cifar = dataset in ('cifar10', 'cifar100')

    tf.keras.backend.clear_session()

    # Build and compile under the strategy so variables are mirrored on every replica.
    with DISTRIBUTED_STRATEGY.scope():
        model = get_model(branch_numbers, head_type, dataset)

    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.6,
        patience=2,
        verbose=1,
        mode='min',
        min_lr=1e-7
    )

    # NOTE(review): without restore_best_weights=True the model evaluated below
    # keeps the last epoch's weights, not the best-val_loss weights that the
    # checkpoint stores — confirm which one the reported numbers should use.
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=1,
        mode='min'
    )

    save_model_checkpoint_file = 'bmvc_rebuttal_ee_v%d_%s_%s_%s.h5' % (
        version,
        head_type,
        dataset,
        '-'.join([str(branch_number) for branch_number in branch_numbers])
    )

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        save_model_checkpoint_file,
        monitor='val_loss',
        verbose=1,
        save_weights_only=False,
        save_best_only=True,
        mode='min',
        save_freq='epoch'
    )

    callbacks = [lr_reduce, early_stop]
    if not temporary:
        callbacks.append(checkpoint)

    batch_size = 4  # per replica; the sequences multiply it by NUM_GPUS
    if is_cifar:
        train_sequence = CIFARSequence('train', batch_size, dataset)
        val_sequence = CIFARSequence('val', batch_size, dataset)
        test_sequence = CIFARSequence('test', batch_size, dataset)
    else:
        train_sequence = DISCOSequence('train', batch_size)
        val_sequence = DISCOSequence('val', batch_size)
        test_sequence = DISCOSequence('test', 2 * batch_size)

    model.fit(
        train_sequence,
        validation_data=val_sequence,
        epochs=max_epochs,
        shuffle=True,
        callbacks=callbacks,
        verbose=1
    )

    if is_cifar:
        # For a multi-output model evaluate() reports the total loss, then one loss
        # per output, then the metrics, so a positional [1] is the first exit's
        # loss. Collect the accuracy entries (in output order) by name instead.
        results = model.evaluate(test_sequence, return_dict=True)
        test_accuracy = [value for key, value in results.items() if key.endswith('accuracy')]
    else:
        test_accuracy = test_cc(model, test_sequence, len(branch_numbers) + 1)

    model_params = get_params(model) / 10 ** 6
    model_flops = get_flops(model) / 10 ** 9

    return model, test_accuracy, model_params, model_flops

In [None]:
# Dataset to train on. The CIFAR pickle caches (every image resized to
# IMAGE_SIZE and pickled) are slow and large to build, so they are only
# built when training on CIFAR and the cache directory is missing or empty.
DATASET = 'disco'
if DATASET in ('cifar10', 'cifar100') and not (os.path.isdir(DATASET) and os.listdir(DATASET)):
    cache_all(DATASET)

model, test_accuracy, model_params, model_flops = train(
    max_epochs=100,
    branch_numbers=[3, 6, 9],
    head_type='resmlp',
    dataset=DATASET,
    version=5,
    temporary=False
)
print(test_accuracy)