In [2]:
# 加载模型
import inference
from utils import doc_topics, paragraph_topics

# Checkpoint directories of the two fine-tuned classifiers:
# the document-level model predicts the activity type of a whole document,
# the paragraph-level model predicts the topic of a single paragraph.
DOC_MODEL_DIR = "./output/doc_all/checkpoint-best/"
PARA_MODEL_DIR = "./output/para_all/checkpoint-best/"

doc_classifier = inference.NewsClassifierForDoc(DOC_MODEL_DIR)
para_classifier = inference.NewsClassifierForPara(PARA_MODEL_DIR)

2022-08-28 14:37:59.758902: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


2022-08-28 14:38:03,576 INFO     Create model.........
2022-08-28 14:38:06,111 INFO     Create model.........


In [3]:
from tqdm import tqdm
import warnings
from utils import preprocess_pdf

# Minimum classifier confidence for a prediction to be accepted.
# Recommended range: 0.5-0.9; larger values make classification stricter.
threshold: float = 0.7
# Length cutoff (len() of the paragraph string) used to skip paragraphs
# that are too short to classify reliably.
min_paragraph_length: int = 20

def predict(pdf_path):
    """Predict and print the activity type and paragraph topics of one PDF.

    Uses the same decision rules as ``predict_all`` so this single-file
    preview agrees with what the batch statistics count:

    * The activity type (document topic) is predicted from the first three
      preprocessed paragraphs joined by a space. If that score is below
      ``threshold``, each paragraph is tried on its own and the first
      confident prediction is used.
    * Paragraph topics are predicted for paragraphs from index 2 onward.
      Paragraphs shorter than ``min_paragraph_length`` and predictions with a
      score below ``threshold`` are ignored.

    Args:
        pdf_path: Path of the PDF file to analyse.

    Prints the path, the activity type and the de-duplicated topics. Returns
    nothing, so the notebook cell shows no ``Out[...]``. Emits a
    ``UserWarning`` when no confident activity type is found. In that case the
    prediction from the first three paragraphs is still printed.
    """
    print(pdf_path)
    preprocessed_text = preprocess_pdf(pdf_path)

    doc_topic, doc_score = doc_classifier.classify(" ".join(preprocessed_text[:3]))
    if doc_score < threshold:
        # Same fallback as predict_all: look for a single confident paragraph.
        # Only adopt it when it is confident; otherwise keep the original guess.
        for para in preprocessed_text:
            para_doc_topic, para_doc_score = doc_classifier.classify(para)
            if para_doc_score >= threshold:
                doc_topic, doc_score = para_doc_topic, para_doc_score
                break
    if doc_score < threshold:
        warnings.warn("活动类型不明确: {}".format(pdf_path))
    print("活动类型为：" + doc_topic)

    # Ordered de-duplication: a dict keeps first-seen order, whereas iterating
    # a set of strings varies between interpreter runs (hash randomization),
    # which made the printed topic list non-reproducible.
    para_topics = {}
    for para in preprocessed_text[2:]:
        # Cheap length filter first, so the model never runs on paragraphs
        # whose result would be discarded anyway.
        if len(para) < min_paragraph_length:
            continue
        topic, score = para_classifier.classify(para)
        if score >= threshold:
            para_topics[topic] = None
    print("活动主题为："+ " ".join(para_topics))


def predict_all(pdf_paths):
    """Count, per activity type, how many documents mention each paragraph topic.

    Args:
        pdf_paths: Iterable of PDF file paths.

    Returns:
        Nested dict ``stats[doc_topic][para_topic] -> int``. Each count is the
        number of documents classified as ``doc_topic`` in which at least one
        paragraph was confidently classified as ``para_topic``. Every pair from
        ``doc_topics`` x ``paragraph_topics`` is present, starting at 0.
        Documents whose activity type cannot be determined confidently are
        skipped with a ``UserWarning``.

    NOTE(review): assumes both classifiers only return labels listed in
    ``doc_topics`` / ``paragraph_topics``. Any other label raises KeyError
    when counting.
    """
    # Initialise the full table so every combination appears in the output.
    stats = {
        doc_topic: {para_topic: 0 for para_topic in paragraph_topics}
        for doc_topic in doc_topics
    }

    for pdf_path in tqdm(pdf_paths):
        preprocessed_text = preprocess_pdf(pdf_path)
        # By default the first three paragraphs predict the document topic.
        doc_topic, doc_score = doc_classifier.classify(" ".join(preprocessed_text[:3]))
        # If that is not confident enough, try each paragraph individually.
        if doc_score < threshold:
            for para in preprocessed_text:
                doc_topic, doc_score = doc_classifier.classify(para)
                if doc_score >= threshold:
                    break
            # Still ambiguous: skip this document entirely.
            if doc_score < threshold:
                warnings.warn("活动类型不明确: {}".format(pdf_path))
                continue

        # Paragraph topics are predicted from index 2 onward (the third
        # paragraph). The length filter runs before the model call, so
        # too-short paragraphs do not cost a classification. The set counts
        # each topic at most once per document.
        para_topics = set()
        for para in preprocessed_text[2:]:
            if len(para) < min_paragraph_length:
                continue
            topic, score = para_classifier.classify(para)
            if score >= threshold:
                para_topics.add(topic)

        for para_topic in para_topics:
            stats[doc_topic][para_topic] += 1

    return stats

In [4]:
# Path of a single PDF file to classify.
sample_pdf_path = "./new_data/会见/01/进一步加强科创合作交流_记者__孟群舒.pdf"
predict(sample_pdf_path)

./new_data/会见/01/进一步加强科创合作交流_记者__孟群舒.pdf
2022-08-28 14:38:09,934 INFO     loading: 'to-unicode-Adobe-GB1'
活动类型为：会见
主题包括：经济


In [5]:
import glob
# Every PDF located two directory levels below ./new_data.
pdf_paths = glob.glob("./new_data/*/*/*.pdf")
stats = predict_all(pdf_paths)

100%|██████████| 210/210 [07:20<00:00,  2.10s/it]


In [6]:
import pandas as pd
from utils import doc_topics, paragraph_topics

def save_excel(stats, path):
    """Write the topic statistics produced by ``predict_all`` to an Excel file.

    Layout: one row per paragraph topic, with index 1..len(paragraph_topics).
    The first column (named " ") holds the paragraph topic. It is followed by
    one column per document topic, named "<k>.<doc_topic>" (k starting at 1),
    holding ``stats[doc_topic][para_topic]``.

    Args:
        stats: Nested dict ``stats[doc_topic][para_topic] -> count`` that
            contains every doc/paragraph topic pair.
        path: Destination file path passed to ``DataFrame.to_excel``.
    """
    # Build all columns first and create the frame once. Filling an empty
    # frame cell by cell via chained indexing (df[col][row] = value) triggers
    # ChainedAssignmentError warnings in pandas 2.x. Under Copy-on-Write
    # (default in pandas 3.0) it does not modify df at all, which would
    # produce an empty spreadsheet.
    columns = {" ": list(paragraph_topics)}
    for col_idx, doc_topic in enumerate(doc_topics):
        columns["{}.{}".format(col_idx + 1, doc_topic)] = [
            stats[doc_topic][para_topic] for para_topic in paragraph_topics
        ]
    df = pd.DataFrame(columns, index=list(range(1, len(paragraph_topics) + 1)))
    df.to_excel(path)

In [7]:
# Destination of the Excel report.
excel_output_path = "./stats.xlsx"
save_excel(stats, excel_output_path)