In [None]:
# Physical GPU indices (as numbered by the driver) that this notebook may use.
SELECTED_GPUS = [4, 5, 6, 7]  # which GPUs to use

import os

# Restrict CUDA to SELECTED_GPUS. This is deliberately done before `import tensorflow`
# so the restriction is already in place when TensorFlow sets up its GPU devices.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])

import tensorflow as tf 

tf.get_logger().setLevel('INFO')

# Fail fast if the masking above left TensorFlow without any usable GPU.
assert len(tf.config.list_physical_devices('GPU')) > 0

# Allocate GPU memory on demand instead of reserving each device's memory up front.
GPUS = tf.config.experimental.list_physical_devices('GPU')
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# Data-parallel replication across the visible GPUs, reducing gradients with NCCL.
# After CUDA_VISIBLE_DEVICES masking, the visible GPUs are re-numbered from 0, so the
# device names are /gpu:0 .. /gpu:<len(SELECTED_GPUS)-1> rather than the physical ids.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

# Number of replicas (one per GPU); later cells multiply the per-GPU batch size by it.
NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import scipy.io
import scipy.stats as st
import sys
import tensorflow_addons as tfa
from scipy import signal
from skimage.transform import resize

# Root of the DISCO dataset, relative to the notebook's working directory.
DISCO_PATH = 'disco'

# Raw inputs: .wav audio clips and .jpg crowd images, keyed by shared file stem.
WAVEFORMS_PATH = os.path.join(DISCO_PATH, 'auds')
IMAGES_PATH = os.path.join(DISCO_PATH, 'imgs')

# Head-annotation .mat files; the directory an example lives in decides its split.
TRAIN_DENSITY_MAPS_PATH, VAL_DENSITY_MAPS_PATH, TEST_DENSITY_MAPS_PATH = (
    os.path.join(DISCO_PATH, split_dir) for split_dir in ('train', 'val', 'test')
)

# Output directory for the pickled per-patch cache.
CACHE_DIR = os.path.join(DISCO_PATH, 'vit_cache')

from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Side length (pixels) of each square patch; also the ViT input resolution.
IMAGE_SIZE = 384
# Each frame is cut into a (rows, cols) grid of IMAGE_SIZE x IMAGE_SIZE patches.
VIDEO_PATCHES = (2, 3)
# Frames are resized to exactly cover that grid: (rows * IMAGE_SIZE, cols * IMAGE_SIZE).
VIDEO_SIZE = tuple(grid_cells * IMAGE_SIZE for grid_cells in VIDEO_PATCHES)

In [None]:
def get_model():
    """Build the ViT-B/16 crowd-count regressor.

    Starts from vit_keras's pretrained ViT-B/16 at IMAGE_SIZE resolution without its
    classification top, and appends a single-unit Dense layer named
    'regression_head' whose output is the predicted count for one patch.

    Returns:
        A tf.keras Model from the backbone's input to the regression head; the
        output's last dimension is 1.
    """
    vit_backbone = vit.vit_b16(
        image_size=IMAGE_SIZE,
        pretrained=True,
        include_top=False,
        pretrained_top=False,
    )
    backbone_features = vit_backbone.get_layer(index=-1).output
    count_prediction = tf.keras.layers.Dense(1, name='regression_head')(backbone_features)
    backbone_input = vit_backbone.get_layer(index=0).input
    return tf.keras.models.Model(inputs=backbone_input, outputs=count_prediction)

In [None]:
def _index_files(directory, extension):
    """Yield (key, path) for each regular file in `directory` whose name ends with `extension`.

    The key is the file name with its last '.'-separated component removed, so files
    sharing a stem across directories map to the same example. The listing is sorted
    so the iteration order is deterministic (os.listdir order is arbitrary).
    """
    for file_name in sorted(os.listdir(directory)):
        path = os.path.join(directory, file_name)
        if os.path.isfile(path) and file_name.endswith(extension):
            yield '.'.join(file_name.split('.')[:-1]), path


