# TensorFlow2 推理

参考：[migrating_checkpoints](https://www.tensorflow.org/guide/migrate/migrating_checkpoints)

下面以模型 [resnet_v2_50](http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz) 为例展示。

需要克隆项目 [models](https://github.com/tensorflow/models)，然后执行如下操作。

In [1]:
import tensorflow as tf
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf
tf.get_logger().setLevel('ERROR')

2023-06-15 19:39:52.245278: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-15 19:39:52.295369: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-15 19:39:52.297971: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


切换到 `models/research/slim` 目录下：

In [2]:
%cd /media/pc/data/lxw/ai/tasks/models/research/slim

/media/pc/data/lxw/ai/tasks/models/research/slim


将 TF1 升级为 TF2：

In [4]:
from nets import resnet_v2
import tf_slim as slim


class ResnetV2_50(tf.keras.layers.Layer):
    def __init__(self, image_size, num_classes=1001, trainable=True, 
                 name="resnet_v2_50", dtype=None, dynamic=False, **kwargs):
        super().__init__(trainable, name, dtype, dynamic, **kwargs)
        self.image_size = image_size
        self.num_classes = num_classes
        # self.preprocessing = get_preprocessing(self.name)

    @tf1.keras.utils.track_tf1_style_variables
    def call(self, inputs, training=False):
        is_training = training or False 
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_50(
                inputs, 
                num_classes=self.num_classes,
                global_pool=True,
                is_training=is_training,
                scope=self.name
            )
        return tf.nn.softmax(logits), end_points

预处理：

In [5]:
from PIL import Image
import numpy as np
from nets import resnet_v2
from preprocessing.preprocessing_factory import get_preprocessing
import tf_slim as slim

preprocessing = get_preprocessing("resnet_v2_50")
image_size = 224
path = '/media/pc/data/board/arria10/lxw/data/test/cat.png' # 将要预测的图片路径
preprocessing = get_preprocessing("resnet_v2_50")


@tf.function
def preprocess_image(image, output_height, output_width):
    # image = tf.constant(image)
    processed_image = preprocessing(image, output_height, output_width)
    return processed_image/256
with Image.open(path) as im:
    if im.mode != "RGB":
        im.convert("RGB")
    image = np.asarray(im)
np_processed_image = preprocess_image(image, image_size, image_size)
np_processed_images = np.expand_dims(np_processed_image, axis=0)

2023-06-15 19:40:25.314629: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


前向推理：

In [6]:
model = ResnetV2_50(224)
model(tf.ones(shape=(1, 224, 224, 3), dtype=tf.float32))
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore("/tmp/checkpoints/resnet_v2_50.ckpt") # 更新模型参数
output, endpoints = model(np_processed_images)
output = output.numpy()



打印标签信息：

In [7]:
from github import Github

g = Github(user_agent="xinetzone")
repo = g.get_repo("tensorflow/models")
label_content = repo.get_contents("research/slim/datasets/imagenet_lsvrc_2015_synsets.txt")
imagenet_labels = label_content.decoded_content.decode().split()
assert len(imagenet_labels) == 1000
metadata = repo.get_contents("research/slim/datasets/imagenet_metadata.txt")
imagenet_metadata = metadata.decoded_content.decode().splitlines()
synset_to_human = {}
for metadata in imagenet_metadata:
    name, value = metadata.split("\t")
    synset_to_human[name] = value
name2id = {name: k+1 for k, name in enumerate(imagenet_labels)}

topk = 5
sorted_inds = output[0].argsort()[::-1]
for prob, sorted_ind in zip(output[0][:topk], sorted_inds[:topk]):
    label = synset_to_human[imagenet_labels[sorted_ind-1]]
    print(f"{sorted_ind-1}: {label.ljust(20)}\t{prob}")

282: tiger cat           	6.913422794241342e-07
285: Egyptian cat        	2.1831299079622113e-07
281: tabby, tabby cat    	1.9797682853095466e-06
278: kit fox, Vulpes macrotis	3.303630080608855e-08
277: red fox, Vulpes vulpes	1.9624657454642147e-07
