In [None]:
SELECTED_GPUS = [4]  # physical GPU ids to expose to this process

import os

# Restrict CUDA to the selected GPUs. This has to happen before TensorFlow is
# imported/initialized; inside the process the visible GPUs are renumbered
# from 0, which is why the device list below uses range(len(SELECTED_GPUS)).
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu_number) for gpu_number in SELECTED_GPUS)

import tensorflow as tf

tf.get_logger().setLevel('INFO')

# Use the stable tf.config API (the `experimental` alias is deprecated) and fail
# with an explicit error: a bare `assert` would be stripped under `python -O`.
GPUS = tf.config.list_physical_devices('GPU')
if not GPUS:
    raise RuntimeError(
        'No GPU visible to TensorFlow; check SELECTED_GPUS / CUDA_VISIBLE_DEVICES=%r'
        % os.environ['CUDA_VISIBLE_DEVICES']
    )

# Allocate GPU memory on demand instead of reserving it all up front.
# This must run before the GPUs are initialized (i.e. before the strategy is created).
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))],
)

NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import numpy as np
import pickle
import random
import sys
import time
from skimage import transform
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Root directory for cached per-branch features. precompute() writes its output
# files under PRECOMPUTE_DIR/<dataset>/ as <split>_branch<b>_sample<i>.pkl.
PRECOMPUTE_DIR = 'precompute'

In [None]:
def get_branch_id(branch_number):
    """Return the Keras layer name of the ``branch_number``-th transformer block.

    Block 1 is named plainly ``transformer_block``; every later block N carries
    the suffix ``_<N-1>`` (block 2 -> ``transformer_block_1``, and so on).
    """
    base_name = 'transformer_block'
    if branch_number != 1:
        return '%s_%d' % (base_name, branch_number - 1)
    return base_name

def get_model(dataset, num_branches=11):
    """Build a frozen multi-output feature extractor from a saved ViT checkpoint.

    Loads ``vit_<dataset>_v1.h5`` (which contains vit_keras custom layers),
    freezes every layer, and exposes the first element of each of the first
    ``num_branches`` TransformerBlock outputs as a separate model output.

    Args:
        dataset: Dataset name used to build the checkpoint filename, e.g. 'cifar10'.
        num_branches: Number of transformer blocks (block 1 upward) to tap.
            Defaults to 11, the previously hard-coded value.

    Returns:
        A ``tf.keras.Model`` whose outputs are a list of ``num_branches``
        tensors, ordered by block number.

    Raises:
        ValueError: if ``num_branches`` is smaller than 1.
    """
    if num_branches < 1:
        raise ValueError('num_branches must be >= 1, got %r' % (num_branches,))

    backbone_model = tf.keras.models.load_model('vit_%s_v1.h5' % dataset, custom_objects={
        'ClassToken': ClassToken,
        'AddPositionEmbs': AddPositionEmbs,
        'MultiHeadSelfAttention': MultiHeadSelfAttention,
        'TransformerBlock': TransformerBlock,
    })

    # Freeze the backbone: it is only used to extract features.
    for layer in backbone_model.layers:
        layer.trainable = False

    outputs = []
    for branch_number in range(1, num_branches + 1):
        # Each TransformerBlock output is a 2-tuple; keep the first element
        # (presumably the hidden states; the second is presumably attention weights).
        hidden_states, _ = backbone_model.get_layer(get_branch_id(branch_number)).output
        outputs.append(hidden_states)

    model = tf.keras.models.Model(
        inputs=backbone_model.get_layer(index=0).input,
        outputs=outputs,
    )

    return model

In [None]:
def _load_cached_samples(dataset, split, sample_indices):
    """Load pickled (image, label) samples for ``sample_indices`` of ``split``.

    Missing files are skipped (the last batch is usually short). Returns
    ``(found_indices, images, labels)``, where ``found_indices`` lists the
    sample indices that were actually present, so outputs stay aligned with them.
    """
    found_indices, images, labels = [], [], []
    for sample_index in sample_indices:
        image_path = os.path.join(dataset, '%s_%d.pkl' % (split, sample_index))
        if not os.path.exists(image_path):
            continue
        # NOTE: pickle.load can execute arbitrary code; only use on locally generated caches.
        with open(image_path, 'rb') as cache_file:
            contents = pickle.load(cache_file)
        found_indices.append(sample_index)
        images.append(contents['image'])
        labels.append(contents['label'])
    return found_indices, images, labels


def _write_branch_features(output_dir, split, branch_number, sample_indices, branch_outputs, labels):
    """Pickle one branch's per-sample features, each paired with its label."""
    for sample_index, features, label in zip(sample_indices, branch_outputs, labels):
        sample_path = os.path.join(
            output_dir,
            '%s_branch%d_sample%d.pkl' % (split, branch_number, sample_index),
        )
        with open(sample_path, 'wb') as sample_file:
            # NOTE(review): `features` is a row of the model's output tensor, not a
            # NumPy array; storing features.numpy() would let the cache load without
            # TensorFlow -- confirm with the consumer before changing the format.
            pickle.dump({
                'features': features,
                'label': label,
            }, sample_file)


def precompute(dataset, batch_size=32 * NUM_GPUS):
    """Run the frozen ViT over every cached sample and store per-branch features.

    For each split ('train', 'val', 'test') this reads ``<dataset>/<split>_<i>.pkl``
    files (dicts with 'image' and 'label' keys), feeds them to ``get_model(dataset)``
    in batches, and writes one pickle per (branch, sample) to
    ``PRECOMPUTE_DIR/<dataset>/<split>_branch<b>_sample<i>.pkl`` holding
    ``{'features': ..., 'label': ...}``.

    Args:
        dataset: Directory holding the cached samples; also selects the checkpoint.
        batch_size: Number of samples per forward pass.
    """
    with DISTRIBUTED_STRATEGY.scope():
        model = get_model(dataset)

    # open(..., 'wb') does not create parent directories, so create them up front.
    output_dir = os.path.join(PRECOMPUTE_DIR, dataset)
    os.makedirs(output_dir, exist_ok=True)

    for split in ['train', 'val', 'test']:
        print(split)
        # Count only files following the '<split>_<i>.pkl' scheme read below.
        prefix = '%s_' % split
        total_count = sum(
            1 for file_name in os.listdir(dataset)
            if file_name.startswith(prefix) and file_name.endswith('.pkl')
        )
        batch_count = math.ceil(total_count / batch_size)
        for batch_index in range(batch_count):
            sys.stdout.write('\r[%d/%d]' % (batch_index + 1, batch_count))
            sys.stdout.flush()
            first_index = batch_index * batch_size
            sample_indices, images, labels = _load_cached_samples(
                dataset, split, range(first_index, first_index + batch_size)
            )
            if not images:
                continue  # nothing on disk for this index range
            outputs = model(np.array(images), training=False)
            # A single-output Keras model returns a bare tensor rather than a list.
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            for branch_number, branch_outputs in enumerate(outputs, start=1):
                _write_branch_features(
                    output_dir, split, branch_number, sample_indices, branch_outputs, labels
                )
        print()  # newline

In [None]:
# Build the per-branch feature cache for each CIFAR variant, in order.
for dataset_name in ('cifar10', 'cifar100'):
    precompute(dataset_name)