In [None]:
# Physical GPU indices this notebook may use. They are exported through
# CUDA_VISIBLE_DEVICES *before* TensorFlow is imported so that TensorFlow only
# ever sees these devices.
SELECTED_GPUS = [7]

import os

os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu_number) for gpu_number in SELECTED_GPUS)

import tensorflow as tf

tf.get_logger().setLevel('INFO')

# List the visible GPUs once, through the public (non-experimental) API.
GPUS = tf.config.list_physical_devices('GPU')
if not GPUS:
    # Explicit error instead of a bare `assert`, which is stripped under `python -O`
    # and gives no hint about what went wrong.
    raise RuntimeError(
        'No GPU visible to TensorFlow (CUDA_VISIBLE_DEVICES=%s)' % os.environ['CUDA_VISIBLE_DEVICES']
    )
# Allocate GPU memory on demand rather than reserving it all up front.
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# The masked devices are addressed as /gpu:0 .. /gpu:len(SELECTED_GPUS)-1.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import numpy as np
import pickle
import random
import string
import sys
from skimage import transform
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Patches per image (presumably rows x columns). get_mae groups
# VIDEO_PATCHES[0] * VIDEO_PATCHES[1] consecutive test samples into one image
# when computing the per-image count error.
VIDEO_PATCHES = (2, 3)

In [None]:
# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py

class MlpBlock(tf.keras.layers.Layer):
    """Two-layer feed-forward block: Dense(hidden_dim) -> activation -> Dense(dim).

    The Dense layers act on the last axis only, so any leading axes are kept.

    Args:
        dim: Width of the output (last) axis.
        hidden_dim: Width of the intermediate Dense layer.
        activation: Anything ``tf.keras.layers.Activation`` accepts (a callable or
            a registered name). Defaults to ``tf.keras.activations.gelu``.
    """

    def __init__(self, dim, hidden_dim, activation=None, **kwargs):
        super(MlpBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        self.dim = dim
        self.hidden_dim = hidden_dim
        # Sub-layers are created in the original order (dense1, activation, dense2);
        # keep it, since weight loading from saved files may depend on it.
        self.dense1 = tf.keras.layers.Dense(hidden_dim)
        self.activation = tf.keras.layers.Activation(activation)
        self.dense2 = tf.keras.layers.Dense(dim)

    def call(self, inputs):
        x = inputs
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x

    def compute_output_shape(self, input_signature):
        # Only the last axis changes. Leading axes are preserved, which also covers
        # rank-3 inputs, e.g. the permuted tensors MixerBlock passes in.
        return tuple(input_signature)[:-1] + (self.dim,)

    def get_config(self):
        config = super(MlpBlock, self).get_config().copy()
        config.update({
            'dim': self.dim,
            'hidden_dim': self.hidden_dim,
            # `self.activation` is the Activation *layer*. Serialize the function it
            # wraps (e.g. 'gelu') so that from_config can pass it back to __init__.
            'activation': tf.keras.activations.serialize(self.activation.activation),
        })
        return config

class MixerBlock(tf.keras.layers.Layer):
    """MLP-Mixer block: token mixing, then channel mixing, each with a residual add.

    The permutes and residual adds below require inputs shaped
    (batch, num_patches, channel_dim); the output has the same shape.

    Args:
        num_patches: Size of axis 1 (tokens / patches).
        channel_dim: Size of axis 2 (channels).
        token_mixer_hidden_dim: Hidden width of the token-mixing MLP.
        channel_mixer_hidden_dim: Hidden width of the channel-mixing MLP;
            defaults to ``token_mixer_hidden_dim``.
        activation: Activation used by both MLPs; defaults to
            ``tf.keras.activations.gelu``.
    """

    def __init__(
        self,
        num_patches,
        channel_dim,
        token_mixer_hidden_dim,
        channel_mixer_hidden_dim=None,
        activation=None,
        **kwargs
    ):
        super(MixerBlock, self).__init__(**kwargs)

        if activation is None:
            activation = tf.keras.activations.gelu

        if channel_mixer_hidden_dim is None:
            channel_mixer_hidden_dim = token_mixer_hidden_dim

        self.num_patches = num_patches
        self.channel_dim = channel_dim
        self.token_mixer_hidden_dim = token_mixer_hidden_dim
        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim
        self.activation = activation

        # Sub-layers stay in their original creation order; weight loading from
        # saved files may depend on it.
        # NOTE(review): axis=1 normalizes over the patch axis. The reference
        # MLP-Mixer normalizes over channels (last axis). This is kept as-is
        # because the saved gamma/beta shapes depend on the chosen axis.
        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)
        self.permute1 = tf.keras.layers.Permute((2, 1))
        # `activation` is now forwarded; previously both MLPs silently used their
        # gelu default. NOTE(review): a model saved with a non-default activation
        # will now apply it; confirm this matches how the heads were trained.
        self.token_mixer = MlpBlock(
            num_patches, token_mixer_hidden_dim, activation=activation, name='token_mixer'
        )

        self.permute2 = tf.keras.layers.Permute((2, 1))
        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)
        self.channel_mixer = MlpBlock(
            channel_dim, channel_mixer_hidden_dim, activation=activation, name='channel_mixer'
        )

        self.skip_connection1 = tf.keras.layers.Add()
        self.skip_connection2 = tf.keras.layers.Add()

    def get_config(self):
        config = super(MixerBlock, self).get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'channel_dim': self.channel_dim,
            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,
            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,
            'activation': self.activation,
        })
        return config

    def call(self, inputs):
        x = inputs
        skip_x = x
        x = self.norm1(x)
        # Mix along the patch axis: move it last so the Dense layers act on it.
        x = self.permute1(x)
        x = self.token_mixer(x)

        x = self.permute2(x)

        x = self.skip_connection1([x, skip_x])
        skip_x = x

        # Mix along the channel axis (already last).
        x = self.norm2(x)
        x = self.channel_mixer(x)

        x = self.skip_connection2([x, skip_x])

        return x

    def compute_output_shape(self, input_shape):
        return input_shape

