In [None]:
import logging
import os
import shutil

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)

# --- Basic configuration ---
# logging.FileHandler opens its target file as soon as it is constructed, so the
# parent directory has to exist first; without this a fresh run fails with
# FileNotFoundError before any logging is configured.
os.makedirs("./logs", exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("./logs/training.log"),  # persist to file
        logging.StreamHandler()                      # echo to the console
    ]
)
logger = logging.getLogger(__name__)

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create the directories this notebook uses (SST-2 specific)
os.makedirs("./result_sst2", exist_ok=True)
os.makedirs("./logs_sst2", exist_ok=True)
os.makedirs("./teacher_checkpoints_sst2", exist_ok=True) # teacher training checkpoints
os.makedirs("./teacher_logs_sst2", exist_ok=True)      # teacher training logs


In [None]:
# --- Data loading and preprocessing (SST-2) ---
logger.info("Loading SST-2 dataset from GLUE...")
# SST-2 is one configuration of the GLUE benchmark
sst2_dataset = load_dataset(path="glue", name="sst2")
# Show the dataset layout to confirm the column names (usually 'sentence', 'label', 'idx')
print("SST-2 Dataset structure:", sst2_dataset, sep="\n")

# Checkpoint id of the teacher model
teacher_model_id = 'microsoft/deberta-v3-base'

logger.info(f"Loading tokenizer for {teacher_model_id}...")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)


In [None]:
# Tokenization function (note: the SST-2 text column is usually 'sentence')
def tokenize_function_sst2(examples):
    """Tokenize the ``"sentence"`` entries of a batch of examples.

    Args:
        examples: Batch mapping handed over by ``Dataset.map(batched=True)``;
            only its ``"sentence"`` entry is read.

    Returns:
        The result of calling the module-level ``tokenizer`` on those
        sentences, truncated to at most 128 tokens and left unpadded.
    """
    # No padding here on purpose: the notebook builds a DataCollatorWithPadding
    # and passes it to the Trainer, which pads each batch to its own longest
    # sequence. Padding every row to 128 at map time would bypass that and make
    # every batch pay for 128 tokens even though SST-2 sentences are short.
    # 128 is kept as the truncation limit; it is usually enough for SST-2.
    return tokenizer(examples["sentence"], truncation=True, max_length=128)

logger.info("Tokenizing SST-2 dataset...")
# Run the tokenization function over the dataset in batched mode
tokenized_datasets_sst2 = sst2_dataset.map(
    tokenize_function_sst2,
    batched=True,
)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer)

# Evaluation metric (still accuracy)
accuracy_metric = evaluate.load(path="accuracy")
def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Args:
        eval_pred: A pair that unpacks into ``(predictions, labels)``; the
            arg-max of ``predictions`` along axis 1 is taken as the predicted
            class.

    Returns:
        Whatever ``accuracy_metric.compute`` returns for those predictions
        and labels.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return accuracy_metric.compute(
        predictions=predicted_classes,
        references=references,
    )

# Split selection (SST-2 usually ships train, validation and test).
# "train" is used for training; "validation" for evaluation and best-model selection.
tokenized_train_sst2, tokenized_val_sst2 = (
    tokenized_datasets_sst2["train"],
    tokenized_datasets_sst2["validation"],
)
# The test split usually has no labels, so it is not used in this fine-tuning run.

logger.info(f"SST-2 Train dataset size: {len(tokenized_train_sst2)}")
logger.info(f"SST-2 Validation dataset size: {len(tokenized_val_sst2)}")



In [None]:
# --- Teacher model fine-tuning and saving (SST-2) ---

# Final save location of the fine-tuned teacher (SST-2 specific)
teacher_model_finetuned_path_sst2 = 'deberta-v3-base-finetuned-sst2'

logger.info("Starting teacher model fine-tuning on SST-2...")

# 1. Load the pretrained DeBERTa-v3-base model for sequence classification.
# The classification head is not part of the pretrained checkpoint, so it is
# randomly initialised during from_pretrained -- before the Trainer has seeded
# anything. Seed torch first so repeated runs start from the same weights.
# 42 mirrors the default TrainingArguments seed.
torch.manual_seed(42)
teacher_model_for_finetune_sst2 = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_id,
    num_labels=2  # SST-2 is a binary classification task
).to(device)

# 2. Training configuration
teacher_training_args_sst2 = TrainingArguments(
    # Output locations and logging
    output_dir='./teacher_checkpoints_sst2',
    logging_dir='./teacher_logs_sst2',
    logging_steps=50,
    report_to="tensorboard",
    # Optimisation settings
    num_train_epochs=3,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    fp16=torch.cuda.is_available(),  # fp16 only when CUDA is available
    # Evaluate and checkpoint once per epoch; reload the best checkpoint by accuracy
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# 3. Build the Trainer from the model, arguments, data and metric defined above
# NOTE(review): newer transformers releases appear to rename `tokenizer=` to
# `processing_class=` -- confirm against the installed version before changing.
teacher_trainer_sst2 = Trainer(
    model=teacher_model_for_finetune_sst2,
    args=teacher_training_args_sst2,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_train_sst2,
    eval_dataset=tokenized_val_sst2,
    compute_metrics=compute_metrics,
)

# 4. Run training
teacher_trainer_sst2.train()
logger.info("Fine-tuning completed. Saving final model...")

# 5. Remove any previous export at the target path before saving
# (shutil is imported in the notebook's top import cell.)
if os.path.exists(teacher_model_finetuned_path_sst2):
    shutil.rmtree(teacher_model_finetuned_path_sst2)

# Save the model and the tokenizer into the same directory
teacher_trainer_sst2.save_model(teacher_model_finetuned_path_sst2)
tokenizer.save_pretrained(teacher_model_finetuned_path_sst2)

print(f"Teacher model training completed. Model saved at: {teacher_model_finetuned_path_sst2}")