In [None]:
!pip install apache-tvm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import tvm 
import numpy as np
import tensorflow as tf

In [None]:
# Record the library versions this run used, one line per package.
for label, library in (
    ("TVM version:", tvm),
    ("NumPy version:", np),
    ("TensorFlow version:", tf),
):
    print(label, library.__version__)

TVM version: 0.9.0
NumPy version: 1.21.6
TensorFlow version: 2.8.2


In [None]:
# Load Fashion-MNIST through the Keras dataset helper. load_data() is unpacked
# as a (train, test) pair, each of which is an (images, labels) pair.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data()
)

In [None]:
# Scale pixel values into [0, 1] by dividing by 255.
# Assumes load_data() returned integer-typed pixel arrays -- TODO confirm.
# The dtype guard makes this cell idempotent: once the arrays have become
# floats, re-running the cell no longer divides them by 255 a second time
# (which would silently shrink the inputs and degrade training).
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images / 255.0
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images / 255.0

In [None]:
# Human-readable names for the ten classes; the list position is used as the
# lookup index for both predicted and ground-truth labels.
class_names = [
    'T-shirt/top',  # 0
    'Trouser',      # 1
    'Pullover',     # 2
    'Dress',        # 3
    'Coat',         # 4
    'Sandal',       # 5
    'Shirt',        # 6
    'Sneaker',      # 7
    'Bag',          # 8
    'Ankle boot',   # 9
]

In [None]:
# Small fully-connected classifier, assembled layer by layer:
# flatten each 28x28 image, one 128-unit ReLU hidden layer, 10 output units.
model = tf.keras.Sequential()
# The explicit layer name matters later: the TVM shape dict is keyed by
# "flatten_1_input".
model.add(tf.keras.layers.Flatten(input_shape=(28, 28), name='flatten_1'))
model.add(tf.keras.layers.Dense(128, activation='relu'))
# No activation on the last layer, so the model emits raw scores (logits).
model.add(tf.keras.layers.Dense(10))

In [None]:
# Configure training. from_logits=True matches the final Dense(10) layer,
# which has no activation and therefore outputs unnormalised scores.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train for ten passes over the training split. The call is left as the cell's
# final expression, so the returned History object is what the cell displays.
model.fit(x=train_images, y=train_labels, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f39c58b5590>

In [None]:
# Score the trained model on the held-out split; with the single 'accuracy'
# metric configured at compile time, evaluate() is unpacked as (loss, accuracy).
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

313/313 - 1s - loss: 0.3410 - accuracy: 0.8840 - 520ms/epoch - 2ms/step

Test accuracy: 0.8840000033378601


In [None]:
# Wrap the trained model with a Softmax layer so its raw scores become
# per-class probabilities.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

predictions = probability_model.predict(test_images)

# Compare the most probable class for the first test image with its label.
print("Keras predict: ", class_names[predictions[0].argmax()])
print("Answer: ", class_names[test_labels[0]])

Keras predict:  Ankle boot
Answer:  Ankle boot


In [None]:
# Build the TVM input tensor from the training images: copy them, insert a
# singleton channel axis right after the batch axis -- (N, 28, 28) becomes
# (N, 1, 28, 28) -- and cast to float32.
data = np.array(train_images)[:, np.newaxis, :, :].astype("float32")

# The Relay importer needs the shape of every graph input, keyed by input
# name; "flatten_1_input" is derived from the first layer's name, 'flatten_1'.
shape_dict = {"flatten_1_input": data.shape}

In [None]:
# Convert the trained Keras model into a Relay module plus a dict of its
# weights, using the input shapes declared in shape_dict.
mod, params = tvm.relay.frontend.from_keras(
    model,
    shape=shape_dict,
    layout="NCHW",
)

In [None]:
# Sanity check: show the input name -> shape mapping that was handed to the
# Relay importer.
print(shape_dict)

{'flatten_1_input': (60000, 1, 28, 28)}


In [None]:
# Print the textual form of the imported Relay module to inspect the operators
# the Keras layers were translated into.
print(mod)

def @main(%flatten_1_input: Tensor[(60000, 1, 28, 28), float32], %v_param_1: Tensor[(128, 784), float32], %v_param_2: Tensor[(128), float32], %v_param_3: Tensor[(10, 128), float32], %v_param_4: Tensor[(10), float32]) {
  %0 = transpose(%flatten_1_input, axes=[0, 2, 3, 1]);
  %1 = nn.batch_flatten(%0);
  %2 = nn.dense(%1, %v_param_1, units=128);
  %3 = nn.bias_add(%2, %v_param_2);
  %4 = nn.relu(%3);
  %5 = nn.dense(%4, %v_param_3, units=10);
  nn.bias_add(%5, %v_param_4)
}



In [None]:
# Summarise the converted weights instead of dumping every tensor: printing
# the whole dict floods the cell output with full weight arrays, while name,
# shape and dtype are enough to confirm the conversion.
for param_name, param_value in params.items():
    print(param_name, param_value.shape, param_value.dtype)

{'_param_1': <tvm.nd.NDArray shape=(128, 784), cpu(0)>
array([[ 0.22257568,  0.23190874,  0.3886681 , ..., -0.01436261,
         0.26669192,  0.19182   ],
       [-0.10346685,  0.04861556,  0.13840634, ...,  0.2748953 ,
        -0.29270855,  0.17692687],
       [ 0.06714239, -0.19274831, -0.0167403 , ...,  0.09110993,
        -0.19981442,  0.07897342],
       ...,
       [ 0.1773131 ,  0.15603375,  0.11255036, ..., -0.14668164,
        -0.21108034,  0.0671009 ],
       [-0.05093393,  0.29168066,  0.2898056 , ..., -0.01702415,
         0.09920012,  0.15783215],
       [ 0.07435638, -0.08680963,  0.07287902, ..., -0.02268694,
         0.08838622,  0.09159937]], dtype=float32), '_param_2': <tvm.nd.NDArray shape=(128,), cpu(0)>
array([ 0.35672253,  0.14644524,  0.39203906, -0.29724568,  0.22118716,
        0.42256585,  0.45139995,  0.47372693,  0.42327607,  0.08090623,
        0.3114824 ,  0.23824978,  0.161685  , -0.01218236,  0.20451668,
        0.07872387,  0.39456075, -0.01198395,  0.0

In [None]:
# Compile the Relay module for the local CPU through LLVM, with pass-context
# optimisation level 3. evaluate() hands back the object that is later called
# with the input tensor to run inference.
with tvm.transform.PassContext(opt_level=3):
    tvm_model = tvm.relay.build_module.create_executor(
        kind="graph",
        mod=mod,
        device=tvm.cpu(0),
        target="llvm",
        params=params,
    ).evaluate()

  "target_host parameter is going to be deprecated. "


In [None]:
# Run the compiled model on the whole input tensor. `data` was already cast to
# float32 when it was built, so it is passed straight to TVM; a second
# .astype("float32") here would only make another full copy of the tensor.
tvm_out = tvm_model(tvm.nd.array(data))

# Index of the highest-scoring class for the first sample in `data`.
top1_tvm = np.argmax(tvm_out.numpy()[0])

In [None]:
print("TVM predict:", class_names[top1_tvm])
# The TVM input `data` was built from train_images, so the ground truth for
# its first sample is train_labels[0]. Reading test_labels[0] here would
# compare the prediction against the label of a different image.
print("Answer: ", class_names[train_labels[0]])

TVM predict: Ankle boot
Answer:  Ankle boot
