# In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os

# Input image and model storage locations
image_path = 'test4.png'  # replace with the path to your own image
model_local_dir = './midas_model'  # local directory where the model is kept

# Download and cache the MiDaS model (the first run needs network access)
def load_midas_model():
    """Load the MiDaS v2.1 small model from TensorFlow Hub.

    The TF Hub download cache is pointed at ``model_local_dir`` (unless the
    ``TFHUB_CACHE_DIR`` environment variable is already set) so that the
    model is downloaded once and reused from disk on later runs.

    Returns:
        The model's ``serving_default`` signature, a callable that maps an
        input tensor to a dict of output tensors.
    """
    model_url = "https://tfhub.dev/intel/midas/v2_1_small/1"
    # TF Hub reads TFHUB_CACHE_DIR when resolving a model handle; without it
    # the model lands in a temp directory and the local path is never used.
    # setdefault() keeps any cache location the user configured themselves.
    cache_dir = os.environ.setdefault(
        "TFHUB_CACHE_DIR", os.path.abspath(model_local_dir)
    )
    if not os.path.exists(cache_dir):
        print("Downloading MiDaS model...")
    midas_model = hub.load(model_url, tags=["serve"])
    return midas_model.signatures['serving_default']

# Image preprocessing
def preprocess_image(image_path, target_size=(256, 256)):
    """Read an image from disk and turn it into a MiDaS input tensor.

    Args:
        image_path: Path of the image file to read.
        target_size: Size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        A tuple ``(input_tensor, original_size, original_img)`` where
        ``input_tensor`` is a float32 tensor of shape
        ``(1, 3, target_height, target_width)`` with values scaled to
        [0, 1], ``original_size`` is the ``(height, width)`` of the source
        image, and ``original_img`` is the source image as an RGB array.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    original_img = cv2.imread(image_path)
    if original_img is None:
        raise FileNotFoundError(f"无法读取图像文件: {image_path}")
    # OpenCV decodes to BGR; convert to RGB for the model and for display.
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    original_size = original_img.shape[:2]

    resized_img = cv2.resize(original_img, target_size)
    normalized_img = resized_img.astype(np.float32) / 255.0
    # The TF Hub MiDaS signature takes channels-first input (N, C, H, W),
    # so move the channel axis to the front before adding the batch axis.
    chw_img = np.transpose(normalized_img, (2, 0, 1))
    input_tensor = tf.convert_to_tensor(chw_img[np.newaxis, ...], dtype=tf.float32)
    return input_tensor, original_size, original_img

# Depth map post-processing
def postprocess_depth(depth_tensor, original_size):
    """Convert the model output into an 8-bit depth image.

    Args:
        depth_tensor: Mapping returned by the model; the first element of
            its ``'default'`` entry is used as the depth prediction.
        original_size: ``(height, width)`` to resize the prediction to.

    Returns:
        A ``uint8`` array of the requested size, min-max scaled to the
        0-255 range.
    """
    height, width = original_size[0], original_size[1]
    prediction = depth_tensor['default'][0].numpy()
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(prediction, (width, height))

    lowest = resized.min()
    # The small epsilon keeps the division defined for a constant map.
    span = resized.max() - lowest + 1e-6
    scaled = 255 * (resized - lowest) / span
    return scaled.astype(np.uint8)

# Main pipeline
def run_depth_estimation(image_path):
    """Estimate depth for one image and show it next to the original.

    Loads the model, preprocesses the image at ``image_path``, runs
    inference, post-processes the prediction, and displays both pictures
    side by side in a matplotlib figure.

    Args:
        image_path: Path of the image file to process.
    """
    infer = load_midas_model()
    model_input, source_hw, source_rgb = preprocess_image(image_path)
    depth_vis = postprocess_depth(infer(model_input), source_hw)

    # Each panel: (title, picture, extra keyword arguments for imshow).
    panels = (
        ("Original Image", source_rgb, {}),
        ("Estimated Depth", depth_vis, {"cmap": 'inferno'}),
    )

    # Show the image and the depth map side by side
    plt.figure(figsize=(12, 5))
    for position, (caption, picture, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(picture, **imshow_kwargs)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Run the pipeline on the configured image
run_depth_estimation(image_path)
