In [None]:
# Physical GPU index/indices to train on. CUDA_VISIBLE_DEVICES is only reliably
# honoured if it is set *before* TensorFlow is imported, so this must stay first.
SELECTED_GPUS = [7]

import os

os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu_number) for gpu_number in SELECTED_GPUS)

import tensorflow as tf

tf.get_logger().setLevel('INFO')

# List the visible GPUs once, using the stable (non-experimental) API, and fail
# loudly when none are visible -- a bare `assert` is skipped under `python -O`.
GPUS = tf.config.list_physical_devices('GPU')
if not GPUS:
    raise RuntimeError(
        'No GPU is visible to TensorFlow; check SELECTED_GPUS / CUDA_VISIBLE_DEVICES.'
    )

# Memory growth has to be configured before any GPU is initialised; it keeps
# TensorFlow from reserving all device memory up front.
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# CUDA_VISIBLE_DEVICES renumbers the selected GPUs, so inside this process they
# are addressed as /gpu:0 .. /gpu:n-1 regardless of their physical indices.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

import math
import numpy as np
import pickle
import sys
from skimage import transform
from vit_keras import vit
from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock

# Examples processed by each replica per step; the global batch scales with the
# number of replicas in the distribution strategy.
PER_REPLICA_BATCH_SIZE = 8
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * NUM_GPUS
# Side length (pixels) images are resized to and the ViT input resolution.
IMAGE_SIZE = 384
# Directory holding the per-example pickle cache.
CACHE_DIR = 'fashion_mnist'
# exist_ok=True replaces the racy "check, then create" pattern.
os.makedirs(CACHE_DIR, exist_ok=True)

In [None]:
def get_model(classes=10, activation='softmax'):
    """Build a ViT-B/16 classifier with pretrained body weights and a new head.

    The transformer body is loaded from pretrained weights, while the
    classification head is freshly initialised (``pretrained_top=False``) with
    ``classes`` outputs. Input resolution is ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    Args:
        classes: Number of output units in the new classification head.
        activation: Activation of the head. Defaults to ``'softmax'`` because
            the model is trained with categorical cross-entropy on mutually
            exclusive one-hot labels; a ``'sigmoid'`` head yields independent
            per-class scores that do not form a probability distribution.

    Returns:
        The Keras model built by ``vit.vit_b16`` (compiled by the caller).
    """
    return vit.vit_b16(
        image_size=IMAGE_SIZE,
        activation=activation,
        pretrained=True,
        include_top=True,
        pretrained_top=False,
        classes=classes
    )

In [None]:
def cache_split(images, labels, split):
    """Resize every example of one split and pickle it into CACHE_DIR.

    Example ``i`` is written to ``<CACHE_DIR>/<split>_<i>.pkl`` as a dict with
    keys ``'image'`` (resized to IMAGE_SIZE x IMAGE_SIZE, float32) and
    ``'label'``.

    Args:
        images: Array of images indexed along axis 0.
        labels: Labels aligned with ``images`` along axis 0.
        split: File-name prefix for this split (e.g. 'train', 'val', 'test').
    """
    total = images.shape[0]
    for i in range(total):
        # Report progress every 100 examples and once more at the very end, so
        # the last number shown is always the true total.
        if (i + 1) % 100 == 0 or i + 1 == total:
            sys.stdout.write('\r%d' % (i + 1))
            sys.stdout.flush()
        # With the default preserve_range=False, resize() returns float64 values;
        # storing float32 halves the size of the cache, and the model's default
        # float32 inputs lose nothing. Resizing before opening the file also
        # means a failed resize does not leave an empty .pkl behind.
        image = transform.resize(images[i], (IMAGE_SIZE, IMAGE_SIZE)).astype(np.float32)
        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:
            pickle.dump({
                'image': image,
                'label': labels[i],
            }, cache_file)
    print()  # newline

