In [1]:
# # IMDB 数据集上的知识蒸馏：DeBERTa-v3-base (Teacher) -> TinyBERT-6L (Student)
# ## 使用预分词方法进行蒸馏 (已修正 Data Collator)


In [2]:
# 确保安装了必要的库
!pip install evaluate


Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.whl (183 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, evaluate
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requ

In [3]:
# 导入核心库
import logging
import os
import sys
from typing import Dict, Any, List # 用于类型提示

import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,  # 用于评估阶段
    PreTrainedTokenizerBase,
    BatchEncoding # 确保导入
)
# from transformers import default_data_collator # 不再需要


2025-04-26 07:29:36.978747: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745652577.161365      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745652577.213326      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
# 配置日志记录器
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout) # 直接输出到标准输出
    ],
    force=True # 强制覆盖现有配置
)
logger = logging.getLogger(__name__)


In [5]:
# 检查并设置计算设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"使用设备: {device}")


2025-04-26 07:29:52,974 [INFO] 使用设备: cuda


In [6]:
# 创建必要的输出目录 (使用预分词后缀区分)
os.makedirs("./result", exist_ok=True)
os.makedirs("./logs", exist_ok=True)
os.makedirs("./teacher_checkpoints", exist_ok=True) # 假设已有微调教师模型
os.makedirs("./distill_checkpoints_tinybert_pretokenized", exist_ok=True) # 蒸馏检查点
os.makedirs("./eval_output_tinybert_pretokenized", exist_ok=True) # 评估输出
os.makedirs("./distill_logs_tinybert_pretokenized", exist_ok=True) # TensorBoard 日志


In [7]:
# ==============================================================================
# ## 第零步：定义模型和分词器标识符
# ==============================================================================


In [8]:
# 定义教师和学生模型的 Hugging Face Hub ID
teacher_model_id: str = 'microsoft/deberta-v3-base'
student_model_id: str = "huawei-noah/TinyBERT_General_6L_768D"

# 定义教师模型微调后和最终学生模型蒸馏后的保存路径
teacher_model_finetuned_path: str = '/kaggle/input/deberta-v3-base-finetuned-imdb/deberta-v3-base-finetuned-imdb' # 假设这个路径下已有微调好的教师模型
final_student_model_path: str = "tinybert-student-distilled-imdb-pretokenized" # 预分词蒸馏后学生模型的保存路径

# 定义分词器的最大序列长度
MAX_LENGTH: int = 512


In [9]:
# ==============================================================================
# ## 第一步：加载数据集、分词器，并进行 *预分词*
# ==============================================================================


In [10]:
# 1. 加载 IMDB 数据集
logger.info("加载 IMDB 数据集...")
try:
    imdb_dataset: DatasetDict = load_dataset("imdb")
except Exception as e:
    logger.error(f"加载 IMDB 数据集失败: {e}")
    raise


2025-04-26 07:29:53,061 [INFO] 加载 IMDB 数据集...


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [11]:
# 2. 加载教师模型 (DeBERTa) 的分词器
logger.info(f"加载教师模型的分词器: {teacher_model_id}")
try:
    teacher_tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(teacher_model_id, use_fast=True)
except Exception as e:
    logger.error(f"加载教师分词器 '{teacher_model_id}' 失败: {e}")
    raise


2025-04-26 07:30:03,857 [INFO] 加载教师模型的分词器: microsoft/deberta-v3-base


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



In [12]:
# 3. 加载学生模型 (TinyBERT) 的分词器
logger.info(f"加载学生模型的分词器: {student_model_id}")
try:
    student_tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(student_model_id, use_fast=True)
    # 建议设置学生分词器的最大长度与教师一致或根据需要调整
    student_tokenizer.model_max_length = MAX_LENGTH
except Exception as e:
    logger.error(f"加载学生分词器 '{student_model_id}' 失败: {e}")
    raise


2025-04-26 07:30:08,171 [INFO] 加载学生模型的分词器: huawei-noah/TinyBERT_General_6L_768D


config.json:   0%|          | 0.00/390 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [13]:
# 4. 定义预分词函数
logger.info("定义预分词函数...")
def tokenize_for_distillation(examples: Dict[str, List]) -> Dict[str, List]:
    """
    使用教师和学生分词器同时处理文本，并添加前缀区分。
    不进行填充，填充交给自定义 DataCollator 处理。
    """
    # 使用教师分词器处理
    teacher_encoding = teacher_tokenizer(
        examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False
    )
    # 使用学生分词器处理
    student_encoding = student_tokenizer(
        examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False
    )

    # 构建新的样本字典
    processed: Dict[str, List] = {}
    # 确保标签键名为 'labels'
    processed["labels"] = examples["label"]
    for k, v in teacher_encoding.items():
        processed[f"teacher_{k}"] = v
    for k, v in student_encoding.items():
        processed[f"student_{k}"] = v
    return processed


