In [5]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, LayerNormalization, Add, Activation, Input
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

In [6]:
def MultiHead_SelfAttention(inputs, embed_dim, num_heads):
    """Multi-head scaled dot-product self-attention assembled from Keras ops.

    Projects `inputs` to queries, keys and values with three `Dense(embed_dim)`
    layers, splits each projection into `num_heads` heads of width
    `embed_dim // num_heads`, computes softmax(Q K^T / sqrt(head_dim)) V per
    head, merges the heads and applies a final `Dense(embed_dim)` projection.

    Args:
        inputs: rank-3 tensor, assumed (batch, seq_len, features) — the head
            split below relies on that layout.
        embed_dim: width of the Q/K/V projections and of the output.
        num_heads: number of attention heads; must divide `embed_dim`.

    Returns:
        Tensor of shape (batch, seq_len, embed_dim).

    Raises:
        ValueError: if `embed_dim` is not divisible by `num_heads`.
    """
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
    projection_dim = embed_dim // num_heads

    # Leave the batch dimension as -1 so the graph accepts any batch size.
    # Reading it via K.int_shape(inputs)[0] yields None for a regular Keras
    # Input (K.reshape then fails), and otherwise hard-codes a single batch
    # size into the model. The sequence length is taken statically when known
    # and dynamically otherwise.
    seq_len = K.int_shape(inputs)[1]
    if seq_len is None:
        seq_len = tf.shape(inputs)[1]

    def split_heads(x):
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, projection_dim)
        x = K.reshape(x, (-1, seq_len, num_heads, projection_dim))
        return K.permute_dimensions(x, (0, 2, 1, 3))

    query = split_heads(Dense(embed_dim)(inputs))
    key = split_heads(Dense(embed_dim)(inputs))
    value = split_heads(Dense(embed_dim)(inputs))

    # Scaled dot-product attention, computed independently per head.
    score = tf.matmul(query, key, transpose_b=True)
    score = score / K.sqrt(K.cast(projection_dim, 'float32'))
    weights = Activation('softmax')(score)
    attention = tf.matmul(weights, value)

    # Merge heads: (batch, num_heads, seq, proj) -> (batch, seq, embed_dim).
    attention = K.permute_dimensions(attention, (0, 2, 1, 3))
    attention = K.reshape(attention, (-1, seq_len, embed_dim))
    return Dense(embed_dim)(attention)

In [7]:
def TransformerBlock(inputs, embed_dim, num_heads, ff_dim, rate=0.1, epsilon=1e-6):
    """Post-norm Transformer encoder block.

    Computes
        out1 = LayerNorm(inputs + Dropout(MHSA(inputs)))
        out  = LayerNorm(out1 + Dropout(Dense(embed_dim)(ReLU-Dense(ff_dim)(out1))))

    Args:
        inputs: rank-3 tensor; its last dimension must equal `embed_dim`
            because it is added residually to the attention output.
        embed_dim: model width.
        num_heads: number of attention heads.
        ff_dim: hidden width of the position-wise feed-forward network.
        rate: dropout rate after attention and after the feed-forward net
            (default 0.1, the previously hard-coded value).
        epsilon: LayerNormalization epsilon (default 1e-6, as before).

    Returns:
        Tensor with the same shape as `inputs`.
    """
    attn_output = MultiHead_SelfAttention(inputs, embed_dim, num_heads)
    attn_output = Dropout(rate)(attn_output)
    out1 = LayerNormalization(epsilon=epsilon)(Add()([inputs, attn_output]))

    ffn_output = Dense(ff_dim, activation="relu")(out1)
    ffn_output = Dense(embed_dim)(ffn_output)
    ffn_output = Dropout(rate)(ffn_output)
    return LayerNormalization(epsilon=epsilon)(Add()([out1, ffn_output]))

