# 随机森林分类

In [9]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    classification_report,
)
import numpy as np

In [10]:
# Load the raw dataset; the path is resolved relative to the notebook's working directory.
input_file = "dataset.csv"
data = pd.read_csv(input_file)

# Clean and encode columns in one pass. `assign` applies its keyword arguments
# in order, so the `sni_encoded` lambda sees the already-filled `sni` column.
data = data.assign(
    # stream_index -> int; entries that cannot be parsed as numbers (and NaN) become 0.
    stream_index=lambda df: (
        pd.to_numeric(df["stream_index"], errors="coerce").fillna(0).astype(int)
    ),
    # Missing SNI values become "" so they form a single category of their own.
    sni=lambda df: df["sni"].fillna(""),
    # One integer code per distinct SNI string.
    sni_encoded=lambda df: df["sni"].astype("category").cat.codes,
)

In [11]:
# Feature matrix and target: stream_index, the encoded SNI, and the
# packet-length columns len1..lenN.
# NOTE(review): the class labels in the report below are hostnames; if `sni`
# carries the same hostname, `sni_encoded` leaks the label into the features.
# Confirm before trusting the scores.
N_LEN_FEATURES = 30
TEST_SIZE = 0.3  # 70% train / 30% test
RANDOM_STATE = 42  # seeds both the split and the forest for reproducibility
N_ESTIMATORS = 200

features = ["stream_index", "sni_encoded"] + [
    f"len{i}" for i in range(1, N_LEN_FEATURES + 1)
]
X = data[features]
y = data["classes"]

# Stratify so each class keeps its proportion in both splits. Several classes
# are tiny, and an unstratified split can leave one absent from train or test.
# A stratified split needs at least 2 samples per class, so fall back to a
# plain random split when that does not hold.
stratify = y if y.value_counts().min() >= 2 else None

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=stratify
)

# Random forest classifier, trained on the training split only.
clf = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=RANDOM_STATE)
clf.fit(X_train, y_train)

In [12]:
# Predict on the held-out test split.
y_pred = clf.predict(X_test)

# Evaluate. "weighted" averages weight each class by its support in y_test.
# zero_division=0 states explicitly how ill-defined per-class metrics are
# scored (e.g. precision for a class the model never predicts). Without it,
# scikit-learn reports 0.0 anyway but emits UndefinedMetricWarning.
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average="weighted", zero_division=0)
recall = recall_score(y_test, y_pred, average="weighted", zero_division=0)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print(classification_report(y_test, y_pred, zero_division=0))

Accuracy: 0.8057553956834532
Precision: 0.8143199725933539
Recall: 0.8057553956834532
                  precision    recall  f1-score   support

bbs.elecfans.com       0.00      0.00      0.00         2
  bbs.kanxue.com       0.70      0.70      0.70        10
   blog.csdn.net       1.00      0.98      0.99        48
         dxy.com       1.00      1.00      1.00         1
     my.4399.com       0.52      0.85      0.65        26
    www.7k7k.com       0.71      0.38      0.50        26
   www.haodf.com       0.00      0.00      0.00         1
     www.news.cn       1.00      1.00      1.00        22
   www.qimao.com       1.00      1.00      1.00         3

        accuracy                           0.81       139
       macro avg       0.66      0.66      0.65       139
    weighted avg       0.81      0.81      0.79       139



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