2025-04-26 07:30:10,814 [INFO] 定义预分词函数...


In [14]:
# 5. 对整个数据集应用预分词函数
logger.info("对 IMDB 数据集进行预分词 (可能需要一些时间)...")
# 使用 map 函数进行批处理，移除原始的 'text' 和 'label' 列
# num_proc 可以根据 CPU 核心数调整以加速
tokenized_datasets: DatasetDict = imdb_dataset.map(
    tokenize_for_distillation,
    batched=True,
    remove_columns=imdb_dataset["train"].column_names, # 移除所有原始列
    num_proc=os.cpu_count() // 2 if os.cpu_count() else 1 # 使用多进程加速
)
logger.info("数据集预分词完成。")

# 查看一个处理后的样本结构
print("\n预分词后的训练集样本示例:")
print(tokenized_datasets["train"][0])


2025-04-26 07:30:10,835 [INFO] 对 IMDB 数据集进行预分词 (可能需要一些时间)...


Map (num_proc=2):   0%|          | 0/25000 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/25000 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

2025-04-26 07:31:51,004 [INFO] 数据集预分词完成。

预分词后的训练集样本示例:
{'labels': 0, 'teacher_input_ids': [1, 273, 11388, 273, 3846, 51696, 65516, 271, 3021, 79919, 1975, 292, 312, 750, 1106, 401, 265, 305, 262, 9046, 272, 5199, 278, 335, 278, 284, 362, 1315, 267, 9785, 260, 273, 327, 1331, 272, 288, 362, 278, 284, 11258, 293, 543, 260, 430, 260, 9969, 337, 278, 632, 1367, 264, 1916, 291, 658, 261, 1928, 411, 266, 2066, 265, 3107, 1403, 307, 94752, 309, 273, 431, 330, 264, 398, 291, 270, 1113, 260, 4052, 8981, 840, 1504, 4052, 8981, 840, 1504, 635, 4278, 269, 10254, 441, 266, 856, 8442, 4522, 1234, 1784, 25127, 328, 1654, 264, 799, 758, 373, 295, 314, 432, 260, 344, 1070, 373, 1654, 264, 1087, 342, 1251, 268, 264, 570, 347, 1667, 265, 6186, 277, 339, 262, 1210, 69083, 708, 314, 991, 1198, 808, 405, 283, 262, 5681, 1752, 263, 1583, 808, 267, 262, 780, 1017, 260, 344, 457, 2331, 5837, 263, 5111, 72043, 265, 16339, 314, 308, 4713, 277, 3252, 261, 373, 303, 7454, 275, 342, 4522, 2274, 261, 16133, 261, 26

In [15]:
# *** 添加自定义 Data Collator 类定义 ***
import torch
from transformers import PreTrainedTokenizerBase, BatchEncoding
from typing import List, Dict, Any

class DistillationDataCollator:
    """
    自定义数据整理器，用于预分词蒸馏。
    分别使用教师和学生分词器填充对应的输入。
    """
    def __init__(self, teacher_tokenizer: PreTrainedTokenizerBase, student_tokenizer: PreTrainedTokenizerBase):
        self.teacher_tokenizer = teacher_tokenizer
        self.student_tokenizer = student_tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # 1. 分离教师、学生输入和标签
        teacher_features = []
        student_features = []
        labels = []

        # 确定教师和学生模型期望的标准输入键名
        teacher_keys_to_extract = ["input_ids", "attention_mask"]
        if "token_type_ids" in self.teacher_tokenizer.model_input_names:
            teacher_keys_to_extract.append("token_type_ids")
        student_keys_to_extract = ["input_ids", "attention_mask"]
        if "token_type_ids" in self.student_tokenizer.model_input_names:
            student_keys_to_extract.append("token_type_ids")

        for feature in features:
            # 提取标签 (确保键名是 'labels')
            if "labels" in feature:
                labels.append(feature["labels"])
            else:
                # 如果没有标签，可能需要处理（例如，在推理时）
                # 对于训练，通常应该有标签
                logger.warning("在 collate 过程中未找到 'labels' 键。")
                pass # 或者根据需要添加 None 或默认值

            # 提取教师输入，并重命名键为标准名称以供 tokenizer.pad 使用
            t_feat = {}
            for standard_key in teacher_keys_to_extract:
                prefixed_key = f"teacher_{standard_key}"
                if prefixed_key in feature:
                    t_feat[standard_key] = feature[prefixed_key]
            if t_feat: # 只有在提取到键时才添加
                teacher_features.append(t_feat)

            # 提取学生输入，并重命名键为标准名称
            s_feat = {}
            for standard_key in student_keys_to_extract:
                prefixed_key = f"student_{standard_key}"
                if prefixed_key in feature:
                    s_feat[standard_key] = feature[prefixed_key]
            if s_feat:
                student_features.append(s_feat)

        # 检查是否成功提取了特征
        if not teacher_features or not student_features:
            # 如果 features 为空（例如最后一个不完整的批次被丢弃），则返回空字典
            if not features:
                return {}
            raise ValueError("未能从特征中提取教师或学生输入，请检查预处理步骤和键名。Features sample: " + str(features[0].keys()))

        # 2. 使用各自的分词器进行填充
        #    tokenizer.pad 会处理 input_ids, attention_mask, token_type_ids 的填充
        #    padding=True 表示填充到批次内的最大长度
        #    return_tensors="pt" 返回 PyTorch 张量
        try:
            padded_teacher_batch = self.teacher_tokenizer.pad(
                teacher_features,
                padding=True,
                return_tensors="pt",
                # max_length=MAX_LENGTH, # 可选：如果希望填充到固定最大长度而不是批次最大长度
                # pad_to_multiple_of=8 # 可选：为了硬件优化
            )
        except Exception as e:
            logger.error(f"教师分词器填充失败: {e}. Teacher features keys: {[f.keys() for f in teacher_features[:2]]}")
            raise e

        try:
            padded_student_batch = self.student_tokenizer.pad(
                student_features,
                padding=True,
                return_tensors="pt",
                # max_length=MAX_LENGTH,
                # pad_to_multiple_of=8
            )
        except Exception as e:
            logger.error(f"学生分词器填充失败: {e}. Student features keys: {[f.keys() for f in student_features[:2]]}")
            raise e

        # 3. 组合最终的批次字典，并恢复前缀
        batch = {}
        if labels: # 仅当存在标签时添加
            batch["labels"] = torch.tensor(labels, dtype=torch.long)

        for k, v in padded_teacher_batch.items():
            batch[f"teacher_{k}"] = v

        for k, v in padded_student_batch.items():
            batch[f"student_{k}"] = v

        return batch


In [16]:
# 6. 划分训练集、验证集、测试集 (从已分词的数据集划分)
train_dataset = tokenized_datasets["train"]
# 将预分词后的测试集对半拆分，得到验证集和新的测试集
logger.info("将预分词后的测试集拆分为验证集和测试集...")
if len(tokenized_datasets["test"]) < 2:
    raise ValueError("预分词后的测试集太小，无法拆分。")
val_test_split = tokenized_datasets["test"].train_test_split(test_size=0.5, seed=42, shuffle=True)
val_dataset = val_test_split['train']
# 注意：这个 test_dataset 是预分词过的，主要用于训练过程中的快速评估
# 最终评估需要使用原始文本+学生分词器处理的测试集
test_dataset_pretokenized = val_test_split['test']

logger.info(f"训练集大小: {len(train_dataset)}")
logger.info(f"验证集大小: {len(val_dataset)}")
logger.info(f"用于训练中评估的测试集大小: {len(test_dataset_pretokenized)}")


2025-04-26 07:31:51,059 [INFO] 将预分词后的测试集拆分为验证集和测试集...
2025-04-26 07:31:51,075 [INFO] 训练集大小: 25000
2025-04-26 07:31:51,075 [INFO] 验证集大小: 12500
2025-04-26 07:31:51,076 [INFO] 用于训练中评估的测试集大小: 12500


In [17]:
# 7. 数据整理器 (Data Collator)
#    *** 修改此处 ***
#    使用自定义的 DistillationDataCollator 处理预分词数据。
logger.info("使用自定义的 DistillationDataCollator 处理预分词数据。")
data_collator = DistillationDataCollator(
    teacher_tokenizer=teacher_tokenizer,
    student_tokenizer=student_tokenizer
)


2025-04-26 07:31:51,098 [INFO] 使用自定义的 DistillationDataCollator 处理预分词数据。


In [18]:
# 8. 准备评估指标计算函数
logger.info("加载 'accuracy' 评估指标...")
try:
    accuracy_metric = evaluate.load("accuracy")
except Exception as e:
    logger.error(f"加载 'accuracy' 指标失败: {e}")
    raise

def compute_metrics(eval_pred):
    """计算准确率指标"""
    predictions, labels = eval_pred
    # predictions 可能是 logits，需要 argmax
    if isinstance(predictions, tuple): # 有些模型输出是 (logits, ...)
        predictions = predictions[0]
    # 确保 predictions 是 numpy array
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    predictions = np.argmax(predictions, axis=1)
    return accuracy_metric.compute(predictions=predictions, references=labels)


2025-04-26 07:31:51,121 [INFO] 加载 'accuracy' 评估指标...


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [19]:
# ==============================================================================
# ## 第二步：加载微调好的教师模型 (DeBERTa-v3-base on IMDB)
# ==============================================================================


In [20]:
logger.info(f"检查微调好的教师模型路径: {teacher_model_finetuned_path}")
if not os.path.exists(teacher_model_finetuned_path):
    logger.error(f"错误：未找到微调好的教师模型路径 '{teacher_model_finetuned_path}'。")
    logger.error("请确保教师模型已在 IMDB 上微调并保存，或修改路径。")
    # 可以在此处停止脚本或提供默认行为
    raise FileNotFoundError(f"教师模型未在路径 {teacher_model_finetuned_path} 找到")
else:
    logger.info(f"加载微调好的教师模型: {teacher_model_finetuned_path}")
    try:
        teacher_model_for_distill = AutoModelForSequenceClassification.from_pretrained(
            teacher_model_finetuned_path, num_labels=2 # IMDB 是二分类
        )
        logger.info("成功加载微调教师模型。")
    except Exception as e:
        logger.error(f"加载微调教师模型 '{teacher_model_finetuned_path}' 失败: {e}")
        raise


2025-04-26 07:31:52,568 [INFO] 检查微调好的教师模型路径: /kaggle/input/deberta-v3-base-finetuned-imdb/deberta-v3-base-finetuned-imdb
2025-04-26 07:31:52,573 [INFO] 加载微调好的教师模型: /kaggle/input/deberta-v3-base-finetuned-imdb/deberta-v3-base-finetuned-imdb
2025-04-26 07:31:53,284 [INFO] 成功加载微调教师模型。


In [21]:
# ==============================================================================
# ## 第三步：加载预训练的学生模型 (TinyBERT)
# ==============================================================================


In [22]:
logger.info(f"加载预训练学生模型 ({student_model_id}) 用于蒸馏...")
try:
    student_model = AutoModelForSequenceClassification.from_pretrained(
        student_model_id,
        num_labels=2, # 匹配 IMDB 任务
        ignore_mismatched_sizes=True # 忽略预训练头和当前任务头尺寸不匹配
    )
    logger.info("成功加载 TinyBERT 学生模型。")
except Exception as e:
    logger.error(f"加载学生模型 {student_model_id} 失败: {e}")
    raise


2025-04-26 07:31:53,328 [INFO] 加载预训练学生模型 (huawei-noah/TinyBERT_General_6L_768D) 用于蒸馏...


pytorch_model.bin:   0%|          | 0.00/287M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_6L_768D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-04-26 07:31:56,087 [INFO] 成功加载 TinyBERT 学生模型。


In [23]:
# ==============================================================================
# ## 第四步：设置并运行知识蒸馏训练 (使用预分词数据)
# ==============================================================================


In [24]:
# 1. 定义蒸馏训练参数类 (继承 TrainingArguments，增加 alpha 和 temperature)
class DistillationTrainingArguments(TrainingArguments):
    """
    自定义训练参数，添加蒸馏所需的 alpha 和 temperature。
    """
    def __init__(self, *args, alpha: float = 0.5, temperature: float = 2.0, **kwargs):
        super().__init__(*args, **kwargs)
        # 明确定义蒸馏超参数
        self.alpha = alpha
        self.temperature = temperature


In [25]:
# 2. 定义蒸馏训练器 (继承标准 Trainer，重写 compute_loss)
class PreTokenizedDistillationTrainer(Trainer):
    """
    处理预分词数据的蒸馏训练器。
    """
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        if teacher_model is None:
            raise ValueError("必须提供教师模型 (teacher_model)。")
        self.teacher = teacher_model
        # 将教师模型移动到与学生模型相同的设备，并设为评估模式
        # self.model 是学生模型，在 Trainer 初始化时已移动到 self.args.device
        self._move_model_to_device(self.teacher, self.args.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs: Dict[str, torch.Tensor], return_outputs=False,**kwargs):
        """
        重写损失计算函数以处理预分词数据和蒸馏逻辑。
        'inputs' 字典包含由 DistillationDataCollator 处理好的张量。
        """
        # 提取标签，并确保在正确的设备上
        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.to(self.args.device)
        # else: # 训练时通常有标签，评估时可能没有
        #     logger.debug("在 compute_loss 中未找到 'labels'。")

        # --- 学生模型前向传播 ---
        # 提取学生模型的输入 (以 "student_" 开头的键)
        student_input_keys = [k for k in inputs if k.startswith("student_")]
        # 构建输入字典，移除前缀，并移动到设备
        student_inputs = {k.replace("student_", ""): v.to(self.args.device) for k, v in inputs.items() if k in student_input_keys}

        # *** 修改此处：检查 model.training 而不是 self.is_training ***
        # 添加标签以计算学生模型的标准损失 (hard loss)
        # 只有在模型处于训练模式 (model.training is True) 且标签存在时才添加
        # 注意：Trainer 在计算 eval_loss 时也会调用 compute_loss，此时 model.training 为 False，
        # 但通常模型 forward 方法设计为即使在 eval 模式下，如果传入 labels 也能计算 loss。
        # 因此，更简单且常用的做法是：如果 labels 存在，就总是传入。
        # Trainer 会负责在训练和评估时都提供 labels（如果数据集里有的话）。
        if labels is not None:
            student_inputs["labels"] = labels
        # ***********************************************************

        # 调用学生模型
        outputs_student = model(**student_inputs)
        # 如果提供了标签，loss 会在 outputs_student 中计算好
        student_loss = outputs_student.loss if hasattr(outputs_student, "loss") else None
        logits_student = outputs_student.logits

        # ... (教师模型前向传播、蒸馏损失计算、组合损失计算保持不变) ...

        # --- 教师模型前向传播 ---
        teacher_input_keys = [k for k in inputs if k.startswith("teacher_")]
        teacher_inputs = {k.replace("teacher_", ""): v.to(self.args.device) for k, v in inputs.items() if k in teacher_input_keys}
        with torch.no_grad():
            outputs_teacher = self.teacher(**teacher_inputs)
        logits_teacher = outputs_teacher.logits.to(self.args.device)

        # --- 计算蒸馏损失 (KL 散度) ---
        temperature = self.args.temperature
        alpha = self.args.alpha
        loss_fct_kl = nn.KLDivLoss(reduction="batchmean", log_target=False)
        log_softmax_student = F.log_softmax(logits_student.float() / temperature, dim=-1)
        softmax_teacher = F.softmax(logits_teacher.float() / temperature, dim=-1)
        distillation_loss = loss_fct_kl(
            log_softmax_student,
            softmax_teacher
        ) * (temperature ** 2)

        # --- 组合损失 ---
        if student_loss is not None: # 只有在计算了硬标签损失时才组合
            loss = alpha * student_loss + (1.0 - alpha) * distillation_loss.to(student_loss.dtype)
        else:
            # 如果没有硬标签损失（例如在评估或无监督蒸馏时），只使用蒸馏损失
            # logger.debug("Student loss is None, using only distillation loss.")
            loss = distillation_loss # 这种情况在有监督训练评估中不应发生，因为eval也有loss

        return (loss, outputs_student) if return_outputs else loss


In [26]:
# 3. 配置蒸馏训练参数
distill_output_dir: str = "./distill_checkpoints_tinybert_pretokenized"
distill_logging_dir: str = './distill_logs_tinybert_pretokenized'

# 使用自定义的 DistillationTrainingArguments
distillation_args = DistillationTrainingArguments(
    output_dir=distill_output_dir,
    # max_steps=1,
    num_train_epochs=3,             # 训练轮数 (可调整)
    per_device_train_batch_size=16, # 根据 GPU 显存调整
    per_device_eval_batch_size=32,  # 评估批次大小
    gradient_accumulation_steps=1,  # 梯度累积步数
    learning_rate=5e-5,             # 学习率
    warmup_ratio=0.1,               # 学习率预热比例
    weight_decay=0.01,              # 权重衰减
    logging_dir=distill_logging_dir,# 日志保存目录
    logging_strategy="steps",       # 按步数记录日志
    logging_steps=100,              # 每 100 步记录一次日志
    eval_strategy="steps",          # 按步数进行评估
    eval_steps=500,                 # 每 500 步评估一次 (在 val_dataset 上)
    save_strategy="steps",          # 按步数保存模型
    save_steps=500,                 # 每 500 步保存一次检查点
    load_best_model_at_end=True,    # 训练结束后加载最佳模型 (基于 metric_for_best_model)
    metric_for_best_model="accuracy",# 使用 accuracy 作为选择最佳模型的指标
    greater_is_better=True,         # accuracy 越高越好
    save_total_limit=2,             # 最多保存 2 个检查点
    fp16=torch.cuda.is_available(), # 如果有 GPU，启用 FP16 混合精度训练
    report_to="tensorboard",        # 将日志报告给 TensorBoard
    # --- 蒸馏特定参数 ---
    alpha=0.5,                      # 硬标签损失 (student_loss) 的权重
    temperature=4.0,                # 蒸馏温度 (用于软化 logits)
    # --- 关键参数 ---
    remove_unused_columns=False     # **非常重要**: 必须保留所有列 (teacher_* 和 student_*)
    # 因为 DistillationDataCollator 和 compute_loss 需要它们
)


In [27]:
# 4. 创建预分词蒸馏训练器实例
distill_trainer = PreTokenizedDistillationTrainer(
    model=student_model,                  # 学生模型实例 (TinyBERT)
    teacher_model=teacher_model_for_distill, # 教师模型实例 (DeBERTa)
    args=distillation_args,               # 训练参数 (包含 alpha, temp)
    train_dataset=train_dataset,          # 预分词后的训练集
    eval_dataset=val_dataset,             # 预分词后的验证集
    data_collator=data_collator,          # *** 使用自定义的 DistillationDataCollator ***
    compute_metrics=compute_metrics,      # 评估指标计算函数
    tokenizer=student_tokenizer           # 主要分词器设为学生的 (用于保存等)
)


  super().__init__(*args, **kwargs)


model.safetensors:   0%|          | 0.00/287M [00:00<?, ?B/s]

In [28]:
# 5. 开始蒸馏训练
logger.info("开始在 IMDB 数据集上进行预分词知识蒸馏...")
try:
    train_result = distill_trainer.train()
    # 可以记录一些训练结果
    logger.info("知识蒸馏训练完成。")
    metrics = train_result.metrics
    distill_trainer.log_metrics("train", metrics)
    distill_trainer.save_metrics("train", metrics)
    distill_trainer.save_state() # 保存训练状态
except Exception as e:
    logger.error(f"蒸馏训练失败: {e}", exc_info=True) # 记录详细的回溯信息
    raise
finally:
    # 确保教师模型从 GPU 释放（如果需要）
    if hasattr(distill_trainer, 'teacher') and hasattr(distill_trainer.teacher, 'to'):
        try:
            distill_trainer.teacher.to('cpu')
            logger.info("教师模型已移至 CPU。")
        except Exception as e_cpu:
            logger.warning(f"将教师模型移至 CPU 时出错: {e_cpu}")



2025-04-26 07:32:03,273 [INFO] 开始在 IMDB 数据集上进行预分词知识蒸馏...


You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy
500,0.5174,0.437679,0.86848
1000,0.4632,0.41132,0.8892
1500,0.4221,0.410105,0.89904
2000,0.2795,0.308357,0.9032
2500,0.2554,0.333867,0.89632
3000,0.2533,0.327107,0.90384
3500,0.1778,0.278062,0.91408
4000,0.155,0.267726,0.91672
4500,0.141,0.265842,0.91944


2025-04-26 09:39:39,157 [INFO] 知识蒸馏训练完成。
***** train metrics *****
  epoch                    =        3.0
  total_flos               =        0GF
  train_loss               =     0.3143
  train_runtime            = 2:07:35.46
  train_samples_per_second =      9.797
  train_steps_per_second   =      0.613
2025-04-26 09:39:39,848 [INFO] 教师模型已移至 CPU。


In [29]:
# 6. 保存最终训练好的 (最佳) 学生模型和 *其对应的分词器*
logger.info(f"保存最终蒸馏学生模型到: {final_student_model_path}")
distill_trainer.save_model(final_student_model_path)
# **关键**: 保存与最终学生模型匹配的分词器 (即 TinyBERT 的分词器)
logger.info(f"同时保存学生分词器到: {final_student_model_path}")
if student_tokenizer:
    student_tokenizer.save_pretrained(final_student_model_path)


2025-04-26 09:39:39,873 [INFO] 保存最终蒸馏学生模型到: tinybert-student-distilled-imdb-pretokenized
2025-04-26 09:39:40,541 [INFO] 同时保存学生分词器到: tinybert-student-distilled-imdb-pretokenized


In [30]:
# 7. 清理 GPU 显存 (可选，但在 notebook 中或连续运行时有用)
logger.info("清理模型和训练器...")
# 确保变量存在再删除
if 'teacher_model_for_distill' in locals() or 'teacher_model_for_distill' in globals():
    del teacher_model_for_distill
if 'student_model' in locals() or 'student_model' in globals():
    del student_model
if 'distill_trainer' in locals() or 'distill_trainer' in globals():
    del distill_trainer

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logger.info("GPU 显存已清理。")


2025-04-26 09:39:40,583 [INFO] 清理模型和训练器...
2025-04-26 09:39:40,766 [INFO] GPU 显存已清理。


In [31]:
# ==============================================================================
# ## 第五步：评估蒸馏后的学生模型 (TinyBERT) on IMDB Test Set
# ==============================================================================
# **重要**: 最终评估需要使用 *原始* 测试集文本，并 *只用学生分词器* 处理。


In [32]:
# 1. 加载最终蒸馏好的学生模型
logger.info(f"加载最终蒸馏学生模型进行评估: {final_student_model_path}")
try:
    # 确保模型存在再加载
    if not os.path.exists(final_student_model_path):
        raise FileNotFoundError(f"最终学生模型路径不存在: {final_student_model_path}")
    final_student_model = AutoModelForSequenceClassification.from_pretrained(
        final_student_model_path, num_labels=2
    ).to(device) # 加载到评估设备
    final_student_model.eval() # 设置为评估模式
except Exception as e:
    logger.error(f"加载最终蒸馏学生模型 '{final_student_model_path}' 失败: {e}")
    raise


2025-04-26 09:39:40,814 [INFO] 加载最终蒸馏学生模型进行评估: tinybert-student-distilled-imdb-pretokenized


In [33]:
# 2. 加载用于评估的学生分词器 (与最终模型一起保存的 TinyBERT 分词器)
logger.info(f"加载学生分词器进行评估: {final_student_model_path}")
try:
    # 确保分词器存在再加载
    if not os.path.exists(final_student_model_path):
        raise FileNotFoundError(f"学生分词器路径不存在: {final_student_model_path}")
    tokenizer_for_eval: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(final_student_model_path)
except Exception as e:
    logger.error(f"加载学生分词器 '{final_student_model_path}' 失败: {e}")
    raise


2025-04-26 09:39:40,971 [INFO] 加载学生分词器进行评估: tinybert-student-distilled-imdb-pretokenized


In [34]:
# 3. 准备用于最终评估的测试集
#    需要加载 *原始* 的测试集文本，然后用 *学生分词器* 处理
logger.info("加载原始 IMDB 测试集用于最终评估...")
# 重新加载原始数据集或确保 imdb_dataset 仍然可用
try:
    imdb_dataset_original_test = load_dataset("imdb", split="test")
except Exception as e:
    logger.error(f"重新加载原始 IMDB 测试集失败: {e}")
    raise

# 执行与训练/验证集相同的 50/50 拆分以获得一致的最终测试集
logger.info("从原始测试集中划分出最终评估集 (与训练时验证集对应)...")
if len(imdb_dataset_original_test) < 2:
    raise ValueError("原始测试集太小，无法拆分。")
original_val_test_split = imdb_dataset_original_test.train_test_split(test_size=0.5, seed=42, shuffle=True)
original_test_dataset_for_eval = original_val_test_split['test'] # 这是包含 'text' 和 'label' 的原始测试集

logger.info(f"最终评估测试集大小: {len(original_test_dataset_for_eval)}")

logger.info("使用学生分词器对原始测试集进行分词以供评估...")
def tokenize_for_student_eval(examples: Dict[str, List]) -> Any:
    """仅使用学生分词器处理文本，用于最终评估。"""
    return tokenizer_for_eval(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False # 填充交给 DataCollator
    )

# 对原始测试集应用分词函数
tokenized_test_for_eval = original_test_dataset_for_eval.map(
    tokenize_for_student_eval,
    batched=True,
    remove_columns=["text"], # 移除原始文本，保留 'label' (会被 Trainer 处理为 'labels')
    num_proc=os.cpu_count() // 2 if os.cpu_count() else 1
)
logger.info("最终评估测试集已使用学生分词器处理完毕。")
print("\n评估用测试集样本示例 (已用学生分词器处理):")
# 检查数据集是否为空
if len(tokenized_test_for_eval) > 0:
    print(tokenized_test_for_eval[0])
else:
    logger.warning("评估用的测试集为空。")



2025-04-26 09:39:41,021 [INFO] 加载原始 IMDB 测试集用于最终评估...
2025-04-26 09:39:47,058 [INFO] 从原始测试集中划分出最终评估集 (与训练时验证集对应)...
2025-04-26 09:39:47,069 [INFO] 最终评估测试集大小: 12500
2025-04-26 09:39:47,069 [INFO] 使用学生分词器对原始测试集进行分词以供评估...


Map (num_proc=2):   0%|          | 0/12500 [00:00<?, ? examples/s]

2025-04-26 09:39:54,052 [INFO] 最终评估测试集已使用学生分词器处理完毕。

评估用测试集样本示例 (已用学生分词器处理):
{'label': 1, 'input_ids': [101, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 2043, 1045, 4895, 13203, 5051, 10985, 2135, 12524, 1037, 4595, 4631, 1010, 1045, 2245, 1045, 2001, 1999, 2005, 2019, 14036, 2332, 26511, 2466, 1998, 1997, 2607, 9393, 1052, 7959, 13355, 2121, 2001, 1999, 2009, 1010, 2061, 2054, 2071, 2175, 3308, 1029, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 2200, 2855, 1010, 2174, 1010, 1045, 3651, 2008, 2023, 2466, 2001, 2055, 1037, 4595, 2060, 2477, 4661, 2074, 4631, 1012, 1045, 2318, 6933, 1998, 2481, 1005, 1056, 2644, 2127, 2146, 2044, 1996, 3185, 3092, 1012, 4067, 2017, 4869, 1010, 6874, 1998, 23786, 1010, 2005, 5026, 2149, 2107, 1037, 6919, 2135, 11259, 1998, 29353, 3185, 999, 4067, 2017, 3459, 1010, 2005, 2108, 2920, 1998, 17274, 1996, 3494, 2007, 2107, 5995, 1998, 7132, 2791, 999, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1045, 3858, 1996, 4854, 2905, 1025, 1996, 19050, 2905, 

In [35]:
# 4. 创建用于评估的数据整理器 (使用学生分词器进行填充)
logger.info("创建用于评估的数据整理器...")
data_collator_for_eval = DataCollatorWithPadding(tokenizer=tokenizer_for_eval)


2025-04-26 09:39:54,081 [INFO] 创建用于评估的数据整理器...


In [36]:
# 5. 创建一个新的标准 Trainer 实例用于评估
logger.info("创建评估用 Trainer...")
eval_output_dir: str = './eval_output_tinybert_pretokenized'
eval_args = TrainingArguments(
    output_dir=eval_output_dir, # 评估输出目录
    per_device_eval_batch_size=64,      # 评估批次大小 (可根据显存调整)
    do_train=False,                     # 不进行训练
    do_eval=True,                       # 进行评估
    report_to="none",                   # 不需要报告给外部服务
    remove_unused_columns=False         # 确保 'label' 列被保留传递给 compute_metrics
)

eval_trainer = Trainer(
    model=final_student_model,          # 加载好的最终学生模型
    args=eval_args,                     # 评估参数
    eval_dataset=tokenized_test_for_eval, # **关键**: 使用学生分词器处理过的测试集
    data_collator=data_collator_for_eval, # 使用包含学生分词器的整理器
    tokenizer=tokenizer_for_eval,       # 传入学生分词器
    compute_metrics=compute_metrics,    # 指标计算函数
)


2025-04-26 09:39:54,107 [INFO] 创建评估用 Trainer...


  eval_trainer = Trainer(


In [37]:
# 6. 在测试集上执行评估
logger.info("在最终 IMDB 测试集上评估蒸馏后的学生模型 (TinyBERT)...")
try:
    # 明确传入评估数据集进行评估
    if len(tokenized_test_for_eval) > 0:
        evaluation_results = eval_trainer.evaluate(eval_dataset=tokenized_test_for_eval)
    else:
        logger.warning("评估测试集为空，跳过评估。")
        evaluation_results = {} # 返回空结果
except Exception as e:
    logger.error(f"评估失败: {e}", exc_info=True)
    raise


2025-04-26 09:39:54,166 [INFO] 在最终 IMDB 测试集上评估蒸馏后的学生模型 (TinyBERT)...


In [38]:
# 7. 打印并保存评估结果
if evaluation_results:
    logger.info("最终学生模型 (预分词蒸馏) 在 IMDB 测试集上的评估结果:")
    # 格式化输出结果
    for key, value in evaluation_results.items():
        # 过滤掉不必要的日志信息，只显示评估核心指标
        if key.startswith("eval_"):
            logger.info(f"  {key}: {value:.4f}") # 打印到小数点后 4 位

    # 保存评估结果到文件
    eval_trainer.log_metrics("eval", evaluation_results)
    eval_trainer.save_metrics("eval", evaluation_results)
else:
    logger.info("没有评估结果可打印或保存。")

print("\n脚本执行完毕。")

2025-04-26 09:41:35,516 [INFO] 最终学生模型 (预分词蒸馏) 在 IMDB 测试集上的评估结果:
2025-04-26 09:41:35,517 [INFO]   eval_loss: 0.2376
2025-04-26 09:41:35,518 [INFO]   eval_model_preparation_time: 0.0019
2025-04-26 09:41:35,518 [INFO]   eval_accuracy: 0.9190
2025-04-26 09:41:35,519 [INFO]   eval_runtime: 101.3177
2025-04-26 09:41:35,520 [INFO]   eval_samples_per_second: 123.3740
2025-04-26 09:41:35,520 [INFO]   eval_steps_per_second: 1.9350
***** eval metrics *****
  eval_accuracy               =      0.919
  eval_loss                   =     0.2376
  eval_model_preparation_time =     0.0019
  eval_runtime                = 0:01:41.31
  eval_samples_per_second     =    123.374
  eval_steps_per_second       =      1.935

脚本执行完毕。
