In [None]:
import tensorflow as tf

import ssl
# WARNING(review): this replaces the default HTTPS context factory with one that
# performs NO certificate verification, for the rest of this kernel session.
# Presumably added so that tf.keras.datasets.cifar10.load_data() can download
# despite a broken local CA store. It silently weakens every later HTTPS
# download that relies on the default context; prefer fixing the certificate
# setup and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

def pipe(data, batch_size = 128, shuffle = False, buffer_size = None):
    """Build a batched, prefetched ``tf.data`` input pipeline.

    Args:
        data: Anything accepted by ``tf.data.Dataset.from_tensor_slices``; in
            this notebook either an ``(x, y)`` tuple or a dict of named arrays.
            Elements are sliced along the first axis.
        batch_size: Number of elements per batch (the last batch may be
            smaller, since ``drop_remainder`` is left at its default).
        shuffle: Whether to shuffle elements before batching. Intended for
            training data only.
        buffer_size: Shuffle buffer size, used only when ``shuffle`` is true.
            Defaults to ``batch_size * 10``. A larger buffer shuffles more
            uniformly at the cost of memory.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``data``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(data)
    if shuffle:
        if buffer_size is None:
            buffer_size = batch_size * 10
        dataset = dataset.shuffle(buffer_size = buffer_size)
    dataset = dataset.batch(batch_size)
    # Let tf.data pick the prefetch depth so input preparation overlaps training.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

(tr_x, tr_y), (te_x, te_y) = tf.keras.datasets.cifar10.load_data()

# Scale the uint8 pixel values to [0, 1]. Casting to float32 first keeps the
# arrays at half the size of the float64 result that `x * 1/255` would produce.
# The arrays are held in memory by tf.data, and Keras computes in float32 by default.
tr_x = tr_x.astype("float32") / 255.0
te_x = te_x.astype("float32") / 255.0

batch_size = 128

# Labels stay as integer class ids; only the training split is shuffled.
tr_data = pipe((tr_x, tr_y), batch_size = batch_size, shuffle = True)
te_data = pipe((te_x, te_y), batch_size = batch_size, shuffle = False)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


# **Vision Transformer (ViT) only**

In [None]:
import vit  # NOTE(review): local project module (vit.py), not shown here

# Manual backbone initialization (example configuration, cf. vit_small):
#
#     x = tf.keras.layers.Input(shape = (32, 32, 3))
#     out = vit.VisionTransformer(n_class = 1000, include_top = True, patch_size = 16,
#                                 distillation = False, emb_dim = 768, n_head = 8,
#                                 n_feature = 2304, n_layer = 8, dropout_rate = 0.1,
#                                 ori_input_shape = None, method = "bicubic")(x)
#     model = tf.keras.Model(x, out)
#
# Below, the prebuilt vit_small constructor is used instead. It gets a 10-way
# head for CIFAR-10 and `weights = None`, so the model is trained from scratch
# in this notebook.
model = vit.vit_small(input_shape = (32, 32, 3), classes = 10, distillation = False, include_top = True, weights = None)

In [None]:
# CIFAR-10 labels are integer class ids, hence the sparse variants of
# cross-entropy and accuracy.
metric = [tf.keras.metrics.sparse_categorical_accuracy]
opt = tf.keras.optimizers.Adam(learning_rate = 1e-4)
loss = tf.keras.losses.sparse_categorical_crossentropy
model.compile(optimizer = opt, loss = loss, metrics = metric)

In [None]:
# Train for a single epoch, validating on the test split.
model.fit(x = tr_data, epochs = 1, validation_data = te_data)



<keras.callbacks.History at 0x7f5afe39cb50>

In [None]:
# Architecture and weights are stored separately. The JSON architecture needs
# the custom VisionTransformer class at load time (see the reload cell); the
# weights go to an HDF5 file.
with open("model.json", mode = "w") as file:
    json_config = model.to_json()
    file.write(json_config)
model.save_weights("model.h5")

In [None]:
# Rebuild the architecture from JSON. VisionTransformer is a custom layer
# from the vit module, so it must be registered via custom_objects for
# deserialization. The saved weights are then loaded into the rebuilt model.
with open("model.json", mode = "r") as file:
    model = tf.keras.models.model_from_json(
        file.read(),
        custom_objects = {"VisionTransformer": vit.VisionTransformer},
    )
model.load_weights("model.h5")

In [None]:
# The reloaded model must be compiled before evaluate(). No optimizer is given
# because it is not trained further here.
metric = [tf.keras.metrics.sparse_categorical_accuracy]
loss = tf.keras.losses.sparse_categorical_crossentropy
model.compile(metrics = metric, loss = loss)
# NOTE(review): the recorded result (loss ≈ 2.3026 ≈ ln(10), accuracy 0.10) is
# exactly chance level for 10 classes. That is consistent with uniform
# predictions, so check that training actually converged and that the weights
# reloaded correctly.
model.evaluate(te_data)



[2.3025832176208496, 0.10000000149011612]

# **With Distillation Token (DeiT)**

In [None]:
# DeiT variant: with distillation = True the model exposes (at least) two
# outputs, presumably the class head and the distillation-token head.
# vit.train_model wraps the input and both heads into the model that is
# actually trained.
model = vit.vit_small(input_shape = (32, 32, 3), classes = 10, distillation = True, include_top = True, weights = None)
logits = model.outputs[0]
kd_logits = model.outputs[1]
tr_model = vit.train_model(model.input, logits, kd_logits)

In [None]:
import numpy as np

# Placeholder teacher targets for the "kd_true" input. Random values stand in
# for real teacher predictions.
# NOTE(review): replace with a trained teacher's outputs for real distillation,
# and confirm the format vit.train_model expects (logits vs. probabilities).
# A seeded generator makes re-runs reproducible. Row counts come from the data
# so the train/test slices of kd_sample always line up with tr_x / te_x.
n_train, n_test, n_class = len(tr_x), len(te_x), 10
kd_rng = np.random.default_rng(42)
kd_sample = kd_rng.random((n_train + n_test, n_class))

# Inputs are keyed by name. model.input.name must match the input layer name
# inside tr_model; "y_true" / "kd_true" are the label inputs of vit.train_model.
tr_data = pipe({model.input.name:tr_x, "y_true":tr_y, "kd_true":kd_sample[:n_train]}, batch_size = batch_size, shuffle = True)
te_data = pipe({model.input.name:te_x, "y_true":te_y, "kd_true":kd_sample[n_train:]}, batch_size = batch_size, shuffle = False)

In [None]:
# Use a fresh optimizer for the DeiT training model. No loss is passed to
# compile(), presumably because vit.train_model attaches its losses to the
# model itself.
opt = tf.keras.optimizers.Adam(learning_rate = 1e-4)
tr_model.compile(optimizer = opt)

In [None]:
# Train the distillation model for a single epoch, validating on the test split.
tr_model.fit(x = tr_data, epochs = 1, validation_data = te_data)



<keras.callbacks.History at 0x7f5af8566cd0>