def get_dataset_split(split):
    """Collect the DISCO examples that belong to one split.

    Examples are joined by file stem across the waveform (.wav), image (.jpg) and
    density-map (.mat) directories. The density-map directory an example appears in
    (train / val / test) determines its split; if a stem appears in several of them,
    the later one in that order wins.

    Args:
        split: 'train', 'val' or 'test'.

    Returns:
        List of dicts, in a deterministic order, each holding whichever of
        'waveform_path', 'image_path', 'density_map_path' and 'split' were found.
        Only examples whose 'split' equals `split` are returned.
    """
    examples = {}
    for key, path in _index_files(WAVEFORMS_PATH, '.wav'):
        examples.setdefault(key, {})['waveform_path'] = path
    for key, path in _index_files(IMAGES_PATH, '.jpg'):
        examples.setdefault(key, {})['image_path'] = path
    density_dirs = (
        ('train', TRAIN_DENSITY_MAPS_PATH),
        ('val', VAL_DENSITY_MAPS_PATH),
        ('test', TEST_DENSITY_MAPS_PATH),
    )
    for split_name, directory in density_dirs:
        for key, path in _index_files(directory, '.mat'):
            info = examples.setdefault(key, {})
            info['density_map_path'] = path
            info['split'] = split_name
    return [info for info in examples.values() if info.get('split') == split]

def visualize_data(image, density_map):
    """Display `image` above `density_map` in one 20x10-inch figure, axes hidden."""
    fig, axes = plt.subplots(2, 1, figsize=(20, 10))
    for ax, picture in zip(axes, (image, density_map)):
        ax.imshow(picture)
        ax.axis('off')
    plt.show()

def get_gaussian_kernel(kernel_size, sigma):
    """
    Return a normalized (kernel_size, kernel_size) 2D Gaussian kernel.

    The standard normal is split into `kernel_size` equal-width bins spanning
    [-sigma, +sigma]; each 1D weight is the probability mass in its bin, and the 2D
    kernel is the outer product of those weights, rescaled to sum to 1.
    Note that `sigma` is the half-width of the covered range in standard deviations,
    not the Gaussian's std in pixels (that is kernel_size / (2 * sigma)).
    Adapted from:
    https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
    """
    bin_edges = np.linspace(-sigma, sigma, kernel_size + 1)
    weights_1d = np.diff(st.norm.cdf(bin_edges))
    weights_2d = weights_1d[:, np.newaxis] * weights_1d[np.newaxis, :]
    return weights_2d / weights_2d.sum()

def extract_patches(image, grid=None, patch_size=None):
    """Split an image into a row-major grid of square, non-overlapping patches.

    Args:
        image: array whose first two axes are rows and columns; any trailing axes
            (e.g. colour channels) are kept unchanged in every patch.
        grid: (rows, cols) number of patches; defaults to VIDEO_PATCHES.
        patch_size: side length of each patch in pixels; defaults to IMAGE_SIZE.

    Returns:
        np.ndarray stacking rows * cols patches in row-major order (patch k comes
        from grid row k // cols, grid column k % cols). When `image` covers the full
        grid, its shape is (rows * cols, patch_size, patch_size, *image.shape[2:]).
    """
    if grid is None:
        grid = VIDEO_PATCHES
    if patch_size is None:
        patch_size = IMAGE_SIZE
    n_rows, n_cols = grid
    patches = [
        # `...` keeps any trailing axes, so 2-D maps and 3-D images share one path.
        image[r * patch_size:(r + 1) * patch_size, c * patch_size:(c + 1) * patch_size, ...]
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return np.array(patches)

def _load_example_patches(info, gaussian_kernel):
    """Load one annotated example and cut it into aligned image / density-map patches.

    Args:
        info: one entry from get_dataset_split; 'image_path' and 'density_map_path'
            are read.
        gaussian_kernel: 2-D kernel used to blur the head-annotation map.

    Returns:
        (image_patches, density_patches, annotated_count): the extract_patches output
        for the resized image and density map, and the sum of the raw annotation map.
    """
    crowd_image = plt.imread(info['image_path'], format='jpeg')
    if crowd_image.ndim == 2:
        # Grayscale JPEGs decode to a 2-D array; replicate the channel so these
        # patches have the same shape as colour ones and can be batched together.
        crowd_image = np.stack([crowd_image] * 3, axis=-1)
    resized_crowd_image = resize(crowd_image, VIDEO_SIZE)
    image_patches = extract_patches(resized_crowd_image)

    head_annotation = scipy.io.loadmat(info['density_map_path'])['map']
    density_map = signal.convolve2d(head_annotation, gaussian_kernel)
    # An interpolating resize roughly keeps the map's mean rather than its sum, so
    # rescale by (original area / new area) to approximately preserve the total count.
    resize_factor = density_map.shape[0] / VIDEO_SIZE[0] * density_map.shape[1] / VIDEO_SIZE[1]
    resized_density_map = resize(density_map, VIDEO_SIZE) * resize_factor
    density_patches = extract_patches(resized_density_map)
    return image_patches, density_patches, np.sum(head_annotation)


def precompute_batches():
    """Write every train/val/test example to CACHE_DIR as per-patch pickle files.

    For each split, example `index` produces VIDEO_PATCHES[0] * VIDEO_PATCHES[1] files
    named '<split>_<index * patches_per_image + patch_index>.pkl', each holding a dict
    with the 'image' patch and its 'density_map' patch. Earlier '<split>_*.pkl' files
    are removed first, because the readers derive the split size by counting files.
    Prints a progress counter per split and, at the end, the mean absolute difference
    between each example's annotated count and the sum of its density patches.

    Returns:
        List with the number of patch files written for 'train', 'val' and 'test'.
    """
    gaussian_kernel = get_gaussian_kernel(15, 4)
    patches_per_image = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]
    os.makedirs(CACHE_DIR, exist_ok=True)
    split_lens = []
    count_errors = []
    for split in ['train', 'val', 'test']:
        # Drop patches left over from a previous run (e.g. one with more examples) so
        # they are not later read back as part of this split.
        stale_prefix = '%s_' % split
        for file_name in os.listdir(CACHE_DIR):
            if file_name.startswith(stale_prefix) and file_name.endswith('.pkl'):
                os.remove(os.path.join(CACHE_DIR, file_name))

        infos = get_dataset_split(split)
        split_lens.append(len(infos) * patches_per_image)
        for index, info in enumerate(infos):
            sys.stdout.write('\r%d' % (index + 1))
            sys.stdout.flush()

            image_patches, density_patches, annotated_count = _load_example_patches(
                info, gaussian_kernel
            )
            # Compare what the model will be trained on (the patch sums) against the
            # raw annotation count: this is the error introduced by blur + resize.
            count_errors.append(np.abs(np.sum(density_patches) - annotated_count))

            for patch_index in range(patches_per_image):
                cache_path = os.path.join(
                    CACHE_DIR,
                    '%s_%d.pkl' % (split, index * patches_per_image + patch_index)
                )
                with open(cache_path, 'wb') as cache_file:
                    pickle.dump({
                        'image': image_patches[patch_index],
                        'density_map': density_patches[patch_index],
                    }, cache_file)
        print()  # newline
    if count_errors:
        print('Mean absolute resize error:', np.mean(count_errors))
    return split_lens

