# 测试 ImageNet 分类精度

In [1]:
import tensorflow as tf
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf
%cd /media/pc/data/lxw/ai/tasks/models/research/slim
from nets import resnet_v2
import tf_slim as slim
import numpy as np
from tvm_book.metric.classification import Accuracy, TopKAccuracy
from tvm_book.data.classification import ImageFolderDataset
from tvm_book.data.imagenet.classification import ImageNet1kAttr

# @tf.function
def preprocessing(
    image,
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=224,
    width=224,
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (0.229, 0.224, 0.225)
):
    """Prepare a single image for evaluation.

    The steps are applied in this order: conversion to a float32 tensor,
    optional grayscale conversion, optional central crop, optional bilinear
    resize to ``(height, width)``, then ``(image - mean) / std``.

    Args:
        image: Anything ``tf.convert_to_tensor`` accepts.
            Assumed to be a single H x W x C image -- TODO confirm with the
            dataset that produces it.
        use_grayscale: If true, apply ``tf.image.rgb_to_grayscale``.
        central_fraction: Fraction passed to ``tf.image.central_crop``.
        central_crop: Crop only when this and ``central_fraction`` are truthy.
        height: Target height; the resize is skipped unless both ``height``
            and ``width`` are truthy.
        width: Target width.
        mean: Values subtracted from the image (broadcast over the last axis).
        std: Values the centred image is divided by.

    Returns:
        The preprocessed float32 tensor.
    """
    # NOTE(review): with use_grayscale=True the image has a single channel, so
    # a 3-element mean/std would broadcast it back to three channels --
    # confirm that this is intended before relying on the grayscale path.
    pixels = tf.convert_to_tensor(image)
    if pixels.dtype != tf.float32:
        pixels = tf.image.convert_image_dtype(pixels, dtype=tf.float32)

    if use_grayscale:
        pixels = tf.image.rgb_to_grayscale(pixels)

    if central_crop and central_fraction:
        pixels = tf.image.central_crop(pixels, central_fraction=central_fraction)

    if height and width:
        # The resize is done on a temporary batch of one image.
        batched = tf.expand_dims(pixels, 0)
        batched = tf.image.resize(
            batched,
            [height, width],
            method='bilinear',
            preserve_aspect_ratio=False,
            antialias=False,
        )
        pixels = tf.squeeze(batched, [0])

    centred = tf.subtract(pixels, tf.constant(mean, dtype=tf.float32))
    return tf.divide(centred, tf.constant(std, dtype=tf.float32))

