In [None]:
import re 
import numpy as np
import pandas as pd
from datasets import load_dataset
from datasets import Dataset, DatasetDict
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer

In [None]:
# Load every WuDao parquet shard into a single Dataset (exposed as the "train" split).
ds = load_dataset(
    "parquet",
    data_files="./dataset/wudao/*.parquet",
    split="train",
)
# Row count corresponding to 5% of the corpus.
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
size_need_data = int(len(ds) * 0.05)
ds

In [None]:
# Patterns used by clean_text, compiled once instead of on every call.
_HTML_TAG_RE = re.compile(r'<[^>]+>')
# Characters permitted in a URI by RFC 3986 (unreserved + reserved + '%').
_URL_RE = re.compile(r"https?://[A-Za-z0-9\-._~:/?#\[\]@!$&'()*+,;=%]+")
# re.ASCII makes \b and \d ASCII-only. In Python's default Unicode mode CJK
# characters count as \w, so in "号13812345678" or "邮箱abc@x.com" there is no
# \b before the digits/letters and the match would silently never happen.
_EMAIL_RE = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', re.ASCII)
_PHONE_RES = (
    re.compile(r'\b\d{3}[-.\s]??\d{3}[-.\s]??\d{4}\b', re.ASCII),  # 3-3-4 digits, optional separators
    re.compile(r'\b\d{11}\b', re.ASCII),                           # 11-digit mobile number
    re.compile(r'\b\d{4}[-.\s]??\d{3}[-.\s]??\d{4}\b', re.ASCII),  # 4-3-4 digits (area code + number)
)
# Chinese promotional / boilerplate phrases. Longer phrases come before their
# substrings ('联系电话' before '电话') so the longer alternative wins, and the
# lazy '.*?' keeps '扫码…获取' from swallowing text up to the LAST '获取'.
_AD_PHRASE_RE = re.compile('|'.join([
    '关注公众号', '扫码.*?获取', '添加微信', '点击下方链接',
    '详情请访问', '领取优惠券', '限时折扣', '立即购买', '了解更多',
    '欢迎转载', '版权声明', '免责声明', '文章来源', '发布于',
    '联系电话', '电话', '热线', '企鹅',
]))
# ASCII ad tokens: case-insensitive, and only when not embedded in a longer
# Latin word, so "hotel" / "Intel" are left intact.
_AD_TOKEN_RE = re.compile(r'(?<![A-Za-z])(?:tel|qq号?|q号)(?![A-Za-z])', re.IGNORECASE)
# Single characters to delete: supplementary planes (most emoji), U+2000-U+2FFF
# (general punctuation, symbols, ...), U+3000-U+303F (CJK symbols/punctuation)
# and a handful of decorative symbols.
_SPECIAL_CHAR_RE = re.compile(r'[\U00010000-\U0010FFFF\u2000-\u2FFF\u3000-\u303F\[\]{}<>#*★◇§♡♥♪♬▶▼▪◆●¡⭐]')
_REPEAT_PUNCT_RE = re.compile(r'([!?。，])[!?。，]+')
_REPEAT_DOT_RE = re.compile(r'\.{2,}')
_WHITESPACE_RE = re.compile(r'\s+')