In [None]:
def get_branch_id(branch_number):
    """Return the layer name of the `branch_number`-th TransformerBlock (1-based).

    Block 1 is named 'transformer_block' and block n > 1 is named
    'transformer_block_<n-1>'. This presumably mirrors Keras' automatic layer
    naming in the backbone.
    """
    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)
    return 'transformer_block' + suffix

def get_model(branch_type, branch_number, version):
    """Compose a truncated ViT backbone with the trained head for one branch.

    The backbone ('vit_cc_backbone_v2.h5') is cut at the output of TransformerBlock
    number `branch_number`. The head is loaded from
    'vit_disco_cw_<branch_number>_<branch_type>_head_precomputed_<version>.h5'.
    The two parts are chained in a Sequential model.
    """
    vit_custom_objects = {
        'ClassToken': ClassToken,
        'AddPositionEmbs': AddPositionEmbs,
        'MultiHeadSelfAttention': MultiHeadSelfAttention,
        'TransformerBlock': TransformerBlock,
    }
    backbone = tf.keras.models.load_model('vit_cc_backbone_v2.h5', custom_objects=vit_custom_objects)

    # The TransformerBlock output is a pair. Only the first element is kept; the
    # second (presumably attention weights) is dropped.
    branch_features, _ = backbone.get_layer(get_branch_id(branch_number)).output
    backend_model = tf.keras.models.Model(
        inputs=backbone.get_layer(index=0).input,
        outputs=branch_features,
    )
    backend_model._name = 'backend_model'

    # Heads may additionally contain the MLP / Mixer layers defined in this notebook.
    head_custom_objects = dict(vit_custom_objects, MlpBlock=MlpBlock, MixerBlock=MixerBlock)
    head_path = 'vit_disco_cw_%d_%s_head_precomputed_%s.h5' % (branch_number, branch_type, version)
    frontend_model = tf.keras.models.load_model(head_path, custom_objects=head_custom_objects)
    frontend_model._name = 'frontend_model'

    return tf.keras.Sequential([backend_model, frontend_model])

In [None]:
# Relative path of the DISCO data directory, and the sub-directory holding the
# cached, pickled samples ('<split>_<index>.pkl') that CCSequence reads.
DISCO_PATH = 'disco'
CACHE_DIR = os.path.join(DISCO_PATH, 'vit_cache')

def horizontal_flip(image):
    """Mirror `image` left-to-right by reversing its second axis (axis 1).

    For ndarray input the result is a view, not a copy.
    """
    return np.fliplr(image)

