In [1]:
import pandas as pd
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score, classification_report
from pathlib import Path
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from scipy.sparse import hstack

DATA = Path('data')

## Data Clean
- 将所有文本转为小写。
- 移除URL、@用户、#话题标签、标点符号和数字，只保留文本内容。

In [21]:
# --- 1. 加载数据 ---
print("正在加载数据...")
train_df = pd.read_csv(DATA / 'train.csv')
val_df = pd.read_csv(DATA / 'val.csv')
train_df.head()

train_df = train_df.sample(frac=1, random_state=18260817).reset_index(drop=True)
val_df = val_df.sample(frac=1, random_state=18260817).reset_index(drop=True)

正在加载数据...


In [22]:
# 填充可能存在的NaN值
train_df['text'] = train_df['text'].fillna('')
val_df['text'] = val_df['text'].fillna('')
train_df['event'] = train_df['event'].fillna('unknown')
val_df['event'] = val_df['event'].fillna('unknown')

In [23]:
# --- 2. 文本预处理函数 ---
def preprocess_text(text):
    """清洗文本数据"""
    text = text.lower()  # 转为小写
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE) # 移除URL
    text = re.sub(r'\@\w+|\#', '', text) # 移除@和#
    text = re.sub(r'[^a-z\s]', '', text) # 只保留英文字母和空格
    return text

In [24]:
print("正在预处理文本...")
train_df['clean_text'] = train_df['text'].apply(preprocess_text)
val_df['clean_text'] = val_df['text'].apply(preprocess_text)
y_train = train_df['label']
y_val = val_df['label']
train_df.head()

正在预处理文本...


Unnamed: 0,id,text,label,event,clean_text
0,544286686142685184,CORRECTION: We reported earlier Sydney air spa...,1,5,correction we reported earlier sydney air spac...
1,499410965075886080,Missouri Mayor: Looters came from out of town....,0,1,missouri mayor looters came from out of town ...
2,536837309078200320,Swiss museum accepts art from Gurlitt http://t...,1,0,swiss museum accepts art from gurlitt via
3,544282877278814208,Flag in Sydney cafe where hostages are being h...,1,5,flag in sydney cafe where hostages are being h...
4,500280477920796672,"Because, of course, when someone commits a rob...",1,1,because of course when someone commits a robbe...


In [None]:
# --- 2. Feature Engineering (in parallel) ---

# 2.1 Text Path (TF-IDF -> SVD)
print("正在处理文本特征...")
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
X_train_tfidf = tfidf_vectorizer.fit_transform(train_df['clean_text'])
X_val_tfidf = tfidf_vectorizer.transform(val_df['clean_text'])

n_components = 150
svd = TruncatedSVD(n_components=n_components, random_state=42)
X_train_svd = svd.fit_transform(X_train_tfidf)
X_val_svd = svd.transform(X_val_tfidf)


正在处理文本特征...


In [26]:
# # 2.2 Event Path (One-Hot Encoding)
# print("正在处理 'event' 特征...")
# onehot_encoder = OneHotEncoder(handle_unknown='ignore')
# # Note: The input to OneHotEncoder should be a 2D array, hence train_df[['event']]
# X_train_event = onehot_encoder.fit_transform(train_df[['event']])
# X_val_event = onehot_encoder.transform(val_df[['event']])

In [27]:
# --- 3. Combine Features ---
# print("正在合并文本特征和 'event' 特征...")
# # np.hstack requires dense arrays for stacking.
# X_train_combined = np.hstack((X_train_svd, X_train_event.toarray()))
# X_val_combined = np.hstack((X_val_svd, X_val_event.toarray()))



In [28]:
# print(f"最终特征维度: {X_train_combined.shape[1]}") # This will be n_components + number_of_event_categories


In [39]:
# --- 4. Model Training and Evaluation (QDA) ---

X_train_combined = X_train_svd
X_val_combined = X_val_svd
print("正在训练和评估 QDA 模型...")
qda_model = QuadraticDiscriminantAnalysis(reg_param=1e-5)
# Note: QDA can sometimes fail if a class has too few samples for a given feature combination,
# leading to a singular covariance matrix. If that happens, LDA is a more robust alternative.
try:
    qda_model.fit(X_train_combined, y_train)
    y_pred_qda = qda_model.predict(X_val_combined)
    accuracy_qda = accuracy_score(y_val, y_pred_qda)

    print(f"\n包含 'event' 的 QDA 模型在验证集上的准确率: {accuracy_qda:.4f}")
    print("\nQDA 分类报告:")
    print(classification_report(y_val, y_pred_qda))
except Exception as e:
    print(f"\n无法训练QDA模型，可能因为特征共线性问题: {e}")
    print("请考虑使用LDA作为替代方案。")

正在训练和评估 QDA 模型...

包含 'event' 的 QDA 模型在验证集上的准确率: 0.8321

QDA 分类报告:
              precision    recall  f1-score   support

           0       0.84      0.86      0.85       226
           1       0.82      0.80      0.81       179

    accuracy                           0.83       405
   macro avg       0.83      0.83      0.83       405
weighted avg       0.83      0.83      0.83       405