def cache_all(train_fraction=0.8, num_classes=10):
    """Load Fashion-MNIST, split it into train/val/test and cache every example.

    The first ``train_fraction`` of the official training set (in its stored
    order, no reshuffle) becomes the 'train' split and the remainder becomes
    'val'. The official test set becomes 'test'. Labels are one-hot encoded.

    Args:
        train_fraction: Fraction of the official training set kept for
            training; must lie strictly between 0 and 1.
        num_classes: Width of the one-hot label vectors. Passed explicitly so
            the label width never depends on which class ids occur in a split.

    Raises:
        ValueError: If ``train_fraction`` is not in the open interval (0, 1).
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError('train_fraction must be in (0, 1), got %r' % (train_fraction,))

    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

    train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=num_classes)
    test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=num_classes)

    # Slice val off the tail first, then trim train to the head.
    val_index = int(len(train_images) * train_fraction)
    val_images, val_labels = train_images[val_index:], train_labels[val_index:]
    train_images, train_labels = train_images[:val_index], train_labels[:val_index]

    for images, labels, split in (
        (train_images, train_labels, 'train'),
        (val_images, val_labels, 'val'),
        (test_images, test_labels, 'test'),
    ):
        cache_split(images, labels, split)

class FashionMNISTSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` streaming one cached split from CACHE_DIR in shuffled batches.

    Reads the ``<split>_<i>.pkl`` files and yields ``(images, labels)`` numpy
    batches. Each cached single-channel image is replicated to three channels
    to fit the three-channel input of the ViT.
    """

    def __init__(self, split, batch_size=None):
        """Index the cached files of one split.

        Args:
            split: Cache file prefix (e.g. 'train', 'val', 'test').
            batch_size: Examples per batch. ``None`` (default) uses the global
                BATCH_SIZE, read when the object is constructed rather than
                when the class is defined.
        """
        # Newer Keras versions (PyDataset) expect subclasses to call the base
        # initializer; on older versions this is a harmless no-op.
        super().__init__()
        self.split = split
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        # Count only '<split>_*.pkl' files, not any name that merely begins
        # with `split`, so unrelated files in CACHE_DIR are ignored.
        prefix = split + '_'
        self.count = sum(
            1 for file_name in os.listdir(CACHE_DIR)
            if file_name.startswith(prefix) and file_name.endswith('.pkl')
        )
        self.random_permutation = np.random.permutation(self.count)

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return math.ceil(self.count / self.batch_size)

    def on_epoch_end(self):
        """Draw a fresh example order for the next epoch."""
        self.random_permutation = np.random.permutation(self.count)

    def __getitem__(self, index):
        """Load batch ``index`` as an ``(images, labels)`` pair of numpy arrays."""
        start = index * self.batch_size
        images = []
        labels = []
        for i in self.random_permutation[start:start + self.batch_size]:
            # pickle.load can execute arbitrary code; CACHE_DIR must only hold
            # files produced by this notebook's caching step.
            with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, i)), 'rb') as cache_file:
                contents = pickle.load(cache_file)
            # Add a trailing channel axis and replicate it to 3 channels.
            images.append(np.repeat(contents['image'][..., np.newaxis], 3, axis=-1))
            labels.append(contents['label'])
        return np.array(images), np.array(labels)

In [None]:
def _build_callbacks(model_checkpoint_file):
    """Return the LR-reduction, early-stopping and best-checkpoint callbacks (all track val_accuracy)."""
    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.6,
        patience=2,
        verbose=1,
        mode='max',
        min_lr=1e-7
    )

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        verbose=1,
        mode='max'
    )

    # Saves the full model whenever val_accuracy reaches a new best.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model_checkpoint_file,
        monitor='val_accuracy',
        verbose=1,
        save_weights_only=False,
        save_best_only=True,
        mode='max',
        save_freq='epoch'
    )

    return [lr_reduce, early_stop, checkpoint]


def train(max_epochs, model_checkpoint_file='vit_fashion_mnist_v1.h5'):
    """Fine-tune the ViT on the cached splits and evaluate the best checkpoint.

    Args:
        max_epochs: Upper bound on training epochs; early stopping may end
            training sooner.
        model_checkpoint_file: Where the best model (by val_accuracy) is saved.

    Returns:
        ``(model, test_accuracy)``: the model holding the best-validation
        weights, and its accuracy on the 'test' split.
    """
    with DISTRIBUTED_STRATEGY.scope():
        model = get_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

    model.fit(
        FashionMNISTSequence('train'),
        validation_data=FashionMNISTSequence('val'),
        epochs=max_epochs,
        shuffle=True,
        callbacks=_build_callbacks(model_checkpoint_file),
        verbose=1
    )

    # After fit() the model holds the *last* epoch's weights. That can be up to
    # `patience` epochs past the best one, because EarlyStopping here does not
    # restore weights, and nothing restores them when max_epochs is reached.
    # Reloading the best checkpoint makes the reported test accuracy belong to
    # the same weights that were saved.
    if os.path.exists(model_checkpoint_file):
        with DISTRIBUTED_STRATEGY.scope():
            model.load_weights(model_checkpoint_file)

    test_metrics = model.evaluate(FashionMNISTSequence('test'), return_dict=True)
    return model, test_metrics['accuracy']

In [None]:
# Upper bound on training epochs; train() may stop earlier via EarlyStopping.
MAX_EPOCHS = 100

# Build the on-disk example cache, then fine-tune and evaluate on the test split.
cache_all()
model, test_accuracy = train(MAX_EPOCHS)
print('Test accuracy: {:.4f}'.format(test_accuracy))