In [10]:
import os
import sys
import json
import random

import torch
import cv2
import pdfplumber

import numpy as np
from pathlib import Path


colors = np.random.randint(125, 255, (80, 3))


import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image


# 图片路径
image_path = 'bus.jpg'  # 替换为你自己的图片路径


# 加载预训练的 ResNet50 模型
# 创建 ResNet50 模型实例（不包括预训练权重）
model = models.resnet50(pretrained=False)

# 加载本地保存的模型参数
model.load_state_dict(torch.load('Resnet/pretrain_models/resnet50.pth', map_location=torch.device('cpu')))
model.eval()  # 设置模型为评估模式


# 图像预处理函数
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# 图像分类函数
def predict_image_class(image_path, model, preprocess):
    # 加载图像并进行预处理
    image = Image.open(image_path)
    image = preprocess(image).unsqueeze(0)  # 增加一个维度，适应模型输入格式

    # 将模型设置为评估模式
    model.eval()

    # 使用模型进行推理
    with torch.no_grad():
        outputs = model(image)

    # 获取预测结果
    _, predicted = torch.max(outputs, 1)
    
    
    # 使用softmax获取每个类别的概率分布
    softmax = torch.nn.Softmax(dim=1)
    probabilities = softmax(outputs)

    # 获取预测类别的置信度或概率
    confidence = torch.max(probabilities).item()

    # 获取所有类别的置信度或概率
    all_probabilities = probabilities.squeeze().tolist()
    
    return predicted.item(), confidence, all_probabilities

# 预测图像类别
predicted_class, confidence, all_probabilities = predict(image_path, model, preprocess)
print("Predicted class index:", predicted_class)
print("Confidence of predicted class:", confidence)


with open("data/imagenet-simple-labels.json") as f:
    class_labels = json.load(f)

predicted_label = class_labels[predicted_class]
print("Predicted class label:", predicted_label)




Predicted class index: 654
Confidence of predicted class: 0.6298723816871643
Predicted class label: minibus


In [42]:
import os
import sys
import json
import shutil

import random
import torch
import cv2
import pdfplumber

import numpy as np
from pathlib import Path


colors = np.random.randint(125, 255, (80, 3))


import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image


# 图片路径
image_path = 'bus.jpg'  # 替换为你自己的图片路径


