In [12]:
import os
import math
import re
import json
import xml.etree.ElementTree as ET
from gensim.summarization import keywords

def preprocess(text):
    text = re.sub(r'[^a-zA-Z\s]', ' ', text)  # 更严格的字符过滤
    return ' '.join(text.lower().split()).strip()  # 合并多余空格

def load_data(data_path, ref_path=None, data_type='train'):
    documents = []
    
    if data_type == 'train':
        # 训练集加载逻辑保持不变
        for file in os.listdir(data_path):
            if file.endswith('.txt'):
                base_name = os.path.splitext(file)[0]
                txt_path = os.path.join(data_path, file)
                key_path = os.path.join(data_path, base_name + '.key')
                
                with open(txt_path, 'r', encoding='utf-8') as f:
                    content = preprocess(f.read())
                
                true_keywords = []
                if os.path.exists(key_path):
                    with open(key_path, 'r', encoding='utf-8') as f:
                        true_keywords = [line.strip().lower() for line in f if line.strip()]
                
                documents.append({
                    'doc_id': base_name,
                    'content': content,
                    'true_keywords': true_keywords
                })
    
    elif data_type == 'test' and ref_path:
        # 加载JSON格式的关键词引用
        with open(ref_path, 'r', encoding='utf-8') as f:
            ref_data = json.load(f)
        
        for file in os.listdir(data_path):
            if file.endswith('.xml'):
                doc_id = os.path.splitext(file)[0]
                xml_path = os.path.join(data_path, file)
                
                try:
                    # 解析XML内容
                    tree = ET.parse(xml_path)
                    root = tree.getroot()
                    
                    # 提取文本内容
                    content_parts = []
                    for sentence in root.findall('.//sentence'):
                        for token in sentence.findall('.//token'):
                            word_node = token.find('word')
                            if word_node is not None and word_node.text:
                                word = word_node.text.strip()
                                if word not in ['--', '']:
                                    content_parts.append(word)
                    content = preprocess(' '.join(content_parts))
                    
                    # 从JSON加载关键词（处理多候选情况）
                    true_keywords = []
                    if doc_id in ref_data:
                        # 合并所有候选关键词并去重
                        candidates = ref_data[doc_id]
                        true_keywords = list({kw.lower().strip() 
                                            for group in candidates 
                                            for kw in group})
                    
                    documents.append({
                        'doc_id': doc_id,
                        'content': content,
                        'true_keywords': true_keywords
                    })
                    
                except Exception as e:
                    print(f"Error processing {file}: {str(e)}")
                    continue
    
    print(f"\nLoaded {len(documents)} {data_type} documents")
    if documents:
        sample = documents[0]
        print(f"Sample Content Preview:\n{sample['content'][:200]}...")
        print(f"Sample Keywords: {sample['true_keywords'][:5]}")
    return documents

def textrank_extract(text, topk=5):
    try:
        if not text or len(text.split()) < 10:
            return []
        # 增加分句处理提升效果
        extracted = keywords(text, 
                            words=topk*2, 
                            ratio=0.2, 
                            split=True, 
                            scores=False,
                            pos_filter=('NN', 'NNS', 'JJ'))  # 过滤名词和形容词
        return [kw.lower().strip() for kw in extracted][:topk]
    except Exception as e:
        print(f"Extraction Error: {str(e)}")
        return []

def evaluate(true_list, pred_list):
    # 添加词干处理
    from nltk.stem import PorterStemmer
    stemmer = PorterStemmer()
    
    true_set = {stemmer.stem(kw) for kw in true_list}
    pred_set = {stemmer.stem(kw) for kw in pred_list}
    
    tp = len(true_set & pred_set)
    
    precision = tp / len(pred_set) if len(pred_set) > 0 else 0
    recall = tp / len(true_set) if len(true_set) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    cos_sim = tp / (math.sqrt(len(pred_set)) * math.sqrt(len(true_set))) if len(pred_set)*len(true_set) > 0 else 0
    
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'cos_sim': cos_sim
    }

def main():
    base_path = r"E:\HKULearning\2025 spring\STAT8021\group work\Krapivin\krapivin-2009-pre-master"
    
    # 路径配置
    train_path = os.path.join(base_path, "src", "all_docs_abstacts_refined")
    test_path = os.path.join(base_path, "test")
    ref_path = os.path.join(base_path, "references", "test.author.stem.json")  # 使用词干化版本
    
    # 加载数据
    train_docs = load_data(train_path, data_type='train')
    test_docs = load_data(test_path, ref_path=ref_path, data_type='test')
    
    # 评估逻辑
    topk = 5
    metrics = ['precision', 'recall', 'f1', 'cos_sim']
    
    def calculate_avg(docs):
        results = {m: [] for m in metrics}
        for doc in docs:
            pred = textrank_extract(doc['content'], topk)
            res = evaluate(doc['true_keywords'], pred)
            for m in metrics:
                results[m].append(res[m])
        return {f'avg_{m}': sum(results[m])/len(docs) for m in metrics}
    
    train_results = calculate_avg(train_docs)
    test_results = calculate_avg(test_docs)
    
    print(f"\n{'Metric':<15} {'Training':<10} {'Test':<10}")
    for m in metrics:
        print(f"{m.capitalize():<15} {train_results[f'avg_{m}']:.4f}      {test_results[f'avg_{m}']:.4f}")

if __name__ == "__main__":
    main()


Loaded 2304 train documents
Sample Content Preview:
t enhancing product recommender systems on sparse binary data a commercial recommender systems use various data mining techniques to make appropriate recommendations to users during online real time s...
Sample Keywords: ['collaborative filtering', 'customer relationship management', 'e-commerce', 'recommender systems', 'dependency networks']

Loaded 2304 test documents
Sample Content Preview:
t enhancing product recommender systems on sparse binary data a commercial recommender systems use various data mining techniques to make appropriate recommendations to users during online real time s...
Sample Keywords: ['recommend system', 'depend network', 'custom relationship manag', 'collabor filter', 'e-commerc']

Metric          Training   Test      
Precision       0.0786      0.0415
Recall          0.0446      0.0261
F1              0.0531      0.0297
Cos_sim         0.0569      0.0315
