In [1]:
import tensorflow as tf

import ssl
# NOTE(review): SECURITY — this overrides the (private) stdlib hook that
# urllib / http.client use to build their default HTTPS context, so HTTPS
# requests relying on that default no longer verify TLS certificates or
# hostnames for the rest of this kernel session (presumably including the
# dataset and pretrained-weight downloads below). It is presumably a workaround
# for a missing local CA certificate store; prefer fixing the certificate store
# and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

def preprocess(image, y_true = None):
    """Resize a single image to 224x224 using bicubic interpolation.

    Intended as a `tf.data.Dataset.map` function: it is called with only the
    image for unlabeled elements, or with `(image, y_true)` for labeled ones.

    Args:
        image: one image without a batch dimension ([height, width, channels]).
        y_true: optional label, passed through unchanged.

    Returns:
        The resized image, or the tuple `(resized_image, y_true)` when a label
        is given.
    """
    # Resize the image as a batch of one, then drop that batch axis again.
    # NOTE(review): bicubic interpolation can overshoot the input value range
    # slightly; clip afterwards if the model needs a strict range.
    batch_of_one = tf.expand_dims(image, axis = 0)
    resized = tf.image.resize(batch_of_one, (224, 224), method = tf.image.ResizeMethod.BICUBIC)[0]
    if y_true is None:
        return resized
    return resized, y_true

