In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
data = pd.read_csv("./data/sft/NSL-KDD-100000-sft.csv")
# 划分训练集和验证集
train_texts, eval_texts, train_labels, eval_labels = train_test_split(
    data["flow"].tolist(), data["class"].tolist(), test_size=0.2, random_state=42
)
# 转换为 Hugging Face Dataset 格式
dataset = DatasetDict({
    "train": Dataset.from_dict({"text": train_texts, "label": train_labels}),
    "eval": Dataset.from_dict({"text": eval_texts, "label": eval_labels})
})

In [2]:
# 遍历数据，找到最长文本的长度（基于逗号分词）
text_lengths = [len(text.split(",")) for text in data["flow"].tolist()]
max_length = max(text_lengths)
print(f"实际设定的 max_length: {max_length}")

实际设定的 max_length: 41


In [3]:
from transformers import BertTokenizer
# 加载 BERT 分词器
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# 逗号分词 + 重新转换为 BERT 需要的 `input_ids`
def custom_tokenize_function(examples):
    # 逗号分词
    tokenized_texts = [text.split(",") for text in examples["text"]]
    # 将每个短语转换为 BERT input_ids
    encodings = tokenizer(
        tokenized_texts,  # 逗号分词后的文本
        padding="max_length",
        max_length=max_length,
        truncation=True,
        is_split_into_words=True  # 关键参数：告诉 tokenizer 文本已经被手动分词
    )
    return encodings

In [4]:
tokenized_datasets = dataset.map(custom_tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["eval"].shuffle(seed=42)

Map:   0%|          | 0/80000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [5]:
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, hidden_dropout_prob=0.2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    return {
        "eval_loss": float(np.mean(logits)),  # 确保 `eval_loss` 存在
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

In [None]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
training_args = TrainingArguments(
    output_dir="./trainer/test_trainer-sft-diy",
    eval_strategy="epoch",
    save_strategy="epoch",  # 每个 epoch 保存一次模型
    save_steps=None,  # 取消按steps保存
    learning_rate=2e-5,  # 学习率  2e-5
    per_device_train_batch_size=64,  # 适当增加 batch_size，默认 8
    per_device_eval_batch_size=64,
    num_train_epochs=5,  # 降低训练轮数，避免过拟合
    weight_decay=0.02,  # 加入 L2 正则化
    load_best_model_at_end=True,  # 训练结束后加载最佳模型
    logging_strategy="epoch",  # 确保每个 epoch 打印 loss
    # logging_steps=500,  # 或者按步数打印
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # 设置早停机制
)

In [9]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0446,0.012943,0.9968,0.995583,0.998398,0.996989
2,0.0124,0.01273,0.9972,0.995586,0.999152,0.997366
3,0.0087,0.004771,0.99865,0.998775,0.998681,0.998728
4,0.0058,0.00615,0.99825,0.997647,0.999058,0.998352
5,0.0041,0.003904,0.9989,0.998775,0.999152,0.998964


TrainOutput(global_step=6250, training_loss=0.015123786392211914, metrics={'train_runtime': 1200.3454, 'train_samples_per_second': 333.237, 'train_steps_per_second': 5.207, 'total_flos': 8427775992000000.0, 'train_loss': 0.015123786392211914, 'epoch': 5.0})