class CCSequence(tf.keras.utils.Sequence):
    """Keras Sequence over the cached, pickled samples of one dataset split.

    Sample `i` of split `s` is read from ``CACHE_DIR/<s>_<i>.pkl``, a pickled dict
    with (at least) 'image' and 'density_map' entries. Each batch is
    ``(images, counts)``, where a count is the sum of the sample's density map.

    The 'test' split is served in index order. Other splits are served in a random
    permutation that is redrawn at the end of every epoch. The 'train' split also
    flips each image horizontally with probability 0.5.

    Args:
        split: Split name used as the cache-file prefix (e.g. 'train', 'test').
        batch_size: Maximum number of samples per batch.
    """

    def __init__(self, split, batch_size):
        super().__init__()
        self.split = split
        # Count only '<split>_*.pkl' files, so that e.g. 'train' does not also
        # count files of another split whose name starts with 'train', or
        # unrelated files.
        prefix = '%s_' % split
        self.split_len = sum(
            1 for file_name in os.listdir(CACHE_DIR)
            if file_name.startswith(prefix) and file_name.endswith('.pkl')
        )
        self.batch_size = batch_size
        self.random_permutation = np.random.permutation(self.split_len)

    def __len__(self):
        return math.ceil(self.split_len / self.batch_size)

    def on_epoch_end(self):
        self.random_permutation = np.random.permutation(self.split_len)

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = start + self.batch_size
        if self.split == 'test':
            sample_indices = range(start, min(stop, self.split_len))
        else:
            sample_indices = self.random_permutation[start:stop]

        images = []
        density_maps = []
        for sample_index in sample_indices:
            sample_path = os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, sample_index))
            # NOTE: pickle.load can execute arbitrary code; only use a trusted cache.
            with open(sample_path, 'rb') as sample_file:
                data = pickle.load(sample_file)
            image = data['image']
            if self.split == 'train' and random.random() < 0.5:  # flip augmentation
                image = horizontal_flip(image)
            images.append(image)
            density_maps.append(np.sum(data['density_map']))

        return np.array(images), np.array(density_maps)

In [None]:
def _predict_split(model, sequence):
    """Run `model` over every batch of `sequence`; return (ground_truth, predictions) as 1-D arrays."""
    gt_batches = []
    out_batches = []
    for i, (images, density_maps) in enumerate(sequence):
        sys.stdout.write('\r%d' % (i + 1))
        sys.stdout.flush()
        gt_batches.append(np.asarray(density_maps).reshape(-1))
        out_batches.append(model(images).numpy().reshape(-1))
    print()  # newline
    if not gt_batches:
        raise ValueError('the evaluation sequence yielded no batches')
    # Concatenate once at the end instead of re-concatenating every batch (quadratic).
    return np.concatenate(gt_batches), np.concatenate(out_batches)


def _per_image_mae(gt, out, patches_per_image):
    """Mean absolute error of per-image totals, each image being `patches_per_image` consecutive samples."""
    if gt.shape[0] != out.shape[0]:
        # Otherwise the per-image groups below would silently misalign.
        raise ValueError(
            'got %d predictions for %d ground-truth counts' % (out.shape[0], gt.shape[0])
        )
    # NOTE(review): if the sample count is not a multiple of patches_per_image,
    # the last group is partial (same as before).
    errors = [
        np.abs(np.sum(gt[start:start + patches_per_image]) - np.sum(out[start:start + patches_per_image]))
        for start in range(0, gt.shape[0], patches_per_image)
    ]
    return np.mean(np.array(errors))


def get_mae(branch_type, branch_number, version, batch_size=32):
    """Per-image count MAE of one branch head on the 'test' split.

    The model's outputs and the ground-truth counts are summed over groups of
    VIDEO_PATCHES[0] * VIDEO_PATCHES[1] consecutive test samples. This grouping
    assumes the test cache stores one image's patches at consecutive indices.
    The MAE is the mean absolute difference of those group sums.

    Args:
        branch_type, branch_number, version: Select the head (see get_model).
        batch_size: Samples per inference batch (default 32, as before).

    Raises:
        ValueError: If the test split is empty or the model returns a different
            number of values than there are samples.
    """
    tf.keras.backend.clear_session()  # drop graphs from previously evaluated heads
    test_sequence = CCSequence('test', batch_size)
    model = get_model(branch_type, branch_number, version)
    gt, out = _predict_split(model, test_sequence)
    return _per_image_mae(gt, out, VIDEO_PATCHES[0] * VIDEO_PATCHES[1])

In [None]:
def get_maes(branch_types=None, branch_numbers=range(1, 12), version='v1'):
    """Evaluate get_mae for every (branch_type, branch_number) pair, printing progress.

    Args:
        branch_types: Iterable of head types; None means all seven types listed below.
        branch_numbers: 1-based TransformerBlock indices to evaluate (default 1..11).
        version: Head checkpoint version suffix passed to get_mae (default 'v1').

    Returns:
        dict mapping each branch_type to its list of MAEs, in branch_numbers order.
    """
    if branch_types is None:
        branch_types = [
            'vit',
            'mlp',
            'cnn_ignore',
            'cnn_add',
            'cnn_project',
            'resmlp',
            'mlp_mixer',
        ]
    results = {}
    for branch_type in branch_types:
        maes = []
        for branch_number in branch_numbers:
            mae = get_mae(branch_type, branch_number, version)
            print(mae)
            maes.append(mae)
        print('###', branch_type, maes)
        results[branch_type] = maes
    return results

In [None]:
# Evaluate every branch head on the test split. Slow: each evaluation reloads the
# backbone and a head from disk.
get_maes()