def horizontal_flip(image):
    """Mirror `image` left-to-right by reversing its second (column) axis; returns a view."""
    as_array = np.asanyarray(image)
    return as_array[:, ::-1]

class CCSequence(tf.keras.utils.Sequence):
    """Batches of cached (image patch, crowd count) pairs for one dataset split.

    Reads the '<split>_<i>.pkl' files in CACHE_DIR (written by precompute_batches).
    Each batch is (images, counts), where counts[k] is the sum of the k-th patch's
    density map. The 'test' split is served in file order, so consecutive patches of
    one image stay adjacent. Other splits use a random order that is reshuffled after
    every epoch. 'train' patches are horizontally flipped with probability 0.5.

    Args:
        split: 'train', 'val' or 'test' — the cache-file prefix to read.
        batch_size: patches per batch (the last batch may be smaller).
    """

    def __init__(self, split, batch_size):
        super().__init__()
        self.split = split
        # Count only this split's cache files; the trailing underscore keeps e.g.
        # 'val' from also matching an unrelated file such as 'validation.txt'.
        cache_prefix = '%s_' % split
        self.split_len = sum(
            1 for file_name in os.listdir(CACHE_DIR)
            if file_name.startswith(cache_prefix) and file_name.endswith('.pkl')
        )
        self.batch_size = batch_size
        self.random_permutation = np.random.permutation(self.split_len)

    def __len__(self):
        return math.ceil(self.split_len / self.batch_size)

    def on_epoch_end(self):
        self.random_permutation = np.random.permutation(self.split_len)

    def _batch_indices(self, index):
        """Return the cache-file indices that make up batch `index`."""
        start = index * self.batch_size
        # Exclusive upper bound: clamp to split_len (not split_len - 1), otherwise the
        # final patch of the split is never served.
        stop = min(start + self.batch_size, self.split_len)
        if self.split == 'test':
            return range(start, stop)
        return self.random_permutation[start:stop]

    def __getitem__(self, index):
        images = []
        counts = []
        for file_index in self._batch_indices(index):
            cache_path = os.path.join(
                CACHE_DIR,
                '%s_%d.pkl' % (self.split, file_index)
            )
            # NOTE: pickle.load can execute arbitrary code; CACHE_DIR must only
            # contain files produced by precompute_batches.
            with open(cache_path, 'rb') as cache_file:
                data = pickle.load(cache_file)
            image = data['image']
            if self.split == 'train' and random.random() < 0.5:  # flip augmentation
                image = horizontal_flip(image)
            images.append(image)
            counts.append(np.sum(data['density_map']))

        return np.array(images), np.array(counts)

