# tf.argmax's output dtype is int64 by default; pass output_type=tf.int32 to get int32

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from copy import deepcopy

# NOTE: `%matplotlib inline` below is an IPython/Jupyter magic; comment it out when running this file as a regular Python script
%matplotlib inline  
# this line should be commented out for regular python run 


""" Hyperparameter """
# Sizes of the MNIST train/test splits (they match the array shapes noted
# next to the data-loading call).
data_size_train, data_size_test = 60000, 10000
# Leading (batch) dimension of the input placeholders.
batch_size = 100
# Learning rate handed to the Adam optimizer.
lr = 0.01
# Presumably the number of training epochs; the training loop is not part of
# this section, so confirm against where it is consumed.
epoch = 1000


""" Data Loading """
def load_mnist_flat():
    """Load MNIST as flattened float32 images plus one-hot and integer labels.

    The dataset is obtained through ``tf.keras.datasets.mnist.load_data()``.
    Pixel values are divided by 255.0 and every image is reshaped to a row of
    784 values.

    Returns:
        A 6-tuple ``(x_train, x_test, y_train, y_test, y_train_cls,
        y_test_cls)`` where

        * ``x_*`` are float32 arrays of shape ``(N, 784)``,
        * ``y_*`` are float32 one-hot arrays of shape ``(N, 10)``,
        * ``y_*_cls`` are int32 arrays of shape ``(N,)`` with the class index.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    def _flatten(images):
        # Divide first and cast last, so the values are the float64 quotient
        # rounded once to float32.
        return (images / 255.0).reshape((-1, 784)).astype(np.float32)

    # astype() returns a new array by default, so the integer labels are
    # already independent copies of the raw label arrays.
    y_train_cls = y_train.astype(np.int32)
    y_test_cls = y_test.astype(np.int32)

    # Row i of the identity matrix is the one-hot encoding of class i;
    # building it directly in float32 avoids a float64 intermediate.
    one_hot = np.eye(10, dtype=np.float32)

    data = (
        _flatten(x_train),
        _flatten(x_test),
        one_hot[y_train],
        one_hot[y_test],
        y_train_cls,
        y_test_cls,
    )

    return data


# Arrays returned by load_mnist_flat() and their shapes:
#   x_train     : (60000, 784)
#   x_test      : (10000, 784)
#   y_train     : (60000, 10)
#   y_test      : (10000, 10)
#   y_train_cls : (60000,)
#   y_test_cls  : (10000,)
data = load_mnist_flat()
(x_train, x_test,
 y_train, y_test,
 y_train_cls, y_test_cls) = data


""" Graph Construction """
# Fix the graph-level random seed so the random initialization of W below is
# repeatable from run to run.
tf.random.set_random_seed(337)

# placeholders
# The leading dimension is fixed to batch_size, so every feed must contain
# exactly batch_size rows.
#   x     : flattened images, 784 values per row
#   y     : one-hot labels over the 10 classes (used by the loss)
#   y_cls : integer class labels (used by the accuracy)
x = tf.placeholder(tf.float32, shape=(batch_size, 784), name='x')
y = tf.placeholder(tf.float32, shape=(batch_size, 10), name='y')
y_cls = tf.placeholder(tf.int32, shape=(batch_size,), name='y_cls')

# weights
# A single linear layer mapping 784 inputs to 10 class scores.
# b has shape (1, 10) and is broadcast over the batch dimension when added.
W = tf.get_variable("W", shape=(784, 10), \
        initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG"))
b = tf.get_variable("b", shape=(1, 10), \
        initializer=tf.constant_initializer(0.0))

# logits, y_pred, and y_pred_cls 
logits = (x @ W) + b
y_pred = tf.nn.softmax(logits, name='y_pred') # probabilities
print(logits.dtype) # float32
print(y_pred.dtype) # float32

# tf.argmax returns int64 unless told otherwise. Requesting int32 here makes
# y_pred_cls share the dtype of the y_cls placeholder, so the two can be
# compared with tf.equal below. An equivalent alternative is to cast the
# default int64 result:
# y_pred_cls = tf.cast(tf.argmax(logits, axis=1), tf.int32) 
y_pred_cls = tf.argmax(logits, axis=1, output_type=tf.int32) 
print(y_pred_cls.dtype) # int32

# cross_entropy cost function
# The op is given the raw logits (not y_pred) together with the one-hot
# labels; cost is the mean of the per-example losses over the batch.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                           labels=y)
cost = tf.reduce_mean(cross_entropy)

# optimizer
# One run of train_op performs one Adam update that minimizes cost.
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)

# test accuracy
# Fraction of rows in the fed batch whose predicted class equals y_cls.
# To evaluate on the test set, feed test images and labels through x and
# y_cls; because the placeholders have a fixed batch dimension, this has to be
# done batch_size rows at a time.
correct_bool = tf.equal(y_cls, y_pred_cls)
test_accuracy = tf.reduce_mean(tf.cast(correct_bool, tf.float32))

<dtype: 'float32'>
<dtype: 'float32'>
<dtype: 'int32'>
