In [4]:
import argparse
import json
import cv2
import os
import numpy as np
import torch
import torchvision.transforms as TTR
from torch.utils.data import DataLoader, Dataset

In [5]:
dms46 = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23,
    24, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 41, 43, 44, 46, 47, 48, 49,
    50, 51, 52, 53, 56,
]
t = json.load(open(os.path.expanduser('./taxonomy.json'), 'rb'))
srgb_colormap = [
    t['srgb_colormap'][i] for i in range(len(t['srgb_colormap'])) if i in dms46
]
srgb_colormap.append([0, 0, 0])  # 无法识别的材料用黑色表示
srgb_colormap = np.array(srgb_colormap, dtype=np.uint8)

In [8]:
def apply_color(label_mask):
    label_mask[label_mask == 255] = len(srgb_colormap) - 1
    vis = np.take(srgb_colormap, label_mask, axis=0)
    return vis[..., ::-1]


def predict_image(model, image_path):
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    image = np.copy(img)

    image = torch.from_numpy(image.transpose((2, 0, 1))).float()
    image = TTR.Normalize([0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255])(image)
    image = image.unsqueeze(0)

    if torch.cuda.is_available():
        image = image.cuda()

    with torch.no_grad():
        prediction = model(image)[0].data.cpu()[0].numpy()

    return prediction


def main():
    model_path = "./DMS46_v1.pt"  # 更改为实际预训练模型文件路径
    image_path = "./input_image.jpg"  # 更改为实际输入图片文件路径
    output_image_path = "./output_image.png"  # 输出彩色图片的路径
    output_csv_path = "./output_results.csv"  # 输出CSV结果文件的路径


In [9]:

    model = torch.jit.load(model_path)
    if torch.cuda.is_available():
        model = model.cuda()

    prediction = predict_image(model, image_path)
    predicted_colored = apply_color(prediction)

    cv2.imwrite(output_image_path, predicted_colored[..., ::-1])


NameError: name 'model_path' is not defined

In [15]:
 #从 taxonomy.json 文件中读取颜色映射
with open('./taxonomy.json', 'rb') as f:
    taxonomy_data = json.load(f)
    all_srgb_colormap = taxonomy_data['srgb_colormap']

# 根据 dms46 创建 srgb_colormap
dms46 = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23,
    24, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 41, 43, 44, 46, 47, 48, 49,
    50, 51, 52, 53, 56,
]
srgb_colormap = [
    all_srgb_colormap[i] for i in range(len(all_srgb_colormap)) if i in dms46
]
srgb_colormap.append([0, 0, 0])  # 添加黑色以表示无法检测的材料
srgb_colormap = np.array(srgb_colormap, dtype=np.uint8)
srgb_colormap

array([[188, 188, 137],
       [  0, 188,   0],
       [188, 188,   0],
       [  0,   0, 188],
       [188,   0, 188],
       [  0, 188, 188],
       [241, 241, 241],
       [  0, 137, 137],
       [225,   0,   0],
       [137, 188,   0],
       [225, 188,   0],
       [137,   0, 188],
       [137, 188, 188],
       [225, 188, 188],
       [  0, 137,   0],
       [188, 137,   0],
       [137, 225, 188],
       [188, 137, 188],
       [  0, 137, 188],
       [188, 225,   0],
       [188, 225, 188],
       [137, 137,   0],
       [137, 225,   0],
       [225, 225,   0],
       [225, 137, 188],
       [  0, 225,   0],
       [  0,   0, 137],
       [188,   0,   0],
       [  0, 188, 137],
       [188,   0, 137],
       [  0,   0, 225],
       [225, 188, 137],
       [  0, 188, 225],
       [188, 188, 225],
       [225,   0, 137],
       [225, 225, 188],
       [137,   0, 225],
       [137, 188, 225],
       [225, 188, 225],
       [  0, 137, 225],
       [188, 137, 137],
       [188, 188

In [17]:
print(prediction[0].shape)

(1, 667, 1000)


In [18]:
# 设置预训练模型文件的路径
model_path = "./DMS46_v1.pt"

# 加载预训练模型
model = torch.jit.load(model_path)
model.eval()

# 用于对标签进行可视化的颜色映射
srgb_colormap = [
    # 在这里添加对应于 dms46 的 sRGB 颜色值
]

# 输入图像路径
input_image_path = "./input_image.jpg"

# 读取输入图像
input_image = cv2.imread(input_image_path)
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# 预处理图像
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
value_scale = 255
mean = [item * value_scale for item in mean]
std = [item * value_scale for item in std]

input_tensor = torch.from_numpy(input_image.transpose((2, 0, 1))).float()
input_tensor = TTR.Normalize(mean, std)(input_tensor)
input_tensor = input_tensor.unsqueeze(0)

# 进行预测
with torch.no_grad():
    prediction = model(input_tensor)[0].data.cpu().numpy()

# 应用颜色映射
# 将预测矩阵转换为二维数组
prediction = np.argmax(prediction, axis=0)

# 定义一个将类别索引转换为颜色的函数
def index_to_color(index):
    return srgb_colormap[index]

# 使用 np.vectorize 将函数应用到整个预测矩阵上
index_to_color_vec = np.vectorize(index_to_color, signature='()->(n)')
predicted_colored = index_to_color_vec(prediction)


# predicted_colored = np.take(srgb_colormap, prediction[0], axis=0)

# 保存预测结果
output_image_path = "./predicted_image.png"
cv2.imwrite(output_image_path, predicted_colored[..., ::-1])

# 在这里添加将分析结果保存到 CSV 文件的代码


IndexError: cannot do a non-empty take from an empty axes.