In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import os 
from tf_CMT.model import CMT_Model

In [2]:
# Point XLA and CUDA tooling at the local CUDA install.
# NOTE(review): `tensorflow` is already imported in the cell above; these
# variables presumably only take effect if TF reads them lazily (e.g. on the
# first XLA compilation) — confirm, or move this cell before the TF import.
CUDA_DIR = "/usr/local/cuda"

# Append to any XLA flags already present in the environment instead of
# overwriting them, and skip the append when the flag is already there so
# re-running this cell stays idempotent.
xla_cuda_flag = f"--xla_gpu_cuda_data_dir={CUDA_DIR}"
existing_xla_flags = os.environ.get("XLA_FLAGS", "").split()
if xla_cuda_flag not in existing_xla_flags:
    os.environ["XLA_FLAGS"] = " ".join(existing_xla_flags + [xla_cuda_flag])
os.environ["CUDA_HOME"] = CUDA_DIR

## Load the MNIST dataset

In [3]:
# Download (if needed) and load the MNIST train/test splits as (image, label)
# pairs, together with the dataset metadata (used below for the shuffle
# buffer size).
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

## Preprocessing: normalize, cache, shuffle, and batch

In [4]:
def normalize_img(image, label):
    """Convert an image to `float32` and divide it by 255.

    For `uint8` input this maps pixel values into [0, 1]. The label is
    returned unchanged, so the function can be passed directly to
    `tf.data.Dataset.map` on `(image, label)` pairs.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

# Training input pipeline: normalize once and cache the result, then shuffle
# with a buffer as large as the whole train split, batch, and prefetch.
# The shuffle comes after the cache so the cached data is reshuffled on every pass.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(256)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [5]:
# Evaluation input pipeline: same normalization, caching, and batch size as
# training, but without shuffling.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .batch(256)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Define the CMT model

In [6]:
# Two-stage CMT classifier for the 10 MNIST digit classes. It is built to
# emit raw logits, which matches the `from_logits=True` loss used at compile time.
model = CMT_Model(
    Block_num=[3, 12],    # number of CMT_Blocks in each stage
    K=2,                  # hyperparameter reducing self-attention cost to O(N^2/k^2)
    n_heads=4,            # number of attention heads
    head_dim=256,         # latent dimension of self-attention
    filters=256,          # number of filters of the CNNs
    num_classes=10,       # number of output classes
    usePosBias=True,      # use a learnable positional bias
    output_logits=True,   # output logits (no final softmax)
)

## Train the model

In [7]:
# Adam with a 1e-3 learning rate. The loss consumes raw logits because the
# model was constructed with `output_logits=True`.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [8]:
# Stop training once validation loss has not improved for 3 consecutive epochs.
# `restore_best_weights=True` rolls the model back to the epoch with the lowest
# val_loss. Without it, when training stops early the model keeps the weights
# of the last epoch, which is `patience` epochs past the best one.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

In [9]:
# Train for up to 20 epochs; `early_stop` may end training sooner.
# Keep the returned History so per-epoch loss/accuracy can be inspected or
# plotted afterwards, instead of discarding it and only printing its repr.
# NOTE(review): `ds_test` is used as the validation set and also drives early
# stopping, so metrics on it are not a fully held-out estimate. Consider carving
# a validation split out of the training data.
history = model.fit(
    ds_train,
    epochs=20,
    validation_data=ds_test,
    callbacks=[early_stop],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20


<tensorflow.python.keras.callbacks.History at 0x7f02801ebd10>