# 8. Декоратор tf.function для ускорения выполнения функций

In [1]:
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

2024-10-17 22:49:33.468225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-17 22:49:33.479411: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-17 22:49:33.482573: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide pixel values by 255, cast to float32 and flatten every 28x28 image
# into a single row of 784 features.
x_train = tf.reshape(tf.cast(x_train / 255, tf.float32), [-1, 28 * 28])
x_test = tf.reshape(tf.cast(x_test / 255, tf.float32), [-1, 28 * 28])

# Only the training labels are one-hot encoded; y_test keeps its integer
# class ids so it can be compared with argmax predictions later on.
y_train = to_categorical(y_train, 10)

In [3]:
class DenseNN(tf.Module):
    """Fully connected layer whose weights are created lazily on the first call.

    Args:
        outputs: Number of output units of the layer.
        activate: 'relu' or 'softmax' selects the activation; any other value
            leaves the affine output ``x @ w + b`` unchanged.
    """

    def __init__(self, outputs, activate='relu'):
        super().__init__()
        self.outputs = outputs
        self.activate = activate
        self.fl_init = False  # becomes True once w and b have been created

    def __call__(self, x):
        """Apply the layer to ``x``, building ``w`` and ``b`` on first use.

        Args:
            x: Input tensor; its last dimension fixes the number of rows of ``w``.

        Returns:
            ``x @ w + b`` passed through the configured activation.
        """
        if not self.fl_init:
            # The input width is only known at call time, so the variables are
            # built here. The names are given to the variables themselves
            # rather than to the temporary initializer tensors.
            self.w = tf.Variable(
                tf.random.truncated_normal((x.shape[-1], self.outputs), stddev=0.1),
                name='w',
            )
            self.b = tf.Variable(
                tf.zeros([self.outputs], dtype=tf.float32),
                name='b',
            )
            self.fl_init = True

        y = x @ self.w + self.b

        if self.activate == 'relu':
            return tf.nn.relu(y)
        elif self.activate == 'softmax':
            return tf.nn.softmax(y)

        return y
    
class SequentialModule(tf.Module):
    """Two-layer network: a 128-unit DenseNN followed by a 10-unit softmax DenseNN."""

    def __init__(self):
        super().__init__()
        self.layer_1 = DenseNN(128)
        self.layer_2 = DenseNN(10, activate='softmax')

    def __call__(self, x):
        """Run ``x`` through ``layer_1`` and then ``layer_2``."""
        hidden = self.layer_1(x)
        return self.layer_2(hidden)

In [4]:
model = SequentialModule()


def cross_entropy(y_true, y_pred):
    """Return the mean of the categorical cross-entropy values for a batch."""
    per_sample = tf.losses.categorical_crossentropy(y_true, y_pred)
    return tf.reduce_mean(per_sample)


opt = tf.keras.optimizers.Adam(learning_rate=0.01)

batch_size = 32
epochs = 10
total = x_train.shape[0]

# Shuffle with a 1024-element buffer, then group the samples into mini-batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

In [5]:
@tf.function  # run the training step as a TensorFlow graph
def train_batch(x_batch, y_batch):
    """Perform one optimization step on a mini-batch and return its loss.

    Args:
        x_batch: Batch of inputs passed to ``model``.
        y_batch: Targets passed to ``cross_entropy`` together with the predictions.

    Returns:
        The loss value computed for this batch.
    """
    with tf.GradientTape() as tape:
        predictions = model(x_batch)
        f_loss = cross_entropy(y_batch, predictions)

    # Read the variables only after the forward pass: the layers create their
    # weights lazily on the first call.
    variables = model.trainable_variables
    grads = tape.gradient(f_loss, variables)
    opt.apply_gradients(zip(grads, variables))
    return f_loss

for epoch in range(epochs):
    # Sum of the per-batch losses over one full pass through train_dataset.
    loss = sum(train_batch(x_batch, y_batch) for x_batch, y_batch in train_dataset)

    print(loss.numpy())

427.97458
298.4348
259.79315
242.99739
219.41093
221.73846
212.92538
191.5984
202.26596
171.78041


In [7]:
y = model(x_test)
y2 = tf.argmax(y, axis=1).numpy()
# Share of test samples (in percent) whose predicted class equals the label.
acc = int((y_test == y2).sum()) / y_test.shape[0] * 100
print(acc)

95.82000000000001


Пример декоратора для перевода вычислений в графы

In [10]:
import time

In [15]:
def function_tf(x, y):
    """Matmul-heavy workload used to compare eager and graph execution.

    Starts from ``zeros_like(x) + x @ y`` and then applies
    ``s = s + (s @ y) * x`` ten times, returning the final tensor.
    """
    acc = tf.zeros_like(x, dtype=tf.float32) + tf.matmul(x, y)
    for _ in range(10):
        acc = acc + tf.matmul(acc, y) * x

    return acc

def test_function(fn):
    def wrapper(*args, **kwargs):
        start = time.time()
        for n in range(10):
            fn(*args, **kwargs)
        dt = time.time() - start
        print(f"Время обработки: {dt} сек")
        
    return wrapper

In [12]:
size = 1000
# Two square float32 matrices filled with ones, used as benchmark inputs.
x = tf.ones([size, size], dtype=tf.float32)
y = tf.ones([size, size], dtype=tf.float32)

In [13]:
# Graph-mode counterpart of function_tf, built with the functional form of
# the tf.function decorator.
function_tf_graph = tf.function(func=function_tf)

In [16]:
# Time the eager version first, then the tf.function-wrapped version.
benchmark_eager = test_function(function_tf)
benchmark_graph = test_function(function_tf_graph)

benchmark_eager(x, y)
benchmark_graph(x, y)

Время обработки: 0.47199010848999023 сек
Время обработки: 0.3979959487915039 сек


Модификация кода для ускорения

In [24]:
def function_for(s, x, y):
    """Apply the update ``s = s + (s @ y) * x`` ten times and return ``s``."""
    for _ in range(10):
        step = tf.matmul(s, y) * x
        s = s + step

    return s

def function_tf(x, y):
    """Build the initial tensor ``zeros_like(x) + x @ y`` and hand it to function_for.

    The ``print`` call is a plain Python side effect; its output shows how
    often the Python body of this function is actually executed.
    """
    print("вызов функции print")
    initial = tf.zeros_like(x, dtype=tf.float32) + tf.matmul(x, y)

    return function_for(initial, x, y)

def test_function(fn):
    def wrapper(*args, **kwargs):
        start = time.time()
        for n in range(10):
            fn(*args, **kwargs)
        dt = time.time() - start
        print(f"Время обработки: {dt} сек")
        
    return wrapper

In [22]:
size = 1000
# Two square float32 matrices filled with ones, used as benchmark inputs.
x = tf.ones([size, size], dtype=tf.float32)
y = tf.ones([size, size], dtype=tf.float32)

# Wrap function_tf again so the graph version is built from the definition
# that is current when this cell runs.
function_tf_graph = tf.function(func=function_tf)

In [25]:
# Time the eager version first, then the tf.function-wrapped version.
benchmark_eager = test_function(function_tf)
benchmark_graph = test_function(function_tf_graph)

benchmark_eager(x, y)
benchmark_graph(x, y)

вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
вызов функции print
Время обработки: 0.49132871627807617 сек
Время обработки: 0.39866089820861816 сек
