### Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory.

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision

Among NVIDIA GPUs, those with compute capability 7.0 or higher will see the greatest performance benefit from mixed precision because they have special hardware units, called Tensor Cores, to accelerate float16 matrix multiplications and convolutions. Older GPUs offer no math performance benefit from mixed precision; however, memory and bandwidth savings can still enable some speedups. You can look up the compute capability for your GPU at NVIDIA's CUDA GPU web page. Examples of GPUs that will benefit most from mixed precision include RTX GPUs, the V100, and the A100.
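As an alternative to looking the value up by hand, the compute capability can be queried from TensorFlow itself. The following is a minimal sketch, assuming a TensorFlow 2.x release that provides `tf.config.experimental.get_device_details`; the helper names `meets_tensor_core_threshold` and `report_gpus` are made up for illustration, and the `compute_capability` entry may be absent on some devices.

```python
def meets_tensor_core_threshold(compute_capability):
    """Return True if a (major, minor) compute capability is at least 7.0."""
    return tuple(compute_capability) >= (7, 0)


def report_gpus():
    """Print each visible GPU's name and whether it reaches the 7.0 threshold."""
    import tensorflow as tf  # imported here so the helper above works without TF

    for gpu in tf.config.list_physical_devices('GPU'):
        details = tf.config.experimental.get_device_details(gpu)
        name = details.get('device_name', 'Unknown GPU')
        capability = details.get('compute_capability')  # e.g. (7, 5); may be None
        if capability is None:
            print('%s: compute capability not reported' % name)
        else:
            print('%s: compute capability %d.%d, at least 7.0: %s'
                  % (name, capability[0], capability[1],
                     meets_tensor_core_threshold(capability)))
```

Calling `report_gpus()` prints one line per visible GPU and prints nothing on a CPU-only machine.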

Among Intel CPUs, those starting with the 4th Gen Intel Xeon processors (code name Sapphire Rapids) will see the greatest performance benefit from mixed precision, as they can accelerate bfloat16 computations using AMX instructions (requires TensorFlow 2.12 or later).

In [2]:
# This command only exists if the NVIDIA drivers are installed; otherwise the following raises an error.
!nvidia-smi -L


GPU 0: NVIDIA GeForce GTX 1650 (UUID: GPU-4cbe939e-e009-c760-7713-5d55bdf1dc95)


# Setting the dtype policy

To use mixed precision in Keras, you need to create a tf.keras.mixed_precision.Policy, typically referred to as a dtype policy. Dtype policies specify the dtypes layers will run in. In this guide, you will construct a policy from the string 'mixed_float16' and set it as the global policy. This will cause subsequently created layers to use mixed precision with a mix of float16 and float32.

In [3]:
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce GTX 1650, compute capability 7.5


In [4]:
# Equivalent to the two lines above: a policy name string is accepted in place of a Policy object
mixed_precision.set_global_policy('mixed_float16')

The policy specifies two important aspects of a layer: the dtype the layer's computations are done in, and the dtype of a layer's variables. Above, you created a mixed_float16 policy (i.e., a mixed_precision.Policy created by passing the string 'mixed_float16' to its constructor). With this policy, layers use float16 computations and float32 variables. Computations are done in float16 for performance, but variables must be kept in float32 for numeric stability. You can directly query these properties of the policy.

In [5]:
print('Compute dtype: %s' % policy.compute_dtype)
print('Variable dtype: %s' % policy.variable_dtype)

Compute dtype: float16
Variable dtype: float32
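To make the stability argument concrete, here is a minimal sketch using only the Python standard library (the `struct` module's `'e'` format is IEEE 754 half precision). It illustrates float16 rounding only and is not how Keras stores or updates variables; the weight and update values, and the helper names `to_float16` and `to_float32`, are made up for the example.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


def to_float32(value):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack('<f', struct.pack('<f', value))[0]


weight = 1.0
update = 1e-4  # hypothetical value of learning_rate * gradient

# Apply the same small update to a float16 copy and a float32 copy of the weight
fp16_weight = to_float16(to_float16(weight) + to_float16(update))
fp32_weight = to_float32(to_float32(weight) + to_float32(update))

print('float16 weight after update:', fp16_weight)
print('float32 weight after update:', fp32_weight)
```

Near 1.0, neighboring float16 values are about 0.001 apart, so an update of 0.0001 is rounded away entirely, while the float32 copy retains it. Keeping variables in float32 avoids this kind of lost update.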


As mentioned before, the mixed_float16 policy will most significantly improve performance on NVIDIA GPUs with compute capability of at least 7.0. The policy will run on other GPUs and CPUs but may not improve performance. For TPUs and CPUs, the mixed_bfloat16 policy should be used instead.
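The rule of thumb above can be written as a small selection helper. This is only a sketch of the guidance in this section, not a TensorFlow API: `pick_policy_name` is a hypothetical function that takes the device type strings visible to TensorFlow and returns a policy name.

```python
def pick_policy_name(device_types):
    """Pick a dtype policy name from a collection of device type strings.

    Follows the guidance in this section: mixed_float16 when a GPU is
    present, mixed_bfloat16 for TPUs and CPUs.
    """
    if 'GPU' in set(device_types):
        return 'mixed_float16'
    return 'mixed_bfloat16'


print(pick_policy_name(['CPU', 'GPU']))
print(pick_policy_name(['CPU']))
```

In this notebook, the device types could be gathered with `{d.device_type for d in tf.config.list_physical_devices()}` and the result passed to `mixed_precision.set_global_policy`. Note that the helper does not check compute capability, so an older GPU would still get mixed_float16 without a math speedup.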

# Building the model

In [6]:
inputs = keras.Input(shape=(784,), name='digits')
if tf.config.list_physical_devices('GPU'):
    print('The model will run with 4090 units on a GPU')
    num_units = 4090
else:
    # Use fewer units on CPU so the model finishes in a reasonable amount of time
    print('The model will run with 64 units on a CPU')
    num_units = 64
dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(x)

The model will run with 4090 units on a GPU


Each layer has a policy and uses the global policy by default. Each of the Dense layers therefore has the mixed_float16 policy because you set the global policy to mixed_float16 previously. This causes the Dense layers to do float16 computations and have float32 variables. They cast their inputs to float16 in order to do float16 computations, so their outputs are float16 as well. Their variables are float32 and are cast to float16 when the layers are called, which avoids errors from dtype mismatches.

In [7]:
print(dense1.dtype_policy)
print('x.dtype: %s' % x.dtype.name)
# 'kernel' is dense1's variable
print('dense1.kernel.dtype: %s' % dense1.kernel.dtype.name)


<Policy "mixed_float16">
x.dtype: float16
dense1.kernel.dtype: float32


Next, create the output predictions. Normally, you can create the output predictions as follows, but this is not always numerically stable with float16.

In [8]:
# INCORRECT: softmax and model output will be float16, when it should be float32
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
print('Outputs dtype: %s' % outputs.dtype.name)

Outputs dtype: float16


In [9]:
# CORRECT: softmax and model output are float32
x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)
print('Output dtype: %s' % outputs.dtype.name)

Output dtype: float32
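The reason float16 is risky for the softmax output comes down to its narrow range. The sketch below uses only the Python standard library to show the two limits involved; it illustrates float16 itself rather than TensorFlow's softmax implementation, and the logit and probability values, along with the helper names, are chosen purely for illustration.

```python
import math
import struct

FLOAT16_MAX = 65504.0  # largest finite float16 value


def to_float16(value):
    """Round a Python float to the nearest float16 value (tiny values become 0.0)."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


def to_float32(value):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack('<f', struct.pack('<f', value))[0]


# Overflow: exp() of a moderately large logit already exceeds the float16 range
logit = 12.0
exp_overflows_float16 = math.exp(logit) > FLOAT16_MAX

# Underflow: a small but nonzero probability becomes exactly zero in float16
tiny_probability = 1e-8
p16 = to_float16(tiny_probability)
p32 = to_float32(tiny_probability)

print('exp(12) exceeds the float16 maximum:', exp_overflows_float16)
print('1e-8 stored as float16:', p16)
print('1e-8 stored as float32:', p32)
```

A probability that rounds to zero is a problem for a cross-entropy loss, which takes the logarithm of the predicted probability. Producing the softmax output in float32 keeps such values representable.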


In [10]:
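# The linear activation is an identity function, so this layer only casts its input to float32.
# Here outputs is already float32, so the values are unchanged.
outputs = layers.Activation('linear', dtype='float32')(outputs)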

In [11]:
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [12]:
initial_weights = model.get_weights()

In [13]:
history = model.fit(x_train, y_train,
                    batch_size=8192,
                    epochs=5,
                    validation_split=0.2)
test_scores = model.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
313/313 - 5s - loss: 0.2996 - accuracy: 0.9040 - 5s/epoch - 17ms/step
Test loss: 0.2995896339416504
Test accuracy: 0.9039999842643738
