In [17]:
import os
import re
import numpy as np
from summa import keywords
from nltk.stem import PorterStemmer
import nltk

nltk.download('punkt')  # 确保下载必要资源

# ========== 配置参数 ==========
DOMAIN_TERMS = {
    'laser', 'quantum', 'modulus', 'polymer', 'nanoparticle', 'spectroscopy',
    'synthesis', 'alloy', 'composite', 'asphalt', 'cullet', 'stiffness', 
    'bitumen', 'asphaltic', 'glass', 'particle'
}

EXTENDED_STOPWORDS = {
    'however', 'furthermore', 'conclusion', 'experiment', 'methodology',
    'result', 'study', 'research', 'data', 'analysis', 'table', 'figure', 
    'sample', 'test', 'show', 'based', 'using'
}

# ========== 核心函数 ==========
def load_data(data_path):
    """增强型数据加载（带词干化和格式验证）"""
    texts = []
    keywords_list = []
    stemmer = PorterStemmer()
    
    for filename in sorted(os.listdir(data_path)):
        if not filename.endswith(".txt"):
            continue
            
        # 读取文本文件
        text_path = os.path.join(data_path, filename)
        ann_path = os.path.join(data_path, filename.replace(".txt", ".ann"))
        
        try:
            with open(text_path, "r", encoding="utf-8", errors="ignore") as f:
                texts.append(f.read().strip())
        except Exception as e:
            print(f"Error reading {text_path}: {str(e)}")
            continue
        
        # 处理标注文件
        doc_keywords = []
        if os.path.exists(ann_path):
            with open(ann_path, "r", encoding="utf-8", errors="ignore") as f:
                for line in f:
                    parts = line.strip().split("\t")
                    if len(parts) >= 3:
                        # 解析标注类型（兼容Term/Material/Process）
                        annot_parts = parts[1].split()
                        if len(annot_parts) == 0:
                            continue
                            
                        annot_type = annot_parts[0]
                        if annot_type in ["Term", "Material", "Process", "T"]:
                            keyword = parts[2].lower().strip()
                            # 过滤无效关键词
                            if 2 < len(keyword) < 50 and not keyword.isnumeric():
                                doc_keywords.append(keyword)
        
        # 词干化处理并去重
        stemmed_kws = [stemmer.stem(kw) for kw in doc_keywords]
        keywords_list.append(list(set(stemmed_kws)))
    
    return texts, keywords_list

def clean_text(text):
    """工业级文本清洗管道"""
    # 移除参考文献引用
    text = re.sub(r'\[\d+(-\d+)?\]', '', text)
    # 处理特殊符号
    text = re.sub(r'&[a-z]+;', lambda m: {'&alpha;':'alpha','&beta;':'beta'}.get(m.group(), ''), text)
    # 统一连字符格式
    text = re.sub(r'(\w+)-(\w+)', r'\1_\2', text)
    # 移除表格内容
    text = re.sub(r'\|.*?\|', ' ', text)
    # 处理数字字母组合
    text = re.sub(r'\b(\d+)([a-zA-Z]+)\b', r'\1 \2', text)
    # 基础清洗
    text = re.sub(r'[^\w\s_]', '', text)
    return re.sub(r'\s+', ' ', text).lower().strip()

def extract_keywords(text, ratio=0.3):
    """领域自适应关键词提取"""
    try:
        # 预处理
        cleaned = clean_text(text)
        if len(cleaned.split()) < 10:  # 过滤过短文本
            return []
        
        # 提取基础关键词
        kw = keywords.keywords(
            cleaned,
            ratio=ratio,
            words=True,
            split=True,
            scores=False,
            language="english"
        )
        
        # 后处理
        processed_kws = []
        stemmer = PorterStemmer()
        for k in kw:
            k_clean = k.strip().lower()
            # 过滤条件
            if (len(k_clean) < 3 or 
                k_clean in EXTENDED_STOPWORDS or
                not any(c.isalpha() for c in k_clean)):
                continue
                
            # 领域词优先
            stemmed = stemmer.stem(k_clean)
            if stemmed in DOMAIN_TERMS or k_clean in DOMAIN_TERMS:
                processed_kws.insert(0, k_clean)
            else:
                processed_kws.append(k_clean)
        
        # 合并同类词
        stemmed_kws = [stemmer.stem(kw) for kw in processed_kws]
        return list(set(stemmed_kws))
        
    except Exception as e:
        print(f"提取错误: {str(e)}")
        return []

def evaluate(y_true, y_pred):
    """词干化评估"""
    stemmer = PorterStemmer()
    
    tp = pred = true = 0
    for true_kws, pred_kws in zip(y_true, y_pred):
        # 词干化处理
        true_set = set(stemmer.stem(kw) for kw in true_kws)
        pred_set = set(stemmer.stem(kw) for kw in pred_kws)
        
        if not true_set:  # 跳过无标注文档
            continue
            
        common = true_set & pred_set
        tp += len(common)
        pred += len(pred_set)
        true += len(true_set)
    
    precision = tp / pred if pred > 0 else 0
    recall = tp / true if true > 0 else 0
    f1 = 2*precision*recall/(precision+recall) if (precision+recall) >0 else 0
    
    return precision, recall, f1

# ========== 主程序 ==========
def main():
    # 配置路径
    train_path = r"E:\HKULearning\2025 spring\STAT8021\group work\scienceie.github.io-master\resources\scienceie2017_train\train2"
    test_path = r"E:\HKULearning\2025 spring\STAT8021\group work\scienceie.github.io-master\resources\semeval_articles_test"

    # 加载数据
    print("Loading training data...")
    X_train, y_train = load_data(train_path)
    print(f"Loaded {len(X_train)} training samples | Avg keywords: {np.mean([len(x) for x in y_train]):.1f}")

    print("\nLoading test data...")
    X_test, y_test = load_data(test_path)
    print(f"Loaded {len(X_test)} test samples | Avg keywords: {np.mean([len(x) for x in y_test]):.1f}")

    # 提取关键词
    print("\nExtracting keywords...")
    y_pred = []
    for i, text in enumerate(X_test):
        kws = extract_keywords(text)
        y_pred.append(kws)
        if i < 3:  # 打印前3个样本结果
            print(f"Doc {i+1} Pred: {kws[:5]}...")

    # 评估
    precision, recall, f1 = evaluate(y_test, y_pred)

    # 输出结果
    print("\n评估结果:")
    print(f"Precision: {precision:.2%}")
    print(f"Recall:    {recall:.2%}")
    print(f"F1 Score:  {f1:.2%}")

    # 诊断信息
    valid_docs = sum(1 for x in y_test if x)
    print(f"\n有效标注文档: {valid_docs}/{len(y_test)}")
    print(f"预测关键词示例: {y_pred[0][:5] if y_pred else 'None'}")

if __name__ == "__main__":
    main()

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\sphy9\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Loading training data...
Loaded 350 training samples | Avg keywords: 12.8

Loading test data...
Loaded 100 test samples | Avg keywords: 14.1

Extracting keywords...
Doc 1 Pred: ['surfac']...
Doc 2 Pred: ['surfac']...
Doc 3 Pred: ['alloy']...

评估结果:
Precision: 17.44%
Recall:    1.07%
F1 Score:  2.01%

有效标注文档: 100/100
预测关键词示例: ['surfac']