In [8]:
class Add_Embedding_Layer(tf.keras.layers.Layer):
    """Prepend a learned class token and add learned position embeddings.

    Given patch embeddings assumed to be (batch, num_patches, d_model), returns
    (batch, num_patches + 1, d_model). The class token sits at sequence index
    0, which is the position a ``x[:, 0]`` readout selects.

    Args:
        num_patches: number of patch tokens per example.
        d_model: embedding width.
        batch_size: number of class-token copies to tile along the batch axis.
            An int fixes the batch size the layer accepts: every batch must
            contain exactly that many samples, as in the original layer.
            ``None`` tiles to the runtime batch size instead. Downstream layers
            must then handle an unknown static batch dimension.
        **kwargs: standard ``tf.keras.layers.Layer`` arguments (e.g. ``name``).
    """

    def __init__(self, num_patches=64, d_model=64, batch_size=16, **kwargs):
        super(Add_Embedding_Layer, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.d_model = d_model
        self.batch_size = batch_size
        # Learned class token, shared by every example in the batch.
        # Attribute name kept as `patch_emb` so object-based checkpoints still match.
        self.patch_emb = self.add_weight(
            name="class_token", shape=[1, 1, d_model], dtype=tf.float32)
        # One learned position vector per token (class token + patches).
        self.pos_emb = self.add_weight(
            name="pos_emb", shape=[1, num_patches + 1, d_model], dtype=tf.float32)

    def call(self, inputs):
        if self.batch_size is None:
            # Follow whatever batch size arrives at run time.
            num_copies = tf.shape(inputs)[0]
        else:
            num_copies = self.batch_size
        cls_tokens = tf.tile(self.patch_emb, [num_copies, 1, 1])
        # Class token goes FIRST. A readout of x[:, 0] then sees the learned
        # token instead of the first patch embedding.
        x = tf.concat([cls_tokens, inputs], axis=1)
        # (1, num_patches + 1, d_model) broadcasts over the batch axis.
        return x + self.pos_emb

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super(Add_Embedding_Layer, self).get_config()
        config.update({
            "num_patches": self.num_patches,
            "d_model": self.d_model,
            "batch_size": self.batch_size,
        })
        return config

In [10]:
# Training hyper-parameters consumed by model.fit further down.
epochs, batch_size = 30, 400

def make_ViT(img_size = 32, ch_size = 3, patch_size = 4,
             batch_size = 400, num_layers = 4, d_model = 64,
             num_heads = 4, mlp_dim = 128, num_classes = 10):
    """Build a small Vision Transformer classifier as a Keras functional Model.

    Pipeline:
        rescale by 1/255 -> split into non-overlapping patch_size x patch_size
        patches (tf.nn.space_to_depth) -> flatten each patch -> Dense(d_model)
        -> Add_Embedding_Layer -> num_layers x TransformerBlock
        -> MLP head on sequence position 0 -> softmax over num_classes.

    Args:
        img_size: height and width of the (square) input images.
        ch_size: number of input channels.
        patch_size: patch edge length; must divide `img_size`.
        batch_size: forwarded to Add_Embedding_Layer as the batch size it
            tiles its embeddings to.
        num_layers: number of Transformer blocks.
        d_model: token embedding width.
        num_heads: attention heads per block.
        mlp_dim: hidden width of the block feed-forward nets and of the head.
        num_classes: number of output classes.

    Returns:
        An uncompiled `tf.keras.Model` mapping images to class probabilities.

    Raises:
        ValueError: if `img_size` is not a multiple of `patch_size`.
    """
    if img_size % patch_size != 0:
        raise ValueError(
            f"img_size ({img_size}) must be a multiple of patch_size ({patch_size})")

    num_patches = (img_size // patch_size) ** 2
    patch_dim = ch_size * patch_size ** 2

    # Build the input from the arguments. A hard-coded (32, 32, 3) would
    # silently ignore img_size/ch_size and disagree with num_patches/patch_dim.
    inputs = Input(shape=(img_size, img_size, ch_size))

    # Rescaling assumes raw 0-255 pixel values.
    x = Rescaling(1./255)(inputs)
    # (B, H, W, C) -> (B, H/p, W/p, p*p*C): one flattened vector per patch.
    x = tf.nn.space_to_depth(x, patch_size)
    x = K.reshape(x, (-1, num_patches, patch_dim))
    x = Dense(d_model)(x)

    x = Add_Embedding_Layer(num_patches, d_model, batch_size)(x)
    for _ in range(num_layers):
        x = TransformerBlock(x, d_model, num_heads, mlp_dim)

    # Classification head reads the token at sequence index 0.
    x = Dense(mlp_dim, activation='relu')(x[:, 0])
    x = Dropout(0.1)(x)
    y = Dense(num_classes, activation='softmax')(x)
    return Model(inputs=inputs, outputs=y)

# Build with the same batch size that model.fit uses (the `batch_size` config
# value). make_ViT's embedding layer tiles to this size, so relying on the
# function's own default silently breaks training if the config value changes.
model = make_ViT(batch_size=batch_size)
model.compile(optimizer='Adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

embedding__layer_1[0][0]    
__________________________________________________________________________________________________
tf_op_layer_RealDiv_4 (TensorFl [(400, 4, 65, 65)]   0           tf_op_layer_BatchMatMulV2_8[0][0]
__________________________________________________________________________________________________
tf_op_layer_Reshape_20 (TensorF [(400, 65, 4, 16)]   0           dense_28[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (400, 4, 65, 65)     0           tf_op_layer_RealDiv_4[0][0]      
__________________________________________________________________________________________________
tf_op_layer_Transpose_18 (Tenso [(400, 4, 65, 16)]   0           tf_op_layer_Reshape_20[0][0]     
__________________________________________________________________________________________________
tf_op_layer_BatchMatMulV2_9 (Te [(400, 4, 65, 16)]   0           activation_4[0]

In [11]:
# Load CIFAR-10, one-hot encode the integer labels, and train. The test split
# doubles as validation data here.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train, y_test = to_categorical(y_train), to_categorical(y_test)
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
)

Epoch 1/30

KeyboardInterrupt: 