class ResnetV2_50(tf.keras.Model):
    """Keras wrapper around the TF-Slim ``resnet_v2_50`` network.

    The model takes a float32 tensor of shape ``[1, 3, 299, 299]`` and returns
    the softmax of the 1001-way logits.
    """

    def __init__(self, trainable=False, *args, **kwargs):
        """Create the wrapper.

        Args:
            trainable: Stored on the model and forwarded to the network as
                ``is_training``.
            *args: Forwarded to ``tf.keras.Model.__init__``.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.trainable = trainable

    @tf.function(input_signature=[tf.TensorSpec([1, 3, 299, 299], 
                                                 tf.float32, name="data")])
    @tf1.keras.utils.track_tf1_style_variables
    def call(self, x):
        """Run the network on ``x`` and return class probabilities."""
        # The input is laid out NCHW; the Slim network is fed NHWC.
        nhwc = tf.transpose(x, perm=(0, 2, 3, 1))
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            # The second return value (end points) is not needed here.
            logits, _ = resnet_v2.resnet_v2_50(
                nhwc,
                num_classes=1001,
                global_pool=True,
                is_training=self.trainable,
                scope="resnet_v2_50"
            )
        return tf.nn.softmax(logits)

2023-06-21 16:03:37.608487: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:03:37.815194: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:03:37.817549: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


/media/pc/data/lxw/ai/tasks/models/research/slim



In [2]:
# Locations of the pretrained checkpoint and of the ImageNet validation images.
ckpt_path = "/media/pc/data/board/arria10/lxw/tests/npu_user_demos/models/resnet50_v2_tf/weight/resnet_v2_50.ckpt"
root = "/media/pc/data/lxw/home/data/datasets/ILSVRC/val"

# Build the model and push one dummy batch through it -- presumably so that
# its variables are created before the checkpoint is restored (TODO confirm).
model = ResnetV2_50()
model(tf.ones((1, 3, 299, 299), dtype=tf.float32))

ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(ckpt_path)  # update the model parameters

valset = ImageFolderDataset(root)

2023-06-21 16:03:52.183666: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS


Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.


In [3]:
from tqdm import tqdm

# Running top-1 and top-5 accuracy over the validation set.
metric = Accuracy()
top5_metric = TopKAccuracy(top_k=5)
imagenet1k_attr = ImageNet1kAttr()

# Wrap the dataset itself rather than the enumerate() generator: a generator
# has no length, so tqdm(enumerate(...)) can never show a total or an ETA.
# With enumerate(tqdm(valset)) tqdm can use len(valset) when it is defined.
for k, (image, label_id) in enumerate(tqdm(valset)):
    # NOTE(review): TF-Slim's own preprocessing for resnet_v2 is believed to be
    # the Inception-style one (pixels scaled to [-1, 1]); the mean/std below
    # differ from that -- confirm this choice is intentional.
    processed_image = preprocessing(
        image,
        use_grayscale=False,
        central_fraction=0.875,
        central_crop=True,
        height=299,
        width=299,
        mean=(0.485, 0.456, 0.406),
        std=(1, 1, 1)
    )
    # Add a batch axis and reorder HWC -> CHW: the model takes [1, 3, 299, 299].
    np_processed_images = np.expand_dims(processed_image.numpy(), axis=0)
    np_processed_images = np_processed_images.transpose(0, 3, 1, 2)
    outputs = model(np_processed_images)
    outputs = outputs.numpy()
    # The network has 1001 outputs, so dataset labels are shifted by one
    # (presumably index 0 is a background class -- TODO confirm).
    metric.update(labels=np.array([label_id+1]), preds=outputs)
    top5_metric.update(labels=np.array([label_id+1]), preds=outputs)
    # Report the running accuracy every 1000 images.
    if k%1000==0:
        print(f"{k+1}: {metric} {top5_metric}")

3it [00:00,  7.35it/s]

1: Accuracy: {'Accuracy': 0.0} TopKAccuracy: {'top_5_accuracy': 1.0}


1003it [01:04, 15.83it/s]

1001: Accuracy: {'Accuracy': 0.8861138861138861} TopKAccuracy: {'top_5_accuracy': 0.967032967032967}


2003it [02:08, 15.98it/s]

2001: Accuracy: {'Accuracy': 0.8275862068965517} TopKAccuracy: {'top_5_accuracy': 0.9545227386306847}


3003it [03:16, 14.74it/s]

3001: Accuracy: {'Accuracy': 0.7944018660446518} TopKAccuracy: {'top_5_accuracy': 0.9456847717427525}


4003it [04:24, 14.77it/s]

4001: Accuracy: {'Accuracy': 0.773056735816046} TopKAccuracy: {'top_5_accuracy': 0.9380154961259685}


5003it [05:40, 14.19it/s]

5001: Accuracy: {'Accuracy': 0.7966406718656269} TopKAccuracy: {'top_5_accuracy': 0.9444111177764447}


6002it [06:54, 14.20it/s]

6001: Accuracy: {'Accuracy': 0.7998666888851858} TopKAccuracy: {'top_5_accuracy': 0.9418430261623063}


7002it [08:07, 13.61it/s]

7001: Accuracy: {'Accuracy': 0.8070275674903585} TopKAccuracy: {'top_5_accuracy': 0.9448650192829596}


8003it [09:18, 14.59it/s]

8001: Accuracy: {'Accuracy': 0.8115235595550556} TopKAccuracy: {'top_5_accuracy': 0.9472565929258843}


9003it [10:35, 10.82it/s]

9001: Accuracy: {'Accuracy': 0.8029107876902566} TopKAccuracy: {'top_5_accuracy': 0.9463392956338185}


9705it [11:28, 12.89it/s]

In [None]:
# Show the accumulated top-1 (metric) and top-5 (top5_metric) accuracy.
print(f"{metric} {top5_metric}")