# 加载预训练的 ResNet50 模型
# 创建 ResNet50 模型实例（不包括预训练权重）
model = models.resnet50(pretrained=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 加载本地保存的模型参数
# model.load_state_dict(torch.load('Resnet/pretrain_models/resnet50.pth', map_location=torch.device('cpu')))
model.load_state_dict(torch.load('Resnet/pretrain_models/resnet50.pth', map_location=device))
model.eval()  # 设置模型为评估模式


# 图像预处理函数
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def pre_process(images_path, pre_process_path):
    res_dict = dict()
    file_list = os.listdir(images_path)
#     img_count = len(file_list)
    
    for idx, filename in enumerate(file_list):
        file_extension = filename.split(".")[-1]
        file_name = filename.split(".")[0]
        res_dict[filename] = {"file_type": file_extension.upper()}

        detection_type_list = ["pdf", "png", "jpg", "jpeg"]
        img_extension_list = ["png", "jpg", "jpeg"]
        if file_extension not in detection_type_list:
            continue
            
        if file_extension in img_extension_list:
            shutil.copy(os.path.join(images_path, filename), os.path.join(pre_process_path, f"{file_name}_page_0.png"))
        else:
            
            with pdfplumber.open(os.path.join(images_path, filename)) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    img = page.to_image()
                    pdf_image = page.to_image().original
    
                    # 将原始图像数据转换为 OpenCV 图像对象
                    open_cv_image = cv2.cvtColor(np.array(pdf_image), cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(pre_process_path, f"{file_name}_page_{page_num}.png"), open_cv_image)
                
    return res_dict


# 图像分类函数
def predict_image_class(task_id, images_path, pre_process_path, model, preprocess, res_dict):
    # 加载图像并进行预处理
    pre_process_files = os.listdir(pre_process_path)
    pre_process_files_count = len(pre_process_files)

    for idx, img_name in enumerate(pre_process_files):
        print(f"Task_id:{task_id} {idx}/{pre_process_files_count} img_name: {img_name}")
        
        file_extension = img_name.split(".")[-1]
        
        detection_type_list = ["pdf", "png", "jpg", "jpeg"]
        img_extension_list = ["png", "jpg", "jpeg"]
        if file_extension not in detection_type_list:
            continue
            
        file_name = img_name.split("_page")[0]
        
        matching_keys = [key for key in res_dict.keys() if file_name in key]
        if not matching_keys:
            continue
            
        related_pdf_name = matching_keys[0]
        print(f"模糊匹配到的键：{matching_keys}")

        print("aa", img_name)
        image = Image.open(os.path.join(pre_process_path, img_name))
        image = preprocess(image).unsqueeze(0)  # 增加一个维度，适应模型输入格式
        
        
        # 将模型设置为评估模式
        model.eval()
    
        # 使用模型进行推理
        with torch.no_grad():
            outputs = model(image)
    
        # 获取预测结果
        _, predicted = torch.max(outputs, 1)
        
        
        # 使用softmax获取每个类别的概率分布
        softmax = torch.nn.Softmax(dim=1)
        probabilities = softmax(outputs)
    
        # 获取预测类别的置信度或概率
        confidence = torch.max(probabilities).item()
    
        # 获取所有类别的置信度或概率
        all_probabilities = probabilities.squeeze().tolist()
        
        
        with open("data/imagenet-simple-labels.json") as f:
            class_labels = json.load(f)
        
        predicted_class = predicted.item()
        predicted_label = class_labels[predicted_class]
        
        print("Predicted class index:", predicted_class)
        print("Confidence of predicted class:", confidence)
        print("Predicted class label:", predicted_label)
        
        detection_res = {
            "class_id": int(predicted_class),
            "label": predicted_label,
            "score": confidence,
        }
        
        if not res_dict[related_pdf_name].get("result"):
                res_dict[related_pdf_name]["result"] = list()

        res_dict[related_pdf_name]["result"].append(detection_res)
        
    return res_dict

def handler(detect_floder, task_id, node):
    images_path = os.path.join("temp_storage", detect_floder)
    
    pre_process_path = os.path.join(images_path, "pre_process")
    if not os.path.exists(pre_process_path):
        os.mkdir(pre_process_path)

    # 预测图像类别

    res_dict = pre_process(images_path, pre_process_path)
    res = predict_image_class(task_id, images_path, pre_process_path, model, preprocess, res_dict)
    print(res)
    
    # Todo: 释放资源， 结果增加到 MongoDB


if __name__ == "__main__":
    detect_floder = "detect_demo1"
    task_id = "123"
    node = "worker1"
    handler(detect_floder, task_id, node)
    


Task_id:123 0/13 img_name: table1_page_2.png
模糊匹配到的键：['table1.pdf']
aa table1_page_2.png
Predicted class index: 916
Confidence of predicted class: 0.8667305111885071
Predicted class label: website
Task_id:123 1/13 img_name: test_page_0.png
模糊匹配到的键：['test.jpeg']
aa test_page_0.png
Predicted class index: 922
Confidence of predicted class: 0.18185554444789886
Predicted class label: menu
Task_id:123 2/13 img_name: table1_page_1.png
模糊匹配到的键：['table1.pdf']
aa table1_page_1.png
Predicted class index: 918
Confidence of predicted class: 0.48874181509017944
Predicted class label: crossword
Task_id:123 3/13 img_name: table2_page_0.png
模糊匹配到的键：['table2.pdf']
aa table2_page_0.png
Predicted class index: 918
Confidence of predicted class: 0.9081958532333374
Predicted class label: crossword
Task_id:123 4/13 img_name: 不动产登记申请表_page_1.png
模糊匹配到的键：['不动产登记申请表.pdf']
aa 不动产登记申请表_page_1.png
Predicted class index: 789
Confidence of predicted class: 0.6891041398048401
Predicted class label: shoji
Task_id:123 5

In [37]:
a = {'户口本.pdf': {'file_type': 'PDF', 'result': [{'class_id': 918, 'label': 'crossword', 'score': 0.5679798126220703}]}, '02公示无异议证明.docx': {'file_type': 'DOCX'}, '03 宗地图.pdf': {'file_type': 'PDF', 'result': [{'class_id': 918, 'label': 'crossword', 'score': 0.9857348799705505}]}, '02-身份证.pdf': {'file_type': 'PDF', 'result': [{'class_id': 549, 'label': 'envelope', 'score': 0.6350112557411194}]}, '不动产登记申请表.pdf': {'file_type': 'PDF', 'result': [{'class_id': 789, 'label': 'shoji', 'score': 0.6891041398048401}, {'class_id': 918, 'label': 'crossword', 'score': 0.8979585766792297}, {'class_id': 916, 'label': 'website', 'score': 0.7475031614303589}]}, 'table2.pdf': {'file_type': 'PDF', 'result': [{'class_id': 918, 'label': 'crossword', 'score': 0.9081958532333374}, {'class_id': 916, 'label': 'website', 'score': 0.725080132484436}]}, 'pre_process': {'file_type': 'PRE_PROCESS'}, 'test.jpeg': {'file_type': 'JPEG'}, 'table1.pdf': {'file_type': 'PDF', 'result': [{'class_id': 916, 'label': 'website', 'score': 0.8667305111885071}, {'class_id': 918, 'label': 'crossword', 'score': 0.48874181509017944}, {'class_id': 916, 'label': 'website', 'score': 0.43206724524497986}]}}


"test" in a.keys()

search_key = "test"
matching_keys = [key for key in a.keys() if search_key in key]
if matching_keys:
    print(f"模糊匹配到的键：{matching_keys}")
else:
    print("未找到匹配的键")

模糊匹配到的键：['test.jpeg']


Task_id:123 0/12 img_name: table1_page_2.png
aa table1_page_2.png
Predicted class index: 916
Confidence of predicted class: 0.8667305111885071
Predicted class label: website
Task_id:123 1/12 img_name: test_page_0.png
Task_id:123 2/12 img_name: table1_page_1.png
aa table1_page_1.png
Predicted class index: 918
Confidence of predicted class: 0.48874181509017944
Predicted class label: crossword
Task_id:123 3/12 img_name: table2_page_0.png
aa table2_page_0.png
Predicted class index: 918
Confidence of predicted class: 0.9081958532333374
Predicted class label: crossword
Task_id:123 4/12 img_name: 不动产登记申请表_page_1.png
aa 不动产登记申请表_page_1.png
Predicted class index: 789
Confidence of predicted class: 0.6891041398048401
Predicted class label: shoji
Task_id:123 5/12 img_name: table2_page_1.png
aa table2_page_1.png
Predicted class index: 916
Confidence of predicted class: 0.725080132484436
Predicted class label: website
Task_id:123 6/12 img_name: 02-身份证_page_0.png
aa 02-身份证_page_0.png
Predicted class index: 549
Confidence of predicted class: 0.6350112557411194
Predicted class label: envelope
Task_id:123 7/12 img_name: 户口本_page_0.png
aa 户口本_page_0.png
Predicted class index: 918
Confidence of predicted class: 0.5679798126220703
Predicted class label: crossword
Task_id:123 8/12 img_name: 不动产登记申请表_page_0.png
aa 不动产登记申请表_page_0.png
Predicted class index: 918
Confidence of predicted class: 0.8979585766792297
Predicted class label: crossword
Task_id:123 9/12 img_name: 03 宗地图_page_0.png
aa 03 宗地图_page_0.png
Predicted class index: 918
Confidence of predicted class: 0.9857348799705505
Predicted class label: crossword
Task_id:123 10/12 img_name: 不动产登记申请表_page_2.png
aa 不动产登记申请表_page_2.png
Predicted class index: 916
Confidence of predicted class: 0.7475031614303589
Predicted class label: website
Task_id:123 11/12 img_name: table1_page_0.png
aa table1_page_0.png
Predicted class index: 916
Confidence of predicted class: 0.43206724524497986
Predicted class label: website