In [None]:
def train_backbone(epochs, batch_size_per_gpu=4, checkpoint_file='vit_cc_backbone_v2.h5'):
    """Fine-tune the ViT count regressor on the cached patches and evaluate on test.

    Trains on the 'train' patches with 'val' for validation, reducing the learning
    rate on plateaus and saving the best-validation model to `checkpoint_file`.
    precompute_batches() must have populated CACHE_DIR beforehand.

    Args:
        epochs: number of training epochs.
        batch_size_per_gpu: patches per replica; the global batch size is this
            multiplied by NUM_GPUS.
        checkpoint_file: path ModelCheckpoint writes the best model to.

    Returns:
        The trained model. If a checkpoint was written during this run, it holds the
        best-validation weights; otherwise it holds the final-epoch weights.
    """
    tf.keras.backend.clear_session()

    batch_size = batch_size_per_gpu * NUM_GPUS
    train_sequence = CCSequence('train', batch_size)
    val_sequence = CCSequence('val', batch_size)
    test_sequence = CCSequence('test', batch_size)

    with DISTRIBUTED_STRATEGY.scope():
        model = get_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
            loss='mean_absolute_error',
            metrics=['mean_absolute_error']
        )

    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_mean_absolute_error',
        factor=0.6,
        patience=2,
        verbose=1,
        mode='min',
        min_lr=1e-7
    )

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_file,
        monitor='val_mean_absolute_error',
        verbose=1,
        save_weights_only=False,
        save_best_only=True,
        mode='min',
        save_freq='epoch'
    )

    model.fit(
        train_sequence,
        validation_data=val_sequence,
        epochs=epochs,
        shuffle=True,
        callbacks=[
            lr_reduce,
            checkpoint
        ],
        verbose=1
    )

    # After fit() the in-memory model holds the LAST epoch's weights, while the file
    # holds the best validation epoch. Restore the best weights so the test score
    # reported below matches the saved model. `best` only becomes finite once the
    # callback has actually saved during this run, so a stale file is never loaded.
    best_val = getattr(checkpoint, 'best', None)
    if best_val is not None and np.isfinite(best_val):
        with DISTRIBUTED_STRATEGY.scope():
            model.load_weights(checkpoint_file)

    model.evaluate(test_sequence)
    return model

In [None]:
# Rebuild the per-patch cache from the raw DISCO files, then fine-tune for 100 epochs.
precompute_batches()
model = train_backbone(epochs=100)

In [None]:
# Test patches in file order, with the same global batch size used in training.
test_sequence = CCSequence('test', 4 * NUM_GPUS)
# Reload the saved checkpoint; the vit_keras layer classes must be supplied so
# Keras can deserialize them.
model = tf.keras.models.load_model(
    'vit_cc_backbone_v2.h5',
    custom_objects=dict(
        ClassToken=ClassToken,
        AddPositionEmbs=AddPositionEmbs,
        MultiHeadSelfAttention=MultiHeadSelfAttention,
        TransformerBlock=TransformerBlock,
    ),
)

In [None]:
# Run the model over the whole test split, collecting per-patch ground-truth counts
# (gt) and predictions (out) in the order test_sequence yields them.
gt_batches = []
out_batches = []
for i, (images, density_maps) in enumerate(test_sequence):
    sys.stdout.write('\r%d' % (i + 1))
    sys.stdout.flush()
    if len(images) == 0:
        continue  # nothing to predict for an empty batch
    gt_batches.append(density_maps)
    out_batches.append(model(images, training=False).numpy().flatten())
print()  # newline
# Concatenate once at the end rather than re-allocating the arrays on every batch.
gt = np.concatenate(gt_batches) if gt_batches else None
out = np.concatenate(out_batches) if out_batches else None

In [None]:
# Per-image absolute count error: consecutive runs of img_patches entries in gt/out
# are the patches of one test image, so compare their summed counts.
img_patches = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]
mae = [
    np.abs(np.sum(gt[start:start + img_patches]) - np.sum(out[start:start + img_patches]))
    for start in range(0, gt.shape[0], img_patches)
]
print(np.mean(np.array(mae)))