def pipe(data, batch_size = 16, shuffle = False):
    """Build a batched, prefetched tf.data pipeline that applies `preprocess`.

    Args:
        data: anything `tf.data.Dataset.from_tensor_slices` accepts. In this
            notebook it is an `(images, labels)` tuple of arrays; each sliced
            `(image, label)` element is passed to `preprocess` as
            `(image, y_true)`.
        batch_size: number of elements per batch. It also sizes the shuffle
            buffer (`batch_size * 10` elements).
        shuffle: if True, shuffle the elements (reshuffled on each iteration,
            which is the `Dataset.shuffle` default).

    Returns:
        A `tf.data.Dataset` that yields batches of preprocessed elements.
    """
    dataset = tf.data.Dataset.from_tensor_slices(data)
    if shuffle:
        # Shuffle the raw elements *before* the expensive 224x224 resize. The
        # buffer then holds un-resized inputs (32x32 for CIFAR-10 here) instead
        # of 224x224 outputs. The order distribution is the same, because the
        # map is applied per element.
        dataset = dataset.shuffle(buffer_size = batch_size * 10)
    dataset = dataset.map(preprocess, num_parallel_calls = tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    # Let tf.data pick the prefetch depth so input work overlaps with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

seed = 42
# Seed TF's global RNG so the training-set shuffle order repeats across
# kernel restarts (Dataset.shuffle below is created without an explicit seed).
tf.random.set_seed(seed)

(tr_x, tr_y), (te_x, te_y) = tf.keras.datasets.cifar10.load_data()

# Scale the uint8 pixels to [0, 1] in float32. Writing `x * 1/255` on a uint8
# array would yield float64 and double the memory of the full train/test
# arrays. tf.image.resize returns float32 anyway, so float64 buys nothing.
tr_x = tr_x.astype("float32") / 255
te_x = te_x.astype("float32") / 255

batch_size = 16
# Only a small subset is used so this demo trains and evaluates quickly;
# raise it for a real run.
n_samples = 1000

tr_data = pipe((tr_x[:n_samples], tr_y[:n_samples]), batch_size = batch_size, shuffle = True)
te_data = pipe((te_x[:n_samples], te_y[:n_samples]), batch_size = batch_size, shuffle = False)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [2]:
import swin_transformer

# Manual backbone initialization (example: swin_transformer_tiny), kept for reference:
#   x = tf.keras.layers.Input(shape = (224, 224, 3))
#   out = swin_transformer.swin_transformer(x, n_class = 10, include_top = True, patch_size = 4, n_feature = 96,
#                                           n_blocks = [2, 2, 6, 2], n_heads = [3, 6, 12, 24], window_size = 7,
#                                           ratio = 4., scale = None, use_bias = True, patch_normalize = True,
#                                           dropout_rate = 0., attention_dropout_rate = 0., droppath_rate = 0.1,
#                                           normalize = tf.keras.layers.LayerNormalization,
#                                           activation = tf.keras.activations.gelu)
#   model = tf.keras.Model(x, out)

# Swin-T backbone with "imagenet" weights and no built-in classification head.
# NOTE(review): pixels are scaled to [0, 1] in the data cell; confirm whether
# these pretrained weights expect an additional mean/std normalization.
backbone = swin_transformer.swin_transformer_tiny(input_shape = (224, 224, 3), include_top = False, weights = "imagenet")

# New 10-class head on top of the backbone features.
# NOTE(review): Reshape([7, 7, -1]) assumes each sample's backbone output holds a
# multiple of 7*7 values (presumably a 7x7 grid / 49 tokens for a 224x224 input);
# confirm against swin_transformer.
grid = tf.keras.layers.Reshape([7, 7, -1])(backbone.output)
pooled = tf.keras.layers.GlobalAveragePooling2D()(grid)
regularized = tf.keras.layers.Dropout(0.5)(pooled)
hidden = tf.keras.layers.Dense(2048, activation = "relu")(regularized)
prediction = tf.keras.layers.Dense(10, activation = "softmax", name = "prediction")(hidden)
model = tf.keras.Model(backbone.input, prediction)

Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth


  0%|          | 0.00/109M [00:00<?, ?B/s]

In [3]:
# Fine-tune end to end with Adam at a small learning rate. The sparse_* loss
# and metric take integer class ids (not one-hot vectors) as labels.
opt = tf.keras.optimizers.Adam(learning_rate = 1e-4)
metric = [tf.keras.metrics.sparse_categorical_accuracy]
loss = tf.keras.losses.sparse_categorical_crossentropy
model.compile(optimizer = opt, loss = loss, metrics = metric)

In [4]:
# Fine-tune for one epoch on the training subset, validating on the test subset.
# Keep the returned History so the per-epoch metrics (history.history) can be
# inspected or plotted later, instead of only echoing its repr as cell output.
history = model.fit(tr_data, validation_data = te_data, epochs = 1)



<keras.callbacks.History at 0x7f7e79217e90>

In [5]:
# Persist only the weights, not the architecture. The next cell rebuilds the same
# architecture and restores them with load_weights. Per the tf.keras docs, a
# filepath ending in ".h5" selects the HDF5 format.
model.save_weights(filepath = "model.h5")

In [6]:
# Rebuild the exact architecture trained above, this time with weights = None:
# every parameter is overwritten by load_weights at the end of this cell.
backbone = swin_transformer.swin_transformer_tiny(input_shape = (224, 224, 3), include_top = False, weights = None)

# The head must match the trained one layer for layer and in the same order.
grid = tf.keras.layers.Reshape([7, 7, -1])(backbone.output)
pooled = tf.keras.layers.GlobalAveragePooling2D()(grid)
regularized = tf.keras.layers.Dropout(0.5)(pooled)
hidden = tf.keras.layers.Dense(2048, activation = "relu")(regularized)
prediction = tf.keras.layers.Dense(10, activation = "softmax", name = "prediction")(hidden)
model = tf.keras.Model(backbone.input, prediction)

# Restore the fine-tuned parameters saved by the previous cell.
model.load_weights(filepath = "model.h5")

In [7]:
# The rebuilt model must be compiled before evaluate. No optimizer is passed
# because nothing is trained here, so Keras' default optimizer is used.
# NOTE(review): te_data is the same test subset that served as validation_data
# during fit.
metric = [tf.keras.metrics.sparse_categorical_accuracy]
loss = tf.keras.losses.sparse_categorical_crossentropy
model.compile(metrics = metric, loss = loss)
model.evaluate(te_data)



[0.6011554598808289, 0.7979999780654907]