def clean_text(text, min_len=64):
    """Clean one raw document and decide whether it is worth keeping.

    Removes HTML tags, URLs, e-mail addresses, phone numbers, promotional
    phrases, emoji and decorative symbols / CJK punctuation, then collapses
    repeated punctuation and whitespace.

    Args:
        text: Raw document text. Non-strings and empty values are rejected.
        min_len: Minimum length (in characters) the cleaned text must have;
            shorter results are treated as noise. Defaults to 64.

    Returns:
        The cleaned string, or ``False`` when the input is not a non-empty
        string, is blank after tag removal, or is shorter than ``min_len``
        after cleaning. The result is truthy exactly when the document is
        kept, so it can be used directly as a filter predicate.
    """
    if not text or not isinstance(text, str):
        return False

    # 1. HTML tags.
    text = _HTML_TAG_RE.sub('', text)

    # 2. Reject early if nothing but whitespace is left.
    if not text.strip():
        return False

    # 3. URLs and e-mails first, so the phone/ad rules below cannot cut them
    #    into fragments (e.g. 'qq' inside "x@qq.com").
    text = _URL_RE.sub('', text)
    text = _EMAIL_RE.sub('', text)

    # 4. Phone numbers.
    for phone_re in _PHONE_RES:
        text = phone_re.sub('', text)

    # 5. Ads and boilerplate phrases.
    text = _AD_PHRASE_RE.sub('', text)
    text = _AD_TOKEN_RE.sub('', text)

    # 6. Emoji and special symbols.
    text = _SPECIAL_CHAR_RE.sub('', text)

    # 7. A run of punctuation collapses to its first character (so '，，' stays
    #    a Chinese comma), '...' to '.', and any whitespace run to one space.
    text = _REPEAT_PUNCT_RE.sub(r'\1', text)
    text = _REPEAT_DOT_RE.sub('.', text)
    cleaned = _WHITESPACE_RE.sub(' ', text).strip()

    # 8. Very short leftovers are most likely meaningless.
    if len(cleaned) < min_len:
        return False
    return cleaned

# Smoke test: a noisy sample mixing an ad phrase, an emoji, a phone number,
# an HTML link and a decorative star in front of the real content.
raw_text = (
    "关注公众号ABC获取资料😄，手机号13812345678，网页链接<a href='x'>链接</a> ★"
    "中国共产党万宁市委员会统一战线工作部\n"
    "中国共产党万宁市委员会统一战线工作部是中国共产党万宁市委员会工作部门。"
)
cleaned_text = clean_text(raw_text)
print(cleaned_text)


In [None]:
# Clean every record and drop the ones clean_text rejects, in a single pass.
# clean_text returns the cleaned string or False: the map stores the cleaned
# text (or "" when rejected) back into the "text" column, and the filter then
# drops the empty rows. Previously the filter ran twice on the raw `ds` (no
# effect) and discarded the cleaned text, so uncleaned text was tokenized.
ds_clean = (
    ds.map(lambda example: {"text": clean_text(example["text"]) or ""})
    .filter(lambda example: len(example["text"]) > 0)
)
ds_clean

In [None]:
# Fine-tuned BERT classifier checkpoint; the tokenizer is loaded from the same directory.
model_path = "./model/classification/BERT_classifier_for_wudao/checkpoint-6654"
tokenizer = AutoTokenizer.from_pretrained(model_path)
BERT_model = BertForSequenceClassification.from_pretrained(model_path)

In [None]:
# Score the whole cleaned corpus. For a quick trial run, subset it first,
# e.g. ds_clean.select(range(10000)).
ds_test = ds_clean
# Placeholder "labels" column, all 0; the cells shown only run prediction.
test = ds_test.add_column("labels", [0] * len(ds_test))


def tokenize(batch):
    """Tokenize a batch of texts, truncating each to at most 512 tokens.

    No padding here; padding is left to the collator at prediction time
    (presumably the Trainer default, since a tokenizer is handed to it — verify).
    """
    texts = batch["text"]
    return tokenizer(texts, truncation=True, max_length=512, padding=False)


tokened_dataset = test.map(tokenize, batched=True)
tokened_dataset

In [None]:
# Inference-only Trainer. Per the TrainingArguments docs, do_train/do_eval are
# not read by Trainer itself; they only document intent here.
test_train_args = TrainingArguments(
    output_dir="./temp_eval",
    do_train=False,
    do_eval=True,
    per_device_eval_batch_size=8,
)

trainer = Trainer(model=BERT_model, args=test_train_args, tokenizer=tokenizer)

In [None]:
# Run inference over the tokenized corpus. `predictions.predictions` holds the
# model's logits; argmax over the class axis gives the predicted label id.
# (numpy is already imported as `np` in the first cell — no re-import here.)
predictions = trainer.predict(tokened_dataset)
logits = predictions.predictions
pred_labels = np.argmax(logits, axis=-1)
print(pred_labels)
